refactor(server): improve oauth login flow (#10648)

close CLOUD-145
This commit is contained in:
fengmk2
2025-03-12 06:53:29 +00:00
parent d823792f85
commit 867ae7933f
16 changed files with 211 additions and 31 deletions

View File

@@ -1,5 +1,7 @@
import '../../plugins/config'; import '../../plugins/config';
import { randomUUID } from 'node:crypto';
import { HttpStatus } from '@nestjs/common'; import { HttpStatus } from '@nestjs/common';
import { PrismaClient } from '@prisma/client'; import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava'; import ava, { TestFn } from 'ava';
@@ -83,6 +85,43 @@ test("should be able to redirect to oauth provider's login page", async t => {
t.is(redirect.searchParams.get('response_type'), 'code'); t.is(redirect.searchParams.get('response_type'), 'code');
t.is(redirect.searchParams.get('prompt'), 'select_account'); t.is(redirect.searchParams.get('prompt'), 'select_account');
t.truthy(redirect.searchParams.get('state')); t.truthy(redirect.searchParams.get('state'));
// state should be a json string
const state = JSON.parse(redirect.searchParams.get('state')!);
t.is(state.provider, 'Google');
t.regex(
state.state,
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/
);
});
test('should be able to redirect to oauth provider with client_nonce', async t => {
const { app } = t.context;
const res = await app
.POST('/api/oauth/preflight')
.send({ provider: 'Google', client: 'affine', client_nonce: '1234567890' })
.expect(HttpStatus.OK);
const { url } = res.body;
const redirect = new URL(url);
t.is(redirect.origin, 'https://accounts.google.com');
t.is(redirect.pathname, '/o/oauth2/v2/auth');
t.is(redirect.searchParams.get('client_id'), 'google-client-id');
t.is(
redirect.searchParams.get('redirect_uri'),
app.get(URLHelper).link('/oauth/callback')
);
t.is(redirect.searchParams.get('response_type'), 'code');
t.is(redirect.searchParams.get('prompt'), 'select_account');
t.truthy(redirect.searchParams.get('state'));
// state should be a json string
const state = JSON.parse(redirect.searchParams.get('state')!);
t.is(state.provider, 'Google');
t.is(state.client, 'affine');
t.falsy(state.clientNonce);
t.truthy(state.state);
}); });
test('should throw if provider is invalid', async t => { test('should throw if provider is invalid', async t => {
@@ -246,13 +285,18 @@ test('should throw if provider is invalid in callback uri', async t => {
t.pass(); t.pass();
}); });
function mockOAuthProvider(app: TestingApp, email: string) { function mockOAuthProvider(
app: TestingApp,
email: string,
clientNonce?: string
) {
const provider = app.get(GoogleOAuthProvider); const provider = app.get(GoogleOAuthProvider);
const oauth = app.get(OAuthService); const oauth = app.get(OAuthService);
Sinon.stub(oauth, 'isValidState').resolves(true); Sinon.stub(oauth, 'isValidState').resolves(true);
Sinon.stub(oauth, 'getOAuthState').resolves({ Sinon.stub(oauth, 'getOAuthState').resolves({
provider: OAuthProviderName.Google, provider: OAuthProviderName.Google,
clientNonce,
}); });
// @ts-expect-error mock // @ts-expect-error mock
@@ -294,6 +338,61 @@ test('should be able to sign up with oauth', async t => {
t.is(user!.connectedAccounts[0].providerAccountId, '1'); t.is(user!.connectedAccounts[0].providerAccountId, '1');
}); });
test('should be able to sign up with oauth and client_nonce', async t => {
const { app, db } = t.context;
const clientNonce = randomUUID();
const userEmail = `${clientNonce}@affine.pro`;
mockOAuthProvider(app, userEmail, clientNonce);
await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1', client_nonce: clientNonce })
.expect(HttpStatus.OK);
const sessionUser = await currentUser(app);
t.truthy(sessionUser);
t.is(sessionUser!.email, userEmail);
const user = await db.user.findFirst({
select: {
email: true,
connectedAccounts: true,
},
where: {
email: userEmail,
},
});
t.truthy(user);
t.is(user!.email, userEmail);
t.is(user!.connectedAccounts[0].providerAccountId, '1');
});
test('should throw if client_nonce is invalid', async t => {
const { app } = t.context;
const clientNonce = randomUUID();
const userEmail = `${clientNonce}@affine.pro`;
mockOAuthProvider(app, userEmail, clientNonce);
await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1', client_nonce: 'invalid' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
code: 'Bad Request',
type: 'BAD_REQUEST',
name: 'INVALID_AUTH_STATE',
message:
'Invalid auth state. You might start the auth progress from another device.',
});
t.pass();
});
test('should not throw if account registered', async t => { test('should not throw if account registered', async t => {
const { app, u1 } = t.context; const { app, u1 } = t.context;

View File

@@ -55,6 +55,7 @@ const IncludedEvents = new Set([
'missing_oauth_query_parameter', 'missing_oauth_query_parameter',
'unknown_oauth_provider', 'unknown_oauth_provider',
'invalid_oauth_callback_state', 'invalid_oauth_callback_state',
'invalid_oauth_state',
'oauth_state_expired', 'oauth_state_expired',
'oauth_account_already_connected', 'oauth_account_already_connected',
]); ]);
@@ -319,6 +320,11 @@ export const USER_FRIENDLY_ERRORS = {
message: ({ status, body }) => message: ({ status, body }) =>
`Invalid callback code parameter, provider response status: ${status} and body: ${body}.`, `Invalid callback code parameter, provider response status: ${status} and body: ${body}.`,
}, },
invalid_auth_state: {
type: 'bad_request',
message:
'Invalid auth state. You might start the auth progress from another device.',
},
missing_oauth_query_parameter: { missing_oauth_query_parameter: {
type: 'bad_request', type: 'bad_request',
args: { name: 'string' }, args: { name: 'string' },

View File

@@ -131,6 +131,12 @@ export class InvalidOauthCallbackCode extends UserFriendlyError {
super('bad_request', 'invalid_oauth_callback_code', message, args); super('bad_request', 'invalid_oauth_callback_code', message, args);
} }
} }
export class InvalidAuthState extends UserFriendlyError {
constructor(message?: string) {
super('bad_request', 'invalid_auth_state', message);
}
}
@ObjectType() @ObjectType()
class MissingOauthQueryParameterDataType { class MissingOauthQueryParameterDataType {
@Field() name!: string @Field() name!: string
@@ -895,6 +901,7 @@ export enum ErrorNames {
OAUTH_STATE_EXPIRED, OAUTH_STATE_EXPIRED,
INVALID_OAUTH_CALLBACK_STATE, INVALID_OAUTH_CALLBACK_STATE,
INVALID_OAUTH_CALLBACK_CODE, INVALID_OAUTH_CALLBACK_CODE,
INVALID_AUTH_STATE,
MISSING_OAUTH_QUERY_PARAMETER, MISSING_OAUTH_QUERY_PARAMETER,
OAUTH_ACCOUNT_ALREADY_CONNECTED, OAUTH_ACCOUNT_ALREADY_CONNECTED,
INVALID_EMAIL, INVALID_EMAIL,

View File

@@ -13,6 +13,7 @@ import { ConnectedAccount } from '@prisma/client';
import type { Request, Response } from 'express'; import type { Request, Response } from 'express';
import { import {
InvalidAuthState,
InvalidOauthCallbackState, InvalidOauthCallbackState,
MissingOauthQueryParameter, MissingOauthQueryParameter,
OauthAccountAlreadyConnected, OauthAccountAlreadyConnected,
@@ -43,14 +44,15 @@ export class OAuthController {
@Post('/preflight') @Post('/preflight')
@HttpCode(HttpStatus.OK) @HttpCode(HttpStatus.OK)
async preflight( async preflight(
@Body('provider') unknownProviderName?: string, @Body('provider') unknownProviderName?: keyof typeof OAuthProviderName,
@Body('redirect_uri') redirectUri?: string @Body('redirect_uri') redirectUri?: string,
@Body('client') client?: string,
@Body('client_nonce') clientNonce?: string
) { ) {
if (!unknownProviderName) { if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' }); throw new MissingOauthQueryParameter({ name: 'provider' });
} }
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName]; const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName); const provider = this.providerFactory.get(providerName);
@@ -61,10 +63,14 @@ export class OAuthController {
const state = await this.oauth.saveOAuthState({ const state = await this.oauth.saveOAuthState({
provider: providerName, provider: providerName,
redirectUri, redirectUri,
client,
clientNonce,
}); });
return { return {
url: provider.getAuthUrl(state), url: provider.getAuthUrl(
JSON.stringify({ state, client, provider: unknownProviderName })
),
}; };
} }
@@ -76,7 +82,8 @@ export class OAuthController {
@Req() req: RawBodyRequest<Request>, @Req() req: RawBodyRequest<Request>,
@Res() res: Response, @Res() res: Response,
@Body('code') code?: string, @Body('code') code?: string,
@Body('state') stateStr?: string @Body('state') stateStr?: string,
@Body('client_nonce') clientNonce?: string
) { ) {
if (!code) { if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' }); throw new MissingOauthQueryParameter({ name: 'code' });
@@ -96,6 +103,11 @@ export class OAuthController {
throw new OauthStateExpired(); throw new OauthStateExpired();
} }
// TODO(@fengmk2): clientNonce should be required after the client version >= 0.21.0
if (state.clientNonce && state.clientNonce !== clientNonce) {
throw new InvalidAuthState();
}
if (!state.provider) { if (!state.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' }); throw new MissingOauthQueryParameter({ name: 'provider' });
} }

View File

@@ -10,6 +10,8 @@ const OAUTH_STATE_KEY = 'OAUTH_STATE';
interface OAuthState { interface OAuthState {
redirectUri?: string; redirectUri?: string;
client?: string;
clientNonce?: string;
provider: OAuthProviderName; provider: OAuthProviderName;
} }

View File

@@ -383,6 +383,7 @@ enum ErrorNames {
FAILED_TO_UPSERT_SNAPSHOT FAILED_TO_UPSERT_SNAPSHOT
GRAPHQL_BAD_REQUEST GRAPHQL_BAD_REQUEST
INTERNAL_SERVER_ERROR INTERNAL_SERVER_ERROR
INVALID_AUTH_STATE
INVALID_CHECKOUT_PARAMETERS INVALID_CHECKOUT_PARAMETERS
INVALID_EMAIL INVALID_EMAIL
INVALID_EMAIL_TOKEN INVALID_EMAIL_TOKEN

View File

@@ -46,8 +46,9 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
let endpoint = try call.getStringEnsure("endpoint") let endpoint = try call.getStringEnsure("endpoint")
let code = try call.getStringEnsure("code") let code = try call.getStringEnsure("code")
let state = try call.getStringEnsure("state") let state = try call.getStringEnsure("state")
let clientNonce = call.getString("clientNonce")
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/oauth/callback", headers: [:], body: ["code": code, "state": state]) let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/oauth/callback", headers: [:], body: ["code": code, "state": state, "client_nonce": clientNonce])
if response.statusCode >= 400 { if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) { if let textBody = String(data: data, encoding: .utf8) {

View File

@@ -177,11 +177,12 @@ framework.scope(ServerScope).override(AuthProvider, resolver => {
}); });
await writeEndpointToken(endpoint, token); await writeEndpointToken(endpoint, token);
}, },
async signInOauth(code, state, _provider) { async signInOauth(code, state, _provider, clientNonce) {
const { token } = await Auth.signInOauth({ const { token } = await Auth.signInOauth({
endpoint, endpoint,
code, code,
state, state,
clientNonce,
}); });
await writeEndpointToken(endpoint, token); await writeEndpointToken(endpoint, token);
return {}; return {};

View File

@@ -8,6 +8,7 @@ export interface AuthPlugin {
endpoint: string; endpoint: string;
code: string; code: string;
state: string; state: string;
clientNonce?: string;
}): Promise<{ token: string }>; }): Promise<{ token: string }>;
signInPassword(options: { signInPassword(options: {
endpoint: string; endpoint: string;

View File

@@ -1,11 +1,15 @@
import { Button } from '@affine/component/ui/button'; import { Button } from '@affine/component/ui/button';
import { ServerService } from '@affine/core/modules/cloud'; import { notify } from '@affine/component/ui/notification';
import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks';
import { AuthService, ServerService } from '@affine/core/modules/cloud';
import { UrlService } from '@affine/core/modules/url'; import { UrlService } from '@affine/core/modules/url';
import { type UserFriendlyError } from '@affine/error';
import { OAuthProviderType } from '@affine/graphql'; import { OAuthProviderType } from '@affine/graphql';
import { useI18n } from '@affine/i18n';
import track from '@affine/track'; import track from '@affine/track';
import { GithubIcon, GoogleIcon, LockIcon } from '@blocksuite/icons/rc'; import { GithubIcon, GoogleIcon, LockIcon } from '@blocksuite/icons/rc';
import { useLiveData, useService } from '@toeverything/infra'; import { useLiveData, useService } from '@toeverything/infra';
import { type ReactElement, type SVGAttributes, useCallback } from 'react'; import { type ReactElement, type SVGAttributes } from 'react';
const OAuthProviderMap: Record< const OAuthProviderMap: Record<
OAuthProviderType, OAuthProviderType,
@@ -64,9 +68,27 @@ function OAuthProvider({
popupWindow: (url: string) => void; popupWindow: (url: string) => void;
}) { }) {
const serverService = useService(ServerService); const serverService = useService(ServerService);
const auth = useService(AuthService);
const { icon } = OAuthProviderMap[provider]; const { icon } = OAuthProviderMap[provider];
const t = useI18n();
const onClick = useAsyncCallback(async () => {
if (scheme && BUILD_CONFIG.isNative) {
let oauthUrl = '';
try {
oauthUrl = await auth.oauthPreflight(provider, scheme);
} catch (e) {
console.error(e);
const err = e as UserFriendlyError;
notify.error({
title: t[`error.${err.name}`](err.data),
});
return;
}
popupWindow(oauthUrl);
return;
}
const onClick = useCallback(() => {
const params = new URLSearchParams(); const params = new URLSearchParams();
params.set('provider', provider); params.set('provider', provider);
@@ -88,7 +110,7 @@ function OAuthProvider({
track.$.$.auth.signIn({ method: 'oauth', provider }); track.$.$.auth.signIn({ method: 'oauth', provider });
popupWindow(oauthUrl); popupWindow(oauthUrl);
}, [popupWindow, provider, redirectUrl, scheme, serverService]); }, [popupWindow, provider, redirectUrl, scheme, serverService, auth, t]);
return ( return (
<Button <Button

View File

@@ -18,10 +18,15 @@ export function configureDefaultAuthProvider(framework: Framework) {
}); });
}, },
async signInOauth(code: string, state: string, _provider: string) { async signInOauth(
code: string,
state: string,
_provider: string,
clientNonce?: string
) {
const res = await fetchService.fetch('/api/oauth/callback', { const res = await fetchService.fetch('/api/oauth/callback', {
method: 'POST', method: 'POST',
body: JSON.stringify({ code, state }), body: JSON.stringify({ code, state, client_nonce: clientNonce }),
headers: { headers: {
'content-type': 'application/json', 'content-type': 'application/json',
}, },

View File

@@ -6,7 +6,8 @@ export interface AuthProvider {
signInOauth( signInOauth(
code: string, code: string,
state: string, state: string,
provider: string provider: string,
clientNonce?: string
): Promise<{ redirectUri?: string }>; ): Promise<{ redirectUri?: string }>;
signInPassword(credential: { signInPassword(credential: {

View File

@@ -3,6 +3,7 @@ import { UserFriendlyError } from '@affine/error';
import type { OAuthProviderType } from '@affine/graphql'; import type { OAuthProviderType } from '@affine/graphql';
import { track } from '@affine/track'; import { track } from '@affine/track';
import { OnEvent, Service } from '@toeverything/infra'; import { OnEvent, Service } from '@toeverything/infra';
import { nanoid } from 'nanoid';
import { distinctUntilChanged, map, skip } from 'rxjs'; import { distinctUntilChanged, map, skip } from 'rxjs';
import { ApplicationFocused } from '../../lifecycle'; import { ApplicationFocused } from '../../lifecycle';
@@ -130,10 +131,16 @@ export class AuthService extends Service {
client: string, client: string,
/** @deprecated*/ redirectUrl?: string /** @deprecated*/ redirectUrl?: string
) { ) {
this.setClientNonce();
try { try {
const res = await this.fetchService.fetch('/api/oauth/preflight', { const res = await this.fetchService.fetch('/api/oauth/preflight', {
method: 'POST', method: 'POST',
body: JSON.stringify({ provider, redirect_uri: redirectUrl }), body: JSON.stringify({
provider,
client,
redirect_uri: redirectUrl,
client_nonce: this.store.getClientNonce(),
}),
headers: { headers: {
'content-type': 'application/json', 'content-type': 'application/json',
}, },
@@ -141,19 +148,6 @@ export class AuthService extends Service {
let { url } = await res.json(); let { url } = await res.json();
// change `state=xxx` to `state={state:xxx,native:true}`
// so we could know the callback should be redirect to native app
const oauthUrl = new URL(url);
oauthUrl.searchParams.set(
'state',
JSON.stringify({
state: oauthUrl.searchParams.get('state'),
client,
provider,
})
);
url = oauthUrl.toString();
return url as string; return url as string;
} catch (e) { } catch (e) {
track.$.$.auth.signInFail({ track.$.$.auth.signInFail({
@@ -228,4 +222,11 @@ export class AuthService extends Service {
return headers; return headers;
} }
private setClientNonce() {
if (BUILD_CONFIG.isNative) {
// send random client nonce on native app
this.store.setClientNonce(nanoid());
}
}
} }

View File

@@ -48,6 +48,14 @@ export class AuthStore extends Store {
this.globalState.set(`${this.serverService.server.id}-auth`, session); this.globalState.set(`${this.serverService.server.id}-auth`, session);
} }
getClientNonce() {
return this.globalState.get<string>('auth-client-nonce');
}
setClientNonce(nonce: string) {
this.globalState.set('auth-client-nonce', nonce);
}
async fetchSession() { async fetchSession() {
const url = `/api/auth/session`; const url = `/api/auth/session`;
const options: RequestInit = { const options: RequestInit = {
@@ -70,7 +78,12 @@ export class AuthStore extends Store {
} }
async signInOauth(code: string, state: string, provider: string) { async signInOauth(code: string, state: string, provider: string) {
return await this.authProvider.signInOauth(code, state, provider); return await this.authProvider.signInOauth(
code,
state,
provider,
this.getClientNonce()
);
} }
async signInPassword(credential: { async signInPassword(credential: {

View File

@@ -12,6 +12,10 @@ declare interface BUILD_CONFIG_TYPE {
isElectron: boolean; isElectron: boolean;
isWeb: boolean; isWeb: boolean;
/**
* 'desktop' | 'ios' | 'android'
*/
isNative: boolean;
isMobileWeb: boolean; isMobileWeb: boolean;
isIOS: boolean; isIOS: boolean;
isAndroid: boolean; isAndroid: boolean;
@@ -30,4 +34,4 @@ declare interface BUILD_CONFIG_TYPE {
linkPreviewUrl: string; linkPreviewUrl: string;
} }
declare var BUILD_CONFIG: BUILD_CONFIG_TYPE; declare var BUILD_CONFIG: BUILD_CONFIG_TYPE;

View File

@@ -33,6 +33,10 @@ export function getBuildConfig(
isMobileWeb: distribution === 'mobile', isMobileWeb: distribution === 'mobile',
isIOS: distribution === 'ios', isIOS: distribution === 'ios',
isAndroid: distribution === 'android', isAndroid: distribution === 'android',
isNative:
distribution === 'desktop' ||
distribution === 'ios' ||
distribution === 'android',
isAdmin: distribution === 'admin', isAdmin: distribution === 'admin',
appBuildType: 'stable' as const, appBuildType: 'stable' as const,