Files
AFFiNE-Mirror/packages/backend/server/src/plugins/oauth/controller.ts
2026-01-17 22:39:20 +08:00

326 lines
8.6 KiB
TypeScript

import {
Body,
Controller,
HttpCode,
HttpStatus,
Logger,
Post,
type RawBodyRequest,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount } from '@prisma/client';
import type { Request, Response } from 'express';
import {
ActionForbidden,
Config,
InvalidAuthState,
InvalidOauthCallbackState,
MissingOauthQueryParameter,
OauthAccountAlreadyConnected,
OauthStateExpired,
SignUpForbidden,
UnknownOauthProvider,
URLHelper,
UseNamedGuard,
} from '../../base';
import { AuthService, Public } from '../../core/auth';
import { Models } from '../../models';
import { OAuthProviderName } from './config';
import { OAuthProviderFactory } from './factory';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthService } from './service';
/**
 * HTTP endpoints for the OAuth sign-in flow under `/api/oauth`.
 *
 * The flow has two steps:
 * 1. `POST /preflight` persists an OAuth state record and returns the
 *    provider's authorization URL.
 * 2. `POST /callback` receives the provider's authorization code, exchanges it
 *    for tokens, resolves or creates the local user and sets the session
 *    cookies.
 */
@Controller('/api/oauth')
export class OAuthController {
  private readonly logger = new Logger(OAuthController.name);

  /**
   * Syntax of a URL scheme per RFC 3986 §3.1.
   *
   * The untrusted `client` value must match this before it is used as the
   * scheme of the native-app deep link in `callback`.
   */
  private static readonly clientSchemePattern = /^[a-z][a-z0-9+.-]*$/i;

  /**
   * Schemes rejected as a native `client`.
   *
   * Using them would make the link handed to `/open-app/url` a web, file or
   * script URL rather than an app deep link.
   * NOTE(review): assumes native clients always register a custom scheme —
   * confirm before onboarding a new client type.
   */
  private static readonly forbiddenClientSchemes = new Set([
    'http',
    'https',
    'javascript',
    'data',
    'vbscript',
    'file',
    'blob',
  ]);

  constructor(
    private readonly auth: AuthService,
    private readonly oauth: OAuthService,
    private readonly models: Models,
    private readonly providerFactory: OAuthProviderFactory,
    private readonly url: URLHelper,
    private readonly config: Config
  ) {}

