From 1ddae40fb2296fd29d4d0522667a3c18ce1d50b1 Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Thu, 21 Sep 2023 21:05:26 +0800 Subject: [PATCH] feat: add auth support for websocket (#4445) --- .../src/modules/sync/events/events.gateway.ts | 69 ++++++++++++++++--- .../src/modules/sync/events/events.module.ts | 3 +- apps/server/src/utils/nestjs.ts | 23 +++++++ 3 files changed, 86 insertions(+), 9 deletions(-) diff --git a/apps/server/src/modules/sync/events/events.gateway.ts b/apps/server/src/modules/sync/events/events.gateway.ts index 81daedd2f3..30f76972f4 100644 --- a/apps/server/src/modules/sync/events/events.gateway.ts +++ b/apps/server/src/modules/sync/events/events.gateway.ts @@ -1,3 +1,4 @@ +import { Logger } from '@nestjs/common'; import { ConnectedSocket, MessageBody, @@ -12,18 +13,24 @@ import { encodeStateAsUpdate, encodeStateVector } from 'yjs'; import { Metrics } from '../../../metrics/metrics'; import { trimGuid } from '../../../utils/doc'; +import { Auth, CurrentUser } from '../../auth'; import { DocManager } from '../../doc'; +import { UserType } from '../../users'; +import { PermissionService } from '../../workspaces/permission'; +import { Permission } from '../../workspaces/types'; @WebSocketGateway({ cors: process.env.NODE_ENV !== 'production', transports: ['websocket'], }) export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { + protected logger = new Logger(EventsGateway.name); private connectionCount = 0; constructor( private readonly docManager: DocManager, - private readonly metric: Metrics + private readonly metric: Metrics, + private readonly permissions: PermissionService ) {} @WebSocketServer() @@ -39,8 +46,10 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { this.metric.socketIOConnectionGauge(this.connectionCount, {}); } + @Auth() @SubscribeMessage('client-handshake') async handleClientHandShake( + @CurrentUser() user: UserType, @MessageBody() workspaceId: string, @ConnectedSocket() client: Socket ) { @@ -48,8 +57,16 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { const endTimer = this.metric.socketIOEventTimer({ event: 'client-handshake', }); - await client.join(workspaceId); + + const canWrite = await this.permissions.tryCheck( + workspaceId, + user.id, + Permission.Write + ); + if (canWrite) await client.join(workspaceId); + endTimer(); + return canWrite; } @SubscribeMessage('client-leave') @@ -77,26 +94,49 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { ) { this.metric.socketIOEventCounter(1, { event: 'client-update' }); const endTimer = this.metric.socketIOEventTimer({ event: 'client-update' }); + if (!client.rooms.has(message.workspaceId)) { + this.logger.verbose( + `Client ${client.id} tried to push update to workspace ${message.workspaceId} without joining it first` + ); + endTimer(); + return; + } + const update = Buffer.from(message.update, 'base64'); client.to(message.workspaceId).emit('server-update', message); - const guid = trimGuid(message.workspaceId, message.guid); + const guid = trimGuid(message.workspaceId, message.guid); await this.docManager.push(message.workspaceId, guid, update); + endTimer(); } + @Auth() @SubscribeMessage('doc-load') async loadDoc( + @CurrentUser() user: UserType, @MessageBody() message: { workspaceId: string; guid: string; stateVector?: string; targetClientId?: number; - } + }, + @ConnectedSocket() client: Socket ): Promise<{ missing: string; state?: string } | false> { this.metric.socketIOEventCounter(1, { event: 'doc-load' }); const endTimer = this.metric.socketIOEventTimer({ event: 'doc-load' }); + if (!client.rooms.has(message.workspaceId)) { + const canRead = await this.permissions.tryCheck( + message.workspaceId, + user.id + ); + if (!canRead) { + endTimer(); + return false; + } + } + const guid = trimGuid(message.workspaceId, message.guid); const doc = await this.docManager.getLatest(message.workspaceId, guid); @@ -131,7 +171,13 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { const endTimer = this.metric.socketIOEventTimer({ event: 'init-awareness', }); - client.to(workspaceId).emit('new-client-awareness-init'); + if (client.rooms.has(workspaceId)) { + client.to(workspaceId).emit('new-client-awareness-init'); + } else { + this.logger.verbose( + `Client ${client.id} tried to init awareness for workspace ${workspaceId} without joining it first` + ); + } endTimer(); } @@ -144,9 +190,16 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect { const endTimer = this.metric.socketIOEventTimer({ event: 'awareness-update', }); - client.to(message.workspaceId).emit('server-awareness-broadcast', { - ...message, - }); + + if (client.rooms.has(message.workspaceId)) { + client.to(message.workspaceId).emit('server-awareness-broadcast', { + ...message, + }); + } else { + this.logger.verbose( + `Client ${client.id} tried to update awareness for workspace ${message.workspaceId} without joining it first` + ); + } endTimer(); return 'ack'; diff --git a/apps/server/src/modules/sync/events/events.module.ts b/apps/server/src/modules/sync/events/events.module.ts index 2d61c910c3..f9a1c1bef0 100644 --- a/apps/server/src/modules/sync/events/events.module.ts +++ b/apps/server/src/modules/sync/events/events.module.ts @@ -1,11 +1,12 @@ import { Module } from '@nestjs/common'; import { DocModule } from '../../doc'; +import { PermissionService } from '../../workspaces/permission'; import { EventsGateway } from './events.gateway'; import { WorkspaceService } from './workspace'; @Module({ imports: [DocModule.forFeature()], - providers: [EventsGateway, WorkspaceService], + providers: [EventsGateway, PermissionService, WorkspaceService], }) export class EventsModule {} diff --git a/apps/server/src/utils/nestjs.ts b/apps/server/src/utils/nestjs.ts index 7da39a15d2..7ee27745ce 100644 --- a/apps/server/src/utils/nestjs.ts +++ b/apps/server/src/utils/nestjs.ts @@ -21,6 +21,29 @@ export function getRequestResponseFromContext(context: ExecutionContext) { res: http.getResponse(), }; } + case 'ws': { + const ws = context.switchToWs(); + const req = ws.getClient().handshake; + + const cookies = req?.headers?.cookie; + // patch cookies to match auth guard logic + if (typeof cookies === 'string') { + req.cookies = cookies + .split(';') + .map(v => v.split('=')) + .reduce( + (acc, v) => { + acc[decodeURIComponent(v[0].trim())] = decodeURIComponent( + v[1].trim() + ); + return acc; + }, + {} as Record + ); + } + + return { req }; + } default: throw new Error('Unknown context type for getting request and response'); }