From 944fab36ac94b880455fe0cf8522e9567b6c7d23 Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:35:36 +0800 Subject: [PATCH] feat: drop outdated session (#14373) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #### PR Dependency Tree * **PR #14373** 👈 This tree was auto-generated by [Charcoal](https://github.com/danerwilliams/charcoal) ## Summary by CodeRabbit * **New Features** * Added client version tracking and validation to ensure application compatibility across authentication flows and sessions. * Enhanced OAuth authentication with improved version handling during sign-in and refresh operations. * **Bug Fixes** * Improved payment callback URL handling with safer defaults for redirect links. * **Tests** * Expanded test coverage for client version enforcement and session management. --- .../migration.sql | 7 + packages/backend/server/schema.prisma | 12 +- .../src/__tests__/auth/controller.spec.ts | 28 +++ .../server/src/__tests__/auth/guard.spec.ts | 209 ++++++++++++++---- .../src/__tests__/models/session.spec.ts | 58 +++++ .../src/__tests__/nestjs/throttler.spec.ts | 67 ++++++ .../src/__tests__/oauth/controller.spec.ts | 60 ++++- .../server/src/__tests__/version.spec.ts | 12 +- .../src/base/helpers/__tests__/url.spec.ts | 23 ++ .../backend/server/src/base/helpers/url.ts | 2 +- .../server/src/base/throttler/index.ts | 8 +- .../backend/server/src/base/utils/index.ts | 1 + .../server/src/base/utils/request-tracker.ts | 44 ++++ .../backend/server/src/core/auth/guard.ts | 111 +++++++++- .../backend/server/src/core/auth/service.ts | 38 +++- .../backend/server/src/core/version/guard.ts | 3 +- .../server/src/core/version/service.ts | 32 ++- packages/backend/server/src/models/session.ts | 9 +- .../server/src/plugins/captcha/service.ts | 9 +- .../server/src/plugins/oauth/controller.ts | 6 +- .../backend/server/src/plugins/oauth/types.ts | 1 + .../src/plugins/payment/manager/selfhost.ts | 2 +- .../src/plugins/payment/manager/user.ts | 2 +- .../src/plugins/payment/manager/workspace.ts | 2 +- .../src/desktop/pages/auth/magic-link.tsx | 18 +- .../src/desktop/pages/auth/oauth-callback.tsx | 18 +- .../core/src/desktop/pages/open-app/index.tsx | 39 ++-- .../open-in-app/__tests__/deeplink.spec.ts | 71 ++++++ .../core/src/modules/open-in-app/utils.ts | 78 ++++++- 29 files changed, 845 insertions(+), 125 deletions(-) create mode 100644 packages/backend/server/migrations/20260204105925_user_sessions_client_version/migration.sql create mode 100644 packages/backend/server/src/base/utils/request-tracker.ts create mode 100644 packages/frontend/core/src/modules/open-in-app/__tests__/deeplink.spec.ts diff --git a/packages/backend/server/migrations/20260204105925_user_sessions_client_version/migration.sql b/packages/backend/server/migrations/20260204105925_user_sessions_client_version/migration.sql new file mode 100644 index 0000000000..0ed16fdc58 --- /dev/null +++ b/packages/backend/server/migrations/20260204105925_user_sessions_client_version/migration.sql @@ -0,0 +1,7 @@ +-- AlterTable +ALTER TABLE + "user_sessions" +ADD + COLUMN "sign_in_client_version" VARCHAR, +ADD + COLUMN "refresh_client_version" VARCHAR; \ No newline at end of file diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 94a7b020dd..dfdd49bfb9 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -83,11 +83,13 @@ model Session { } model UserSession { - id String @id @default(uuid()) @db.VarChar - sessionId String @map("session_id") @db.VarChar - userId String @map("user_id") @db.VarChar - expiresAt DateTime? @map("expires_at") @db.Timestamptz(3) - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + id String @id @default(uuid()) @db.VarChar + sessionId String @map("session_id") @db.VarChar + userId String @map("user_id") @db.VarChar + expiresAt DateTime? @map("expires_at") @db.Timestamptz(3) + signInClientVersion String? @map("sign_in_client_version") @db.VarChar + refreshClientVersion String? @map("refresh_client_version") @db.VarChar + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) session Session @relation(fields: [sessionId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade) diff --git a/packages/backend/server/src/__tests__/auth/controller.spec.ts b/packages/backend/server/src/__tests__/auth/controller.spec.ts index 30bcebb5a8..cb43a16469 100644 --- a/packages/backend/server/src/__tests__/auth/controller.spec.ts +++ b/packages/backend/server/src/__tests__/auth/controller.spec.ts @@ -53,6 +53,34 @@ test('should be able to sign in with credential', async t => { t.is(session?.id, u1.id); }); +test('should record sign in client version when header is provided', async t => { + const { app, db } = t.context; + + const u1 = await app.createUser('u1@affine.pro'); + + await app + .POST('/api/auth/sign-in') + .set('x-affine-version', '0.25.1') + .send({ email: u1.email, password: u1.password }) + .expect(200); + + const userSession1 = await db.userSession.findFirst({ + where: { userId: u1.id }, + }); + t.is(userSession1?.signInClientVersion, '0.25.1'); + + // should not overwrite existing value with null/undefined + await app + .POST('/api/auth/sign-in') + .send({ email: u1.email, password: u1.password }) + .expect(200); + + const userSession2 = await db.userSession.findFirst({ + where: { userId: u1.id }, + }); + t.is(userSession2?.signInClientVersion, '0.25.1'); +}); + test('should be able to sign in with email', async t => { const { app } = t.context; diff --git a/packages/backend/server/src/__tests__/auth/guard.spec.ts b/packages/backend/server/src/__tests__/auth/guard.spec.ts index 934b83f09c..ecd3d21d31 100644 --- a/packages/backend/server/src/__tests__/auth/guard.spec.ts +++ b/packages/backend/server/src/__tests__/auth/guard.spec.ts @@ -1,13 +1,14 @@ -import { Controller, Get, HttpStatus, INestApplication } from '@nestjs/common'; +import { Controller, Get, HttpStatus } from '@nestjs/common'; import { PrismaClient } from '@prisma/client'; import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; import request from 'supertest'; +import { ConfigFactory } from '../../base'; import { AuthModule, CurrentUser, Public, Session } from '../../core/auth'; import { AuthService } from '../../core/auth/service'; import { Models } from '../../models'; -import { createTestingApp } from '../utils'; +import { createTestingApp, TestingApp } from '../utils'; @Controller('/') class TestController { @@ -29,31 +30,46 @@ class TestController { } const test = ava as TestFn<{ - app: INestApplication; + app: TestingApp; + server: any; + auth: AuthService; + models: Models; + db: PrismaClient; + config: ConfigFactory; + u1: CurrentUser; + sessionId: string; }>; -let server!: any; -let auth!: AuthService; -let u1!: CurrentUser; - -let sessionId = ''; - test.before(async t => { const app = await createTestingApp({ imports: [AuthModule], controllers: [TestController], }); - auth = app.get(AuthService); - u1 = await auth.signUp('u1@affine.pro', '1'); - - const models = app.get(Models); - const session = await models.session.createSession(); - sessionId = session.id; - await auth.createUserSession(u1.id, sessionId); - - server = app.getHttpServer(); t.context.app = app; + t.context.server = app.getHttpServer(); + t.context.auth = app.get(AuthService); + t.context.models = app.get(Models); + t.context.db = app.get(PrismaClient); + t.context.config = app.get(ConfigFactory); +}); + +test.beforeEach(async t => { + Sinon.restore(); + await t.context.app.initTestingDB(); + t.context.config.override({ + client: { + versionControl: { + enabled: false, + requiredVersion: '>=0.25.0', + }, + }, + }); + + t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1'); + const session = await t.context.models.session.createSession(); + t.context.sessionId = session.id; + await t.context.auth.createUserSession(t.context.u1.id, t.context.sessionId); }); test.after.always(async t => { @@ -61,92 +77,95 @@ test.after.always(async t => { }); test('should be able to visit public api if not signed in', async t => { - const res = await request(server).get('/public').expect(200); + const res = await request(t.context.server).get('/public').expect(200); t.is(res.body.user, undefined); }); test('should be able to visit public api if signed in', async t => { - const res = await request(server) + const res = await request(t.context.server) .get('/public') - .set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`) + .set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) .expect(HttpStatus.OK); - t.is(res.body.user.id, u1.id); + t.is(res.body.user.id, t.context.u1.id); }); test('should not be able to visit private api if not signed in', async t => { - await request(server).get('/private').expect(HttpStatus.UNAUTHORIZED).expect({ - status: 401, - code: 'Unauthorized', - type: 'AUTHENTICATION_REQUIRED', - name: 'AUTHENTICATION_REQUIRED', - message: 'You must sign in first to access this resource.', - }); + await request(t.context.server) + .get('/private') + .expect(HttpStatus.UNAUTHORIZED) + .expect({ + status: 401, + code: 'Unauthorized', + type: 'AUTHENTICATION_REQUIRED', + name: 'AUTHENTICATION_REQUIRED', + message: 'You must sign in first to access this resource.', + }); t.assert(true); }); test('should be able to visit private api if signed in', async t => { - const res = await request(server) + const res = await request(t.context.server) .get('/private') - .set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`) + .set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) .expect(HttpStatus.OK); - t.is(res.body.user.id, u1.id); + t.is(res.body.user.id, t.context.u1.id); }); test('should be able to visit private api with access token', async t => { const models = t.context.app.get(Models); const token = await models.accessToken.create({ - userId: u1.id, + userId: t.context.u1.id, name: 'test', }); - const res = await request(server) + const res = await request(t.context.server) .get('/private') .set('Authorization', `Bearer ${token.token}`) .expect(HttpStatus.OK); - t.is(res.body.user.id, u1.id); + t.is(res.body.user.id, t.context.u1.id); }); test('should be able to parse session cookie', async t => { - const spy = Sinon.spy(auth, 'getUserSession'); - await request(server) + const spy = Sinon.spy(t.context.auth, 'getUserSession'); + await request(t.context.server) .get('/public') - .set('cookie', `${AuthService.sessionCookieName}=${sessionId}`) + .set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) .expect(200); - t.deepEqual(spy.firstCall.args, [sessionId, undefined]); + t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]); spy.restore(); }); test('should be able to parse bearer token', async t => { - const spy = Sinon.spy(auth, 'getUserSession'); + const spy = Sinon.spy(t.context.auth, 'getUserSession'); - await request(server) + await request(t.context.server) .get('/public') - .auth(sessionId, { type: 'bearer' }) + .auth(t.context.sessionId, { type: 'bearer' }) .expect(200); - t.deepEqual(spy.firstCall.args, [sessionId, undefined]); + t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]); spy.restore(); }); test('should be able to refresh session if needed', async t => { await t.context.app.get(PrismaClient).userSession.updateMany({ where: { - sessionId, + sessionId: t.context.sessionId, }, data: { expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */), }, }); - const res = await request(server) + const res = await request(t.context.server) .get('/session') - .set('cookie', `${AuthService.sessionCookieName}=${sessionId}`) + .set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) .expect(200); const cookie = res @@ -155,3 +174,101 @@ test('should be able to refresh session if needed', async t => { t.truthy(cookie); }); + +test('should record refresh client version when refreshed', async t => { + await t.context.db.userSession.updateMany({ + where: { sessionId: t.context.sessionId }, + data: { + expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */), + }, + }); + + await request(t.context.server) + .get('/session') + .set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) + .set('x-affine-version', '0.25.2') + .expect(200); + + const userSession = await t.context.db.userSession.findFirst({ + where: { sessionId: t.context.sessionId, userId: t.context.u1.id }, + }); + t.is(userSession?.refreshClientVersion, '0.25.2'); +}); + +test('should allow auth when header is missing but stored version is valid', async t => { + t.context.config.override({ + client: { + versionControl: { + enabled: true, + requiredVersion: '>=0.25.0', + }, + }, + }); + + await t.context.db.userSession.updateMany({ + where: { sessionId: t.context.sessionId }, + data: { signInClientVersion: '0.25.0' }, + }); + + const res = await request(t.context.server) + .get('/private') + .set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) + .expect(200); + + t.is(res.body.user.id, t.context.u1.id); +}); + +test('should kick out unsupported client version on non-public handler', async t => { + t.context.config.override({ + client: { + versionControl: { + enabled: true, + requiredVersion: '>=0.25.0', + }, + }, + }); + + const res = await request(t.context.server) + .get('/private') + .set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) + .set('x-affine-version', '0.24.0') + .expect(403); + + const setCookies = res.get('Set-Cookie') ?? []; + t.true( + setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`)) + ); + t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`))); + t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`))); + + const session = await t.context.db.session.findFirst({ + where: { id: t.context.sessionId }, + }); + t.is(session, null); +}); + +test('should not block public handler when client version is unsupported', async t => { + t.context.config.override({ + client: { + versionControl: { + enabled: true, + requiredVersion: '>=0.25.0', + }, + }, + }); + + const res = await request(t.context.server) + .get('/public') + .set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`) + .set('x-affine-version', '0.24.0') + .expect(200); + + t.is(res.body.user, undefined); + + const setCookies = res.get('Set-Cookie') ?? []; + t.true( + setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`)) + ); + t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`))); + t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`))); +}); diff --git a/packages/backend/server/src/__tests__/models/session.spec.ts b/packages/backend/server/src/__tests__/models/session.spec.ts index bf11545036..50481ddd94 100644 --- a/packages/backend/server/src/__tests__/models/session.spec.ts +++ b/packages/backend/server/src/__tests__/models/session.spec.ts @@ -122,6 +122,64 @@ test('should refresh exists userSession', async t => { ); }); +test('should record sign-in client version on create and update', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.session.createSession(); + + const userSession1 = await t.context.session.createOrRefreshUserSession( + user.id, + session.id, + undefined, + '0.25.0' + ); + t.is(userSession1.signInClientVersion, '0.25.0'); + + const userSession2 = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + t.is(userSession2.signInClientVersion, '0.25.0'); + + const userSession3 = await t.context.session.createOrRefreshUserSession( + user.id, + session.id, + undefined, + '0.26.0' + ); + t.is(userSession3.signInClientVersion, '0.26.0'); +}); + +test('should record refresh client version only when refreshed', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.session.createSession(); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + + // force refresh + userSession.expiresAt = new Date( + userSession.expiresAt!.getTime() - + t.context.config.auth.session.ttr * 2 * 1000 + ); + + const newExpiresAt = await t.context.session.refreshUserSessionIfNeeded( + userSession, + undefined, + '0.25.0' + ); + t.truthy(newExpiresAt); + + const refreshed = await t.context.db.userSession.findFirst({ + where: { id: userSession.id }, + }); + t.is(refreshed?.refreshClientVersion, '0.25.0'); +}); + test('should not refresh userSession when expires time not hit ttr', async t => { const user = await t.context.user.create({ email: 'test@affine.pro', diff --git a/packages/backend/server/src/__tests__/nestjs/throttler.spec.ts b/packages/backend/server/src/__tests__/nestjs/throttler.spec.ts index e535ac2747..d9bfda647c 100644 --- a/packages/backend/server/src/__tests__/nestjs/throttler.spec.ts +++ b/packages/backend/server/src/__tests__/nestjs/throttler.spec.ts @@ -1,4 +1,5 @@ import { Controller, Get, HttpStatus, UseGuards } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; import { type Response } from 'supertest'; @@ -144,6 +145,72 @@ test('should be able to prevent requests if limit is reached', async t => { stub.restore(); }); +test('should use session id as tracker when available', async t => { + const { app } = t.context; + + const user = await app.signupV1('u1@affine.pro'); + const userSession = await app.get(PrismaClient).userSession.findFirst({ + where: { userId: user.id }, + }); + t.truthy(userSession); + + const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({ + timeToExpire: 10, + totalHits: 1, + isBlocked: false, + timeToBlockExpire: 0, + }); + + await app.GET('/throttled/default').expect(200); + + const key = stub.firstCall.args[0] as string; + t.true(key.startsWith(`throttler:${userSession!.sessionId};default`)); + + stub.restore(); +}); + +test('should use CF-Connecting-IP as tracker when present', async t => { + const { app } = t.context; + + const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({ + timeToExpire: 10, + totalHits: 1, + isBlocked: false, + timeToBlockExpire: 0, + }); + + await app + .GET('/nonthrottled/default') + .set('CF-Connecting-IP', '1.2.3.4') + .expect(200); + + const key = stub.firstCall.args[0] as string; + t.true(key.startsWith('throttler:1.2.3.4;default')); + + stub.restore(); +}); + +test('should use X-Forwarded-For as tracker when present', async t => { + const { app } = t.context; + + const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({ + timeToExpire: 10, + totalHits: 1, + isBlocked: false, + timeToBlockExpire: 0, + }); + + await app + .GET('/nonthrottled/default') + .set('X-Forwarded-For', '5.6.7.8, 9.9.9.9') + .expect(200); + + const key = stub.firstCall.args[0] as string; + t.true(key.startsWith('throttler:5.6.7.8;default')); + + stub.restore(); +}); + // ====== unauthenticated user visits ====== test('should use default throttler for unauthenticated user when not specified', async t => { const { app } = t.context; diff --git a/packages/backend/server/src/__tests__/oauth/controller.spec.ts b/packages/backend/server/src/__tests__/oauth/controller.spec.ts index 3f2723dd92..d2b27845f1 100644 --- a/packages/backend/server/src/__tests__/oauth/controller.spec.ts +++ b/packages/backend/server/src/__tests__/oauth/controller.spec.ts @@ -6,7 +6,7 @@ import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; import { AppModule } from '../../app.module'; -import { URLHelper } from '../../base'; +import { ConfigFactory, URLHelper } from '../../base'; import { ConfigModule } from '../../base/config'; import { CurrentUser } from '../../core/auth'; import { AuthService } from '../../core/auth/service'; @@ -56,6 +56,14 @@ test.before(async t => { test.beforeEach(async t => { Sinon.restore(); await t.context.app.initTestingDB(); + t.context.app.get(ConfigFactory).override({ + client: { + versionControl: { + enabled: false, + requiredVersion: '>=0.25.0', + }, + }, + }); t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1'); }); @@ -156,6 +164,56 @@ test('should be able to redirect to oauth provider with client_nonce', async t = t.truthy(state.state); }); +test('should record sign in client version from oauth preflight state', async t => { + const { app, db } = t.context; + + const config = app.get(ConfigFactory); + config.override({ + client: { + versionControl: { + enabled: true, + requiredVersion: '>=0.25.0', + }, + }, + }); + + const preflightRes = await app + .POST('/api/oauth/preflight') + .set('x-affine-version', '0.25.3') + .send({ provider: 'Google', client_nonce: 'test-nonce' }) + .expect(HttpStatus.OK); + + const redirect = new URL(preflightRes.body.url as string); + const stateParam = redirect.searchParams.get('state'); + t.truthy(stateParam); + + // state should be a json string + const rawState = JSON.parse(stateParam!); + + const provider = app.get(GoogleOAuthProvider); + Sinon.stub(provider, 'getToken').resolves({ accessToken: '1' }); + Sinon.stub(provider, 'getUser').resolves({ + id: '1', + email: 'oauth-version@affine.pro', + avatarUrl: 'avatar', + }); + + const callbackRes = await app + .POST('/api/oauth/callback') + .send({ code: '1', state: stateParam, client_nonce: 'test-nonce' }) + .expect(HttpStatus.OK); + + const userId = callbackRes.body.id as string; + t.truthy(userId); + + const userSession = await db.userSession.findFirst({ + where: { userId }, + }); + t.is(userSession?.signInClientVersion, '0.25.3'); + t.is(userSession?.refreshClientVersion, null); + t.truthy(rawState.state); +}); + test('should forbid preflight with untrusted redirect_uri', async t => { const { app } = t.context; diff --git a/packages/backend/server/src/__tests__/version.spec.ts b/packages/backend/server/src/__tests__/version.spec.ts index 814ee85894..6e156273b7 100644 --- a/packages/backend/server/src/__tests__/version.spec.ts +++ b/packages/backend/server/src/__tests__/version.spec.ts @@ -73,7 +73,7 @@ test('should passthrough if version check is not enabled', async t => { spy.restore(); }); -test('should passthrough is version range is invalid', async t => { +test('should enforce hard required version when version range is invalid', async t => { config.override({ client: { versionControl: { @@ -82,9 +82,17 @@ test('should passthrough is version range is invalid', async t => { }, }); - let res = await app.GET('/guarded/test').set('x-affine-version', 'invalid'); + let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0'); t.is(res.status, 200); + + res = await app.GET('/guarded/test').set('x-affine-version', 'invalid'); + + t.is(res.status, 403); + t.is( + res.body.message, + 'Unsupported client with version [invalid], required version is [>=0.25.0].' + ); }); test('should pass if client version is allowed', async t => { diff --git a/packages/backend/server/src/base/helpers/__tests__/url.spec.ts b/packages/backend/server/src/base/helpers/__tests__/url.spec.ts index 1cd73e5ba8..dfa19e9486 100644 --- a/packages/backend/server/src/base/helpers/__tests__/url.spec.ts +++ b/packages/backend/server/src/base/helpers/__tests__/url.spec.ts @@ -86,6 +86,29 @@ test('can create link', t => { ); }); +test('addSimpleQuery should not double encode', t => { + t.is( + t.context.url.addSimpleQuery( + 'https://app.affine.local/path', + 'redirect_uri', + '/path' + ), + 'https://app.affine.local/path?redirect_uri=%2Fpath' + ); +}); + +test('addSimpleQuery should allow unescaped value when escape=false', t => { + t.is( + t.context.url.addSimpleQuery( + 'https://app.affine.local/path', + 'session_id', + '{CHECKOUT_SESSION_ID}', + false + ), + 'https://app.affine.local/path?session_id={CHECKOUT_SESSION_ID}' + ); +}); + test('can validate callbackUrl allowlist', t => { t.true(t.context.url.isAllowedCallbackUrl('/magic-link')); t.true( diff --git a/packages/backend/server/src/base/helpers/url.ts b/packages/backend/server/src/base/helpers/url.ts index 83caaabbc8..bf61d54eca 100644 --- a/packages/backend/server/src/base/helpers/url.ts +++ b/packages/backend/server/src/base/helpers/url.ts @@ -109,7 +109,7 @@ export class URLHelper { ) { const urlObj = new URL(url); if (escape) { - urlObj.searchParams.set(key, encodeURIComponent(value)); + urlObj.searchParams.set(key, String(value)); return urlObj.toString(); } else { const query = diff --git a/packages/backend/server/src/base/throttler/index.ts b/packages/backend/server/src/base/throttler/index.ts index a20174aa8e..67e51e3a56 100644 --- a/packages/backend/server/src/base/throttler/index.ts +++ b/packages/backend/server/src/base/throttler/index.ts @@ -16,6 +16,7 @@ import type { Request, Response } from 'express'; import { Config } from '../config'; import { getRequestResponseFromContext } from '../utils/request'; +import { getRequestTrackerId } from '../utils/request-tracker'; import type { ThrottlerType } from './config'; import { THROTTLER_PROTECTED, Throttlers } from './decorators'; @@ -63,11 +64,8 @@ export class CloudThrottlerGuard extends ThrottlerGuard { } override getTracker(req: Request): Promise { - return Promise.resolve( - // ↓ prefer session id if available - `throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}` - // ^ throttler prefix make the key in store recognizable - ); + // throttler prefix make the key in store recognizable + return Promise.resolve(`throttler:${getRequestTrackerId(req)}`); } override generateKey( diff --git a/packages/backend/server/src/base/utils/index.ts b/packages/backend/server/src/base/utils/index.ts index a385f25a52..d1f5694df1 100644 --- a/packages/backend/server/src/base/utils/index.ts +++ b/packages/backend/server/src/base/utils/index.ts @@ -1,6 +1,7 @@ export * from './duration'; export * from './promise'; export * from './request'; +export * from './request-tracker'; export * from './ssrf'; export * from './stream'; export * from './types'; diff --git a/packages/backend/server/src/base/utils/request-tracker.ts b/packages/backend/server/src/base/utils/request-tracker.ts new file mode 100644 index 0000000000..8fae68f420 --- /dev/null +++ b/packages/backend/server/src/base/utils/request-tracker.ts @@ -0,0 +1,44 @@ +import type { Request } from 'express'; + +function firstForwardedForIp(value?: string) { + if (!value) { + return; + } + + const [first] = value.split(',', 1); + const ip = first?.trim(); + + return ip || undefined; +} + +function firstNonEmpty(...values: Array) { + for (const value of values) { + const trimmed = value?.trim(); + if (trimmed) { + return trimmed; + } + } + return; +} + +export function getRequestClientIp(req: Request) { + return firstNonEmpty( + req.get('CF-Connecting-IP'), + firstForwardedForIp(req.get('X-Forwarded-For')), + req.get('X-Real-IP'), + req.ip + )!; +} + +export function getRequestTrackerId(req: Request) { + return ( + req.session?.sessionId ?? + firstNonEmpty( + req.get('CF-Connecting-IP'), + firstForwardedForIp(req.get('X-Forwarded-For')), + req.get('X-Real-IP'), + req.get('CF-Ray'), + req.ip + )! + ); +} diff --git a/packages/backend/server/src/core/auth/guard.ts b/packages/backend/server/src/core/auth/guard.ts index fcee58dab9..5c7f874fcc 100644 --- a/packages/backend/server/src/core/auth/guard.ts +++ b/packages/backend/server/src/core/auth/guard.ts @@ -7,6 +7,7 @@ import type { import { Injectable, SetMetadata } from '@nestjs/common'; import { ModuleRef, Reflector } from '@nestjs/core'; import type { Request, Response } from 'express'; +import semver from 'semver'; import { Socket } from 'socket.io'; import { @@ -15,8 +16,10 @@ import { Cache, Config, CryptoHelper, + getClientVersionFromRequest, getRequestResponseFromContext, parseCookies, + UnsupportedClientVersion, } from '../../base'; import { WEBSOCKET_OPTIONS } from '../../base/websocket'; import { AuthService } from './service'; @@ -30,10 +33,13 @@ const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30 * 1000; @Injectable() export class AuthGuard implements CanActivate, OnModuleInit { private auth!: AuthService; + private readonly cachedVersionRange = new Map(); + private static readonly HARD_REQUIRED_VERSION = '>=0.25.0'; constructor( private readonly crypto: CryptoHelper, private readonly cache: Cache, + private readonly config: Config, private readonly ref: ModuleRef, private readonly reflector: Reflector ) {} @@ -78,14 +84,14 @@ export class AuthGuard implements CanActivate, OnModuleInit { throw new AccessDenied('Invalid internal request'); } - const authedUser = await this.signIn(req, res); - // api is public const isPublic = this.reflector.getAllAndOverride( PUBLIC_ENTRYPOINT_SYMBOL, [clazz, handler] ); + const authedUser = await this.signIn(req, res, isPublic); + if (isPublic) { return true; } @@ -99,9 +105,10 @@ export class AuthGuard implements CanActivate, OnModuleInit { async signIn( req: Request, - res?: Response + res?: Response, + isPublic = false ): Promise { - const userSession = await this.signInWithCookie(req, res); + const userSession = await this.signInWithCookie(req, res, isPublic); if (userSession) { return userSession; } @@ -111,7 +118,8 @@ export class AuthGuard implements CanActivate, OnModuleInit { async signInWithCookie( req: Request, - res?: Response + res?: Response, + isPublic = false ): Promise { if (req.session) { return req.session; @@ -121,8 +129,38 @@ export class AuthGuard implements CanActivate, OnModuleInit { const userSession = await this.auth.getUserSessionFromRequest(req, res); if (userSession) { + const headerClientVersion = getClientVersionFromRequest(req); + if (this.config.client.versionControl.enabled) { + const clientVersion = + headerClientVersion ?? + userSession.session.refreshClientVersion ?? + userSession.session.signInClientVersion; + + const versionCheckResult = this.checkClientVersion(clientVersion); + if (!versionCheckResult.ok) { + await this.auth.signOut(userSession.session.sessionId); + if (res) { + await this.auth.refreshCookies(res, userSession.session.sessionId); + } + + if (isPublic) { + return null; + } + + throw new UnsupportedClientVersion({ + clientVersion: clientVersion ?? 'unset_or_invalid', + requiredVersion: versionCheckResult.requiredVersion, + }); + } + } + if (res) { - await this.auth.refreshUserSessionIfNeeded(res, userSession.session); + await this.auth.refreshUserSessionIfNeeded( + res, + userSession.session, + undefined, + headerClientVersion + ); } req.session = { @@ -154,6 +192,59 @@ export class AuthGuard implements CanActivate, OnModuleInit { return null; } + + private getVersionRange(versionRange: string): semver.Range | null { + if (this.cachedVersionRange.has(versionRange)) { + // oxlint-disable-next-line @typescript-eslint/no-non-null-assertion + return this.cachedVersionRange.get(versionRange)!; + } + + let range: semver.Range | null = null; + try { + range = new semver.Range(versionRange, { loose: false }); + if (!semver.validRange(range)) { + range = null; + } + } catch { + range = null; + } + + this.cachedVersionRange.set(versionRange, range); + return range; + } + + private checkClientVersion( + clientVersion?: string | null + ): { ok: true } | { ok: false; requiredVersion: string } { + const requiredVersion = this.config.client.versionControl.requiredVersion; + + const configRange = this.getVersionRange(requiredVersion); + if ( + configRange && + (!clientVersion || + !semver.satisfies(clientVersion, configRange, { + includePrerelease: true, + })) + ) { + return { ok: false, requiredVersion }; + } + + const hardRange = this.getVersionRange(AuthGuard.HARD_REQUIRED_VERSION); + if (!hardRange) { + return { ok: true }; + } + + if ( + !clientVersion || + !semver.satisfies(clientVersion, hardRange, { + includePrerelease: true, + }) + ) { + return { ok: false, requiredVersion: AuthGuard.HARD_REQUIRED_VERSION }; + } + + return { ok: true }; + } } /** @@ -184,7 +275,13 @@ export const AuthWebsocketOptionsProvider: FactoryProvider = { ...upgradeReq.cookies, }; - const session = await guard.signIn(upgradeReq); + const session = await (async () => { + try { + return await guard.signIn(upgradeReq); + } catch { + return null; + } + })(); return !!session; }, diff --git a/packages/backend/server/src/core/auth/service.ts b/packages/backend/server/src/core/auth/service.ts index 7e6a71f0ce..87d1223a0a 100644 --- a/packages/backend/server/src/core/auth/service.ts +++ b/packages/backend/server/src/core/auth/service.ts @@ -4,7 +4,11 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common'; import type { CookieOptions, Request, Response } from 'express'; import { assign, pick } from 'lodash-es'; -import { Config, SignUpForbidden } from '../../base'; +import { + Config, + getClientVersionFromRequest, + SignUpForbidden, +} from '../../base'; import { Models, type User, type UserSession } from '../../models'; import { Mailer } from '../mail/mailer'; import { createDevUsers } from './dev'; @@ -130,11 +134,17 @@ export class AuthService implements OnApplicationBootstrap { return await this.models.session.findUserSessionsBySessionId(sessionId); } - async createUserSession(userId: string, sessionId?: string, ttl?: number) { + async createUserSession( + userId: string, + sessionId?: string, + ttl?: number, + signInClientVersion?: string + ) { return await this.models.session.createOrRefreshUserSession( userId, sessionId, - ttl + ttl, + signInClientVersion ); } @@ -159,11 +169,13 @@ export class AuthService implements OnApplicationBootstrap { async refreshUserSessionIfNeeded( res: Response, userSession: UserSession, - ttr?: number + ttr?: number, + refreshClientVersion?: string ): Promise { const newExpiresAt = await this.models.session.refreshUserSessionIfNeeded( userSession, - ttr + ttr, + refreshClientVersion ); if (!newExpiresAt) { // no need to refresh @@ -205,10 +217,22 @@ export class AuthService implements OnApplicationBootstrap { }; } - async setCookies(req: Request, res: Response, userId: string) { + async setCookies( + req: Request, + res: Response, + userId: string, + clientVersion?: string + ) { const { sessionId } = this.getSessionOptionsFromRequest(req); - const userSession = await this.createUserSession(userId, sessionId); + const signInClientVersion = + clientVersion ?? getClientVersionFromRequest(req); + const userSession = await this.createUserSession( + userId, + sessionId, + undefined, + signInClientVersion + ); res.cookie(AuthService.sessionCookieName, userSession.sessionId, { ...this.cookieOptions, diff --git a/packages/backend/server/src/core/version/guard.ts b/packages/backend/server/src/core/version/guard.ts index ee3ddccb16..7a6ce5f07e 100644 --- a/packages/backend/server/src/core/version/guard.ts +++ b/packages/backend/server/src/core/version/guard.ts @@ -7,6 +7,7 @@ import { Injectable } from '@nestjs/common'; import { Config, + getClientVersionFromRequest, getRequestResponseFromContext, GuardProvider, } from '../../base'; @@ -33,7 +34,7 @@ export class VersionGuardProvider const { req } = getRequestResponseFromContext(context); - const version = req.headers['x-affine-version'] as string | undefined; + const version = getClientVersionFromRequest(req); return this.version.checkVersion(version); } diff --git a/packages/backend/server/src/core/version/service.ts b/packages/backend/server/src/core/version/service.ts index 722094d7ea..bc7705103e 100644 --- a/packages/backend/server/src/core/version/service.ts +++ b/packages/backend/server/src/core/version/service.ts @@ -6,23 +6,24 @@ import { Config, UnsupportedClientVersion } from '../../base'; @Injectable() export class VersionService { private readonly logger = new Logger(VersionService.name); + private static readonly HARD_REQUIRED_VERSION = '>=0.25.0'; constructor(private readonly config: Config) {} async checkVersion(clientVersion?: string) { const requiredVersion = this.config.client.versionControl.requiredVersion; - const range = await this.getVersionRange(requiredVersion); - if (!range) { - // ignore invalid allowed version config - return true; - } + const hardRange = await this.getVersionRange( + VersionService.HARD_REQUIRED_VERSION + ); + const configRange = await this.getVersionRange(requiredVersion); if ( - !clientVersion || - !semver.satisfies(clientVersion, range, { - includePrerelease: true, - }) + configRange && + (!clientVersion || + !semver.satisfies(clientVersion, configRange, { + includePrerelease: true, + })) ) { throw new UnsupportedClientVersion({ clientVersion: clientVersion ?? 'unset_or_invalid', @@ -30,6 +31,19 @@ export class VersionService { }); } + if ( + hardRange && + (!clientVersion || + !semver.satisfies(clientVersion, hardRange, { + includePrerelease: true, + })) + ) { + throw new UnsupportedClientVersion({ + clientVersion: clientVersion ?? 'unset_or_invalid', + requiredVersion: VersionService.HARD_REQUIRED_VERSION, + }); + } + return true; } diff --git a/packages/backend/server/src/models/session.ts b/packages/backend/server/src/models/session.ts index 44940b27cb..e37cec0e77 100644 --- a/packages/backend/server/src/models/session.ts +++ b/packages/backend/server/src/models/session.ts @@ -46,7 +46,8 @@ export class SessionModel extends BaseModel { async createOrRefreshUserSession( userId: string, sessionId?: string, - ttl = this.config.auth.session.ttl + ttl = this.config.auth.session.ttl, + signInClientVersion?: string ) { // check whether given session is valid if (sessionId) { @@ -76,18 +77,21 @@ export class SessionModel extends BaseModel { }, update: { expiresAt, + ...(signInClientVersion ? { signInClientVersion } : {}), }, create: { sessionId, userId, expiresAt, + ...(signInClientVersion ? { signInClientVersion } : {}), }, }); } async refreshUserSessionIfNeeded( userSession: UserSession, - ttr = this.config.auth.session.ttr + ttr = this.config.auth.session.ttr, + refreshClientVersion?: string ): Promise { if ( userSession.expiresAt && @@ -106,6 +110,7 @@ export class SessionModel extends BaseModel { }, data: { expiresAt: newExpiresAt, + ...(refreshClientVersion ? { refreshClientVersion } : {}), }, }); diff --git a/packages/backend/server/src/plugins/captcha/service.ts b/packages/backend/server/src/plugins/captcha/service.ts index 01ebd9fc29..679cbc0ba4 100644 --- a/packages/backend/server/src/plugins/captcha/service.ts +++ b/packages/backend/server/src/plugins/captcha/service.ts @@ -5,7 +5,12 @@ import type { Request } from 'express'; import { nanoid } from 'nanoid'; import { z } from 'zod'; -import { CaptchaVerificationFailed, Config, OnEvent } from '../../base'; +import { + CaptchaVerificationFailed, + Config, + getRequestClientIp, + OnEvent, +} from '../../base'; import { ServerFeature, ServerService } from '../../core'; import { Models, TokenType } from '../../models'; import { verifyChallengeResponse } from '../../native'; @@ -133,7 +138,7 @@ export class CaptchaService { } else { const isTokenVerified = await this.verifyCaptchaToken( credential.token, - req.headers['CF-Connecting-IP'] as string + getRequestClientIp(req) ); if (!isTokenVerified) { diff --git a/packages/backend/server/src/plugins/oauth/controller.ts b/packages/backend/server/src/plugins/oauth/controller.ts index 2f94137b66..97965a1090 100644 --- a/packages/backend/server/src/plugins/oauth/controller.ts +++ b/packages/backend/server/src/plugins/oauth/controller.ts @@ -15,6 +15,7 @@ import type { Request, Response } from 'express'; import { ActionForbidden, Config, + getClientVersionFromRequest, InvalidAuthState, InvalidOauthCallbackState, MissingOauthQueryParameter, @@ -50,6 +51,7 @@ export class OAuthController { @Post('/preflight') @HttpCode(HttpStatus.OK) async preflight( + @Req() req: Request, @Body('provider') unknownProviderName?: keyof typeof OAuthProviderName, @Body('redirect_uri') redirectUri?: string, @Body('client') client?: string, @@ -75,11 +77,13 @@ export class OAuthController { throw new ActionForbidden(); } + const clientVersion = getClientVersionFromRequest(req); const state = await this.oauth.saveOAuthState({ provider: providerName, redirectUri, client, clientNonce, + clientVersion, ...(pkce ? { pkce: { @@ -220,7 +224,7 @@ export class OAuthController { tokens ); - await this.auth.setCookies(req, res, user.id); + await this.auth.setCookies(req, res, user.id, state.clientVersion); if ( state.provider === OAuthProviderName.Apple && diff --git a/packages/backend/server/src/plugins/oauth/types.ts b/packages/backend/server/src/plugins/oauth/types.ts index ba75989257..39fcff7081 100644 --- a/packages/backend/server/src/plugins/oauth/types.ts +++ b/packages/backend/server/src/plugins/oauth/types.ts @@ -13,6 +13,7 @@ export interface OAuthState { redirectUri?: string; client?: string; clientNonce?: string; + clientVersion?: string; provider: OAuthProviderName; pkce?: OAuthPkceState; token?: string; diff --git a/packages/backend/server/src/plugins/payment/manager/selfhost.ts b/packages/backend/server/src/plugins/payment/manager/selfhost.ts index 53fad460a2..bd4b7324d4 100644 --- a/packages/backend/server/src/plugins/payment/manager/selfhost.ts +++ b/packages/backend/server/src/plugins/payment/manager/selfhost.ts @@ -87,7 +87,7 @@ export class SelfhostTeamSubscriptionManager extends SubscriptionManager { return { allow_promotion_codes: true }; })(); - let successUrl = this.url.link(params.successCallbackLink); + let successUrl = this.url.safeLink(params.successCallbackLink || '/'); // stripe only accept unescaped '{CHECKOUT_SESSION_ID}' as query successUrl = this.url.addSimpleQuery( successUrl, diff --git a/packages/backend/server/src/plugins/payment/manager/user.ts b/packages/backend/server/src/plugins/payment/manager/user.ts index 9743beadd7..e0647644b3 100644 --- a/packages/backend/server/src/plugins/payment/manager/user.ts +++ b/packages/backend/server/src/plugins/payment/manager/user.ts @@ -204,7 +204,7 @@ export class UserSubscriptionManager extends SubscriptionManager { ], ...mode, ...discounts, - success_url: this.url.link(params.successCallbackLink), + success_url: this.url.safeLink(params.successCallbackLink || '/'), }); } diff --git a/packages/backend/server/src/plugins/payment/manager/workspace.ts b/packages/backend/server/src/plugins/payment/manager/workspace.ts index 8e2b3a4d71..6501ed740b 100644 --- a/packages/backend/server/src/plugins/payment/manager/workspace.ts +++ b/packages/backend/server/src/plugins/payment/manager/workspace.ts @@ -120,7 +120,7 @@ export class WorkspaceSubscriptionManager extends SubscriptionManager { }, }, ...discounts, - success_url: this.url.link(params.successCallbackLink), + success_url: this.url.safeLink(params.successCallbackLink || '/'), }); } diff --git a/packages/frontend/core/src/desktop/pages/auth/magic-link.tsx b/packages/frontend/core/src/desktop/pages/auth/magic-link.tsx index ce8ba88e1c..f8db1a4ed9 100644 --- a/packages/frontend/core/src/desktop/pages/auth/magic-link.tsx +++ b/packages/frontend/core/src/desktop/pages/auth/magic-link.tsx @@ -4,11 +4,14 @@ import { type LoaderFunction, redirect, useLoaderData, - // eslint-disable-next-line @typescript-eslint/no-restricted-imports useNavigate, } from 'react-router-dom'; import { AuthService } from '../../../modules/cloud'; +import { + buildAuthenticationDeepLink, + buildOpenAppUrlRoute, +} from '../../../modules/open-in-app'; import { supportedClient } from './common'; interface LoaderData { @@ -44,13 +47,14 @@ export const loader: LoaderFunction = ({ request }) => { return redirect('/sign-in?error=Invalid callback parameters'); } - const authParams = new URLSearchParams(); - authParams.set('method', 'magic-link'); - authParams.set('payload', JSON.stringify(payload)); + const urlToOpen = buildAuthenticationDeepLink({ + scheme: clientCheckResult.data, + method: 'magic-link', + payload, + server: location.origin, + }); - return redirect( - `/open-app/url?url=${encodeURIComponent(`${client}://authentication?${authParams.toString()}`)}` - ); + return redirect(buildOpenAppUrlRoute(urlToOpen)); }; export const Component = () => { diff --git a/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx b/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx index ff158563d2..5b5a14924d 100644 --- a/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx +++ b/packages/frontend/core/src/desktop/pages/auth/oauth-callback.tsx @@ -8,6 +8,10 @@ import { } from 'react-router-dom'; import { AuthService } from '../../../modules/cloud'; +import { + buildAuthenticationDeepLink, + buildOpenAppUrlRoute, +} from '../../../modules/open-in-app'; import { supportedClient } from './common'; interface LoaderData { @@ -45,14 +49,14 @@ export const loader: LoaderFunction = async ({ request }) => { return redirect('/sign-in?error=Invalid oauth callback parameters'); } - const authParams = new URLSearchParams(); - authParams.set('method', 'oauth'); - authParams.set('payload', JSON.stringify(payload)); - authParams.set('server', location.origin); + const urlToOpen = buildAuthenticationDeepLink({ + scheme: clientCheckResult.data, + method: 'oauth', + payload, + server: location.origin, + }); - return redirect( - `/open-app/url?url=${encodeURIComponent(`${client}://authentication?${authParams.toString()}`)}` - ); + return redirect(buildOpenAppUrlRoute(urlToOpen)); } catch { return redirect('/sign-in?error=Invalid oauth callback parameters'); } diff --git a/packages/frontend/core/src/desktop/pages/open-app/index.tsx b/packages/frontend/core/src/desktop/pages/open-app/index.tsx index 2e7be8a495..05c69004f6 100644 --- a/packages/frontend/core/src/desktop/pages/open-app/index.tsx +++ b/packages/frontend/core/src/desktop/pages/open-app/index.tsx @@ -1,5 +1,10 @@ import { useNavigateHelper } from '@affine/core/components/hooks/use-navigate-helper'; import { AuthService } from '@affine/core/modules/cloud'; +import { + buildAuthenticationDeepLink, + buildOpenAppUrlRoute, + normalizeOpenAppSignInNextParam, +} from '@affine/core/modules/open-in-app'; import { OpenInAppPage } from '@affine/core/modules/open-in-app/views/open-in-app-page'; import { appSchemaUrl, @@ -7,8 +12,8 @@ import { channelToScheme, } from '@affine/core/utils/channel'; import { useService } from '@toeverything/infra'; -import { useCallback, useEffect, useRef, useState } from 'react'; -import { useParams, useSearchParams } from 'react-router-dom'; +import { useCallback, useEffect, useRef } from 'react'; +import { useNavigate, useParams, useSearchParams } from 'react-router-dom'; import { AppContainer } from '../../components/app-container'; @@ -51,13 +56,16 @@ const OpenAppSignInRedirect = () => { const authService = useService(AuthService); const [params] = useSearchParams(); const triggeredRef = useRef(false); - const [urlToOpen, setUrlToOpen] = useState(null); + const navigate = useNavigate(); const maybeScheme = appSchemes.safeParse(params.get('scheme')); const scheme = maybeScheme.success ? maybeScheme.data : channelToScheme[BUILD_CONFIG.appBuildType]; - const next = params.get('next') || undefined; + const next = normalizeOpenAppSignInNextParam( + params.get('next'), + location.origin + ); useEffect(() => { if (triggeredRef.current) { @@ -68,23 +76,18 @@ const OpenAppSignInRedirect = () => { authService .createOpenAppSignInCode() .then(code => { - const authParams = new URLSearchParams(); - authParams.set('method', 'open-app-signin'); - authParams.set( - 'payload', - JSON.stringify(next ? { code, next } : { code }) - ); - authParams.set('server', location.origin); - setUrlToOpen(`${scheme}://authentication?${authParams.toString()}`); + const urlToOpen = buildAuthenticationDeepLink({ + scheme, + method: 'open-app-signin', + payload: next ? { code, next } : { code }, + server: location.origin, + }); + navigate(buildOpenAppUrlRoute(urlToOpen), { replace: true }); }) .catch(console.error); - }, [authService, next, scheme]); + }, [authService, navigate, next, scheme]); - if (!urlToOpen) { - return ; - } - - return ; + return ; }; export const Component = () => { diff --git a/packages/frontend/core/src/modules/open-in-app/__tests__/deeplink.spec.ts b/packages/frontend/core/src/modules/open-in-app/__tests__/deeplink.spec.ts new file mode 100644 index 0000000000..686fae4e65 --- /dev/null +++ b/packages/frontend/core/src/modules/open-in-app/__tests__/deeplink.spec.ts @@ -0,0 +1,71 @@ +import { expect, test } from 'vitest'; + +import { + buildAuthenticationDeepLink, + buildOpenAppUrlRoute, + normalizeOpenAppSignInNextParam, +} from '../utils'; + +test('buildAuthenticationDeepLink', () => { + const payload = { code: '1', next: '/workspace/123' }; + const url = buildAuthenticationDeepLink({ + scheme: 'affine', + method: 'open-app-signin', + payload, + server: 'https://app.affine.local', + }); + + const parsed = new URL(url); + + expect(parsed.protocol).toBe('affine:'); + expect(parsed.hostname).toBe('authentication'); + expect(parsed.searchParams.get('method')).toBe('open-app-signin'); + expect(parsed.searchParams.get('payload')).toBe(JSON.stringify(payload)); + expect(parsed.searchParams.get('server')).toBe('https://app.affine.local'); +}); + +test('buildOpenAppUrlRoute', () => { + const urlToOpen = 'affine://authentication?method=oauth&payload=%7B%7D'; + const route = buildOpenAppUrlRoute(urlToOpen); + + const parsed = new URL(route, 'https://app.affine.local'); + expect(parsed.pathname).toBe('/open-app/url'); + expect(parsed.searchParams.get('url')).toBe(urlToOpen); +}); + +test('normalizeOpenAppSignInNextParam', () => { + expect( + normalizeOpenAppSignInNextParam( + '/workspace/123', + 'https://app.affine.local' + ) + ).toBe('/workspace/123'); + + expect( + normalizeOpenAppSignInNextParam( + 'https://app.affine.local/workspace/123?foo=1#bar', + 'https://app.affine.local' + ) + ).toBe('/workspace/123?foo=1#bar'); + + expect( + normalizeOpenAppSignInNextParam( + 'https://evil.example/workspace/123', + 'https://app.affine.local' + ) + ).toBeUndefined(); + + expect( + normalizeOpenAppSignInNextParam( + '//evil.example/workspace/123', + 'https://app.affine.local' + ) + ).toBeUndefined(); + + expect( + normalizeOpenAppSignInNextParam( + '/redirect-proxy?redirect_uri=https://evil.example', + 'https://app.affine.local' + ) + ).toBeUndefined(); +}); diff --git a/packages/frontend/core/src/modules/open-in-app/utils.ts b/packages/frontend/core/src/modules/open-in-app/utils.ts index e0c0a2a4f4..3df4773de8 100644 --- a/packages/frontend/core/src/modules/open-in-app/utils.ts +++ b/packages/frontend/core/src/modules/open-in-app/utils.ts @@ -1,8 +1,84 @@ -import { channelToScheme } from '@affine/core/utils'; +import { channelToScheme } from '@affine/core/utils/channel'; import { DebugLogger } from '@affine/debug'; const logger = new DebugLogger('open-in-app'); +export type AuthenticationMethod = 'magic-link' | 'oauth' | 'open-app-signin'; + +export function buildAuthenticationDeepLink(options: { + scheme: string; + method: AuthenticationMethod; + payload: unknown; + server?: string; +}) { + const params = new URLSearchParams(); + params.set('method', options.method); + params.set('payload', JSON.stringify(options.payload)); + if (options.server) { + params.set('server', options.server); + } + + return `${options.scheme}://authentication?${params.toString()}`; +} + +export function buildOpenAppUrlRoute(urlToOpen: string) { + const params = new URLSearchParams(); + params.set('url', urlToOpen); + return `/open-app/url?${params.toString()}`; +} + +function isAllowedOpenAppSignInNext(next: string) { + if (next === '/') { + return true; + } + + if (next.startsWith('/workspace')) { + const boundary = next.charAt('/workspace'.length); + return ( + boundary === '' || + boundary === '/' || + boundary === '?' || + boundary === '#' + ); + } + + return next.startsWith('/share/'); +} + +export function normalizeOpenAppSignInNextParam( + next: string | null, + currentOrigin: string +) { + if (!next) { + return; + } + + // Disallow protocol-relative urls like `//evil.example`. + if (next.startsWith('//')) { + return; + } + + let parsed: URL; + try { + parsed = new URL(next, currentOrigin); + } catch { + return; + } + + // Only allow navigation within current origin. + if (parsed.origin !== currentOrigin) { + return; + } + + const normalized = `${parsed.pathname}${parsed.search}${parsed.hash}`; + + if (!isAllowedOpenAppSignInNext(normalized)) { + return; + } + + return normalized; +} + // return an AFFiNE app's url to be opened in desktop app export const getOpenUrlInDesktopAppLink = ( url: string,