  /**
   * Starts an OAuth sign-in.
   *
   * Steps:
   * - Validates the requested provider and redirect URI.
   * - Persists a state record. It includes the PKCE code verifier when the
   *   provider requires PKCE.
   * - Returns the provider's authorization URL.
   *
   * The `state` forwarded to the provider is a JSON string. It carries the
   * stored state token, the client and the provider key, plus the PKCE code
   * challenge when applicable.
   *
   * @param unknownProviderName - Key of `OAuthProviderName` sent by the client.
   * @param redirectUri - Optional post-login redirect. It must pass
   *   `URLHelper.isAllowedRedirectUri`.
   * @param client - Optional client identifier. For Apple sign-in from a
   *   native app it later becomes the deep-link scheme.
   * @param clientNonce - Required nonce, forwarded to `provider.getAuthUrl`.
   * @returns `{ url }`, the provider authorization URL.
   * @throws MissingOauthQueryParameter when `provider` or `client_nonce` is
   *   missing.
   * @throws UnknownOauthProvider when the provider key is not an
   *   `OAuthProviderName` key, or no provider is available for it.
   * @throws ActionForbidden when `redirect_uri` is not allowed.
   */
  @Public()
  @UseNamedGuard('version')
  @Post('/preflight')
  @HttpCode(HttpStatus.OK)
  async preflight(
    @Body('provider') unknownProviderName?: keyof typeof OAuthProviderName,
    @Body('redirect_uri') redirectUri?: string,
    @Body('client') client?: string,
    @Body('client_nonce') clientNonce?: string
  ) {
    if (!unknownProviderName) {
      throw new MissingOauthQueryParameter({ name: 'provider' });
    }
    if (!clientNonce) {
      throw new MissingOauthQueryParameter({ name: 'client_nonce' });
    }

    // The key comes straight from the request body. Only the enum's own keys
    // are accepted, so inherited names such as `toString` or `constructor`
    // never resolve to prototype members.
    if (
      !Object.prototype.hasOwnProperty.call(
        OAuthProviderName,
        unknownProviderName
      )
    ) {
      throw new UnknownOauthProvider({ name: unknownProviderName });
    }
    const providerName = OAuthProviderName[unknownProviderName];
    const provider = this.providerFactory.get(providerName);
    if (!provider) {
      throw new UnknownOauthProvider({ name: unknownProviderName });
    }

    // Reject a disallowed redirect before generating any PKCE material.
    if (redirectUri && !this.url.isAllowedRedirectUri(redirectUri)) {
      throw new ActionForbidden();
    }

    const pkce = provider.requiresPkce ? this.oauth.createPkcePair() : null;

    // The verifier is kept server-side in the stored state. Only the challenge
    // is sent to the provider (see `statePayload` below).
    const state = await this.oauth.saveOAuthState({
      provider: providerName,
      redirectUri,
      client,
      clientNonce,
      ...(pkce
        ? {
            pkce: {
              codeVerifier: pkce.codeVerifier,
              codeChallengeMethod: pkce.codeChallengeMethod,
            },
          }
        : {}),
    });

    const statePayload: Record<string, unknown> = {
      state,
      client,
      provider: unknownProviderName,
    };
    if (pkce) {
      statePayload.pkce = {
        codeChallenge: pkce.codeChallenge,
        codeChallengeMethod: pkce.codeChallengeMethod,
      };
    }

    const stateStr = JSON.stringify(statePayload);
    return {
      url: provider.getAuthUrl(stateStr, clientNonce),
    };
  }

  // The preceding `/api/oauth/preflight` request already checked the client
  // version. The guard is skipped here because it would block Apple's
  // `form_post` callback mode.
  // @UseNamedGuard('version')
  /**
   * Completes an OAuth sign-in for a state created by `preflight`.
   *
   * Accepted state forms:
   * - the bare state token;
   * - the JSON state string built by `preflight`, which Apple posts directly
   *   to this endpoint.
   *
   * Apple sign-in started by a native client (`client` set and not `'web'`)
   * is not completed here. The request is redirected to `/open-app/url` with a
   * `<client>://authentication` deep link that carries the code.
   *
   * All other requests:
   * 1. Check `client_nonce` (skipped for Apple).
   * 2. Exchange the code for tokens.
   * 3. Resolve or create the user and set the session cookies.
   * 4. Respond. Apple web sign-in redirects to the stored `redirectUri`
   *    (default `/`); other providers receive `{ id, redirectUri }`.
   *
   * @throws MissingOauthQueryParameter when `code`, `state` or the stored
   *   provider is missing.
   * @throws InvalidOauthCallbackState when the state token is malformed.
   * @throws OauthStateExpired when no stored state is found.
   * @throws InvalidAuthState on a client-nonce mismatch, or when the stored
   *   Apple `client` is not an acceptable deep-link scheme.
   * @throws UnknownOauthProvider when the stored provider is not available.
   */
  @Public()
  @Post('/callback')
  @HttpCode(HttpStatus.OK)
  async callback(
    @Req() req: RawBodyRequest<Request>,
    @Res() res: Response,
    @Body('code') code?: string,
    @Body('state') stateStr?: string,
    @Body('client_nonce') clientNonce?: string
  ) {
    // TODO(@forehalo): refactor and remove deprecated code in 0.23
    if (!code) {
      throw new MissingOauthQueryParameter({ name: 'code' });
    }
    if (!stateStr) {
      throw new MissingOauthQueryParameter({ name: 'state' });
    }

    // Apple sign-in posts directly to /callback and echoes back the JSON state
    // string built in `preflight`. Values longer than 36 characters (presumably
    // longer than a bare state token) are treated as that JSON envelope, and
    // the stored state token is unwrapped from it. Anything that fails to parse
    // is left as-is and validated below.
    let rawState = null;
    if (typeof stateStr === 'string' && stateStr.length > 36) {
      try {
        rawState = JSON.parse(stateStr);
        stateStr = rawState.state;
      } catch {
        /* noop: not JSON, validated as a bare token below */
      }
    }

    if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
      throw new InvalidOauthCallbackState();
    }

