mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-24 18:02:47 +08:00
feat: drop outdated session (#14373)
#### PR Dependency Tree * **PR #14373** 👈 This tree was auto-generated by [Charcoal](https://github.com/danerwilliams/charcoal) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added client version tracking and validation to ensure application compatibility across authentication flows and sessions. * Enhanced OAuth authentication with improved version handling during sign-in and refresh operations. * **Bug Fixes** * Improved payment callback URL handling with safer defaults for redirect links. * **Tests** * Expanded test coverage for client version enforcement and session management. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -53,6 +53,34 @@ test('should be able to sign in with credential', async t => {
|
||||
t.is(session?.id, u1.id);
|
||||
});
|
||||
|
||||
test('should record sign in client version when header is provided', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.set('x-affine-version', '0.25.1')
|
||||
.send({ email: u1.email, password: u1.password })
|
||||
.expect(200);
|
||||
|
||||
const userSession1 = await db.userSession.findFirst({
|
||||
where: { userId: u1.id },
|
||||
});
|
||||
t.is(userSession1?.signInClientVersion, '0.25.1');
|
||||
|
||||
// should not overwrite existing value with null/undefined
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.send({ email: u1.email, password: u1.password })
|
||||
.expect(200);
|
||||
|
||||
const userSession2 = await db.userSession.findFirst({
|
||||
where: { userId: u1.id },
|
||||
});
|
||||
t.is(userSession2?.signInClientVersion, '0.25.1');
|
||||
});
|
||||
|
||||
test('should be able to sign in with email', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { Controller, Get, HttpStatus, INestApplication } from '@nestjs/common';
|
||||
import { Controller, Get, HttpStatus } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
import request from 'supertest';
|
||||
|
||||
import { ConfigFactory } from '../../base';
|
||||
import { AuthModule, CurrentUser, Public, Session } from '../../core/auth';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
import { Models } from '../../models';
|
||||
import { createTestingApp } from '../utils';
|
||||
import { createTestingApp, TestingApp } from '../utils';
|
||||
|
||||
@Controller('/')
|
||||
class TestController {
|
||||
@@ -29,31 +30,46 @@ class TestController {
|
||||
}
|
||||
|
||||
const test = ava as TestFn<{
|
||||
app: INestApplication;
|
||||
app: TestingApp;
|
||||
server: any;
|
||||
auth: AuthService;
|
||||
models: Models;
|
||||
db: PrismaClient;
|
||||
config: ConfigFactory;
|
||||
u1: CurrentUser;
|
||||
sessionId: string;
|
||||
}>;
|
||||
|
||||
let server!: any;
|
||||
let auth!: AuthService;
|
||||
let u1!: CurrentUser;
|
||||
|
||||
let sessionId = '';
|
||||
|
||||
test.before(async t => {
|
||||
const app = await createTestingApp({
|
||||
imports: [AuthModule],
|
||||
controllers: [TestController],
|
||||
});
|
||||
|
||||
auth = app.get(AuthService);
|
||||
u1 = await auth.signUp('u1@affine.pro', '1');
|
||||
|
||||
const models = app.get(Models);
|
||||
const session = await models.session.createSession();
|
||||
sessionId = session.id;
|
||||
await auth.createUserSession(u1.id, sessionId);
|
||||
|
||||
server = app.getHttpServer();
|
||||
t.context.app = app;
|
||||
t.context.server = app.getHttpServer();
|
||||
t.context.auth = app.get(AuthService);
|
||||
t.context.models = app.get(Models);
|
||||
t.context.db = app.get(PrismaClient);
|
||||
t.context.config = app.get(ConfigFactory);
|
||||
});
|
||||
|
||||
test.beforeEach(async t => {
|
||||
Sinon.restore();
|
||||
await t.context.app.initTestingDB();
|
||||
t.context.config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: false,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
|
||||
const session = await t.context.models.session.createSession();
|
||||
t.context.sessionId = session.id;
|
||||
await t.context.auth.createUserSession(t.context.u1.id, t.context.sessionId);
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
@@ -61,92 +77,95 @@ test.after.always(async t => {
|
||||
});
|
||||
|
||||
test('should be able to visit public api if not signed in', async t => {
|
||||
const res = await request(server).get('/public').expect(200);
|
||||
const res = await request(t.context.server).get('/public').expect(200);
|
||||
|
||||
t.is(res.body.user, undefined);
|
||||
});
|
||||
|
||||
test('should be able to visit public api if signed in', async t => {
|
||||
const res = await request(server)
|
||||
const res = await request(t.context.server)
|
||||
.get('/public')
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`)
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
t.is(res.body.user.id, u1.id);
|
||||
t.is(res.body.user.id, t.context.u1.id);
|
||||
});
|
||||
|
||||
test('should not be able to visit private api if not signed in', async t => {
|
||||
await request(server).get('/private').expect(HttpStatus.UNAUTHORIZED).expect({
|
||||
status: 401,
|
||||
code: 'Unauthorized',
|
||||
type: 'AUTHENTICATION_REQUIRED',
|
||||
name: 'AUTHENTICATION_REQUIRED',
|
||||
message: 'You must sign in first to access this resource.',
|
||||
});
|
||||
await request(t.context.server)
|
||||
.get('/private')
|
||||
.expect(HttpStatus.UNAUTHORIZED)
|
||||
.expect({
|
||||
status: 401,
|
||||
code: 'Unauthorized',
|
||||
type: 'AUTHENTICATION_REQUIRED',
|
||||
name: 'AUTHENTICATION_REQUIRED',
|
||||
message: 'You must sign in first to access this resource.',
|
||||
});
|
||||
|
||||
t.assert(true);
|
||||
});
|
||||
|
||||
test('should be able to visit private api if signed in', async t => {
|
||||
const res = await request(server)
|
||||
const res = await request(t.context.server)
|
||||
.get('/private')
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${sessionId}`)
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
t.is(res.body.user.id, u1.id);
|
||||
t.is(res.body.user.id, t.context.u1.id);
|
||||
});
|
||||
|
||||
test('should be able to visit private api with access token', async t => {
|
||||
const models = t.context.app.get(Models);
|
||||
const token = await models.accessToken.create({
|
||||
userId: u1.id,
|
||||
userId: t.context.u1.id,
|
||||
name: 'test',
|
||||
});
|
||||
|
||||
const res = await request(server)
|
||||
const res = await request(t.context.server)
|
||||
.get('/private')
|
||||
.set('Authorization', `Bearer ${token.token}`)
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
t.is(res.body.user.id, u1.id);
|
||||
t.is(res.body.user.id, t.context.u1.id);
|
||||
});
|
||||
|
||||
test('should be able to parse session cookie', async t => {
|
||||
const spy = Sinon.spy(auth, 'getUserSession');
|
||||
await request(server)
|
||||
const spy = Sinon.spy(t.context.auth, 'getUserSession');
|
||||
await request(t.context.server)
|
||||
.get('/public')
|
||||
.set('cookie', `${AuthService.sessionCookieName}=${sessionId}`)
|
||||
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.expect(200);
|
||||
|
||||
t.deepEqual(spy.firstCall.args, [sessionId, undefined]);
|
||||
t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]);
|
||||
spy.restore();
|
||||
});
|
||||
|
||||
test('should be able to parse bearer token', async t => {
|
||||
const spy = Sinon.spy(auth, 'getUserSession');
|
||||
const spy = Sinon.spy(t.context.auth, 'getUserSession');
|
||||
|
||||
await request(server)
|
||||
await request(t.context.server)
|
||||
.get('/public')
|
||||
.auth(sessionId, { type: 'bearer' })
|
||||
.auth(t.context.sessionId, { type: 'bearer' })
|
||||
.expect(200);
|
||||
|
||||
t.deepEqual(spy.firstCall.args, [sessionId, undefined]);
|
||||
t.deepEqual(spy.firstCall.args, [t.context.sessionId, undefined]);
|
||||
spy.restore();
|
||||
});
|
||||
|
||||
test('should be able to refresh session if needed', async t => {
|
||||
await t.context.app.get(PrismaClient).userSession.updateMany({
|
||||
where: {
|
||||
sessionId,
|
||||
sessionId: t.context.sessionId,
|
||||
},
|
||||
data: {
|
||||
expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */),
|
||||
},
|
||||
});
|
||||
|
||||
const res = await request(server)
|
||||
const res = await request(t.context.server)
|
||||
.get('/session')
|
||||
.set('cookie', `${AuthService.sessionCookieName}=${sessionId}`)
|
||||
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.expect(200);
|
||||
|
||||
const cookie = res
|
||||
@@ -155,3 +174,101 @@ test('should be able to refresh session if needed', async t => {
|
||||
|
||||
t.truthy(cookie);
|
||||
});
|
||||
|
||||
test('should record refresh client version when refreshed', async t => {
|
||||
await t.context.db.userSession.updateMany({
|
||||
where: { sessionId: t.context.sessionId },
|
||||
data: {
|
||||
expiresAt: new Date(Date.now() + 1000 * 60 * 60 /* expires in 1 hour */),
|
||||
},
|
||||
});
|
||||
|
||||
await request(t.context.server)
|
||||
.get('/session')
|
||||
.set('cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.set('x-affine-version', '0.25.2')
|
||||
.expect(200);
|
||||
|
||||
const userSession = await t.context.db.userSession.findFirst({
|
||||
where: { sessionId: t.context.sessionId, userId: t.context.u1.id },
|
||||
});
|
||||
t.is(userSession?.refreshClientVersion, '0.25.2');
|
||||
});
|
||||
|
||||
test('should allow auth when header is missing but stored version is valid', async t => {
|
||||
t.context.config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: true,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await t.context.db.userSession.updateMany({
|
||||
where: { sessionId: t.context.sessionId },
|
||||
data: { signInClientVersion: '0.25.0' },
|
||||
});
|
||||
|
||||
const res = await request(t.context.server)
|
||||
.get('/private')
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.expect(200);
|
||||
|
||||
t.is(res.body.user.id, t.context.u1.id);
|
||||
});
|
||||
|
||||
test('should kick out unsupported client version on non-public handler', async t => {
|
||||
t.context.config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: true,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const res = await request(t.context.server)
|
||||
.get('/private')
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.set('x-affine-version', '0.24.0')
|
||||
.expect(403);
|
||||
|
||||
const setCookies = res.get('Set-Cookie') ?? [];
|
||||
t.true(
|
||||
setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`))
|
||||
);
|
||||
t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`)));
|
||||
t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`)));
|
||||
|
||||
const session = await t.context.db.session.findFirst({
|
||||
where: { id: t.context.sessionId },
|
||||
});
|
||||
t.is(session, null);
|
||||
});
|
||||
|
||||
test('should not block public handler when client version is unsupported', async t => {
|
||||
t.context.config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: true,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const res = await request(t.context.server)
|
||||
.get('/public')
|
||||
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
|
||||
.set('x-affine-version', '0.24.0')
|
||||
.expect(200);
|
||||
|
||||
t.is(res.body.user, undefined);
|
||||
|
||||
const setCookies = res.get('Set-Cookie') ?? [];
|
||||
t.true(
|
||||
setCookies.some(c => c.startsWith(`${AuthService.sessionCookieName}=`))
|
||||
);
|
||||
t.true(setCookies.some(c => c.startsWith(`${AuthService.userCookieName}=`)));
|
||||
t.true(setCookies.some(c => c.startsWith(`${AuthService.csrfCookieName}=`)));
|
||||
});
|
||||
|
||||
@@ -122,6 +122,64 @@ test('should refresh exists userSession', async t => {
|
||||
);
|
||||
});
|
||||
|
||||
test('should record sign-in client version on create and update', async t => {
|
||||
const user = await t.context.user.create({
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
const session = await t.context.session.createSession();
|
||||
|
||||
const userSession1 = await t.context.session.createOrRefreshUserSession(
|
||||
user.id,
|
||||
session.id,
|
||||
undefined,
|
||||
'0.25.0'
|
||||
);
|
||||
t.is(userSession1.signInClientVersion, '0.25.0');
|
||||
|
||||
const userSession2 = await t.context.session.createOrRefreshUserSession(
|
||||
user.id,
|
||||
session.id
|
||||
);
|
||||
t.is(userSession2.signInClientVersion, '0.25.0');
|
||||
|
||||
const userSession3 = await t.context.session.createOrRefreshUserSession(
|
||||
user.id,
|
||||
session.id,
|
||||
undefined,
|
||||
'0.26.0'
|
||||
);
|
||||
t.is(userSession3.signInClientVersion, '0.26.0');
|
||||
});
|
||||
|
||||
test('should record refresh client version only when refreshed', async t => {
|
||||
const user = await t.context.user.create({
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
const session = await t.context.session.createSession();
|
||||
const userSession = await t.context.session.createOrRefreshUserSession(
|
||||
user.id,
|
||||
session.id
|
||||
);
|
||||
|
||||
// force refresh
|
||||
userSession.expiresAt = new Date(
|
||||
userSession.expiresAt!.getTime() -
|
||||
t.context.config.auth.session.ttr * 2 * 1000
|
||||
);
|
||||
|
||||
const newExpiresAt = await t.context.session.refreshUserSessionIfNeeded(
|
||||
userSession,
|
||||
undefined,
|
||||
'0.25.0'
|
||||
);
|
||||
t.truthy(newExpiresAt);
|
||||
|
||||
const refreshed = await t.context.db.userSession.findFirst({
|
||||
where: { id: userSession.id },
|
||||
});
|
||||
t.is(refreshed?.refreshClientVersion, '0.25.0');
|
||||
});
|
||||
|
||||
test('should not refresh userSession when expires time not hit ttr', async t => {
|
||||
const user = await t.context.user.create({
|
||||
email: 'test@affine.pro',
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Controller, Get, HttpStatus, UseGuards } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
import { type Response } from 'supertest';
|
||||
@@ -144,6 +145,72 @@ test('should be able to prevent requests if limit is reached', async t => {
|
||||
stub.restore();
|
||||
});
|
||||
|
||||
test('should use session id as tracker when available', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const user = await app.signupV1('u1@affine.pro');
|
||||
const userSession = await app.get(PrismaClient).userSession.findFirst({
|
||||
where: { userId: user.id },
|
||||
});
|
||||
t.truthy(userSession);
|
||||
|
||||
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
|
||||
timeToExpire: 10,
|
||||
totalHits: 1,
|
||||
isBlocked: false,
|
||||
timeToBlockExpire: 0,
|
||||
});
|
||||
|
||||
await app.GET('/throttled/default').expect(200);
|
||||
|
||||
const key = stub.firstCall.args[0] as string;
|
||||
t.true(key.startsWith(`throttler:${userSession!.sessionId};default`));
|
||||
|
||||
stub.restore();
|
||||
});
|
||||
|
||||
test('should use CF-Connecting-IP as tracker when present', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
|
||||
timeToExpire: 10,
|
||||
totalHits: 1,
|
||||
isBlocked: false,
|
||||
timeToBlockExpire: 0,
|
||||
});
|
||||
|
||||
await app
|
||||
.GET('/nonthrottled/default')
|
||||
.set('CF-Connecting-IP', '1.2.3.4')
|
||||
.expect(200);
|
||||
|
||||
const key = stub.firstCall.args[0] as string;
|
||||
t.true(key.startsWith('throttler:1.2.3.4;default'));
|
||||
|
||||
stub.restore();
|
||||
});
|
||||
|
||||
test('should use X-Forwarded-For as tracker when present', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
|
||||
timeToExpire: 10,
|
||||
totalHits: 1,
|
||||
isBlocked: false,
|
||||
timeToBlockExpire: 0,
|
||||
});
|
||||
|
||||
await app
|
||||
.GET('/nonthrottled/default')
|
||||
.set('X-Forwarded-For', '5.6.7.8, 9.9.9.9')
|
||||
.expect(200);
|
||||
|
||||
const key = stub.firstCall.args[0] as string;
|
||||
t.true(key.startsWith('throttler:5.6.7.8;default'));
|
||||
|
||||
stub.restore();
|
||||
});
|
||||
|
||||
// ====== unauthenticated user visits ======
|
||||
test('should use default throttler for unauthenticated user when not specified', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
@@ -6,7 +6,7 @@ import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AppModule } from '../../app.module';
|
||||
import { URLHelper } from '../../base';
|
||||
import { ConfigFactory, URLHelper } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
@@ -56,6 +56,14 @@ test.before(async t => {
|
||||
test.beforeEach(async t => {
|
||||
Sinon.restore();
|
||||
await t.context.app.initTestingDB();
|
||||
t.context.app.get(ConfigFactory).override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: false,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
t.context.u1 = await t.context.auth.signUp('u1@affine.pro', '1');
|
||||
});
|
||||
|
||||
@@ -156,6 +164,56 @@ test('should be able to redirect to oauth provider with client_nonce', async t =
|
||||
t.truthy(state.state);
|
||||
});
|
||||
|
||||
test('should record sign in client version from oauth preflight state', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
const config = app.get(ConfigFactory);
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled: true,
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const preflightRes = await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.set('x-affine-version', '0.25.3')
|
||||
.send({ provider: 'Google', client_nonce: 'test-nonce' })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const redirect = new URL(preflightRes.body.url as string);
|
||||
const stateParam = redirect.searchParams.get('state');
|
||||
t.truthy(stateParam);
|
||||
|
||||
// state should be a json string
|
||||
const rawState = JSON.parse(stateParam!);
|
||||
|
||||
const provider = app.get(GoogleOAuthProvider);
|
||||
Sinon.stub(provider, 'getToken').resolves({ accessToken: '1' });
|
||||
Sinon.stub(provider, 'getUser').resolves({
|
||||
id: '1',
|
||||
email: 'oauth-version@affine.pro',
|
||||
avatarUrl: 'avatar',
|
||||
});
|
||||
|
||||
const callbackRes = await app
|
||||
.POST('/api/oauth/callback')
|
||||
.send({ code: '1', state: stateParam, client_nonce: 'test-nonce' })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const userId = callbackRes.body.id as string;
|
||||
t.truthy(userId);
|
||||
|
||||
const userSession = await db.userSession.findFirst({
|
||||
where: { userId },
|
||||
});
|
||||
t.is(userSession?.signInClientVersion, '0.25.3');
|
||||
t.is(userSession?.refreshClientVersion, null);
|
||||
t.truthy(rawState.state);
|
||||
});
|
||||
|
||||
test('should forbid preflight with untrusted redirect_uri', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ test('should passthrough if version check is not enabled', async t => {
|
||||
spy.restore();
|
||||
});
|
||||
|
||||
test('should passthrough is version range is invalid', async t => {
|
||||
test('should enforce hard required version when version range is invalid', async t => {
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
@@ -82,9 +82,17 @@ test('should passthrough is version range is invalid', async t => {
|
||||
},
|
||||
});
|
||||
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
|
||||
|
||||
t.is(res.status, 200);
|
||||
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
|
||||
|
||||
t.is(res.status, 403);
|
||||
t.is(
|
||||
res.body.message,
|
||||
'Unsupported client with version [invalid], required version is [>=0.25.0].'
|
||||
);
|
||||
});
|
||||
|
||||
test('should pass if client version is allowed', async t => {
|
||||
|
||||
@@ -86,6 +86,29 @@ test('can create link', t => {
|
||||
);
|
||||
});
|
||||
|
||||
test('addSimpleQuery should not double encode', t => {
|
||||
t.is(
|
||||
t.context.url.addSimpleQuery(
|
||||
'https://app.affine.local/path',
|
||||
'redirect_uri',
|
||||
'/path'
|
||||
),
|
||||
'https://app.affine.local/path?redirect_uri=%2Fpath'
|
||||
);
|
||||
});
|
||||
|
||||
test('addSimpleQuery should allow unescaped value when escape=false', t => {
|
||||
t.is(
|
||||
t.context.url.addSimpleQuery(
|
||||
'https://app.affine.local/path',
|
||||
'session_id',
|
||||
'{CHECKOUT_SESSION_ID}',
|
||||
false
|
||||
),
|
||||
'https://app.affine.local/path?session_id={CHECKOUT_SESSION_ID}'
|
||||
);
|
||||
});
|
||||
|
||||
test('can validate callbackUrl allowlist', t => {
|
||||
t.true(t.context.url.isAllowedCallbackUrl('/magic-link'));
|
||||
t.true(
|
||||
|
||||
@@ -109,7 +109,7 @@ export class URLHelper {
|
||||
) {
|
||||
const urlObj = new URL(url);
|
||||
if (escape) {
|
||||
urlObj.searchParams.set(key, encodeURIComponent(value));
|
||||
urlObj.searchParams.set(key, String(value));
|
||||
return urlObj.toString();
|
||||
} else {
|
||||
const query =
|
||||
|
||||
@@ -16,6 +16,7 @@ import type { Request, Response } from 'express';
|
||||
|
||||
import { Config } from '../config';
|
||||
import { getRequestResponseFromContext } from '../utils/request';
|
||||
import { getRequestTrackerId } from '../utils/request-tracker';
|
||||
import type { ThrottlerType } from './config';
|
||||
import { THROTTLER_PROTECTED, Throttlers } from './decorators';
|
||||
|
||||
@@ -63,11 +64,8 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
|
||||
}
|
||||
|
||||
override getTracker(req: Request): Promise<string> {
|
||||
return Promise.resolve(
|
||||
// ↓ prefer session id if available
|
||||
`throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
|
||||
// ^ throttler prefix make the key in store recognizable
|
||||
);
|
||||
// throttler prefix make the key in store recognizable
|
||||
return Promise.resolve(`throttler:${getRequestTrackerId(req)}`);
|
||||
}
|
||||
|
||||
override generateKey(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export * from './duration';
|
||||
export * from './promise';
|
||||
export * from './request';
|
||||
export * from './request-tracker';
|
||||
export * from './ssrf';
|
||||
export * from './stream';
|
||||
export * from './types';
|
||||
|
||||
44
packages/backend/server/src/base/utils/request-tracker.ts
Normal file
44
packages/backend/server/src/base/utils/request-tracker.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import type { Request } from 'express';
|
||||
|
||||
function firstForwardedForIp(value?: string) {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
const [first] = value.split(',', 1);
|
||||
const ip = first?.trim();
|
||||
|
||||
return ip || undefined;
|
||||
}
|
||||
|
||||
function firstNonEmpty(...values: Array<string | undefined>) {
|
||||
for (const value of values) {
|
||||
const trimmed = value?.trim();
|
||||
if (trimmed) {
|
||||
return trimmed;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
export function getRequestClientIp(req: Request) {
|
||||
return firstNonEmpty(
|
||||
req.get('CF-Connecting-IP'),
|
||||
firstForwardedForIp(req.get('X-Forwarded-For')),
|
||||
req.get('X-Real-IP'),
|
||||
req.ip
|
||||
)!;
|
||||
}
|
||||
|
||||
export function getRequestTrackerId(req: Request) {
|
||||
return (
|
||||
req.session?.sessionId ??
|
||||
firstNonEmpty(
|
||||
req.get('CF-Connecting-IP'),
|
||||
firstForwardedForIp(req.get('X-Forwarded-For')),
|
||||
req.get('X-Real-IP'),
|
||||
req.get('CF-Ray'),
|
||||
req.ip
|
||||
)!
|
||||
);
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
import { Injectable, SetMetadata } from '@nestjs/common';
|
||||
import { ModuleRef, Reflector } from '@nestjs/core';
|
||||
import type { Request, Response } from 'express';
|
||||
import semver from 'semver';
|
||||
import { Socket } from 'socket.io';
|
||||
|
||||
import {
|
||||
@@ -15,8 +16,10 @@ import {
|
||||
Cache,
|
||||
Config,
|
||||
CryptoHelper,
|
||||
getClientVersionFromRequest,
|
||||
getRequestResponseFromContext,
|
||||
parseCookies,
|
||||
UnsupportedClientVersion,
|
||||
} from '../../base';
|
||||
import { WEBSOCKET_OPTIONS } from '../../base/websocket';
|
||||
import { AuthService } from './service';
|
||||
@@ -30,10 +33,13 @@ const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30 * 1000;
|
||||
@Injectable()
|
||||
export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
private auth!: AuthService;
|
||||
private readonly cachedVersionRange = new Map<string, semver.Range | null>();
|
||||
private static readonly HARD_REQUIRED_VERSION = '>=0.25.0';
|
||||
|
||||
constructor(
|
||||
private readonly crypto: CryptoHelper,
|
||||
private readonly cache: Cache,
|
||||
private readonly config: Config,
|
||||
private readonly ref: ModuleRef,
|
||||
private readonly reflector: Reflector
|
||||
) {}
|
||||
@@ -78,14 +84,14 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
throw new AccessDenied('Invalid internal request');
|
||||
}
|
||||
|
||||
const authedUser = await this.signIn(req, res);
|
||||
|
||||
// api is public
|
||||
const isPublic = this.reflector.getAllAndOverride<boolean>(
|
||||
PUBLIC_ENTRYPOINT_SYMBOL,
|
||||
[clazz, handler]
|
||||
);
|
||||
|
||||
const authedUser = await this.signIn(req, res, isPublic);
|
||||
|
||||
if (isPublic) {
|
||||
return true;
|
||||
}
|
||||
@@ -99,9 +105,10 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
|
||||
async signIn(
|
||||
req: Request,
|
||||
res?: Response
|
||||
res?: Response,
|
||||
isPublic = false
|
||||
): Promise<Session | TokenSession | null> {
|
||||
const userSession = await this.signInWithCookie(req, res);
|
||||
const userSession = await this.signInWithCookie(req, res, isPublic);
|
||||
if (userSession) {
|
||||
return userSession;
|
||||
}
|
||||
@@ -111,7 +118,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
|
||||
async signInWithCookie(
|
||||
req: Request,
|
||||
res?: Response
|
||||
res?: Response,
|
||||
isPublic = false
|
||||
): Promise<Session | null> {
|
||||
if (req.session) {
|
||||
return req.session;
|
||||
@@ -121,8 +129,38 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
const userSession = await this.auth.getUserSessionFromRequest(req, res);
|
||||
|
||||
if (userSession) {
|
||||
const headerClientVersion = getClientVersionFromRequest(req);
|
||||
if (this.config.client.versionControl.enabled) {
|
||||
const clientVersion =
|
||||
headerClientVersion ??
|
||||
userSession.session.refreshClientVersion ??
|
||||
userSession.session.signInClientVersion;
|
||||
|
||||
const versionCheckResult = this.checkClientVersion(clientVersion);
|
||||
if (!versionCheckResult.ok) {
|
||||
await this.auth.signOut(userSession.session.sessionId);
|
||||
if (res) {
|
||||
await this.auth.refreshCookies(res, userSession.session.sessionId);
|
||||
}
|
||||
|
||||
if (isPublic) {
|
||||
return null;
|
||||
}
|
||||
|
||||
throw new UnsupportedClientVersion({
|
||||
clientVersion: clientVersion ?? 'unset_or_invalid',
|
||||
requiredVersion: versionCheckResult.requiredVersion,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (res) {
|
||||
await this.auth.refreshUserSessionIfNeeded(res, userSession.session);
|
||||
await this.auth.refreshUserSessionIfNeeded(
|
||||
res,
|
||||
userSession.session,
|
||||
undefined,
|
||||
headerClientVersion
|
||||
);
|
||||
}
|
||||
|
||||
req.session = {
|
||||
@@ -154,6 +192,59 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private getVersionRange(versionRange: string): semver.Range | null {
|
||||
if (this.cachedVersionRange.has(versionRange)) {
|
||||
// oxlint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
return this.cachedVersionRange.get(versionRange)!;
|
||||
}
|
||||
|
||||
let range: semver.Range | null = null;
|
||||
try {
|
||||
range = new semver.Range(versionRange, { loose: false });
|
||||
if (!semver.validRange(range)) {
|
||||
range = null;
|
||||
}
|
||||
} catch {
|
||||
range = null;
|
||||
}
|
||||
|
||||
this.cachedVersionRange.set(versionRange, range);
|
||||
return range;
|
||||
}
|
||||
|
||||
private checkClientVersion(
|
||||
clientVersion?: string | null
|
||||
): { ok: true } | { ok: false; requiredVersion: string } {
|
||||
const requiredVersion = this.config.client.versionControl.requiredVersion;
|
||||
|
||||
const configRange = this.getVersionRange(requiredVersion);
|
||||
if (
|
||||
configRange &&
|
||||
(!clientVersion ||
|
||||
!semver.satisfies(clientVersion, configRange, {
|
||||
includePrerelease: true,
|
||||
}))
|
||||
) {
|
||||
return { ok: false, requiredVersion };
|
||||
}
|
||||
|
||||
const hardRange = this.getVersionRange(AuthGuard.HARD_REQUIRED_VERSION);
|
||||
if (!hardRange) {
|
||||
return { ok: true };
|
||||
}
|
||||
|
||||
if (
|
||||
!clientVersion ||
|
||||
!semver.satisfies(clientVersion, hardRange, {
|
||||
includePrerelease: true,
|
||||
})
|
||||
) {
|
||||
return { ok: false, requiredVersion: AuthGuard.HARD_REQUIRED_VERSION };
|
||||
}
|
||||
|
||||
return { ok: true };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -184,7 +275,13 @@ export const AuthWebsocketOptionsProvider: FactoryProvider = {
|
||||
...upgradeReq.cookies,
|
||||
};
|
||||
|
||||
const session = await guard.signIn(upgradeReq);
|
||||
const session = await (async () => {
|
||||
try {
|
||||
return await guard.signIn(upgradeReq);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})();
|
||||
|
||||
return !!session;
|
||||
},
|
||||
|
||||
@@ -4,7 +4,11 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
|
||||
import type { CookieOptions, Request, Response } from 'express';
|
||||
import { assign, pick } from 'lodash-es';
|
||||
|
||||
import { Config, SignUpForbidden } from '../../base';
|
||||
import {
|
||||
Config,
|
||||
getClientVersionFromRequest,
|
||||
SignUpForbidden,
|
||||
} from '../../base';
|
||||
import { Models, type User, type UserSession } from '../../models';
|
||||
import { Mailer } from '../mail/mailer';
|
||||
import { createDevUsers } from './dev';
|
||||
@@ -130,11 +134,17 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
return await this.models.session.findUserSessionsBySessionId(sessionId);
|
||||
}
|
||||
|
||||
async createUserSession(userId: string, sessionId?: string, ttl?: number) {
|
||||
async createUserSession(
|
||||
userId: string,
|
||||
sessionId?: string,
|
||||
ttl?: number,
|
||||
signInClientVersion?: string
|
||||
) {
|
||||
return await this.models.session.createOrRefreshUserSession(
|
||||
userId,
|
||||
sessionId,
|
||||
ttl
|
||||
ttl,
|
||||
signInClientVersion
|
||||
);
|
||||
}
|
||||
|
||||
@@ -159,11 +169,13 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
async refreshUserSessionIfNeeded(
|
||||
res: Response,
|
||||
userSession: UserSession,
|
||||
ttr?: number
|
||||
ttr?: number,
|
||||
refreshClientVersion?: string
|
||||
): Promise<boolean> {
|
||||
const newExpiresAt = await this.models.session.refreshUserSessionIfNeeded(
|
||||
userSession,
|
||||
ttr
|
||||
ttr,
|
||||
refreshClientVersion
|
||||
);
|
||||
if (!newExpiresAt) {
|
||||
// no need to refresh
|
||||
@@ -205,10 +217,22 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
};
|
||||
}
|
||||
|
||||
async setCookies(req: Request, res: Response, userId: string) {
|
||||
async setCookies(
|
||||
req: Request,
|
||||
res: Response,
|
||||
userId: string,
|
||||
clientVersion?: string
|
||||
) {
|
||||
const { sessionId } = this.getSessionOptionsFromRequest(req);
|
||||
|
||||
const userSession = await this.createUserSession(userId, sessionId);
|
||||
const signInClientVersion =
|
||||
clientVersion ?? getClientVersionFromRequest(req);
|
||||
const userSession = await this.createUserSession(
|
||||
userId,
|
||||
sessionId,
|
||||
undefined,
|
||||
signInClientVersion
|
||||
);
|
||||
|
||||
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
|
||||
...this.cookieOptions,
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Injectable } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
Config,
|
||||
getClientVersionFromRequest,
|
||||
getRequestResponseFromContext,
|
||||
GuardProvider,
|
||||
} from '../../base';
|
||||
@@ -33,7 +34,7 @@ export class VersionGuardProvider
|
||||
|
||||
const { req } = getRequestResponseFromContext(context);
|
||||
|
||||
const version = req.headers['x-affine-version'] as string | undefined;
|
||||
const version = getClientVersionFromRequest(req);
|
||||
|
||||
return this.version.checkVersion(version);
|
||||
}
|
||||
|
||||
@@ -6,23 +6,24 @@ import { Config, UnsupportedClientVersion } from '../../base';
|
||||
@Injectable()
|
||||
export class VersionService {
|
||||
private readonly logger = new Logger(VersionService.name);
|
||||
private static readonly HARD_REQUIRED_VERSION = '>=0.25.0';
|
||||
|
||||
constructor(private readonly config: Config) {}
|
||||
|
||||
async checkVersion(clientVersion?: string) {
|
||||
const requiredVersion = this.config.client.versionControl.requiredVersion;
|
||||
|
||||
const range = await this.getVersionRange(requiredVersion);
|
||||
if (!range) {
|
||||
// ignore invalid allowed version config
|
||||
return true;
|
||||
}
|
||||
const hardRange = await this.getVersionRange(
|
||||
VersionService.HARD_REQUIRED_VERSION
|
||||
);
|
||||
const configRange = await this.getVersionRange(requiredVersion);
|
||||
|
||||
if (
|
||||
!clientVersion ||
|
||||
!semver.satisfies(clientVersion, range, {
|
||||
includePrerelease: true,
|
||||
})
|
||||
configRange &&
|
||||
(!clientVersion ||
|
||||
!semver.satisfies(clientVersion, configRange, {
|
||||
includePrerelease: true,
|
||||
}))
|
||||
) {
|
||||
throw new UnsupportedClientVersion({
|
||||
clientVersion: clientVersion ?? 'unset_or_invalid',
|
||||
@@ -30,6 +31,19 @@ export class VersionService {
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
hardRange &&
|
||||
(!clientVersion ||
|
||||
!semver.satisfies(clientVersion, hardRange, {
|
||||
includePrerelease: true,
|
||||
}))
|
||||
) {
|
||||
throw new UnsupportedClientVersion({
|
||||
clientVersion: clientVersion ?? 'unset_or_invalid',
|
||||
requiredVersion: VersionService.HARD_REQUIRED_VERSION,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,8 @@ export class SessionModel extends BaseModel {
|
||||
async createOrRefreshUserSession(
|
||||
userId: string,
|
||||
sessionId?: string,
|
||||
ttl = this.config.auth.session.ttl
|
||||
ttl = this.config.auth.session.ttl,
|
||||
signInClientVersion?: string
|
||||
) {
|
||||
// check whether given session is valid
|
||||
if (sessionId) {
|
||||
@@ -76,18 +77,21 @@ export class SessionModel extends BaseModel {
|
||||
},
|
||||
update: {
|
||||
expiresAt,
|
||||
...(signInClientVersion ? { signInClientVersion } : {}),
|
||||
},
|
||||
create: {
|
||||
sessionId,
|
||||
userId,
|
||||
expiresAt,
|
||||
...(signInClientVersion ? { signInClientVersion } : {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async refreshUserSessionIfNeeded(
|
||||
userSession: UserSession,
|
||||
ttr = this.config.auth.session.ttr
|
||||
ttr = this.config.auth.session.ttr,
|
||||
refreshClientVersion?: string
|
||||
): Promise<Date | undefined> {
|
||||
if (
|
||||
userSession.expiresAt &&
|
||||
@@ -106,6 +110,7 @@ export class SessionModel extends BaseModel {
|
||||
},
|
||||
data: {
|
||||
expiresAt: newExpiresAt,
|
||||
...(refreshClientVersion ? { refreshClientVersion } : {}),
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -5,7 +5,12 @@ import type { Request } from 'express';
|
||||
import { nanoid } from 'nanoid';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { CaptchaVerificationFailed, Config, OnEvent } from '../../base';
|
||||
import {
|
||||
CaptchaVerificationFailed,
|
||||
Config,
|
||||
getRequestClientIp,
|
||||
OnEvent,
|
||||
} from '../../base';
|
||||
import { ServerFeature, ServerService } from '../../core';
|
||||
import { Models, TokenType } from '../../models';
|
||||
import { verifyChallengeResponse } from '../../native';
|
||||
@@ -133,7 +138,7 @@ export class CaptchaService {
|
||||
} else {
|
||||
const isTokenVerified = await this.verifyCaptchaToken(
|
||||
credential.token,
|
||||
req.headers['CF-Connecting-IP'] as string
|
||||
getRequestClientIp(req)
|
||||
);
|
||||
|
||||
if (!isTokenVerified) {
|
||||
|
||||
@@ -15,6 +15,7 @@ import type { Request, Response } from 'express';
|
||||
import {
|
||||
ActionForbidden,
|
||||
Config,
|
||||
getClientVersionFromRequest,
|
||||
InvalidAuthState,
|
||||
InvalidOauthCallbackState,
|
||||
MissingOauthQueryParameter,
|
||||
@@ -50,6 +51,7 @@ export class OAuthController {
|
||||
@Post('/preflight')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
async preflight(
|
||||
@Req() req: Request,
|
||||
@Body('provider') unknownProviderName?: keyof typeof OAuthProviderName,
|
||||
@Body('redirect_uri') redirectUri?: string,
|
||||
@Body('client') client?: string,
|
||||
@@ -75,11 +77,13 @@ export class OAuthController {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
const clientVersion = getClientVersionFromRequest(req);
|
||||
const state = await this.oauth.saveOAuthState({
|
||||
provider: providerName,
|
||||
redirectUri,
|
||||
client,
|
||||
clientNonce,
|
||||
clientVersion,
|
||||
...(pkce
|
||||
? {
|
||||
pkce: {
|
||||
@@ -220,7 +224,7 @@ export class OAuthController {
|
||||
tokens
|
||||
);
|
||||
|
||||
await this.auth.setCookies(req, res, user.id);
|
||||
await this.auth.setCookies(req, res, user.id, state.clientVersion);
|
||||
|
||||
if (
|
||||
state.provider === OAuthProviderName.Apple &&
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface OAuthState {
|
||||
redirectUri?: string;
|
||||
client?: string;
|
||||
clientNonce?: string;
|
||||
clientVersion?: string;
|
||||
provider: OAuthProviderName;
|
||||
pkce?: OAuthPkceState;
|
||||
token?: string;
|
||||
|
||||
@@ -87,7 +87,7 @@ export class SelfhostTeamSubscriptionManager extends SubscriptionManager {
|
||||
return { allow_promotion_codes: true };
|
||||
})();
|
||||
|
||||
let successUrl = this.url.link(params.successCallbackLink);
|
||||
let successUrl = this.url.safeLink(params.successCallbackLink || '/');
|
||||
// stripe only accept unescaped '{CHECKOUT_SESSION_ID}' as query
|
||||
successUrl = this.url.addSimpleQuery(
|
||||
successUrl,
|
||||
|
||||
@@ -204,7 +204,7 @@ export class UserSubscriptionManager extends SubscriptionManager {
|
||||
],
|
||||
...mode,
|
||||
...discounts,
|
||||
success_url: this.url.link(params.successCallbackLink),
|
||||
success_url: this.url.safeLink(params.successCallbackLink || '/'),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ export class WorkspaceSubscriptionManager extends SubscriptionManager {
|
||||
},
|
||||
},
|
||||
...discounts,
|
||||
success_url: this.url.link(params.successCallbackLink),
|
||||
success_url: this.url.safeLink(params.successCallbackLink || '/'),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user