feat: drop outdated session (#14373)

#### PR Dependency Tree


* **PR #14373** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added client version tracking and validation to ensure application
compatibility across authentication flows and sessions.
* Enhanced OAuth authentication with improved version handling during
sign-in and refresh operations.

* **Bug Fixes**
* Improved payment callback URL handling with safer defaults for
redirect links.

* **Tests**
* Expanded test coverage for client version enforcement and session
management.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2026-02-05 21:35:36 +08:00
committed by GitHub
parent 161eb302fd
commit 944fab36ac
29 changed files with 845 additions and 125 deletions

View File

@@ -53,6 +53,34 @@ test('should be able to sign in with credential', async t => {
t.is(session?.id, u1.id);
});
test('should record sign in client version when header is provided', async t => {
const { app, db } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.set('x-affine-version', '0.25.1')
.send({ email: u1.email, password: u1.password })
.expect(200);
const userSession1 = await db.userSession.findFirst({
where: { userId: u1.id },
});
t.is(userSession1?.signInClientVersion, '0.25.1');
// should not overwrite existing value with null/undefined
await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
const userSession2 = await db.userSession.findFirst({
where: { userId: u1.id },
});
t.is(userSession2?.signInClientVersion, '0.25.1');
});
test('should be able to sign in with email', async t => {
const { app } = t.context;

View File

@@ -1,13 +1,14 @@
import { Controller, Get, HttpStatus, INestApplication } from '@nestjs/common';
import { Controller, Get, HttpStatus } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import request from 'supertest';
import { ConfigFactory } from '../../base';
import { AuthModule, CurrentUser, Public, Session } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { Models } from '../../models';
import { createTestingApp } from '../utils';
import { createTestingApp, TestingApp } from '../utils';
@Controller('/')
class TestController {
@@ -29,31 +30,46 @@ class TestController {
}
const test = ava as TestFn<{
app: INestApplication;
app: TestingApp;
server: any;
auth: AuthService;
models: Models;
db: PrismaClient;
config: ConfigFactory;
u1: CurrentUser;
sessionId: string;
}>;
let server!: any;
let auth!: AuthService;
let u1!: CurrentUser;
let sessionId = '';
test.before(async t => {
const app = await createTestingApp({
imports: [AuthModule],
controllers: [TestController],
});
auth = app.get(AuthService);
u1 = await auth.signUp('u1@affine.pro', '1');
const models = app.get(Models);
const session = await models.session.createSession();
sessionId = session.id;
await auth.createUserSession(u1.id, sessionId);
server = app.getHttpServer();
t.context.app = app;
t.context.server = app.getHttpServer();
t.context.auth = app.get(AuthService);
t.context.models = app.get(Models);
t.context.db = app.get(PrismaClient);
t.context.config = app.get(ConfigFactory);
});
test.beforeEach(async t => {
Sinon.restore();
await t.context.app.initTestingDB();
t.context.config.override({
client: {
versionControl: {
enabled: false,
requiredVersion: '>=0.25.0',
},
},
});
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
const session = await t.context.models.session.createSession();
t.context.sessionId = session.id;
await t.context.auth.createUserSession(t.context.u1.id, t.context.sessionId);
});
test.after.always(async t => {
@@ -61,92 +77,95 @@ test.after.always(async t => {
});
test('should be able to visit public api if not signed in', async t => {
const res = await request(server).get('/public').expect(200);
const res = await request(t.context.server).get('/public').expect(200);
t.is(res.body.user, undefined);
});
test('should be able to visit public api if signed in', async t => {
const res = await request(server)
const res = await request(t.context.server)
.get('/public')
.set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`)
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, u1.id);
t.is(res.body.user.id, t.context.u1.id);
});
test('should not be able to visit private api if not signed in', async t => {
await request(server).get('/private').expect(HttpStatus.UNAUTHORIZED).expect({
status: 401,
code: 'Unauthorized',
type: 'AUTHENTICATION_REQUIRED',
name: 'AUTHENTICATION_REQUIRED',
message: 'You must sign in first to access this resource.',
});
await request(t.context.server)
.get('/private')
.expect(HttpStatus.UNAUTHORIZED)
.expect({
status: 401,
code: 'Unauthorized',
type: 'AUTHENTICATION_REQUIRED',
name: 'AUTHENTICATION_REQUIRED',
message: 'You must sign in first to access this resource.',
});
t.assert(true);
});
test('should be able to visit private api if signed in', async t => {
const res = await request(server)
const res = await request(t.context.server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`)
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, u1.id);
t.is(res.body.user.id, t.context.u1.id);
});
test('should be able to visit private api with access token', async t => {
const models = t.context.app.get(Models);
const token = await models.accessToken.create({
userId: u1.id,
userId: t.context.u1.id,
name: 'test',
});
const res = await request(server)
const res = await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${token.token}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, u1.id);
t.is(res.body.user.id, t.context.u1.id);
});
test('should be able to parse session cookie', async t => {
const spy = Sinon.spy(auth, 'getUserSession');
await request(server)
const spy = Sinon.spy(t.context.auth, 'getUserSession');
await request(t.context.server)
.get('/public')
.set('cookie', `${AuthService.sessionCookieName}=${sessionId}`)
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.expect(200);
t.deepEqual(spy.firstCall.args, [sessionId, undefined]);
t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]);
spy.restore();
});
test('should be able to parse bearer token', async t => {
const spy = Sinon.spy(auth, 'getUserSession');
const spy = Sinon.spy(t.context.auth, 'getUserSession');
await request(server)
await request(t.context.server)
.get('/public')
.auth(sessionId, { type: 'bearer' })
.auth(t.context.sessionId, { type: 'bearer' })
.expect(200);
t.deepEqual(spy.firstCall.args, [sessionId, undefined]);
t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]);
spy.restore();
});
test('should be able to refresh session if needed', async t => {
await t.context.app.get(PrismaClient).userSession.updateMany({
where: {
sessionId,
sessionId: t.context.sessionId,
},
data: {
expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */),
},
});
const res = await request(server)
const res = await request(t.context.server)
.get('/session')
.set('cookie', `${AuthService.sessionCookieName}=${sessionId}`)
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.expect(200);
const cookie = res
@@ -155,3 +174,101 @@ test('should be able to refresh session if needed', async t => {
t.truthy(cookie);
});
test('should record refresh client version when refreshed', async t => {
await t.context.db.userSession.updateMany({
where: { sessionId: t.context.sessionId },
data: {
expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */),
},
});
await request(t.context.server)
.get('/session')
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.set('x-affine-version', '0.25.2')
.expect(200);
const userSession = await t.context.db.userSession.findFirst({
where: { sessionId: t.context.sessionId, userId: t.context.u1.id },
});
t.is(userSession?.refreshClientVersion, '0.25.2');
});
test('should allow auth when header is missing but stored version is valid', async t => {
t.context.config.override({
client: {
versionControl: {
enabled: true,
requiredVersion: '>=0.25.0',
},
},
});
await t.context.db.userSession.updateMany({
where: { sessionId: t.context.sessionId },
data: { signInClientVersion: '0.25.0' },
});
const res = await request(t.context.server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.expect(200);
t.is(res.body.user.id, t.context.u1.id);
});
test('should kick out unsupported client version on non-public handler', async t => {
t.context.config.override({
client: {
versionControl: {
enabled: true,
requiredVersion: '>=0.25.0',
},
},
});
const res = await request(t.context.server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.set('x-affine-version', '0.24.0')
.expect(403);
const setCookies = res.get('Set-Cookie') ?? [];
t.true(
setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`))
);
t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`)));
t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`)));
const session = await t.context.db.session.findFirst({
where: { id: t.context.sessionId },
});
t.is(session, null);
});
test('should not block public handler when client version is unsupported', async t => {
t.context.config.override({
client: {
versionControl: {
enabled: true,
requiredVersion: '>=0.25.0',
},
},
});
const res = await request(t.context.server)
.get('/public')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.set('x-affine-version', '0.24.0')
.expect(200);
t.is(res.body.user, undefined);
const setCookies = res.get('Set-Cookie') ?? [];
t.true(
setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`))
);
t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`)));
t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`)));
});

