diff --git a/packages/backend/server/src/app.module.ts b/packages/backend/server/src/app.module.ts index 1382ef0957..45f86a2d54 100644 --- a/packages/backend/server/src/app.module.ts +++ b/packages/backend/server/src/app.module.ts @@ -152,6 +152,8 @@ function buildAppModule() { factor // common fundamental modules .use(...FunctionalityModules) + .useIf(config => config.flavor.sync, WebSocketModule) + // auth .use(AuthModule) @@ -159,7 +161,7 @@ function buildAppModule() { .use(DocModule) // sync server only - .useIf(config => config.flavor.sync, WebSocketModule, SyncModule) + .useIf(config => config.flavor.sync, SyncModule) // graphql server only .useIf( diff --git a/packages/backend/server/src/core/auth/current-user.ts b/packages/backend/server/src/core/auth/current-user.ts index b6757314f1..ca736fb303 100644 --- a/packages/backend/server/src/core/auth/current-user.ts +++ b/packages/backend/server/src/core/auth/current-user.ts @@ -1,6 +1,6 @@ import type { ExecutionContext } from '@nestjs/common'; import { createParamDecorator } from '@nestjs/common'; -import { User } from '@prisma/client'; +import { User, UserSession } from '@prisma/client'; import { getRequestResponseFromContext } from '../../fundamentals'; @@ -53,3 +53,5 @@ export interface CurrentUser hasPassword: boolean | null; emailVerified: boolean; } + +export { type UserSession }; diff --git a/packages/backend/server/src/core/auth/guard.ts b/packages/backend/server/src/core/auth/guard.ts index 25679c40ec..bc3dd128ac 100644 --- a/packages/backend/server/src/core/auth/guard.ts +++ b/packages/backend/server/src/core/auth/guard.ts @@ -1,15 +1,22 @@ import type { CanActivate, ExecutionContext, + FactoryProvider, OnModuleInit, } from '@nestjs/common'; import { Injectable, SetMetadata, UseGuards } from '@nestjs/common'; import { ModuleRef, Reflector } from '@nestjs/core'; +import type { Request } from 'express'; import { AuthenticationRequired, + Config, getRequestResponseFromContext, + mapAnyError, + parseCookies, } from '../../fundamentals'; +import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket'; +import { CurrentUser, UserSession } from './current-user'; import { AuthService, parseAuthUserSeqNum } from './service'; function extractTokenFromHeader(authorization: string) { @@ -38,37 +45,9 @@ export class AuthGuard implements CanActivate, OnModuleInit { async canActivate(context: ExecutionContext) { const { req, res } = getRequestResponseFromContext(context); - // check cookie - let sessionToken: string | undefined = - req.cookies[AuthService.sessionCookieName]; - - if (!sessionToken && req.headers.authorization) { - sessionToken = extractTokenFromHeader(req.headers.authorization); - } - - if (sessionToken) { - const userSeq = parseAuthUserSeqNum( - req.headers[AuthService.authUserSeqHeaderName] - ); - - const { user, expiresAt } = await this.auth.getUser( - sessionToken, - userSeq - ); - if (res && user && expiresAt) { - await this.auth.refreshUserSessionIfNeeded( - req, - res, - sessionToken, - user.id, - expiresAt - ); - } - - if (user) { - req.sid = sessionToken; - req.user = user; - } + const userSession = await this.signIn(req); + if (res && userSession && userSession.session.expiresAt) { + await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session); } // api is public @@ -84,9 +63,44 @@ export class AuthGuard implements CanActivate, OnModuleInit { if (!req.user) { throw new AuthenticationRequired(); } - return true; } + + async signIn( + req: Request + ): Promise<{ user: CurrentUser; session: UserSession } | null> { + if (req.user && req.session) { + return { + user: req.user, + session: req.session, + }; + } + + parseCookies(req); + let sessionToken: string | undefined = + req.cookies[AuthService.sessionCookieName]; + + if (!sessionToken && req.headers.authorization) { + sessionToken = extractTokenFromHeader(req.headers.authorization); + } + + if (sessionToken) { + const userSeq = parseAuthUserSeqNum( + req.headers[AuthService.authUserSeqHeaderName] + ); + + const userSession = await this.auth.getUserSession(sessionToken, userSeq); + + if (userSession) { + req.session = userSession.session; + req.user = userSession.user; + } + + return userSession; + } + + return null; + } } /** @@ -111,3 +125,35 @@ export const Auth = () => { // api is public accessible export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true); + +export const AuthWebsocketOptionsProvider: FactoryProvider = { + provide: WEBSOCKET_OPTIONS, + useFactory: (config: Config, guard: AuthGuard) => { + return { + ...config.websocket, + allowRequest: async ( + req: any, + pass: (err: string | null | undefined, success: boolean) => void + ) => { + if (!config.websocket.requireAuthentication) { + return pass(null, true); + } + + try { + const authentication = await guard.signIn(req); + + if (authentication) { + return pass(null, true); + } else { + return pass('unauthenticated', false); + } + } catch (e) { + const error = mapAnyError(e); + error.log('Websocket'); + return pass('unauthenticated', false); + } + }, + }; + }, + inject: [Config, AuthGuard], +}; diff --git a/packages/backend/server/src/core/auth/index.ts b/packages/backend/server/src/core/auth/index.ts index c1551d6752..7244027533 100644 --- a/packages/backend/server/src/core/auth/index.ts +++ b/packages/backend/server/src/core/auth/index.ts @@ -6,15 +6,21 @@ import { FeatureModule } from '../features'; import { QuotaModule } from '../quota'; import { UserModule } from '../user'; import { AuthController } from './controller'; -import { AuthGuard } from './guard'; +import { AuthGuard, AuthWebsocketOptionsProvider } from './guard'; import { AuthResolver } from './resolver'; import { AuthService } from './service'; import { TokenService, TokenType } from './token'; @Module({ imports: [FeatureModule, UserModule, QuotaModule], - providers: [AuthService, AuthResolver, TokenService, AuthGuard], - exports: [AuthService, AuthGuard], + providers: [ + AuthService, + AuthResolver, + TokenService, + AuthGuard, + AuthWebsocketOptionsProvider, + ], + exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider], controllers: [AuthController], }) export class AuthModule {} diff --git a/packages/backend/server/src/core/auth/service.ts b/packages/backend/server/src/core/auth/service.ts index 9e261188ec..c372bb80c1 100644 --- a/packages/backend/server/src/core/auth/service.ts +++ b/packages/backend/server/src/core/auth/service.ts @@ -1,6 +1,6 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common'; import { Cron, CronExpression } from '@nestjs/schedule'; -import type { User } from '@prisma/client'; +import type { User, UserSession } from '@prisma/client'; import { PrismaClient } from '@prisma/client'; import type { CookieOptions, Request, Response } from 'express'; import { assign, omit } from 'lodash-es'; @@ -121,27 +121,27 @@ export class AuthService implements OnApplicationBootstrap { return sessionUser(user); } - async getUser( + async getUserSession( token: string, seq = 0 - ): Promise<{ user: CurrentUser | null; expiresAt: Date | null }> { + ): Promise<{ user: CurrentUser; session: UserSession } | null> { const session = await this.getSession(token); // no such session if (!session) { - return { user: null, expiresAt: null }; + return null; } const userSession = session.userSessions.at(seq); // no such user session if (!userSession) { - return { user: null, expiresAt: null }; + return null; } // user session expired if (userSession.expiresAt && userSession.expiresAt <= new Date()) { - return { user: null, expiresAt: null }; + return null; } const user = await this.db.user.findUnique({ @@ -149,10 +149,10 @@ export class AuthService implements OnApplicationBootstrap { }); if (!user) { - return { user: null, expiresAt: null }; + return null; } - return { user: sessionUser(user), expiresAt: userSession.expiresAt }; + return { user: sessionUser(user), session: userSession }; } async getUserList(token: string) { @@ -251,12 +251,13 @@ export class AuthService implements OnApplicationBootstrap { async refreshUserSessionIfNeeded( _req: Request, res: Response, - sessionId: string, - userId: string, - expiresAt: Date, + session: UserSession, ttr = this.config.auth.session.ttr ): Promise { - if (expiresAt && expiresAt.getTime() - Date.now() > ttr * 1000) { + if ( + session.expiresAt && + session.expiresAt.getTime() - Date.now() > ttr * 1000 + ) { // no need to refresh return false; } @@ -267,17 +268,14 @@ export class AuthService implements OnApplicationBootstrap { await this.db.userSession.update({ where: { - sessionId_userId: { - sessionId, - userId, - }, + id: session.id, }, data: { expiresAt: newExpiresAt, }, }); - res.cookie(AuthService.sessionCookieName, sessionId, { + res.cookie(AuthService.sessionCookieName, session.sessionId, { expires: newExpiresAt, ...this.cookieOptions, }); diff --git a/packages/backend/server/src/core/sync/events/events.gateway.ts b/packages/backend/server/src/core/sync/events/events.gateway.ts index 78e66cee35..73aa53e372 100644 --- a/packages/backend/server/src/core/sync/events/events.gateway.ts +++ b/packages/backend/server/src/core/sync/events/events.gateway.ts @@ -50,12 +50,7 @@ function Awareness(workspaceId: string): `${string}:awareness` { return `${workspaceId}:awareness`; } -@WebSocketGateway({ - cors: !AFFiNE.node.prod, - transports: ['websocket'], - // see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize - maxHttpBufferSize: 1e8, // 100 MB -}) +@WebSocketGateway() export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { protected logger = new Logger(EventsGateway.name); private connectionCount = 0; diff --git a/packages/backend/server/src/fundamentals/index.ts b/packages/backend/server/src/fundamentals/index.ts index 98059cefc8..b55bd0b117 100644 --- a/packages/backend/server/src/fundamentals/index.ts +++ b/packages/backend/server/src/fundamentals/index.ts @@ -36,5 +36,6 @@ export { getRequestFromHost, getRequestResponseFromContext, getRequestResponseFromHost, + parseCookies, } from './utils/request'; export type * from './utils/types'; diff --git a/packages/backend/server/src/fundamentals/nestjs/exception.ts b/packages/backend/server/src/fundamentals/nestjs/exception.ts index 2460cb32ea..e553d22da1 100644 --- a/packages/backend/server/src/fundamentals/nestjs/exception.ts +++ b/packages/backend/server/src/fundamentals/nestjs/exception.ts @@ -2,8 +2,10 @@ import { ArgumentsHost, Catch, Logger } from '@nestjs/common'; import { BaseExceptionFilter } from '@nestjs/core'; import { GqlContextType } from '@nestjs/graphql'; import { ThrottlerException } from '@nestjs/throttler'; +import { BaseWsExceptionFilter } from '@nestjs/websockets'; import { Response } from 'express'; import { of } from 'rxjs'; +import { Socket } from 'socket.io'; import { InternalServerError, @@ -44,6 +46,20 @@ export class GlobalExceptionFilter extends BaseExceptionFilter { } } +export class GlobalWsExceptionFilter extends BaseWsExceptionFilter { + // @ts-expect-error satisfies the override + override handleError(client: Socket, exception: any): void { + const error = mapAnyError(exception); + error.log('Websocket'); + metrics.socketio + .counter('unhandled_error') + .add(1, { status: error.status }); + client.emit('error', { + error: toWebsocketError(error), + }); + } +} + /** * Only exists for websocket error body backward compatibility * diff --git a/packages/backend/server/src/fundamentals/throttler/index.ts b/packages/backend/server/src/fundamentals/throttler/index.ts index fd68967a3f..f7a4993a08 100644 --- a/packages/backend/server/src/fundamentals/throttler/index.ts +++ b/packages/backend/server/src/fundamentals/throttler/index.ts @@ -57,7 +57,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard { override getTracker(req: Request): Promise { return Promise.resolve( // ↓ prefer session id if available - `throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}` + `throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}` // ^ throttler prefix make the key in store recognizable ); } diff --git a/packages/backend/server/src/fundamentals/utils/request.ts b/packages/backend/server/src/fundamentals/utils/request.ts index 37d965e3eb..7ab04be3f1 100644 --- a/packages/backend/server/src/fundamentals/utils/request.ts +++ b/packages/backend/server/src/fundamentals/utils/request.ts @@ -66,3 +66,29 @@ export function getRequestFromHost(host: ArgumentsHost) { export function getRequestResponseFromContext(ctx: ExecutionContext) { return getRequestResponseFromHost(ctx); } + +/** + * simple patch for request not protected by `cookie-parser` + * only take effect if `req.cookies` is not defined + */ +export function parseCookies(req: Request) { + if (req.cookies) { + return; + } + + const cookieStr = req?.headers?.cookie ?? ''; + req.cookies = cookieStr.split(';').reduce( + (cookies, cookie) => { + const [key, val] = cookie.split('='); + + if (key) { + cookies[decodeURIComponent(key.trim())] = val + ? decodeURIComponent(val.trim()) + : val; + } + + return cookies; + }, + {} as Record + ); +} diff --git a/packages/backend/server/src/fundamentals/websocket/config.ts b/packages/backend/server/src/fundamentals/websocket/config.ts new file mode 100644 index 0000000000..2aa49aa50f --- /dev/null +++ b/packages/backend/server/src/fundamentals/websocket/config.ts @@ -0,0 +1,20 @@ +import { GatewayMetadata } from '@nestjs/websockets'; + +import { defineStartupConfig, ModuleConfig } from '../config'; + +declare module '../config' { + interface AppConfig { + websocket: ModuleConfig< + GatewayMetadata & { + requireAuthentication?: boolean; + } + >; + } +} + +defineStartupConfig('websocket', { + // see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize + transports: ['websocket'], + maxHttpBufferSize: 1e8, // 100 MB + requireAuthentication: true, +}); diff --git a/packages/backend/server/src/fundamentals/websocket/index.ts b/packages/backend/server/src/fundamentals/websocket/index.ts index f216a05e01..2446677636 100644 --- a/packages/backend/server/src/fundamentals/websocket/index.ts +++ b/packages/backend/server/src/fundamentals/websocket/index.ts @@ -1,17 +1,46 @@ -import { Module, Provider } from '@nestjs/common'; +import './config'; + +import { + FactoryProvider, + INestApplicationContext, + Module, + Provider, +} from '@nestjs/common'; import { IoAdapter } from '@nestjs/platform-socket.io'; +import { Server } from 'socket.io'; + +import { Config } from '../config'; export const SocketIoAdapterImpl = Symbol('SocketIoAdapterImpl'); -export class SocketIoAdapter extends IoAdapter {} +export class SocketIoAdapter extends IoAdapter { + constructor(protected readonly app: INestApplicationContext) { + super(app); + } + + override createIOServer(port: number, options?: any): Server { + const config = this.app.get(WEBSOCKET_OPTIONS); + return super.createIOServer(port, { ...config, ...options }); + } +} const SocketIoAdapterImplProvider: Provider = { provide: SocketIoAdapterImpl, useValue: SocketIoAdapter, }; +export const WEBSOCKET_OPTIONS = Symbol('WEBSOCKET_OPTIONS'); + +export const websocketOptionsProvider: FactoryProvider = { + provide: WEBSOCKET_OPTIONS, + useFactory: (config: Config) => { + return config.websocket; + }, + inject: [Config], +}; + @Module({ - providers: [SocketIoAdapterImplProvider], - exports: [SocketIoAdapterImplProvider], + providers: [SocketIoAdapterImplProvider, websocketOptionsProvider], + exports: [SocketIoAdapterImplProvider, websocketOptionsProvider], }) export class WebSocketModule {} diff --git a/packages/backend/server/src/global.d.ts b/packages/backend/server/src/global.d.ts index fecd5d8551..700bbc61f0 100644 --- a/packages/backend/server/src/global.d.ts +++ b/packages/backend/server/src/global.d.ts @@ -1,7 +1,7 @@ declare namespace Express { interface Request { user?: import('./core/auth/current-user').CurrentUser; - sid?: string; + session?: import('./core/auth/current-user').UserSession; } } diff --git a/packages/backend/server/src/plugins/redis/ws-adapter.ts b/packages/backend/server/src/plugins/redis/ws-adapter.ts index 4633fc235c..8f7c646c84 100644 --- a/packages/backend/server/src/plugins/redis/ws-adapter.ts +++ b/packages/backend/server/src/plugins/redis/ws-adapter.ts @@ -18,7 +18,7 @@ export function createSockerIoAdapterImpl( console.error(err); }); - const server = super.createIOServer(port, options) as Server; + const server = super.createIOServer(port, options); server.adapter(createAdapter(pubClient, subClient)); return server; } diff --git a/packages/backend/server/tests/auth/guard.spec.ts b/packages/backend/server/tests/auth/guard.spec.ts index 3acfac4649..23e276a559 100644 --- a/packages/backend/server/tests/auth/guard.spec.ts +++ b/packages/backend/server/tests/auth/guard.spec.ts @@ -69,7 +69,7 @@ test('should be able to visit public api if signed in', async t => { const { app, auth } = t.context; // @ts-expect-error mock - auth.getUser.resolves({ user: { id: '1' } }); + auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } }); const res = await request(app.getHttpServer()) .get('/public') @@ -100,7 +100,7 @@ test('should be able to visit private api if signed in', async t => { const { app, auth } = t.context; // @ts-expect-error mock - auth.getUser.resolves({ user: { id: '1' } }); + auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } }); const res = await request(app.getHttpServer()) .get('/private') @@ -114,26 +114,26 @@ test('should be able to parse session cookie', async t => { const { app, auth } = t.context; // @ts-expect-error mock - auth.getUser.resolves({ user: { id: '1' } }); + auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } }); await request(app.getHttpServer()) .get('/public') .set('cookie', `${AuthService.sessionCookieName}=1`) .expect(200); - t.deepEqual(auth.getUser.firstCall.args, ['1', 0]); + t.deepEqual(auth.getUserSession.firstCall.args, ['1', 0]); }); test('should be able to parse bearer token', async t => { const { app, auth } = t.context; // @ts-expect-error mock - auth.getUser.resolves({ user: { id: '1' } }); + auth.getUserSession.resolves({ user: { id: '1' }, session: { id: '1' } }); await request(app.getHttpServer()) .get('/public') .auth('1', { type: 'bearer' }) .expect(200); - t.deepEqual(auth.getUser.firstCall.args, ['1', 0]); + t.deepEqual(auth.getUserSession.firstCall.args, ['1', 0]); }); diff --git a/packages/backend/server/tests/auth/service.spec.ts b/packages/backend/server/tests/auth/service.spec.ts index de6ee350da..eecd783056 100644 --- a/packages/backend/server/tests/auth/service.spec.ts +++ b/packages/backend/server/tests/auth/service.spec.ts @@ -157,10 +157,10 @@ test('should be able to get user from session', async t => { const session = await auth.createUserSession(u1); - const { user } = await auth.getUser(session.sessionId); + const userSession = await auth.getUserSession(session.sessionId); - t.not(user, null); - t.is(user!.id, u1.id); + t.not(userSession, null); + t.is(userSession!.user.id, u1.id); }); test('should be able to sign out session', async t => { @@ -203,19 +203,19 @@ test('should be able to signout multi accounts session', async t => { t.not(signedOutSession, null); - const { user: signedU2 } = await auth.getUser(session.sessionId, 0); - const { user: noUser } = await auth.getUser(session.sessionId, 1); + const userSession1 = await auth.getUserSession(session.sessionId, 0); + const userSession2 = await auth.getUserSession(session.sessionId, 1); - t.is(noUser, null); - t.not(signedU2, null); + t.is(userSession2, null); + t.not(userSession1, null); - t.is(signedU2!.id, u2.id); + t.is(userSession1!.user.id, u2.id); // sign out user at seq(0) signedOutSession = await auth.signOut(session.sessionId); t.is(signedOutSession, null); - const { user: noUser2 } = await auth.getUser(session.sessionId, 0); - t.is(noUser2, null); + const userSession3 = await auth.getUserSession(session.sessionId, 0); + t.is(userSession3, null); }); diff --git a/packages/backend/server/tests/oauth/controller.spec.ts b/packages/backend/server/tests/oauth/controller.spec.ts index dff50c2c9c..8c9e1ff550 100644 --- a/packages/backend/server/tests/oauth/controller.spec.ts +++ b/packages/backend/server/tests/oauth/controller.spec.ts @@ -341,8 +341,10 @@ test('should throw if oauth account already connected', async t => { }, }); - // @ts-expect-error mock - Sinon.stub(auth, 'getUser').resolves({ user: { id: 'u2-id' } }); + Sinon.stub(auth, 'getUserSession').resolves({ + user: { id: 'u2-id' }, + session: {}, + } as any); mockOAuthProvider(app, 'u2@affine.pro'); @@ -363,8 +365,10 @@ test('should throw if oauth account already connected', async t => { test('should be able to connect oauth account', async t => { const { app, u1, auth, db } = t.context; - // @ts-expect-error mock - Sinon.stub(auth, 'getUser').resolves({ user: { id: u1.id } }); + Sinon.stub(auth, 'getUserSession').resolves({ + user: { id: u1.id }, + session: {}, + } as any); mockOAuthProvider(app, u1.email);