    const state = await this.oauth.getOAuthState(stateStr);
    if (!state) {
      throw new OauthStateExpired();
    }

    // Backfill the token from the lookup key when the stored state lacks it.
    if (!state.token) {
      state.token = stateStr;
    }

    // Apple sign-in started from a native client: hand the code back to the
    // app through a deep link instead of finishing the login here.
    if (
      state.provider === OAuthProviderName.Apple &&
      rawState &&
      state.client &&
      state.client !== 'web'
    ) {
      // `client` originates from the untrusted preflight body and becomes the
      // URL scheme below. Accept only a bare, non-web scheme, so the value can
      // neither smuggle in a host/path (e.g. `https://evil.example/#`) nor
      // make `new URL` throw.
      if (!this.isAllowedClientScheme(state.client)) {
        throw new InvalidAuthState();
      }

      const clientUrl = new URL(`${state.client}://authentication`);
      clientUrl.searchParams.set('method', 'oauth');
      clientUrl.searchParams.set(
        'payload',
        JSON.stringify({
          state: stateStr,
          code,
          provider: rawState.provider,
        })
      );
      clientUrl.searchParams.set('server', this.url.requestOrigin);

      return res.redirect(
        this.url.link('/open-app/url?', {
          url: clientUrl.toString(),
        })
      );
    }

    if (!state.provider) {
      throw new MissingOauthQueryParameter({ name: 'provider' });
    }

    const provider = this.providerFactory.get(state.provider);
    if (!provider) {
      throw new UnknownOauthProvider({ name: state.provider });
    }

    // Apple's form_post callback carries no client nonce, so only the other
    // providers are required to echo the nonce stored at preflight.
    if (
      state.provider !== OAuthProviderName.Apple &&
      (!clientNonce || !state.clientNonce || state.clientNonce !== clientNonce)
    ) {
      throw new InvalidAuthState();
    }

    let tokens: Tokens;
    try {
      tokens = await provider.getToken(code, state);
    } catch (err) {
      let rawBodyString = '';
      if (req.rawBody) {
        // only log the first 4096 bytes of the raw body
        rawBodyString = req.rawBody.subarray(0, 4096).toString('utf-8');
      }
      this.logger.warn(
        `Error getting oauth token for ${state.provider}, callback code: ${code}, stateStr: ${stateStr}, rawBody: ${rawBodyString}, error: ${err}`
      );
      throw err;
    }

    const externAccount = await provider.getUser(tokens, state);
    const user = await this.getOrCreateUserFromOauth(
      state.provider,
      externAccount,
      tokens
    );

    await this.auth.setCookies(req, res, user.id);

    // Apple web sign-in arrived as a browser form post, so redirect the
    // browser instead of returning JSON.
    if (
      state.provider === OAuthProviderName.Apple &&
      (!state.client || state.client === 'web')
    ) {
      return this.url.safeRedirect(res, state.redirectUri ?? '/');
    }