View File

@@ -122,6 +122,64 @@ test('should refresh exists userSession', async t => {
);
});
test('should record sign-in client version on create and update', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const session = await t.context.session.createSession();
const userSession1 = await t.context.session.createOrRefreshUserSession(
user.id,
session.id,
undefined,
'0.25.0'
);
t.is(userSession1.signInClientVersion, '0.25.0');
const userSession2 = await t.context.session.createOrRefreshUserSession(
user.id,
session.id
);
t.is(userSession2.signInClientVersion, '0.25.0');
const userSession3 = await t.context.session.createOrRefreshUserSession(
user.id,
session.id,
undefined,
'0.26.0'
);
t.is(userSession3.signInClientVersion, '0.26.0');
});
test('should record refresh client version only when refreshed', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const session = await t.context.session.createSession();
const userSession = await t.context.session.createOrRefreshUserSession(
user.id,
session.id
);
// force refresh
userSession.expiresAt = new Date(
userSession.expiresAt!.getTime() -
t.context.config.auth.session.ttr * 2 * 1000
);
const newExpiresAt = await t.context.session.refreshUserSessionIfNeeded(
userSession,
undefined,
'0.25.0'
);
t.truthy(newExpiresAt);
const refreshed = await t.context.db.userSession.findFirst({
where: { id: userSession.id },
});
t.is(refreshed?.refreshClientVersion, '0.25.0');
});
test('should not refresh userSession when expires time not hit ttr', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',

View File

@@ -1,4 +1,5 @@
import { Controller, Get, HttpStatus, UseGuards } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { type Response } from 'supertest';
@@ -144,6 +145,72 @@ test('should be able to prevent requests if limit is reached', async t => {
stub.restore();
});
test('should use session id as tracker when available', async t => {
const { app } = t.context;
const user = await app.signupV1('u1@affine.pro');
const userSession = await app.get(PrismaClient).userSession.findFirst({
where: { userId: user.id },
});
t.truthy(userSession);
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 1,
isBlocked: false,
timeToBlockExpire: 0,
});
await app.GET('/throttled/default').expect(200);
const key = stub.firstCall.args[0] as string;
t.true(key.startsWith(`throttler:${userSession!.sessionId};default`));
stub.restore();
});
test('should use CF-Connecting-IP as tracker when present', async t => {
const { app } = t.context;
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 1,
isBlocked: false,
timeToBlockExpire: 0,
});
await app
.GET('/nonthrottled/default')
.set('CF-Connecting-IP', '1.2.3.4')
.expect(200);
const key = stub.firstCall.args[0] as string;
t.true(key.startsWith('throttler:1.2.3.4;default'));
stub.restore();
});
test('should use X-Forwarded-For as tracker when present', async t => {
const { app } = t.context;
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 1,
isBlocked: false,
timeToBlockExpire: 0,
});
await app
.GET('/nonthrottled/default')
.set('X-Forwarded-For', '5.6.7.8, 9.9.9.9')
.expect(200);
const key = stub.firstCall.args[0] as string;
t.true(key.startsWith('throttler:5.6.7.8;default'));
stub.restore();
});
// ====== unauthenticated user visits ======
test('should use default throttler for unauthenticated user when not specified', async t => {
const { app } = t.context;

View File

@@ -6,7 +6,7 @@ import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { AppModule } from '../../app.module';
import { URLHelper } from '../../base';
import { ConfigFactory, URLHelper } from '../../base';
import { ConfigModule } from '../../base/config';
import { CurrentUser } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
@@ -56,6 +56,14 @@ test.before(async t => {
test.beforeEach(async t => {
Sinon.restore();
await t.context.app.initTestingDB();
t.context.app.get(ConfigFactory).override({
client: {
versionControl: {
enabled: false,
requiredVersion: '>=0.25.0',
},
},
});
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
});
@@ -156,6 +164,56 @@ test('should be able to redirect to oauth provider with client_nonce', async t =
t.truthy(state.state);
});
test('should record sign in client version from oauth preflight state', async t => {
const { app, db } = t.context;
const config = app.get(ConfigFactory);
config.override({
client: {
versionControl: {
enabled: true,
requiredVersion: '>=0.25.0',
},
},
});
const preflightRes = await app
.POST('/api/oauth/preflight')
.set('x-affine-version', '0.25.3')
.send({ provider: 'Google', client_nonce: 'test-nonce' })
.expect(HttpStatus.OK);
const redirect = new URL(preflightRes.body.url as string);
const stateParam = redirect.searchParams.get('state');
t.truthy(stateParam);
// state should be a json string
const rawState = JSON.parse(stateParam!);
const provider = app.get(GoogleOAuthProvider);
Sinon.stub(provider, 'getToken').resolves({ accessToken: '1' });
Sinon.stub(provider, 'getUser').resolves({
id: '1',
email: 'oauth-version@affine.pro',
avatarUrl: 'avatar',
});
const callbackRes = await app
.POST('/api/oauth/callback')
.send({ code: '1', state: stateParam, client_nonce: 'test-nonce' })
.expect(HttpStatus.OK);
const userId = callbackRes.body.id as string;
t.truthy(userId);
const userSession = await db.userSession.findFirst({
where: { userId },
});
t.is(userSession?.signInClientVersion, '0.25.3');
t.is(userSession?.refreshClientVersion, null);
t.truthy(rawState.state);
});
test('should forbid preflight with untrusted redirect_uri', async t => {
const { app } = t.context;

View File

@@ -73,7 +73,7 @@ test('should passthrough if version check is not enabled', async t => {
spy.restore();
});
test('should passthrough is version range is invalid', async t => {
test('should enforce hard required version when version range is invalid', async t => {
config.override({
client: {
versionControl: {
@@ -82,9 +82,17 @@ test('should passthrough is version range is invalid', async t => {
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
t.is(res.status, 200);
res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
t.is(res.status, 403);
t.is(
res.body.message,
'Unsupported client with version [invalid], required version is [>=0.25.0].'
);
});
test('should pass if client version is allowed', async t => {

View File

@@ -86,6 +86,29 @@ test('can create link', t => {
);
});
test('addSimpleQuery should not double encode', t => {
t.is(
t.context.url.addSimpleQuery(
'https://app.affine.local/path',
'redirect_uri',
'/path'
),
'https://app.affine.local/path?redirect_uri=%2Fpath'
);
});
test('addSimpleQuery should allow unescaped value when escape=false', t => {
t.is(
t.context.url.addSimpleQuery(
'https://app.affine.local/path',
'session_id',
'{CHECKOUT_SESSION_ID}',
false
),
'https://app.affine.local/path?session_id={CHECKOUT_SESSION_ID}'
);
});
test('can validate callbackUrl allowlist', t => {
t.true(t.context.url.isAllowedCallbackUrl('/magic-link'));
t.true(

View File

@@ -109,7 +109,7 @@ export class URLHelper {
) {
const urlObj = new URL(url);
if (escape) {
urlObj.searchParams.set(key, encodeURIComponent(value));
urlObj.searchParams.set(key, String(value));
return urlObj.toString();
} else {
const query =

View File

@@ -16,6 +16,7 @@ import type { Request, Response } from 'express';
import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request';
import { getRequestTrackerId } from '../utils/request-tracker';
import type { ThrottlerType } from './config';
import { THROTTLER_PROTECTED, Throttlers } from './decorators';
@@ -63,11 +64,8 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
}
override getTracker(req: Request): Promise<string> {
return Promise.resolve(
// ↓ prefer session id if available
`throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
// ^ throttler prefix make the key in store recognizable
);
// throttler prefix make the key in store recognizable
return Promise.resolve(`throttler:${getRequestTrackerId(req)}`);
}
override generateKey(

View File

@@ -1,6 +1,7 @@
export * from './duration';
export * from './promise';
export * from './request';
export * from './request-tracker';
export * from './ssrf';
export * from './stream';
export * from './types';

View File

@@ -0,0 +1,44 @@
import type { Request } from 'express';
function firstForwardedForIp(value?: string) {
if (!value) {
return;
}
const [first] = value.split(',', 1);
const ip = first?.trim();
return ip || undefined;
}
function firstNonEmpty(...values: Array<string | undefined>) {
for (const value of values) {
const trimmed = value?.trim();
if (trimmed) {
return trimmed;
}
}
return;
}
export function getRequestClientIp(req: Request) {
return firstNonEmpty(
req.get('CF-Connecting-IP'),
firstForwardedForIp(req.get('X-Forwarded-For')),
req.get('X-Real-IP'),
req.ip
)!;
}
export function getRequestTrackerId(req: Request) {
return (
req.session?.sessionId ??
firstNonEmpty(
req.get('CF-Connecting-IP'),
firstForwardedForIp(req.get('X-Forwarded-For')),
req.get('X-Real-IP'),
req.get('CF-Ray'),
req.ip
)!
);
}

View File

@@ -7,6 +7,7 @@ import type {
import { Injectable, SetMetadata } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request, Response } from 'express';
import semver from 'semver';
import { Socket } from 'socket.io';
import {
@@ -15,8 +16,10 @@ import {
Cache,
Config,
CryptoHelper,
getClientVersionFromRequest,
getRequestResponseFromContext,
parseCookies,
UnsupportedClientVersion,
} from '../../base';
import { WEBSOCKET_OPTIONS } from '../../base/websocket';
import { AuthService } from './service';
@@ -30,10 +33,13 @@ const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30 * 1000;
@Injectable()
export class AuthGuard implements CanActivate, OnModuleInit {
private auth!: AuthService;
private readonly cachedVersionRange = new Map<string, semver.Range | null>();
private static readonly HARD_REQUIRED_VERSION = '>=0.25.0';
constructor(
private readonly crypto: CryptoHelper,
private readonly cache: Cache,
private readonly config: Config,
private readonly ref: ModuleRef,
private readonly reflector: Reflector
) {}
@@ -78,14 +84,14 @@ export class AuthGuard implements CanActivate, OnModuleInit {
throw new AccessDenied('Invalid internal request');
}
const authedUser = await this.signIn(req, res);
// api is public
const isPublic = this.reflector.getAllAndOverride<boolean>(
PUBLIC_ENTRYPOINT_SYMBOL,
[clazz, handler]
);
const authedUser = await this.signIn(req, res, isPublic);
if (isPublic) {
return true;
}
@@ -99,9 +105,10 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async signIn(
req: Request,
res?: Response
res?: Response,
isPublic = false
): Promise<Session | TokenSession | null> {
const userSession = await this.signInWithCookie(req, res);
const userSession = await this.signInWithCookie(req, res, isPublic);
if (userSession) {
return userSession;
}
@@ -111,7 +118,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async signInWithCookie(
req: Request,
res?: Response
res?: Response,
isPublic = false
): Promise<Session | null> {
if (req.session) {
return req.session;
@@ -121,8 +129,38 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const userSession = await this.auth.getUserSessionFromRequest(req, res);
if (userSession) {
const headerClientVersion = getClientVersionFromRequest(req);
if (this.config.client.versionControl.enabled) {
const clientVersion =
headerClientVersion ??
userSession.session.refreshClientVersion ??
userSession.session.signInClientVersion;
const versionCheckResult = this.checkClientVersion(clientVersion);
if (!versionCheckResult.ok) {
await this.auth.signOut(userSession.session.sessionId);
if (res) {
await this.auth.refreshCookies(res, userSession.session.sessionId);
}
if (isPublic) {
return null;
}
throw new UnsupportedClientVersion({
clientVersion: clientVersion ?? 'unset_or_invalid',
requiredVersion: versionCheckResult.requiredVersion,
});
}
}
if (res) {
await this.auth.refreshUserSessionIfNeeded(res, userSession.session);
await this.auth.refreshUserSessionIfNeeded(
res,
userSession.session,
undefined,
headerClientVersion
);
}
req.session = {
@@ -154,6 +192,59 @@ export class AuthGuard implements CanActivate, OnModuleInit {
return null;
}
private getVersionRange(versionRange: string): semver.Range | null {
if (this.cachedVersionRange.has(versionRange)) {
// oxlint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.cachedVersionRange.get(versionRange)!;
}
let range: semver.Range | null = null;
try {
range = new semver.Range(versionRange, { loose: false });
if (!semver.validRange(range)) {
range = null;
}
} catch {
range = null;
}
this.cachedVersionRange.set(versionRange, range);
return range;
}
private checkClientVersion(
clientVersion?: string | null
): { ok: true } | { ok: false; requiredVersion: string } {
const requiredVersion = this.config.client.versionControl.requiredVersion;
const configRange = this.getVersionRange(requiredVersion);
if (
configRange &&
(!clientVersion ||
!semver.satisfies(clientVersion, configRange, {
includePrerelease: true,
}))
) {
return { ok: false, requiredVersion };
}
const hardRange = this.getVersionRange(AuthGuard.HARD_REQUIRED_VERSION);
if (!hardRange) {
return { ok: true };
}
if (
!clientVersion ||
!semver.satisfies(clientVersion, hardRange, {
includePrerelease: true,
})
) {
return { ok: false, requiredVersion: AuthGuard.HARD_REQUIRED_VERSION };
}
return { ok: true };
}
}
/**
@@ -184,7 +275,13 @@ export const AuthWebsocketOptionsProvider: FactoryProvider = {
...upgradeReq.cookies,
};
const session = await guard.signIn(upgradeReq);
const session = await (async () => {
try {
return await guard.signIn(upgradeReq);
} catch {
return null;
}
})();
return !!session;
},

View File

@@ -4,7 +4,11 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import type { CookieOptions, Request, Response } from 'express';
import { assign, pick } from 'lodash-es';
import { Config, SignUpForbidden } from '../../base';
import {
Config,
getClientVersionFromRequest,
SignUpForbidden,
} from '../../base';
import { Models, type User, type UserSession } from '../../models';
import { Mailer } from '../mail/mailer';
import { createDevUsers } from './dev';
@@ -130,11 +134,17 @@ export class AuthService implements OnApplicationBootstrap {
return await this.models.session.findUserSessionsBySessionId(sessionId);
}
async createUserSession(userId: string, sessionId?: string, ttl?: number) {
async createUserSession(
userId: string,
sessionId?: string,
ttl?: number,
signInClientVersion?: string
) {
return await this.models.session.createOrRefreshUserSession(
userId,
sessionId,
ttl
ttl,
signInClientVersion
);
}
@@ -159,11 +169,13 @@ export class AuthService implements OnApplicationBootstrap {
async refreshUserSessionIfNeeded(
res: Response,
userSession: UserSession,
ttr?: number
ttr?: number,
refreshClientVersion?: string
): Promise<boolean> {
const newExpiresAt = await this.models.session.refreshUserSessionIfNeeded(
userSession,
ttr
ttr,
refreshClientVersion
);
if (!newExpiresAt) {
// no need to refresh
@@ -205,10 +217,22 @@ export class AuthService implements OnApplicationBootstrap {
};
}
async setCookies(req: Request, res: Response, userId: string) {
async setCookies(
req: Request,
res: Response,
userId: string,
clientVersion?: string
) {
const { sessionId } = this.getSessionOptionsFromRequest(req);
const userSession = await this.createUserSession(userId, sessionId);
const signInClientVersion =
clientVersion ?? getClientVersionFromRequest(req);
const userSession = await this.createUserSession(
userId,
sessionId,
undefined,
signInClientVersion
);
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
...this.cookieOptions,

View File

@@ -7,6 +7,7 @@ import { Injectable } from '@nestjs/common';
import {
Config,
getClientVersionFromRequest,
getRequestResponseFromContext,
GuardProvider,
} from '../../base';
@@ -33,7 +34,7 @@ export class VersionGuardProvider
const { req } = getRequestResponseFromContext(context);
const version = req.headers['x-affine-version'] as string | undefined;
const version = getClientVersionFromRequest(req);
return this.version.checkVersion(version);
}

View File

@@ -6,23 +6,24 @@ import { Config, UnsupportedClientVersion } from '../../base';
@Injectable()
export class VersionService {
private readonly logger = new Logger(VersionService.name);
private static readonly HARD_REQUIRED_VERSION = '>=0.25.0';
constructor(private readonly config: Config) {}
async checkVersion(clientVersion?: string) {
const requiredVersion = this.config.client.versionControl.requiredVersion;
const range = await this.getVersionRange(requiredVersion);
if (!range) {
// ignore invalid allowed version config
return true;
}
const hardRange = await this.getVersionRange(
VersionService.HARD_REQUIRED_VERSION
);
const configRange = await this.getVersionRange(requiredVersion);
if (
!clientVersion ||
!semver.satisfies(clientVersion, range, {
includePrerelease: true,
})
configRange &&
(!clientVersion ||
!semver.satisfies(clientVersion, configRange, {
includePrerelease: true,
}))
) {
throw new UnsupportedClientVersion({
clientVersion: clientVersion ?? 'unset_or_invalid',
@@ -30,6 +31,19 @@ export class VersionService {
});
}
if (
hardRange &&
(!clientVersion ||
!semver.satisfies(clientVersion, hardRange, {
includePrerelease: true,
}))
) {
throw new UnsupportedClientVersion({
clientVersion: clientVersion ?? 'unset_or_invalid',
requiredVersion: VersionService.HARD_REQUIRED_VERSION,
});
}
return true;
}

View File

@@ -46,7 +46,8 @@ export class SessionModel extends BaseModel {
async createOrRefreshUserSession(
userId: string,
sessionId?: string,
ttl = this.config.auth.session.ttl
ttl = this.config.auth.session.ttl,
signInClientVersion?: string
) {
// check whether given session is valid
if (sessionId) {
@@ -76,18 +77,21 @@ export class SessionModel extends BaseModel {
},
update: {
expiresAt,
...(signInClientVersion ? { signInClientVersion } : {}),
},
create: {
sessionId,
userId,
expiresAt,
...(signInClientVersion ? { signInClientVersion } : {}),
},
});
}
async refreshUserSessionIfNeeded(
userSession: UserSession,
ttr = this.config.auth.session.ttr
ttr = this.config.auth.session.ttr,
refreshClientVersion?: string
): Promise<Date | undefined> {
if (
userSession.expiresAt &&
@@ -106,6 +110,7 @@ export class SessionModel extends BaseModel {
},
data: {
expiresAt: newExpiresAt,
...(refreshClientVersion ? { refreshClientVersion } : {}),
},
});

View File

@@ -5,7 +5,12 @@ import type { Request } from 'express';
import { nanoid } from 'nanoid';
import { z } from 'zod';
import { CaptchaVerificationFailed, Config, OnEvent } from '../../base';
import {
CaptchaVerificationFailed,
Config,
getRequestClientIp,
OnEvent,
} from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { Models, TokenType } from '../../models';
import { verifyChallengeResponse } from '../../native';
@@ -133,7 +138,7 @@ export class CaptchaService {
} else {
const isTokenVerified = await this.verifyCaptchaToken(
credential.token,
req.headers['CF-Connecting-IP'] as string
getRequestClientIp(req)
);
if (!isTokenVerified) {

View File

@@ -15,6 +15,7 @@ import type { Request, Response } from 'express';
import {
ActionForbidden,
Config,
getClientVersionFromRequest,
InvalidAuthState,
InvalidOauthCallbackState,
MissingOauthQueryParameter,
@@ -50,6 +51,7 @@ export class OAuthController {
@Post('/preflight')
@HttpCode(HttpStatus.OK)
async preflight(
@Req() req: Request,
@Body('provider') unknownProviderName?: keyof typeof OAuthProviderName,
@Body('redirect_uri') redirectUri?: string,
@Body('client') client?: string,
@@ -75,11 +77,13 @@ export class OAuthController {
throw new ActionForbidden();
}
const clientVersion = getClientVersionFromRequest(req);
const state = await this.oauth.saveOAuthState({
provider: providerName,
redirectUri,
client,
clientNonce,
clientVersion,
...(pkce
? {
pkce: {
@@ -220,7 +224,7 @@ export class OAuthController {
tokens
);
await this.auth.setCookies(req, res, user.id);
await this.auth.setCookies(req, res, user.id, state.clientVersion);
if (
state.provider === OAuthProviderName.Apple &&

View File

@@ -13,6 +13,7 @@ export interface OAuthState {
redirectUri?: string;
client?: string;
clientNonce?: string;
clientVersion?: string;
provider: OAuthProviderName;
pkce?: OAuthPkceState;
token?: string;

View File

@@ -87,7 +87,7 @@ export class SelfhostTeamSubscriptionManager extends SubscriptionManager {
return { allow_promotion_codes: true };
})();
let successUrl = this.url.link(params.successCallbackLink);
let successUrl = this.url.safeLink(params.successCallbackLink || '/');
// stripe only accept unescaped '{CHECKOUT_SESSION_ID}' as query
successUrl = this.url.addSimpleQuery(
successUrl,

View File

@@ -204,7 +204,7 @@ export class UserSubscriptionManager extends SubscriptionManager {
],
...mode,
...discounts,
success_url: this.url.link(params.successCallbackLink),
success_url: this.url.safeLink(params.successCallbackLink || '/'),
});
}

View File

@@ -120,7 +120,7 @@ export class WorkspaceSubscriptionManager extends SubscriptionManager {
},
},
...discounts,
success_url: this.url.link(params.successCallbackLink),
success_url: this.url.safeLink(params.successCallbackLink || '/'),
});
}