feat: add auth support for websocket (#4445)

This commit is contained in:
DarkSky
2023-09-21 21:05:26 +08:00
committed by GitHub
parent 872dc3521b
commit 1ddae40fb2
3 changed files with 86 additions and 9 deletions

View File

@@ -1,3 +1,4 @@
import { Logger } from '@nestjs/common';
import {
ConnectedSocket,
MessageBody,
@@ -12,18 +13,24 @@ import { encodeStateAsUpdate, encodeStateVector } from 'yjs';
import { Metrics } from '../../../metrics/metrics';
import { trimGuid } from '../../../utils/doc';
import { Auth, CurrentUser } from '../../auth';
import { DocManager } from '../../doc';
import { UserType } from '../../users';
import { PermissionService } from '../../workspaces/permission';
import { Permission } from '../../workspaces/types';
@WebSocketGateway({
cors: process.env.NODE_ENV !== 'production',
transports: ['websocket'],
})
export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
protected logger = new Logger(EventsGateway.name);
private connectionCount = 0;
constructor(
private readonly docManager: DocManager,
private readonly metric: Metrics
private readonly metric: Metrics,
private readonly permissions: PermissionService
) {}
@WebSocketServer()
@@ -39,8 +46,10 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
this.metric.socketIOConnectionGauge(this.connectionCount, {});
}
@Auth()
@SubscribeMessage('client-handshake')
async handleClientHandShake(
@CurrentUser() user: UserType,
@MessageBody() workspaceId: string,
@ConnectedSocket() client: Socket
) {
@@ -48,8 +57,16 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
const endTimer = this.metric.socketIOEventTimer({
event: 'client-handshake',
});
await client.join(workspaceId);
const canWrite = await this.permissions.tryCheck(
workspaceId,
user.id,
Permission.Write
);
if (canWrite) await client.join(workspaceId);
endTimer();
return canWrite;
}
@SubscribeMessage('client-leave')
@@ -77,26 +94,49 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
) {
this.metric.socketIOEventCounter(1, { event: 'client-update' });
const endTimer = this.metric.socketIOEventTimer({ event: 'client-update' });
if (!client.rooms.has(message.workspaceId)) {
this.logger.verbose(
`Client ${client.id} tried to push update to workspace ${message.workspaceId} without joining it first`
);
endTimer();
return;
}
const update = Buffer.from(message.update, 'base64');
client.to(message.workspaceId).emit('server-update', message);
const guid = trimGuid(message.workspaceId, message.guid);
const guid = trimGuid(message.workspaceId, message.guid);
await this.docManager.push(message.workspaceId, guid, update);
endTimer();
}
@Auth()
@SubscribeMessage('doc-load')
async loadDoc(
@CurrentUser() user: UserType,
@MessageBody()
message: {
workspaceId: string;
guid: string;
stateVector?: string;
targetClientId?: number;
}
},
@ConnectedSocket() client: Socket
): Promise<{ missing: string; state?: string } | false> {
this.metric.socketIOEventCounter(1, { event: 'doc-load' });
const endTimer = this.metric.socketIOEventTimer({ event: 'doc-load' });
if (!client.rooms.has(message.workspaceId)) {
const canRead = await this.permissions.tryCheck(
message.workspaceId,
user.id
);
if (!canRead) {
endTimer();
return false;
}
}
const guid = trimGuid(message.workspaceId, message.guid);
const doc = await this.docManager.getLatest(message.workspaceId, guid);
@@ -131,7 +171,13 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
const endTimer = this.metric.socketIOEventTimer({
event: 'init-awareness',
});
client.to(workspaceId).emit('new-client-awareness-init');
if (client.rooms.has(workspaceId)) {
client.to(workspaceId).emit('new-client-awareness-init');
} else {
this.logger.verbose(
`Client ${client.id} tried to init awareness for workspace ${workspaceId} without joining it first`
);
}
endTimer();
}
@@ -144,9 +190,16 @@ export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
const endTimer = this.metric.socketIOEventTimer({
event: 'awareness-update',
});
client.to(message.workspaceId).emit('server-awareness-broadcast', {
...message,
});
if (client.rooms.has(message.workspaceId)) {
client.to(message.workspaceId).emit('server-awareness-broadcast', {
...message,
});
} else {
this.logger.verbose(
`Client ${client.id} tried to update awareness for workspace ${message.workspaceId} without joining it first`
);
}
endTimer();
return 'ack';

View File

@@ -1,11 +1,12 @@
import { Module } from '@nestjs/common';
import { DocModule } from '../../doc';
import { PermissionService } from '../../workspaces/permission';
import { EventsGateway } from './events.gateway';
import { WorkspaceService } from './workspace';
@Module({
imports: [DocModule.forFeature()],
providers: [EventsGateway, WorkspaceService],
providers: [EventsGateway, PermissionService, WorkspaceService],
})
export class EventsModule {}

View File

@@ -21,6 +21,29 @@ export function getRequestResponseFromContext(context: ExecutionContext) {
res: http.getResponse<Response>(),
};
}
case 'ws': {
const ws = context.switchToWs();
const req = ws.getClient().handshake;
const cookies = req?.headers?.cookie;
// patch cookies to match auth guard logic
if (typeof cookies === 'string') {
req.cookies = cookies
.split(';')
.map(v => v.split('='))
.reduce(
(acc, v) => {
acc[decodeURIComponent(v[0].trim())] = decodeURIComponent(
v[1].trim()
);
return acc;
},
{} as Record<string, string>
);
}
return { req };
}
default:
throw new Error('Unknown context type for getting request and response');
}