feat(server): improve oidc compatibility (#14686)

fix #13938 
fix #14683 
fix #14532

#### PR Dependency Tree


* **PR #14686** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Flexible OIDC claim mapping for email/name, automatic OIDC discovery
retry with exponential backoff, and explicit OAuth flow modes (popup vs
redirect) propagated through the auth flow.

* **Bug Fixes**
* Stricter OIDC email validation, clearer error messages listing
attempted claim candidates, and improved callback redirect handling for
various flow scenarios.

* **Tests**
* Added unit tests covering OIDC behaviors, backoff scheduler/promise
utilities, and frontend OAuth flow parsing/redirect logic.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2026-03-20 04:02:37 +08:00
committed by GitHub
parent 1ffb8c922c
commit 16a8f17717
9 changed files with 807 additions and 77 deletions

View File

@@ -6,13 +6,16 @@ import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { AppModule } from '../../app.module';
import { ConfigFactory, URLHelper } from '../../base';
import { ConfigFactory, InvalidOauthResponse, URLHelper } from '../../base';
import { ConfigModule } from '../../base/config';
import { CurrentUser } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { ServerFeature } from '../../core/config/types';
import { Models } from '../../models';
import { OAuthProviderName } from '../../plugins/oauth/config';
import { OAuthProviderFactory } from '../../plugins/oauth/factory';
import { GoogleOAuthProvider } from '../../plugins/oauth/providers/google';
import { OIDCProvider } from '../../plugins/oauth/providers/oidc';
import { OAuthService } from '../../plugins/oauth/service';
import { createTestingApp, currentUser, TestingApp } from '../utils';
@@ -35,6 +38,12 @@ test.before(async t => {
clientId: 'google-client-id',
clientSecret: 'google-client-secret',
},
oidc: {
clientId: '',
clientSecret: '',
issuer: '',
args: {},
},
},
},
server: {
@@ -432,6 +441,87 @@ function mockOAuthProvider(
return clientNonce;
}
function mockOidcProvider(
provider: OIDCProvider,
{
args = {},
idTokenClaims,
userinfo,
}: {
args?: Record<string, string>;
idTokenClaims: Record<string, unknown>;
userinfo: Record<string, unknown>;
}
) {
Sinon.stub(provider, 'config').get(() => ({
clientId: '',
clientSecret: '',
issuer: '',
args,
}));
Sinon.stub(
provider as unknown as { endpoints: { userinfo_endpoint: string } },
'endpoints'
).get(() => ({
userinfo_endpoint: 'https://oidc.affine.dev/userinfo',
}));
Sinon.stub(
provider as unknown as { verifyIdToken: () => unknown },
'verifyIdToken'
).resolves(idTokenClaims);
Sinon.stub(
provider as unknown as { fetchJson: () => unknown },
'fetchJson'
).resolves(userinfo);
}
function createOidcRegistrationHarness(config?: {
clientId?: string;
clientSecret?: string;
issuer?: string;
}) {
const server = {
enableFeature: Sinon.spy(),
disableFeature: Sinon.spy(),
};
const factory = new OAuthProviderFactory(server as any);
const affineConfig = {
server: {
externalUrl: 'https://affine.example',
host: 'localhost',
path: '',
https: true,
hosts: [],
},
oauth: {
providers: {
oidc: {
clientId: config?.clientId ?? 'oidc-client-id',
clientSecret: config?.clientSecret ?? 'oidc-client-secret',
issuer: config?.issuer ?? 'https://issuer.affine.dev',
args: {},
},
},
},
};
const provider = new OIDCProvider(new URLHelper(affineConfig as any));
(provider as any).factory = factory;
(provider as any).AFFiNEConfig = affineConfig;
return {
provider,
factory,
server,
};
}
async function flushAsyncWork(iterations = 5) {
for (let i = 0; i < iterations; i++) {
await new Promise(resolve => setImmediate(resolve));
}
}
test('should be able to sign up with oauth', async t => {
const { app, db } = t.context;
@@ -554,3 +644,209 @@ test('should be able to fullfil user with oauth sign in', async t => {
t.truthy(account);
t.is(account!.user.id, u3.id);
});
test('oidc should accept email from id token when userinfo email is missing', async t => {
const { app } = t.context;
const provider = app.get(OIDCProvider);
mockOidcProvider(provider, {
idTokenClaims: {
sub: 'oidc-user',
email: 'oidc-id-token@affine.pro',
name: 'OIDC User',
},
userinfo: {
sub: 'oidc-user',
name: 'OIDC User',
},
});
const user = await provider.getUser(
{ accessToken: 'token', idToken: 'id-token' },
{ token: 'nonce', provider: OAuthProviderName.OIDC }
);
t.is(user.id, 'oidc-user');
t.is(user.email, 'oidc-id-token@affine.pro');
t.is(user.name, 'OIDC User');
});
test('oidc should resolve custom email claim from userinfo', async t => {
const { app } = t.context;
const provider = app.get(OIDCProvider);
mockOidcProvider(provider, {
args: { claim_email: 'mail', claim_name: 'display_name' },
idTokenClaims: {
sub: 'oidc-user',
},
userinfo: {
sub: 'oidc-user',
mail: 'oidc-userinfo@affine.pro',
display_name: 'OIDC Custom',
},
});
const user = await provider.getUser(
{ accessToken: 'token', idToken: 'id-token' },
{ token: 'nonce', provider: OAuthProviderName.OIDC }
);
t.is(user.id, 'oidc-user');
t.is(user.email, 'oidc-userinfo@affine.pro');
t.is(user.name, 'OIDC Custom');
});
test('oidc should resolve custom email claim from id token', async t => {
const { app } = t.context;
const provider = app.get(OIDCProvider);
mockOidcProvider(provider, {
args: { claim_email: 'mail', claim_email_verified: 'mail_verified' },
idTokenClaims: {
sub: 'oidc-user',
mail: 'oidc-custom-id-token@affine.pro',
mail_verified: 'true',
},
userinfo: {
sub: 'oidc-user',
},
});
const user = await provider.getUser(
{ accessToken: 'token', idToken: 'id-token' },
{ token: 'nonce', provider: OAuthProviderName.OIDC }
);
t.is(user.id, 'oidc-user');
t.is(user.email, 'oidc-custom-id-token@affine.pro');
});
test('oidc should reject responses without a usable email claim', async t => {
const { app } = t.context;
const provider = app.get(OIDCProvider);
mockOidcProvider(provider, {
args: { claim_email: 'mail' },
idTokenClaims: {
sub: 'oidc-user',
mail: 'not-an-email',
},
userinfo: {
sub: 'oidc-user',
mail: 'still-not-an-email',
},
});
const error = await t.throwsAsync(
provider.getUser(
{ accessToken: 'token', idToken: 'id-token' },
{ token: 'nonce', provider: OAuthProviderName.OIDC }
)
);
t.true(error instanceof InvalidOauthResponse);
t.true(
error.message.includes(
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
)
);
});
test('oidc should not fall back to default email claim when custom claim is configured', async t => {
const { app } = t.context;
const provider = app.get(OIDCProvider);
mockOidcProvider(provider, {
args: { claim_email: 'mail' },
idTokenClaims: {
sub: 'oidc-user',
email: 'fallback@affine.pro',
},
userinfo: {
sub: 'oidc-user',
email: 'userinfo-fallback@affine.pro',
},
});
const error = await t.throwsAsync(
provider.getUser(
{ accessToken: 'token', idToken: 'id-token' },
{ token: 'nonce', provider: OAuthProviderName.OIDC }
)
);
t.true(error instanceof InvalidOauthResponse);
t.true(
error.message.includes(
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
)
);
});
test('oidc discovery should remove oauth feature on failure and restore it after backoff retry succeeds', async t => {
const { provider, factory, server } = createOidcRegistrationHarness();
const fetchStub = Sinon.stub(globalThis, 'fetch');
const scheduledRetries: Array<() => void> = [];
const retryDelays: number[] = [];
const setTimeoutStub = Sinon.stub(globalThis, 'setTimeout').callsFake(((
callback: Parameters<typeof setTimeout>[0],
delay?: number
) => {
retryDelays.push(Number(delay));
scheduledRetries.push(callback as () => void);
return Symbol('timeout') as unknown as ReturnType<typeof setTimeout>;
}) as typeof setTimeout);
t.teardown(() => {
provider.onModuleDestroy();
fetchStub.restore();
setTimeoutStub.restore();
});
fetchStub
.onFirstCall()
.rejects(new Error('temporary discovery failure'))
.onSecondCall()
.rejects(new Error('temporary discovery failure'))
.onThirdCall()
.resolves(
new Response(
JSON.stringify({
authorization_endpoint: 'https://issuer.affine.dev/auth',
token_endpoint: 'https://issuer.affine.dev/token',
userinfo_endpoint: 'https://issuer.affine.dev/userinfo',
issuer: 'https://issuer.affine.dev',
jwks_uri: 'https://issuer.affine.dev/jwks',
}),
{
status: 200,
headers: { 'Content-Type': 'application/json' },
}
)
);
(provider as any).setup();
await flushAsyncWork();
t.deepEqual(factory.providers, []);
t.true(server.disableFeature.calledWith(ServerFeature.OAuth));
t.is(fetchStub.callCount, 1);
t.deepEqual(retryDelays, [1000]);
const firstRetry = scheduledRetries.shift();
t.truthy(firstRetry);
firstRetry!();
await flushAsyncWork();
t.is(fetchStub.callCount, 2);
t.deepEqual(factory.providers, []);
t.deepEqual(retryDelays, [1000, 2000]);
const secondRetry = scheduledRetries.shift();
t.truthy(secondRetry);
secondRetry!();
await flushAsyncWork();
t.is(fetchStub.callCount, 3);
t.deepEqual(factory.providers, [OAuthProviderName.OIDC]);
t.true(server.enableFeature.calledWith(ServerFeature.OAuth));
t.is(scheduledRetries.length, 0);
});

View File

@@ -0,0 +1,75 @@
import test from 'ava';
import Sinon from 'sinon';
import {
exponentialBackoffDelay,
ExponentialBackoffScheduler,
} from '../promise';
test('exponentialBackoffDelay should cap exponential growth at maxDelayMs', t => {
t.is(exponentialBackoffDelay(0, { baseDelayMs: 100, maxDelayMs: 500 }), 100);
t.is(exponentialBackoffDelay(1, { baseDelayMs: 100, maxDelayMs: 500 }), 200);
t.is(exponentialBackoffDelay(3, { baseDelayMs: 100, maxDelayMs: 500 }), 500);
});
test('ExponentialBackoffScheduler should track pending callback and increase delay per attempt', async t => {
const clock = Sinon.useFakeTimers();
t.teardown(() => {
clock.restore();
});
const calls: number[] = [];
const scheduler = new ExponentialBackoffScheduler({
baseDelayMs: 100,
maxDelayMs: 500,
});
t.is(
scheduler.schedule(() => {
calls.push(1);
}),
100
);
t.true(scheduler.pending);
t.is(
scheduler.schedule(() => {
calls.push(2);
}),
null
);
await clock.tickAsync(100);
t.deepEqual(calls, [1]);
t.false(scheduler.pending);
t.is(
scheduler.schedule(() => {
calls.push(3);
}),
200
);
await clock.tickAsync(200);
t.deepEqual(calls, [1, 3]);
});
test('ExponentialBackoffScheduler reset should clear pending work and restart from the base delay', t => {
const scheduler = new ExponentialBackoffScheduler({
baseDelayMs: 100,
maxDelayMs: 500,
});
t.is(
scheduler.schedule(() => {}),
100
);
t.true(scheduler.pending);
scheduler.reset();
t.false(scheduler.pending);
t.is(
scheduler.schedule(() => {}),
100
);
scheduler.clear();
});

View File

@@ -1,4 +1,4 @@
import { setTimeout } from 'node:timers/promises';
import { setTimeout as delay } from 'node:timers/promises';
import { defer as rxjsDefer, retry } from 'rxjs';
@@ -52,5 +52,61 @@ export function defer(dispose: () => Promise<void>) {
}
export function sleep(ms: number): Promise<void> {
return setTimeout(ms);
return delay(ms);
}
export function exponentialBackoffDelay(
attempt: number,
{
baseDelayMs,
maxDelayMs,
factor = 2,
}: { baseDelayMs: number; maxDelayMs: number; factor?: number }
): number {
return Math.min(
baseDelayMs * Math.pow(factor, Math.max(0, attempt)),
maxDelayMs
);
}
export class ExponentialBackoffScheduler {
#attempt = 0;
#timer: ReturnType<typeof globalThis.setTimeout> | null = null;
constructor(
private readonly options: {
baseDelayMs: number;
maxDelayMs: number;
factor?: number;
}
) {}
get pending() {
return this.#timer !== null;
}
clear() {
if (this.#timer) {
clearTimeout(this.#timer);
this.#timer = null;
}
}
reset() {
this.#attempt = 0;
this.clear();
}
schedule(callback: () => void) {
if (this.#timer) return null;
const timeout = exponentialBackoffDelay(this.#attempt, this.options);
this.#timer = globalThis.setTimeout(() => {
this.#timer = null;
callback();
}, timeout);
this.#attempt += 1;
return timeout;
}
}

View File

@@ -1,9 +1,10 @@
import { Injectable } from '@nestjs/common';
import { Injectable, OnModuleDestroy } from '@nestjs/common';
import { createRemoteJWKSet, type JWTPayload, jwtVerify } from 'jose';
import { omit } from 'lodash-es';
import { z } from 'zod';
import {
ExponentialBackoffScheduler,
InvalidAuthState,
InvalidOauthResponse,
URLHelper,
@@ -35,7 +36,7 @@ const OIDCUserInfoSchema = z
.object({
sub: z.string(),
preferred_username: z.string().optional(),
email: z.string().email(),
email: z.string().optional(),
name: z.string().optional(),
email_verified: z
.union([z.boolean(), z.enum(['true', 'false', '1', '0', 'yes', 'no'])])
@@ -44,6 +45,8 @@ const OIDCUserInfoSchema = z
})
.passthrough();
const OIDCEmailSchema = z.string().email();
const OIDCConfigurationSchema = z.object({
authorization_endpoint: z.string().url(),
token_endpoint: z.string().url(),
@@ -54,16 +57,28 @@ const OIDCConfigurationSchema = z.object({
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
const OIDC_DISCOVERY_INITIAL_RETRY_DELAY = 1000;
const OIDC_DISCOVERY_MAX_RETRY_DELAY = 60_000;
@Injectable()
export class OIDCProvider extends OAuthProvider {
export class OIDCProvider extends OAuthProvider implements OnModuleDestroy {
override provider = OAuthProviderName.OIDC;
#endpoints: OIDCConfiguration | null = null;
#jwks: ReturnType<typeof createRemoteJWKSet> | null = null;
readonly #retryScheduler = new ExponentialBackoffScheduler({
baseDelayMs: OIDC_DISCOVERY_INITIAL_RETRY_DELAY,
maxDelayMs: OIDC_DISCOVERY_MAX_RETRY_DELAY,
});
#validationGeneration = 0;
constructor(private readonly url: URLHelper) {
super();
}
onModuleDestroy() {
this.#retryScheduler.clear();
}
override get requiresPkce() {
return true;
}
@@ -87,58 +102,109 @@ export class OIDCProvider extends OAuthProvider {
}
protected override setup() {
const validate = async () => {
this.#endpoints = null;
this.#jwks = null;
const generation = ++this.#validationGeneration;
this.#retryScheduler.clear();
if (super.configured) {
const config = this.config as OAuthOIDCProviderConfig;
if (!config.issuer) {
this.logger.error('Missing OIDC issuer configuration');
super.setup();
return;
}
try {
const res = await fetch(
`${config.issuer}/.well-known/openid-configuration`,
{
method: 'GET',
headers: { Accept: 'application/json' },
}
);
if (res.ok) {
const configuration = OIDCConfigurationSchema.parse(
await res.json()
);
if (
this.normalizeIssuer(config.issuer) !==
this.normalizeIssuer(configuration.issuer)
) {
this.logger.error(
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
);
} else {
this.#endpoints = configuration;
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
}
} else {
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
}
} catch (e) {
this.logger.error('Failed to validate OIDC configuration', e);
}
}
super.setup();
};
validate().catch(() => {
this.validateAndSync(generation).catch(() => {
/* noop */
});
}
private async validateAndSync(generation: number) {
if (generation !== this.#validationGeneration) {
return;
}
if (!super.configured) {
this.resetState();
this.#retryScheduler.reset();
super.setup();
return;
}
const config = this.config as OAuthOIDCProviderConfig;
if (!config.issuer) {
this.logger.error('Missing OIDC issuer configuration');
this.resetState();
this.#retryScheduler.reset();
super.setup();
return;
}
try {
const res = await fetch(
`${config.issuer}/.well-known/openid-configuration`,
{
method: 'GET',
headers: { Accept: 'application/json' },
}
);
if (generation !== this.#validationGeneration) {
return;
}
if (!res.ok) {
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
this.onValidationFailure(generation);
return;
}
const configuration = OIDCConfigurationSchema.parse(await res.json());
if (
this.normalizeIssuer(config.issuer) !==
this.normalizeIssuer(configuration.issuer)
) {
this.logger.error(
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
);
this.onValidationFailure(generation);
return;
}
this.#endpoints = configuration;
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
this.#retryScheduler.reset();
super.setup();
} catch (e) {
if (generation !== this.#validationGeneration) {
return;
}
this.logger.error('Failed to validate OIDC configuration', e);
this.onValidationFailure(generation);
}
}
private onValidationFailure(generation: number) {
this.resetState();
super.setup();
this.scheduleRetry(generation);
}
private scheduleRetry(generation: number) {
if (generation !== this.#validationGeneration) {
return;
}
const delay = this.#retryScheduler.schedule(() => {
this.validateAndSync(generation).catch(() => {
/* noop */
});
});
if (delay === null) {
return;
}
this.logger.warn(
`OIDC discovery validation failed, retrying in ${delay}ms`
);
}
private resetState() {
this.#endpoints = null;
this.#jwks = null;
}
getAuthUrl(state: string): string {
const parsedState = this.parseStatePayload(state);
const nonce = parsedState?.state ?? state;
@@ -291,6 +357,68 @@ export class OIDCProvider extends OAuthProvider {
return undefined;
}
private claimCandidates(
configuredClaim: string | undefined,
defaultClaim: string
) {
if (typeof configuredClaim === 'string' && configuredClaim.length > 0) {
return [configuredClaim];
}
return [defaultClaim];
}
private formatClaimCandidates(claims: string[]) {
return claims.map(claim => `"${claim}"`).join(', ');
}
private resolveStringClaim(
claims: string[],
...sources: Array<Record<string, unknown>>
) {
for (const claim of claims) {
for (const source of sources) {
const value = this.extractString(source[claim]);
if (value) {
return value;
}
}
}
return undefined;
}
private resolveBooleanClaim(
claims: string[],
...sources: Array<Record<string, unknown>>
) {
for (const claim of claims) {
for (const source of sources) {
const value = this.extractBoolean(source[claim]);
if (value !== undefined) {
return value;
}
}
}
return undefined;
}
private resolveEmailClaim(
claims: string[],
...sources: Array<Record<string, unknown>>
) {
for (const claim of claims) {
for (const source of sources) {
const value = this.extractString(source[claim]);
if (value && OIDCEmailSchema.safeParse(value).success) {
return value;
}
}
}
return undefined;
}
async getUser(tokens: Tokens, state: OAuthState): Promise<OAuthAccount> {
if (!tokens.idToken) {
throw new InvalidOauthResponse({
@@ -315,6 +443,8 @@ export class OIDCProvider extends OAuthProvider {
{ treatServerErrorAsInvalid: true }
);
const user = OIDCUserInfoSchema.parse(rawUser);
const userClaims = user as Record<string, unknown>;
const idTokenClaimsRecord = idTokenClaims as Record<string, unknown>;
if (!user.sub || !idTokenClaims.sub) {
throw new InvalidOauthResponse({
@@ -327,22 +457,29 @@ export class OIDCProvider extends OAuthProvider {
}
const args = this.config.args ?? {};
const idClaims = this.claimCandidates(args.claim_id, 'sub');
const emailClaims = this.claimCandidates(args.claim_email, 'email');
const nameClaims = this.claimCandidates(args.claim_name, 'name');
const emailVerifiedClaims = this.claimCandidates(
args.claim_email_verified,
'email_verified'
);
const claimsMap = {
id: args.claim_id || 'sub',
email: args.claim_email || 'email',
name: args.claim_name || 'name',
emailVerified: args.claim_email_verified || 'email_verified',
};
const accountId =
this.extractString(user[claimsMap.id]) ?? idTokenClaims.sub;
const email =
this.extractString(user[claimsMap.email]) ||
this.extractString(idTokenClaims.email);
const emailVerified =
this.extractBoolean(user[claimsMap.emailVerified]) ??
this.extractBoolean(idTokenClaims.email_verified);
const accountId = this.resolveStringClaim(
idClaims,
userClaims,
idTokenClaimsRecord
);
const email = this.resolveEmailClaim(
emailClaims,
userClaims,
idTokenClaimsRecord
);
const emailVerified = this.resolveBooleanClaim(
emailVerifiedClaims,
userClaims,
idTokenClaimsRecord
);
if (!accountId) {
throw new InvalidOauthResponse({
@@ -352,7 +489,7 @@ export class OIDCProvider extends OAuthProvider {
if (!email) {
throw new InvalidOauthResponse({
reason: 'Missing required claim for email',
reason: `Missing valid email claim in OIDC response. Tried userinfo and ID token claims: ${this.formatClaimCandidates(emailClaims)}`,
});
}
@@ -367,9 +504,11 @@ export class OIDCProvider extends OAuthProvider {
email,
};
const name =
this.extractString(user[claimsMap.name]) ||
this.extractString(idTokenClaims.name);
const name = this.resolveStringClaim(
nameClaims,
userClaims,
idTokenClaimsRecord
);
if (name) {
account.name = name;
}

View File

@@ -0,0 +1,63 @@
import {
attachOAuthFlowToAuthUrl,
parseOAuthCallbackState,
resolveOAuthFlowMode,
resolveOAuthRedirect,
} from '@affine/core/desktop/pages/auth/oauth-flow';
import { describe, expect, test } from 'vitest';
describe('oauth flow mode', () => {
test('defaults to redirect for missing or unknown values', () => {
expect(resolveOAuthFlowMode()).toBe('redirect');
expect(resolveOAuthFlowMode(null)).toBe('redirect');
expect(resolveOAuthFlowMode('unknown')).toBe('redirect');
});
test('persists flow in oauth state instead of web storage', () => {
const url = attachOAuthFlowToAuthUrl(
'https://example.com/auth?state=%7B%22state%22%3A%22nonce%22%2C%22provider%22%3A%22Google%22%2C%22client%22%3A%22web%22%7D',
'redirect'
);
expect(
parseOAuthCallbackState(new URL(url).searchParams.get('state')!)
).toEqual({
client: 'web',
flow: 'redirect',
provider: 'Google',
state: 'nonce',
});
});
test('falls back to popup when callback state has no flow', () => {
expect(
parseOAuthCallbackState(
JSON.stringify({ client: 'web', provider: 'Google', state: 'nonce' })
).flow
).toBe('popup');
});
test('keeps same-origin redirects direct', () => {
expect(resolveOAuthRedirect('/workspace', 'https://app.affine.pro')).toBe(
'/workspace'
);
expect(
resolveOAuthRedirect(
'https://app.affine.pro/workspace?from=oauth',
'https://app.affine.pro'
)
).toBe('https://app.affine.pro/workspace?from=oauth');
});
test('wraps external redirects with redirect-proxy', () => {
expect(
resolveOAuthRedirect(
'https://github.com/toeverything/AFFiNE',
'https://app.affine.pro'
)
).toBe(
'https://app.affine.pro/redirect-proxy?redirect_uri=https%3A%2F%2Fgithub.com%2Ftoeverything%2FAFFiNE'
);
});
});

View File

@@ -73,11 +73,13 @@ export function OAuth({ redirectUrl }: { redirectUrl?: string }) {
params.set('redirect_uri', redirectUrl);
}
params.set('flow', 'redirect');
const oauthUrl =
serverService.server.baseUrl +
`/oauth/login?${params.toString()}`;
urlService.openPopupWindow(oauthUrl);
urlService.openExternal(oauthUrl);
};
const ret = open();

View File

@@ -13,10 +13,16 @@ import {
buildOpenAppUrlRoute,
} from '../../../modules/open-in-app';
import { supportedClient } from './common';
import {
type OAuthFlowMode,
parseOAuthCallbackState,
resolveOAuthRedirect,
} from './oauth-flow';
interface LoaderData {
state: string;
code: string;
flow: OAuthFlowMode;
provider: string;
}
@@ -31,12 +37,18 @@ export const loader: LoaderFunction = async ({ request }) => {
}
try {
const { state, client, provider } = JSON.parse(stateStr);
const { state, client, flow, provider } = parseOAuthCallbackState(stateStr);
if (!state || !provider) {
return redirect('/sign-in?error=Invalid oauth callback parameters');
}
stateStr = state;
const payload: LoaderData = {
state,
code,
flow,
provider,
};
@@ -79,8 +91,13 @@ export const Component = () => {
triggeredRef.current = true;
auth
.signInOauth(data.code, data.state, data.provider)
.then(() => {
window.close();
.then(({ redirectUri }) => {
if (data.flow === 'popup') {
window.close();
return;
}
location.replace(resolveOAuthRedirect(redirectUri, location.origin));
})
.catch(e => {
nav(`/sign-in?error=${encodeURIComponent(e.message)}`);

View File

@@ -0,0 +1,73 @@
export const oauthFlowModes = ['popup', 'redirect'] as const;
export type OAuthFlowMode = (typeof oauthFlowModes)[number];
export function resolveOAuthFlowMode(
mode?: string | null,
fallback: OAuthFlowMode = 'redirect'
): OAuthFlowMode {
return mode === 'popup' || mode === 'redirect' ? mode : fallback;
}
export function attachOAuthFlowToAuthUrl(url: string, flow: OAuthFlowMode) {
const authUrl = new URL(url);
const state = authUrl.searchParams.get('state');
if (!state) return url;
try {
const payload = JSON.parse(state) as Record<string, unknown>;
authUrl.searchParams.set('state', JSON.stringify({ ...payload, flow }));
return authUrl.toString();
} catch {
return url;
}
}
export function readOAuthFlowModeFromCallbackState(state: string | null) {
if (!state) return 'popup';
try {
const payload = JSON.parse(state) as { flow?: string };
return resolveOAuthFlowMode(payload.flow, 'popup');
} catch {
return 'popup';
}
}
export function parseOAuthCallbackState(state: string) {
const parsed = JSON.parse(state) as {
client?: string;
provider?: string;
state?: string;
};
return {
client: parsed.client,
flow: readOAuthFlowModeFromCallbackState(state),
provider: parsed.provider,
state: parsed.state,
};
}
export function resolveOAuthRedirect(
redirectUri: string | null | undefined,
currentOrigin: string
) {
if (!redirectUri) return '/';
if (redirectUri.startsWith('/') && !redirectUri.startsWith('//')) {
return redirectUri;
}
let target: URL;
try {
target = new URL(redirectUri);
} catch {
return '/';
}
if (target.origin === currentOrigin) return target.toString();
const redirectProxy = new URL('/redirect-proxy', currentOrigin);
redirectProxy.searchParams.set('redirect_uri', target.toString());
return redirectProxy.toString();
}

View File

@@ -12,6 +12,7 @@ import {
import { z } from 'zod';
import { supportedClient } from './common';
import { attachOAuthFlowToAuthUrl, resolveOAuthFlowMode } from './oauth-flow';
const supportedProvider = z.nativeEnum(OAuthProviderType);
const CSRF_COOKIE_NAME = 'affine_csrf_token';
@@ -36,12 +37,14 @@ const oauthParameters = z.object({
provider: supportedProvider,
client: supportedClient,
redirectUri: z.string().optional().nullable(),
flow: z.string().optional().nullable(),
});
interface LoaderData {
provider: OAuthProviderType;
client: string;
redirectUri?: string;
flow: string;
}
export const loader: LoaderFunction = async ({ request }) => {
@@ -50,6 +53,7 @@ export const loader: LoaderFunction = async ({ request }) => {
const provider = searchParams.get('provider');
const client = searchParams.get('client') ?? 'web';
const redirectUri = searchParams.get('redirect_uri');
const flow = searchParams.get('flow');
// sign out first, web only
if (client === 'web') {
@@ -64,6 +68,7 @@ export const loader: LoaderFunction = async ({ request }) => {
provider,
client,
redirectUri,
flow,
});
if (paramsParseResult.success) {
@@ -71,6 +76,7 @@ export const loader: LoaderFunction = async ({ request }) => {
provider,
client,
redirectUri,
flow: resolveOAuthFlowMode(flow),
};
}
@@ -90,7 +96,10 @@ export const Component = () => {
.oauthPreflight(data.provider, data.client, data.redirectUri)
.then(({ url }) => {
// this is the url of oauth provider auth page, can't navigate with react-router
location.href = url;
location.href = attachOAuthFlowToAuthUrl(
url,
resolveOAuthFlowMode(data.flow)
);
})
.catch(e => {
nav(`/sign-in?error=${encodeURIComponent(e.message)}`);