diff --git a/packages/backend/server/src/__tests__/oauth/controller.spec.ts b/packages/backend/server/src/__tests__/oauth/controller.spec.ts index b71bf4a112..4205b8a67d 100644 --- a/packages/backend/server/src/__tests__/oauth/controller.spec.ts +++ b/packages/backend/server/src/__tests__/oauth/controller.spec.ts @@ -1,5 +1,7 @@ import '../../plugins/config'; +import { randomUUID } from 'node:crypto'; + import { HttpStatus } from '@nestjs/common'; import { PrismaClient } from '@prisma/client'; import ava, { TestFn } from 'ava'; @@ -83,6 +85,43 @@ test("should be able to redirect to oauth provider's login page", async t => { t.is(redirect.searchParams.get('response_type'), 'code'); t.is(redirect.searchParams.get('prompt'), 'select_account'); t.truthy(redirect.searchParams.get('state')); + // state should be a json string + const state = JSON.parse(redirect.searchParams.get('state')!); + t.is(state.provider, 'Google'); + t.regex( + state.state, + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/ + ); +}); + +test('should be able to redirect to oauth provider with client_nonce', async t => { + const { app } = t.context; + + const res = await app + .POST('/api/oauth/preflight') + .send({ provider: 'Google', client: 'affine', client_nonce: '1234567890' }) + .expect(HttpStatus.OK); + + const { url } = res.body; + + const redirect = new URL(url); + t.is(redirect.origin, 'https://accounts.google.com'); + + t.is(redirect.pathname, '/o/oauth2/v2/auth'); + t.is(redirect.searchParams.get('client_id'), 'google-client-id'); + t.is( + redirect.searchParams.get('redirect_uri'), + app.get(URLHelper).link('/oauth/callback') + ); + t.is(redirect.searchParams.get('response_type'), 'code'); + t.is(redirect.searchParams.get('prompt'), 'select_account'); + t.truthy(redirect.searchParams.get('state')); + // state should be a json string + const state = JSON.parse(redirect.searchParams.get('state')!); + t.is(state.provider, 'Google'); + t.is(state.client, 'affine'); + t.falsy(state.clientNonce); + t.truthy(state.state); }); test('should throw if provider is invalid', async t => { @@ -246,13 +285,18 @@ test('should throw if provider is invalid in callback uri', async t => { t.pass(); }); -function mockOAuthProvider(app: TestingApp, email: string) { +function mockOAuthProvider( + app: TestingApp, + email: string, + clientNonce?: string +) { const provider = app.get(GoogleOAuthProvider); const oauth = app.get(OAuthService); Sinon.stub(oauth, 'isValidState').resolves(true); Sinon.stub(oauth, 'getOAuthState').resolves({ provider: OAuthProviderName.Google, + clientNonce, }); // @ts-expect-error mock @@ -294,6 +338,61 @@ test('should be able to sign up with oauth', async t => { t.is(user!.connectedAccounts[0].providerAccountId, '1'); }); +test('should be able to sign up with oauth and client_nonce', async t => { + const { app, db } = t.context; + + const clientNonce = randomUUID(); + const userEmail = `${clientNonce}@affine.pro`; + mockOAuthProvider(app, userEmail, clientNonce); + + await app + .POST('/api/oauth/callback') + .send({ code: '1', state: '1', client_nonce: clientNonce }) + .expect(HttpStatus.OK); + + const sessionUser = await currentUser(app); + + t.truthy(sessionUser); + t.is(sessionUser!.email, userEmail); + + const user = await db.user.findFirst({ + select: { + email: true, + connectedAccounts: true, + }, + where: { + email: userEmail, + }, + }); + + t.truthy(user); + t.is(user!.email, userEmail); + t.is(user!.connectedAccounts[0].providerAccountId, '1'); +}); + +test('should throw if client_nonce is invalid', async t => { + const { app } = t.context; + + const clientNonce = randomUUID(); + const userEmail = `${clientNonce}@affine.pro`; + mockOAuthProvider(app, userEmail, clientNonce); + + await app + .POST('/api/oauth/callback') + .send({ code: '1', state: '1', client_nonce: 'invalid' }) + .expect(HttpStatus.BAD_REQUEST) + .expect({ + status: 400, + code: 'Bad Request', + type: 'BAD_REQUEST', + name: 'INVALID_AUTH_STATE', + message: + 'Invalid auth state. You might start the auth progress from another device.', + }); + + t.pass(); +}); + test('should not throw if account registered', async t => { const { app, u1 } = t.context; diff --git a/packages/backend/server/src/base/error/def.ts b/packages/backend/server/src/base/error/def.ts index ed7778fcc8..3cea817f56 100644 --- a/packages/backend/server/src/base/error/def.ts +++ b/packages/backend/server/src/base/error/def.ts @@ -55,6 +55,7 @@ const IncludedEvents = new Set([ 'missing_oauth_query_parameter', 'unknown_oauth_provider', 'invalid_oauth_callback_state', + 'invalid_oauth_state', 'oauth_state_expired', 'oauth_account_already_connected', ]); @@ -319,6 +320,11 @@ export const USER_FRIENDLY_ERRORS = { message: ({ status, body }) => `Invalid callback code parameter, provider response status: ${status} and body: ${body}.`, }, + invalid_auth_state: { + type: 'bad_request', + message: + 'Invalid auth state. You might start the auth progress from another device.', + }, missing_oauth_query_parameter: { type: 'bad_request', args: { name: 'string' }, diff --git a/packages/backend/server/src/base/error/errors.gen.ts b/packages/backend/server/src/base/error/errors.gen.ts index b0f3595fa1..6e7eed9f2e 100644 --- a/packages/backend/server/src/base/error/errors.gen.ts +++ b/packages/backend/server/src/base/error/errors.gen.ts @@ -131,6 +131,12 @@ export class InvalidOauthCallbackCode extends UserFriendlyError { super('bad_request', 'invalid_oauth_callback_code', message, args); } } + +export class InvalidAuthState extends UserFriendlyError { + constructor(message?: string) { + super('bad_request', 'invalid_auth_state', message); + } +} @ObjectType() class MissingOauthQueryParameterDataType { @Field() name!: string @@ -895,6 +901,7 @@ export enum ErrorNames { OAUTH_STATE_EXPIRED, INVALID_OAUTH_CALLBACK_STATE, INVALID_OAUTH_CALLBACK_CODE, + INVALID_AUTH_STATE, MISSING_OAUTH_QUERY_PARAMETER, OAUTH_ACCOUNT_ALREADY_CONNECTED, INVALID_EMAIL, diff --git a/packages/backend/server/src/plugins/oauth/controller.ts b/packages/backend/server/src/plugins/oauth/controller.ts index af2329e0b6..5c9a0ec217 100644 --- a/packages/backend/server/src/plugins/oauth/controller.ts +++ b/packages/backend/server/src/plugins/oauth/controller.ts @@ -13,6 +13,7 @@ import { ConnectedAccount } from '@prisma/client'; import type { Request, Response } from 'express'; import { + InvalidAuthState, InvalidOauthCallbackState, MissingOauthQueryParameter, OauthAccountAlreadyConnected, @@ -43,14 +44,15 @@ export class OAuthController { @Post('/preflight') @HttpCode(HttpStatus.OK) async preflight( - @Body('provider') unknownProviderName?: string, - @Body('redirect_uri') redirectUri?: string + @Body('provider') unknownProviderName?: keyof typeof OAuthProviderName, + @Body('redirect_uri') redirectUri?: string, + @Body('client') client?: string, + @Body('client_nonce') clientNonce?: string ) { if (!unknownProviderName) { throw new MissingOauthQueryParameter({ name: 'provider' }); } - // @ts-expect-error safe const providerName = OAuthProviderName[unknownProviderName]; const provider = this.providerFactory.get(providerName); @@ -61,10 +63,14 @@ export class OAuthController { const state = await this.oauth.saveOAuthState({ provider: providerName, redirectUri, + client, + clientNonce, }); return { - url: provider.getAuthUrl(state), + url: provider.getAuthUrl( + JSON.stringify({ state, client, provider: unknownProviderName }) + ), }; } @@ -76,7 +82,8 @@ export class OAuthController { @Req() req: RawBodyRequest, @Res() res: Response, @Body('code') code?: string, - @Body('state') stateStr?: string + @Body('state') stateStr?: string, + @Body('client_nonce') clientNonce?: string ) { if (!code) { throw new MissingOauthQueryParameter({ name: 'code' }); @@ -96,6 +103,11 @@ export class OAuthController { throw new OauthStateExpired(); } + // TODO(@fengmk2): clientNonce should be required after the client version >= 0.21.0 + if (state.clientNonce && state.clientNonce !== clientNonce) { + throw new InvalidAuthState(); + } + if (!state.provider) { throw new MissingOauthQueryParameter({ name: 'provider' }); } diff --git a/packages/backend/server/src/plugins/oauth/service.ts b/packages/backend/server/src/plugins/oauth/service.ts index da0ccee9ce..251425f663 100644 --- a/packages/backend/server/src/plugins/oauth/service.ts +++ b/packages/backend/server/src/plugins/oauth/service.ts @@ -10,6 +10,8 @@ const OAUTH_STATE_KEY = 'OAUTH_STATE'; interface OAuthState { redirectUri?: string; + client?: string; + clientNonce?: string; provider: OAuthProviderName; } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 42f31f467e..e402bd7f2b 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -383,6 +383,7 @@ enum ErrorNames { FAILED_TO_UPSERT_SNAPSHOT GRAPHQL_BAD_REQUEST INTERNAL_SERVER_ERROR + INVALID_AUTH_STATE INVALID_CHECKOUT_PARAMETERS INVALID_EMAIL INVALID_EMAIL_TOKEN diff --git a/packages/frontend/apps/ios/App/App/Plugins/Auth/AuthPlugin.swift b/packages/frontend/apps/ios/App/App/Plugins/Auth/AuthPlugin.swift index 01970e2f72..af9fb945f3 100644 --- a/packages/frontend/apps/ios/App/App/Plugins/Auth/AuthPlugin.swift +++ b/packages/frontend/apps/ios/App/App/Plugins/Auth/AuthPlugin.swift @@ -46,8 +46,9 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin { let endpoint = try call.getStringEnsure("endpoint") let code = try call.getStringEnsure("code") let state = try call.getStringEnsure("state") + let clientNonce = call.getString("clientNonce") - let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/oauth/callback", headers: [:], body: ["code": code, "state": state]) + let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/oauth/callback", headers: [:], body: ["code": code, "state": state, "client_nonce": clientNonce]) if response.statusCode >= 400 { if let textBody = String(data: data, encoding: .utf8) { diff --git a/packages/frontend/apps/ios/src/app.tsx b/packages/frontend/apps/ios/src/app.tsx index a839bb4dcc..b34dd6d5f5 100644 --- a/packages/frontend/apps/ios/src/app.tsx +++ b/packages/frontend/apps/ios/src/app.tsx @@ -177,11 +177,12 @@ framework.scope(ServerScope).override(AuthProvider, resolver => { }); await writeEndpointToken(endpoint, token); }, - async signInOauth(code, state, _provider) { + async signInOauth(code, state, _provider, clientNonce) { const { token } = await Auth.signInOauth({ endpoint, code, state, + clientNonce, }); await writeEndpointToken(endpoint, token); return {}; diff --git a/packages/frontend/apps/ios/src/plugins/auth/definitions.ts b/packages/frontend/apps/ios/src/plugins/auth/definitions.ts index b59bf30c07..160b18b20a 100644 --- a/packages/frontend/apps/ios/src/plugins/auth/definitions.ts +++ b/packages/frontend/apps/ios/src/plugins/auth/definitions.ts @@ -8,6 +8,7 @@ export interface AuthPlugin { endpoint: string; code: string; state: string; + clientNonce?: string; }): Promise<{ token: string }>; signInPassword(options: { endpoint: string; diff --git a/packages/frontend/core/src/components/affine/auth/oauth.tsx b/packages/frontend/core/src/components/affine/auth/oauth.tsx index 7be403761a..0cc6cd1a9f 100644 --- a/packages/frontend/core/src/components/affine/auth/oauth.tsx +++ b/packages/frontend/core/src/components/affine/auth/oauth.tsx @@ -1,11 +1,15 @@ import { Button } from '@affine/component/ui/button'; -import { ServerService } from '@affine/core/modules/cloud'; +import { notify } from '@affine/component/ui/notification'; +import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks'; +import { AuthService, ServerService } from '@affine/core/modules/cloud'; import { UrlService } from '@affine/core/modules/url'; +import { type UserFriendlyError } from '@affine/error'; import { OAuthProviderType } from '@affine/graphql'; +import { useI18n } from '@affine/i18n'; import track from '@affine/track'; import { GithubIcon, GoogleIcon, LockIcon } from '@blocksuite/icons/rc'; import { useLiveData, useService } from '@toeverything/infra'; -import { type ReactElement, type SVGAttributes, useCallback } from 'react'; +import { type ReactElement, type SVGAttributes } from 'react'; const OAuthProviderMap: Record< OAuthProviderType, @@ -64,9 +68,27 @@ function OAuthProvider({ popupWindow: (url: string) => void; }) { const serverService = useService(ServerService); + const auth = useService(AuthService); const { icon } = OAuthProviderMap[provider]; + const t = useI18n(); + + const onClick = useAsyncCallback(async () => { + if (scheme && BUILD_CONFIG.isNative) { + let oauthUrl = ''; + try { + oauthUrl = await auth.oauthPreflight(provider, scheme); + } catch (e) { + console.error(e); + const err = e as UserFriendlyError; + notify.error({ + title: t[`error.${err.name}`](err.data), + }); + return; + } + popupWindow(oauthUrl); + return; + } - const onClick = useCallback(() => { const params = new URLSearchParams(); params.set('provider', provider); @@ -88,7 +110,7 @@ function OAuthProvider({ track.$.$.auth.signIn({ method: 'oauth', provider }); popupWindow(oauthUrl); - }, [popupWindow, provider, redirectUrl, scheme, serverService]); + }, [popupWindow, provider, redirectUrl, scheme, serverService, auth, t]); return (