diff --git a/packages/backend/server/src/__tests__/oauth/controller.spec.ts b/packages/backend/server/src/__tests__/oauth/controller.spec.ts index d2b27845f1..1e89e0fc3e 100644 --- a/packages/backend/server/src/__tests__/oauth/controller.spec.ts +++ b/packages/backend/server/src/__tests__/oauth/controller.spec.ts @@ -6,13 +6,16 @@ import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; import { AppModule } from '../../app.module'; -import { ConfigFactory, URLHelper } from '../../base'; +import { ConfigFactory, InvalidOauthResponse, URLHelper } from '../../base'; import { ConfigModule } from '../../base/config'; import { CurrentUser } from '../../core/auth'; import { AuthService } from '../../core/auth/service'; +import { ServerFeature } from '../../core/config/types'; import { Models } from '../../models'; import { OAuthProviderName } from '../../plugins/oauth/config'; +import { OAuthProviderFactory } from '../../plugins/oauth/factory'; import { GoogleOAuthProvider } from '../../plugins/oauth/providers/google'; +import { OIDCProvider } from '../../plugins/oauth/providers/oidc'; import { OAuthService } from '../../plugins/oauth/service'; import { createTestingApp, currentUser, TestingApp } from '../utils'; @@ -35,6 +38,12 @@ test.before(async t => { clientId: 'google-client-id', clientSecret: 'google-client-secret', }, + oidc: { + clientId: '', + clientSecret: '', + issuer: '', + args: {}, + }, }, }, server: { @@ -432,6 +441,87 @@ function mockOAuthProvider( return clientNonce; } +function mockOidcProvider( + provider: OIDCProvider, + { + args = {}, + idTokenClaims, + userinfo, + }: { + args?: Record; + idTokenClaims: Record; + userinfo: Record; + } +) { + Sinon.stub(provider, 'config').get(() => ({ + clientId: '', + clientSecret: '', + issuer: '', + args, + })); + Sinon.stub( + provider as unknown as { endpoints: { userinfo_endpoint: string } }, + 'endpoints' + ).get(() => ({ + userinfo_endpoint: 'https://oidc.affine.dev/userinfo', + })); + Sinon.stub( + provider as unknown as { verifyIdToken: () => unknown }, + 'verifyIdToken' + ).resolves(idTokenClaims); + Sinon.stub( + provider as unknown as { fetchJson: () => unknown }, + 'fetchJson' + ).resolves(userinfo); +} + +function createOidcRegistrationHarness(config?: { + clientId?: string; + clientSecret?: string; + issuer?: string; +}) { + const server = { + enableFeature: Sinon.spy(), + disableFeature: Sinon.spy(), + }; + const factory = new OAuthProviderFactory(server as any); + const affineConfig = { + server: { + externalUrl: 'https://affine.example', + host: 'localhost', + path: '', + https: true, + hosts: [], + }, + oauth: { + providers: { + oidc: { + clientId: config?.clientId ?? 'oidc-client-id', + clientSecret: config?.clientSecret ?? 'oidc-client-secret', + issuer: config?.issuer ?? 'https://issuer.affine.dev', + args: {}, + }, + }, + }, + }; + const provider = new OIDCProvider(new URLHelper(affineConfig as any)); + + (provider as any).factory = factory; + (provider as any).AFFiNEConfig = affineConfig; + + return { + provider, + factory, + server, + }; +} + +async function flushAsyncWork(iterations = 5) { + for (let i = 0; i < iterations; i++) { + await new Promise(resolve => setImmediate(resolve)); + } +} + test('should be able to sign up with oauth', async t => { const { app, db } = t.context; @@ -554,3 +644,209 @@ test('should be able to fullfil user with oauth sign in', async t => { t.truthy(account); t.is(account!.user.id, u3.id); }); + +test('oidc should accept email from id token when userinfo email is missing', async t => { + const { app } = t.context; + + const provider = app.get(OIDCProvider); + mockOidcProvider(provider, { + idTokenClaims: { + sub: 'oidc-user', + email: 'oidc-id-token@affine.pro', + name: 'OIDC User', + }, + userinfo: { + sub: 'oidc-user', + name: 'OIDC User', + }, + }); + + const user = await provider.getUser( + { accessToken: 'token', idToken: 'id-token' }, + { token: 'nonce', provider: OAuthProviderName.OIDC } + ); + + t.is(user.id, 'oidc-user'); + t.is(user.email, 'oidc-id-token@affine.pro'); + t.is(user.name, 'OIDC User'); +}); + +test('oidc should resolve custom email claim from userinfo', async t => { + const { app } = t.context; + + const provider = app.get(OIDCProvider); + mockOidcProvider(provider, { + args: { claim_email: 'mail', claim_name: 'display_name' }, + idTokenClaims: { + sub: 'oidc-user', + }, + userinfo: { + sub: 'oidc-user', + mail: 'oidc-userinfo@affine.pro', + display_name: 'OIDC Custom', + }, + }); + + const user = await provider.getUser( + { accessToken: 'token', idToken: 'id-token' }, + { token: 'nonce', provider: OAuthProviderName.OIDC } + ); + + t.is(user.id, 'oidc-user'); + t.is(user.email, 'oidc-userinfo@affine.pro'); + t.is(user.name, 'OIDC Custom'); +}); + +test('oidc should resolve custom email claim from id token', async t => { + const { app } = t.context; + + const provider = app.get(OIDCProvider); + mockOidcProvider(provider, { + args: { claim_email: 'mail', claim_email_verified: 'mail_verified' }, + idTokenClaims: { + sub: 'oidc-user', + mail: 'oidc-custom-id-token@affine.pro', + mail_verified: 'true', + }, + userinfo: { + sub: 'oidc-user', + }, + }); + + const user = await provider.getUser( + { accessToken: 'token', idToken: 'id-token' }, + { token: 'nonce', provider: OAuthProviderName.OIDC } + ); + + t.is(user.id, 'oidc-user'); + t.is(user.email, 'oidc-custom-id-token@affine.pro'); +}); + +test('oidc should reject responses without a usable email claim', async t => { + const { app } = t.context; + + const provider = app.get(OIDCProvider); + mockOidcProvider(provider, { + args: { claim_email: 'mail' }, + idTokenClaims: { + sub: 'oidc-user', + mail: 'not-an-email', + }, + userinfo: { + sub: 'oidc-user', + mail: 'still-not-an-email', + }, + }); + + const error = await t.throwsAsync( + provider.getUser( + { accessToken: 'token', idToken: 'id-token' }, + { token: 'nonce', provider: OAuthProviderName.OIDC } + ) + ); + + t.true(error instanceof InvalidOauthResponse); + t.true( + error.message.includes( + 'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"' + ) + ); +}); + +test('oidc should not fall back to default email claim when custom claim is configured', async t => { + const { app } = t.context; + + const provider = app.get(OIDCProvider); + mockOidcProvider(provider, { + args: { claim_email: 'mail' }, + idTokenClaims: { + sub: 'oidc-user', + email: 'fallback@affine.pro', + }, + userinfo: { + sub: 'oidc-user', + email: 'userinfo-fallback@affine.pro', + }, + }); + + const error = await t.throwsAsync( + provider.getUser( + { accessToken: 'token', idToken: 'id-token' }, + { token: 'nonce', provider: OAuthProviderName.OIDC } + ) + ); + + t.true(error instanceof InvalidOauthResponse); + t.true( + error.message.includes( + 'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"' + ) + ); +}); + +test('oidc discovery should remove oauth feature on failure and restore it after backoff retry succeeds', async t => { + const { provider, factory, server } = createOidcRegistrationHarness(); + const fetchStub = Sinon.stub(globalThis, 'fetch'); + const scheduledRetries: Array<() => void> = []; + const retryDelays: number[] = []; + const setTimeoutStub = Sinon.stub(globalThis, 'setTimeout').callsFake((( + callback: Parameters[0], + delay?: number + ) => { + retryDelays.push(Number(delay)); + scheduledRetries.push(callback as () => void); + return Symbol('timeout') as unknown as ReturnType; + }) as typeof setTimeout); + t.teardown(() => { + provider.onModuleDestroy(); + fetchStub.restore(); + setTimeoutStub.restore(); + }); + + fetchStub + .onFirstCall() + .rejects(new Error('temporary discovery failure')) + .onSecondCall() + .rejects(new Error('temporary discovery failure')) + .onThirdCall() + .resolves( + new Response( + JSON.stringify({ + authorization_endpoint: 'https://issuer.affine.dev/auth', + token_endpoint: 'https://issuer.affine.dev/token', + userinfo_endpoint: 'https://issuer.affine.dev/userinfo', + issuer: 'https://issuer.affine.dev', + jwks_uri: 'https://issuer.affine.dev/jwks', + }), + { + status: 200, + headers: { 'Content-Type': 'application/json' }, + } + ) + ); + + (provider as any).setup(); + + await flushAsyncWork(); + t.deepEqual(factory.providers, []); + t.true(server.disableFeature.calledWith(ServerFeature.OAuth)); + t.is(fetchStub.callCount, 1); + t.deepEqual(retryDelays, [1000]); + + const firstRetry = scheduledRetries.shift(); + t.truthy(firstRetry); + firstRetry!(); + await flushAsyncWork(); + t.is(fetchStub.callCount, 2); + t.deepEqual(factory.providers, []); + t.deepEqual(retryDelays, [1000, 2000]); + + const secondRetry = scheduledRetries.shift(); + t.truthy(secondRetry); + secondRetry!(); + await flushAsyncWork(); + t.is(fetchStub.callCount, 3); + t.deepEqual(factory.providers, [OAuthProviderName.OIDC]); + t.true(server.enableFeature.calledWith(ServerFeature.OAuth)); + t.is(scheduledRetries.length, 0); +}); diff --git a/packages/backend/server/src/base/utils/__tests__/promise.spec.ts b/packages/backend/server/src/base/utils/__tests__/promise.spec.ts new file mode 100644 index 0000000000..6ea35a3545 --- /dev/null +++ b/packages/backend/server/src/base/utils/__tests__/promise.spec.ts @@ -0,0 +1,75 @@ +import test from 'ava'; +import Sinon from 'sinon'; + +import { + exponentialBackoffDelay, + ExponentialBackoffScheduler, +} from '../promise'; + +test('exponentialBackoffDelay should cap exponential growth at maxDelayMs', t => { + t.is(exponentialBackoffDelay(0, { baseDelayMs: 100, maxDelayMs: 500 }), 100); + t.is(exponentialBackoffDelay(1, { baseDelayMs: 100, maxDelayMs: 500 }), 200); + t.is(exponentialBackoffDelay(3, { baseDelayMs: 100, maxDelayMs: 500 }), 500); +}); + +test('ExponentialBackoffScheduler should track pending callback and increase delay per attempt', async t => { + const clock = Sinon.useFakeTimers(); + t.teardown(() => { + clock.restore(); + }); + + const calls: number[] = []; + const scheduler = new ExponentialBackoffScheduler({ + baseDelayMs: 100, + maxDelayMs: 500, + }); + + t.is( + scheduler.schedule(() => { + calls.push(1); + }), + 100 + ); + t.true(scheduler.pending); + t.is( + scheduler.schedule(() => { + calls.push(2); + }), + null + ); + + await clock.tickAsync(100); + t.deepEqual(calls, [1]); + t.false(scheduler.pending); + + t.is( + scheduler.schedule(() => { + calls.push(3); + }), + 200 + ); + await clock.tickAsync(200); + t.deepEqual(calls, [1, 3]); +}); + +test('ExponentialBackoffScheduler reset should clear pending work and restart from the base delay', t => { + const scheduler = new ExponentialBackoffScheduler({ + baseDelayMs: 100, + maxDelayMs: 500, + }); + + t.is( + scheduler.schedule(() => {}), + 100 + ); + t.true(scheduler.pending); + + scheduler.reset(); + t.false(scheduler.pending); + t.is( + scheduler.schedule(() => {}), + 100 + ); + + scheduler.clear(); +}); diff --git a/packages/backend/server/src/base/utils/promise.ts b/packages/backend/server/src/base/utils/promise.ts index 27023dc341..1c64e00c4e 100644 --- a/packages/backend/server/src/base/utils/promise.ts +++ b/packages/backend/server/src/base/utils/promise.ts @@ -1,4 +1,4 @@ -import { setTimeout } from 'node:timers/promises'; +import { setTimeout as delay } from 'node:timers/promises'; import { defer as rxjsDefer, retry } from 'rxjs'; @@ -52,5 +52,61 @@ export function defer(dispose: () => Promise) { } export function sleep(ms: number): Promise { - return setTimeout(ms); + return delay(ms); +} + +export function exponentialBackoffDelay( + attempt: number, + { + baseDelayMs, + maxDelayMs, + factor = 2, + }: { baseDelayMs: number; maxDelayMs: number; factor?: number } +): number { + return Math.min( + baseDelayMs * Math.pow(factor, Math.max(0, attempt)), + maxDelayMs + ); +} + +export class ExponentialBackoffScheduler { + #attempt = 0; + #timer: ReturnType | null = null; + + constructor( + private readonly options: { + baseDelayMs: number; + maxDelayMs: number; + factor?: number; + } + ) {} + + get pending() { + return this.#timer !== null; + } + + clear() { + if (this.#timer) { + clearTimeout(this.#timer); + this.#timer = null; + } + } + + reset() { + this.#attempt = 0; + this.clear(); + } + + schedule(callback: () => void) { + if (this.#timer) return null; + + const timeout = exponentialBackoffDelay(this.#attempt, this.options); + this.#timer = globalThis.setTimeout(() => { + this.#timer = null; + callback(); + }, timeout); + this.#attempt += 1; + + return timeout; + } } diff --git a/packages/backend/server/src/plugins/oauth/providers/oidc.ts b/packages/backend/server/src/plugins/oauth/providers/oidc.ts index c63add0200..f4859afa5e 100644 --- a/packages/backend/server/src/plugins/oauth/providers/oidc.ts +++ b/packages/backend/server/src/plugins/oauth/providers/oidc.ts @@ -1,9 +1,10 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, OnModuleDestroy } from '@nestjs/common'; import { createRemoteJWKSet, type JWTPayload, jwtVerify } from 'jose'; import { omit } from 'lodash-es'; import { z } from 'zod'; import { + ExponentialBackoffScheduler, InvalidAuthState, InvalidOauthResponse, URLHelper, @@ -35,7 +36,7 @@ const OIDCUserInfoSchema = z .object({ sub: z.string(), preferred_username: z.string().optional(), - email: z.string().email(), + email: z.string().optional(), name: z.string().optional(), email_verified: z .union([z.boolean(), z.enum(['true', 'false', '1', '0', 'yes', 'no'])]) @@ -44,6 +45,8 @@ const OIDCUserInfoSchema = z }) .passthrough(); +const OIDCEmailSchema = z.string().email(); + const OIDCConfigurationSchema = z.object({ authorization_endpoint: z.string().url(), token_endpoint: z.string().url(), @@ -54,16 +57,28 @@ const OIDCConfigurationSchema = z.object({ type OIDCConfiguration = z.infer; +const OIDC_DISCOVERY_INITIAL_RETRY_DELAY = 1000; +const OIDC_DISCOVERY_MAX_RETRY_DELAY = 60_000; + @Injectable() -export class OIDCProvider extends OAuthProvider { +export class OIDCProvider extends OAuthProvider implements OnModuleDestroy { override provider = OAuthProviderName.OIDC; #endpoints: OIDCConfiguration | null = null; #jwks: ReturnType | null = null; + readonly #retryScheduler = new ExponentialBackoffScheduler({ + baseDelayMs: OIDC_DISCOVERY_INITIAL_RETRY_DELAY, + maxDelayMs: OIDC_DISCOVERY_MAX_RETRY_DELAY, + }); + #validationGeneration = 0; constructor(private readonly url: URLHelper) { super(); } + onModuleDestroy() { + this.#retryScheduler.clear(); + } + override get requiresPkce() { return true; } @@ -87,58 +102,109 @@ export class OIDCProvider extends OAuthProvider { } protected override setup() { - const validate = async () => { - this.#endpoints = null; - this.#jwks = null; + const generation = ++this.#validationGeneration; + this.#retryScheduler.clear(); - if (super.configured) { - const config = this.config as OAuthOIDCProviderConfig; - if (!config.issuer) { - this.logger.error('Missing OIDC issuer configuration'); - super.setup(); - return; - } - - try { - const res = await fetch( - `${config.issuer}/.well-known/openid-configuration`, - { - method: 'GET', - headers: { Accept: 'application/json' }, - } - ); - - if (res.ok) { - const configuration = OIDCConfigurationSchema.parse( - await res.json() - ); - if ( - this.normalizeIssuer(config.issuer) !== - this.normalizeIssuer(configuration.issuer) - ) { - this.logger.error( - `OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}` - ); - } else { - this.#endpoints = configuration; - this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri)); - } - } else { - this.logger.error(`Invalid OIDC issuer ${config.issuer}`); - } - } catch (e) { - this.logger.error('Failed to validate OIDC configuration', e); - } - } - - super.setup(); - }; - - validate().catch(() => { + this.validateAndSync(generation).catch(() => { /* noop */ }); } + private async validateAndSync(generation: number) { + if (generation !== this.#validationGeneration) { + return; + } + + if (!super.configured) { + this.resetState(); + this.#retryScheduler.reset(); + super.setup(); + return; + } + + const config = this.config as OAuthOIDCProviderConfig; + if (!config.issuer) { + this.logger.error('Missing OIDC issuer configuration'); + this.resetState(); + this.#retryScheduler.reset(); + super.setup(); + return; + } + + try { + const res = await fetch( + `${config.issuer}/.well-known/openid-configuration`, + { + method: 'GET', + headers: { Accept: 'application/json' }, + } + ); + + if (generation !== this.#validationGeneration) { + return; + } + + if (!res.ok) { + this.logger.error(`Invalid OIDC issuer ${config.issuer}`); + this.onValidationFailure(generation); + return; + } + + const configuration = OIDCConfigurationSchema.parse(await res.json()); + if ( + this.normalizeIssuer(config.issuer) !== + this.normalizeIssuer(configuration.issuer) + ) { + this.logger.error( + `OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}` + ); + this.onValidationFailure(generation); + return; + } + + this.#endpoints = configuration; + this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri)); + this.#retryScheduler.reset(); + super.setup(); + } catch (e) { + if (generation !== this.#validationGeneration) { + return; + } + this.logger.error('Failed to validate OIDC configuration', e); + this.onValidationFailure(generation); + } + } + + private onValidationFailure(generation: number) { + this.resetState(); + super.setup(); + this.scheduleRetry(generation); + } + + private scheduleRetry(generation: number) { + if (generation !== this.#validationGeneration) { + return; + } + + const delay = this.#retryScheduler.schedule(() => { + this.validateAndSync(generation).catch(() => { + /* noop */ + }); + }); + if (delay === null) { + return; + } + + this.logger.warn( + `OIDC discovery validation failed, retrying in ${delay}ms` + ); + } + + private resetState() { + this.#endpoints = null; + this.#jwks = null; + } + getAuthUrl(state: string): string { const parsedState = this.parseStatePayload(state); const nonce = parsedState?.state ?? state; @@ -291,6 +357,68 @@ export class OIDCProvider extends OAuthProvider { return undefined; } + private claimCandidates( + configuredClaim: string | undefined, + defaultClaim: string + ) { + if (typeof configuredClaim === 'string' && configuredClaim.length > 0) { + return [configuredClaim]; + } + return [defaultClaim]; + } + + private formatClaimCandidates(claims: string[]) { + return claims.map(claim => `"${claim}"`).join(', '); + } + + private resolveStringClaim( + claims: string[], + ...sources: Array> + ) { + for (const claim of claims) { + for (const source of sources) { + const value = this.extractString(source[claim]); + if (value) { + return value; + } + } + } + + return undefined; + } + + private resolveBooleanClaim( + claims: string[], + ...sources: Array> + ) { + for (const claim of claims) { + for (const source of sources) { + const value = this.extractBoolean(source[claim]); + if (value !== undefined) { + return value; + } + } + } + + return undefined; + } + + private resolveEmailClaim( + claims: string[], + ...sources: Array> + ) { + for (const claim of claims) { + for (const source of sources) { + const value = this.extractString(source[claim]); + if (value && OIDCEmailSchema.safeParse(value).success) { + return value; + } + } + } + + return undefined; + } + async getUser(tokens: Tokens, state: OAuthState): Promise { if (!tokens.idToken) { throw new InvalidOauthResponse({ @@ -315,6 +443,8 @@ export class OIDCProvider extends OAuthProvider { { treatServerErrorAsInvalid: true } ); const user = OIDCUserInfoSchema.parse(rawUser); + const userClaims = user as Record; + const idTokenClaimsRecord = idTokenClaims as Record; if (!user.sub || !idTokenClaims.sub) { throw new InvalidOauthResponse({ @@ -327,22 +457,29 @@ export class OIDCProvider extends OAuthProvider { } const args = this.config.args ?? {}; + const idClaims = this.claimCandidates(args.claim_id, 'sub'); + const emailClaims = this.claimCandidates(args.claim_email, 'email'); + const nameClaims = this.claimCandidates(args.claim_name, 'name'); + const emailVerifiedClaims = this.claimCandidates( + args.claim_email_verified, + 'email_verified' + ); - const claimsMap = { - id: args.claim_id || 'sub', - email: args.claim_email || 'email', - name: args.claim_name || 'name', - emailVerified: args.claim_email_verified || 'email_verified', - }; - - const accountId = - this.extractString(user[claimsMap.id]) ?? idTokenClaims.sub; - const email = - this.extractString(user[claimsMap.email]) || - this.extractString(idTokenClaims.email); - const emailVerified = - this.extractBoolean(user[claimsMap.emailVerified]) ?? - this.extractBoolean(idTokenClaims.email_verified); + const accountId = this.resolveStringClaim( + idClaims, + userClaims, + idTokenClaimsRecord + ); + const email = this.resolveEmailClaim( + emailClaims, + userClaims, + idTokenClaimsRecord + ); + const emailVerified = this.resolveBooleanClaim( + emailVerifiedClaims, + userClaims, + idTokenClaimsRecord + ); if (!accountId) { throw new InvalidOauthResponse({ @@ -352,7 +489,7 @@ export class OIDCProvider extends OAuthProvider { if (!email) { throw new InvalidOauthResponse({ - reason: 'Missing required claim for email', + reason: `Missing valid email claim in OIDC response. Tried userinfo and ID token claims: ${this.formatClaimCandidates(emailClaims)}`, }); } @@ -367,9 +504,11 @@ export class OIDCProvider extends OAuthProvider { email, }; - const name = - this.extractString(user[claimsMap.name]) || - this.extractString(idTokenClaims.name); + const name = this.resolveStringClaim( + nameClaims, + userClaims, + idTokenClaimsRecord + ); if (name) { account.name = name; } diff --git a/packages/frontend/core/src/__tests__/oauth-flow.spec.ts b/packages/frontend/core/src/__tests__/oauth-flow.spec.ts new file mode 100644 index 0000000000..52b378acbc --- /dev/null +++ b/packages/frontend/core/src/__tests__/oauth-flow.spec.ts @@ -0,0 +1,63 @@ +import { + attachOAuthFlowToAuthUrl, + parseOAuthCallbackState, + resolveOAuthFlowMode, + resolveOAuthRedirect, +} from '@affine/core/desktop/pages/auth/oauth-flow'; +import { describe, expect, test } from 'vitest'; + +describe('oauth flow mode', () => { + test('defaults to redirect for missing or unknown values', () => { + expect(resolveOAuthFlowMode()).toBe('redirect'); + expect(resolveOAuthFlowMode(null)).toBe('redirect'); + expect(resolveOAuthFlowMode('unknown')).toBe('redirect'); + }); + + test('persists flow in oauth state instead of web storage', () => { + const url = attachOAuthFlowToAuthUrl( + 'https://example.com/auth?state=%7B%22state%22%3A%22nonce%22%2C%22provider%22%3A%22Google%22%2C%22client%22%3A%22web%22%7D', + 'redirect' + ); + + expect( + parseOAuthCallbackState(new URL(url).searchParams.get('state')!) + ).toEqual({ + client: 'web', + flow: 'redirect', + provider: 'Google', + state: 'nonce', + }); + }); + + test('falls back to popup when callback state has no flow', () => { + expect( + parseOAuthCallbackState( + JSON.stringify({ client: 'web', provider: 'Google', state: 'nonce' }) + ).flow + ).toBe('popup'); + }); + + test('keeps same-origin redirects direct', () => { + expect(resolveOAuthRedirect('/workspace', 'https://app.affine.pro')).toBe( + '/workspace' + ); + + expect( + resolveOAuthRedirect( + 'https://app.affine.pro/workspace?from=oauth', + 'https://app.affine.pro' + ) + ).toBe('https://app.affine.pro/workspace?from=oauth'); + }); + + test('wraps external redirects with redirect-proxy', () => { + expect( + resolveOAuthRedirect( + 'https://github.com/toeverything/AFFiNE', + 'https://app.affine.pro' + ) + ).toBe( + 'https://app.affine.pro/redirect-proxy?redirect_uri=https%3A%2F%2Fgithub.com%2Ftoeverything%2FAFFiNE' + ); + }); +}); diff --git a/packages/frontend/core/src/components/affine/auth/oauth.tsx b/packages/frontend/core/src/components/affine/auth/oauth.tsx index 6fda9696e3..0a16792c3b 100644 --- a/packages/frontend/core/src/components/affine/auth/oauth.tsx +++ b/packages/frontend/core/src/components/affine/auth/oauth.tsx @@ -73,11 +73,13 @@ export function OAuth({ redirectUrl }: { redirectUrl?: string }) { params.set('redirect_uri', redirectUrl); } + params.set('flow', 'redirect'); + const oauthUrl = serverService.server.baseUrl + `/oauth/login?${params.toString()}`; - urlService.openPopupWindow(oauthUrl); + urlService.openExternal(oauthUrl); }; const ret = open(); diff --git a/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx b/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx index 5b5a14924d..e8d1e49487 100644 --- a/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx +++ b/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx @@ -13,10 +13,16 @@ import { buildOpenAppUrlRoute, } from '../../../modules/open-in-app'; import { supportedClient } from './common'; +import { + type OAuthFlowMode, + parseOAuthCallbackState, + resolveOAuthRedirect, +} from './oauth-flow'; interface LoaderData { state: string; code: string; + flow: OAuthFlowMode; provider: string; } @@ -31,12 +37,18 @@ export const loader: LoaderFunction = async ({ request }) => { } try { - const { state, client, provider } = JSON.parse(stateStr); + const { state, client, flow, provider } = parseOAuthCallbackState(stateStr); + + if (!state || !provider) { + return redirect('/sign-in?error=Invalid oauth callback parameters'); + } + stateStr = state; const payload: LoaderData = { state, code, + flow, provider, }; @@ -79,8 +91,13 @@ export const Component = () => { triggeredRef.current = true; auth .signInOauth(data.code, data.state, data.provider) - .then(() => { - window.close(); + .then(({ redirectUri }) => { + if (data.flow === 'popup') { + window.close(); + return; + } + + location.replace(resolveOAuthRedirect(redirectUri, location.origin)); }) .catch(e => { nav(`/sign-in?error=${encodeURIComponent(e.message)}`); diff --git a/packages/frontend/core/src/desktop/pages/auth/oauth-flow.ts b/packages/frontend/core/src/desktop/pages/auth/oauth-flow.ts new file mode 100644 index 0000000000..60707a16d5 --- /dev/null +++ b/packages/frontend/core/src/desktop/pages/auth/oauth-flow.ts @@ -0,0 +1,73 @@ +export const oauthFlowModes = ['popup', 'redirect'] as const; + +export type OAuthFlowMode = (typeof oauthFlowModes)[number]; + +export function resolveOAuthFlowMode( + mode?: string | null, + fallback: OAuthFlowMode = 'redirect' +): OAuthFlowMode { + return mode === 'popup' || mode === 'redirect' ? mode : fallback; +} + +export function attachOAuthFlowToAuthUrl(url: string, flow: OAuthFlowMode) { + const authUrl = new URL(url); + const state = authUrl.searchParams.get('state'); + if (!state) return url; + + try { + const payload = JSON.parse(state) as Record; + authUrl.searchParams.set('state', JSON.stringify({ ...payload, flow })); + return authUrl.toString(); + } catch { + return url; + } +} + +export function readOAuthFlowModeFromCallbackState(state: string | null) { + if (!state) return 'popup'; + + try { + const payload = JSON.parse(state) as { flow?: string }; + return resolveOAuthFlowMode(payload.flow, 'popup'); + } catch { + return 'popup'; + } +} + +export function parseOAuthCallbackState(state: string) { + const parsed = JSON.parse(state) as { + client?: string; + provider?: string; + state?: string; + }; + + return { + client: parsed.client, + flow: readOAuthFlowModeFromCallbackState(state), + provider: parsed.provider, + state: parsed.state, + }; +} + +export function resolveOAuthRedirect( + redirectUri: string | null | undefined, + currentOrigin: string +) { + if (!redirectUri) return '/'; + if (redirectUri.startsWith('/') && !redirectUri.startsWith('//')) { + return redirectUri; + } + + let target: URL; + try { + target = new URL(redirectUri); + } catch { + return '/'; + } + + if (target.origin === currentOrigin) return target.toString(); + + const redirectProxy = new URL('/redirect-proxy', currentOrigin); + redirectProxy.searchParams.set('redirect_uri', target.toString()); + return redirectProxy.toString(); +} diff --git a/packages/frontend/core/src/desktop/pages/auth/oauth-login.tsx b/packages/frontend/core/src/desktop/pages/auth/oauth-login.tsx index 4cb3a270b5..c2c03c2bb4 100644 --- a/packages/frontend/core/src/desktop/pages/auth/oauth-login.tsx +++ b/packages/frontend/core/src/desktop/pages/auth/oauth-login.tsx @@ -12,6 +12,7 @@ import { import { z } from 'zod'; import { supportedClient } from './common'; +import { attachOAuthFlowToAuthUrl, resolveOAuthFlowMode } from './oauth-flow'; const supportedProvider = z.nativeEnum(OAuthProviderType); const CSRF_COOKIE_NAME = 'affine_csrf_token'; @@ -36,12 +37,14 @@ const oauthParameters = z.object({ provider: supportedProvider, client: supportedClient, redirectUri: z.string().optional().nullable(), + flow: z.string().optional().nullable(), }); interface LoaderData { provider: OAuthProviderType; client: string; redirectUri?: string; + flow: string; } export const loader: LoaderFunction = async ({ request }) => { @@ -50,6 +53,7 @@ export const loader: LoaderFunction = async ({ request }) => { const provider = searchParams.get('provider'); const client = searchParams.get('client') ?? 'web'; const redirectUri = searchParams.get('redirect_uri'); + const flow = searchParams.get('flow'); // sign out first, web only if (client === 'web') { @@ -64,6 +68,7 @@ export const loader: LoaderFunction = async ({ request }) => { provider, client, redirectUri, + flow, }); if (paramsParseResult.success) { @@ -71,6 +76,7 @@ export const loader: LoaderFunction = async ({ request }) => { provider, client, redirectUri, + flow: resolveOAuthFlowMode(flow), }; } @@ -90,7 +96,10 @@ export const Component = () => { .oauthPreflight(data.provider, data.client, data.redirectUri) .then(({ url }) => { // this is the url of oauth provider auth page, can't navigate with react-router - location.href = url; + location.href = attachOAuthFlowToAuthUrl( + url, + resolveOAuthFlowMode(data.flow) + ); }) .catch(e => { nav(`/sign-in?error=${encodeURIComponent(e.message)}`);