    res.send({
      id: user.id,
      redirectUri: state.redirectUri,
    });
  }

  /**
   * Whether `client` may be used as the scheme of the native-app deep link.
   *
   * @param client - Client identifier stored in the OAuth state.
   * @returns `true` when `client` is a syntactically valid URL scheme and is
   *   not one of `forbiddenClientSchemes` (compared case-insensitively).
   */
  private isAllowedClientScheme(client: string): boolean {
    return (
      OAuthController.clientSchemePattern.test(client) &&
      !OAuthController.forbiddenClientSchemes.has(client.toLowerCase())
    );
  }

  /**
   * Resolves the local user for an external OAuth account.
   *
   * Already-connected account:
   * - Refreshes its stored tokens.
   * - Marks the user's email as verified when it is still unverified and
   *   matches the provider email (case-insensitively).
   *
   * Not-yet-connected account (only when OAuth sign-up is enabled):
   * - Obtains a user for the provider email via `models.user.fulfill`.
   *   Presumably this returns an existing user for that email or creates
   *   one — confirm in the user model.
   * - Connects the external account to that user.
   *
   * @param provider - Provider the external account belongs to.
   * @param externalAccount - Account profile returned by the provider.
   * @param tokens - Tokens obtained from the provider.
   * @returns The resolved user.
   * @throws SignUpForbidden when the account is not connected yet and
   *   `auth.allowSignupForOauth` is disabled.
   */
  private async getOrCreateUserFromOauth(
    provider: OAuthProviderName,
    externalAccount: OAuthAccount,
    tokens: Tokens
  ) {
    const connectedAccount = await this.models.user.getConnectedAccount(
      provider,
      externalAccount.id
    );

    if (connectedAccount) {
      // already connected
      await this.updateConnectedAccount(connectedAccount, tokens);

      if (
        !connectedAccount.user.emailVerifiedAt &&
        // external email may change, check if it matches the existing email
        externalAccount.email.toLowerCase() ===
          connectedAccount.user.email.toLowerCase()
      ) {
        await this.auth.setEmailVerified(connectedAccount.userId);
      }

      return connectedAccount.user;
    }

    if (!this.config.auth.allowSignupForOauth) {
      throw new SignUpForbidden();
    }

    const user = await this.models.user.fulfill(externalAccount.email, {
      name: externalAccount.name,
      avatarUrl: externalAccount.avatarUrl,
    });

    await this.models.user.createConnectedAccount({
      userId: user.id,
      provider,
      providerAccountId: externalAccount.id,
      accessToken: tokens.accessToken,
      refreshToken: tokens.refreshToken,
      expiresAt: tokens.expiresAt,
    });

    return user;
  }

  /**
   * Stores freshly issued provider tokens on an existing connected account.
   *
   * @param connectedAccount - The connected account row to update.
   * @param tokens - New access/refresh tokens and their expiry.
   * @returns Whatever `models.user.updateConnectedAccount` resolves to.
   */
  private async updateConnectedAccount(
    connectedAccount: ConnectedAccount,
    tokens: Tokens
  ) {
    return await this.models.user.updateConnectedAccount(connectedAccount.id, {
      accessToken: tokens.accessToken,
      refreshToken: tokens.refreshToken,
      expiresAt: tokens.expiresAt,
    });
  }

  /**
   * Connects an external OAuth account to an existing user.
   *
   * Throws `OauthAccountAlreadyConnected` when the external account is already
   * bound to a different user. Does nothing when it is already bound to
   * `user`.
   *
   * We currently don't support connecting an OAuth account to an existing
   * user; this is kept in case it is needed in the future.
   */
  // @ts-expect-error allow unused
  private async _connectAccount(
    user: { id: string },
    provider: OAuthProviderName,
    externalAccount: OAuthAccount,
    tokens: Tokens
  ) {
    const connectedAccount = await this.models.user.getConnectedAccount(
      provider,
      externalAccount.id
    );
    if (connectedAccount) {
      if (connectedAccount.userId !== user.id) {
        throw new OauthAccountAlreadyConnected();
      }
    } else {
      await this.models.user.createConnectedAccount({
        userId: user.id,
        provider,
        providerAccountId: externalAccount.id,
        accessToken: tokens.accessToken,
        refreshToken: tokens.refreshToken,
        expiresAt: tokens.expiresAt,
      });
    }
  }
}