feat(server): authenticate user before ws connected (#7777)

This commit is contained in:
forehalo
2024-08-08 08:30:55 +00:00
parent 83244f0201
commit f2eafc374c
17 changed files with 232 additions and 87 deletions

View File

@@ -36,5 +36,6 @@ export {
getRequestFromHost,
getRequestResponseFromContext,
getRequestResponseFromHost,
parseCookies,
} from './utils/request';
export type * from './utils/types';

View File

@@ -2,8 +2,10 @@ import { ArgumentsHost, Catch, Logger } from '@nestjs/common';
import { BaseExceptionFilter } from '@nestjs/core';
import { GqlContextType } from '@nestjs/graphql';
import { ThrottlerException } from '@nestjs/throttler';
import { BaseWsExceptionFilter } from '@nestjs/websockets';
import { Response } from 'express';
import { of } from 'rxjs';
import { Socket } from 'socket.io';
import {
InternalServerError,
@@ -44,6 +46,20 @@ export class GlobalExceptionFilter extends BaseExceptionFilter {
}
}
export class GlobalWsExceptionFilter extends BaseWsExceptionFilter {
// @ts-expect-error satisfies the override
override handleError(client: Socket, exception: any): void {
const error = mapAnyError(exception);
error.log('Websocket');
metrics.socketio
.counter('unhandled_error')
.add(1, { status: error.status });
client.emit('error', {
error: toWebsocketError(error),
});
}
}
/**
* Only exists for websocket error body backward compatibility
*

View File

@@ -57,7 +57,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
override getTracker(req: Request): Promise<string> {
return Promise.resolve(
// ↓ prefer session id if available
`throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
`throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
// ^ throttler prefix make the key in store recognizable
);
}

View File

@@ -66,3 +66,29 @@ export function getRequestFromHost(host: ArgumentsHost) {
export function getRequestResponseFromContext(ctx: ExecutionContext) {
return getRequestResponseFromHost(ctx);
}
/**
* simple patch for request not protected by `cookie-parser`
* only take effect if `req.cookies` is not defined
*/
export function parseCookies(req: Request) {
if (req.cookies) {
return;
}
const cookieStr = req?.headers?.cookie ?? '';
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');
if (key) {
cookies[decodeURIComponent(key.trim())] = val
? decodeURIComponent(val.trim())
: val;
}
return cookies;
},
{} as Record<string, string>
);
}

View File

@@ -0,0 +1,20 @@
import { GatewayMetadata } from '@nestjs/websockets';
import { defineStartupConfig, ModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
websocket: ModuleConfig<
GatewayMetadata & {
requireAuthentication?: boolean;
}
>;
}
}
defineStartupConfig('websocket', {
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
transports: ['websocket'],
maxHttpBufferSize: 1e8, // 100 MB
requireAuthentication: true,
});

View File

@@ -1,17 +1,46 @@
import { Module, Provider } from '@nestjs/common';
import './config';
import {
FactoryProvider,
INestApplicationContext,
Module,
Provider,
} from '@nestjs/common';
import { IoAdapter } from '@nestjs/platform-socket.io';
import { Server } from 'socket.io';
import { Config } from '../config';
export const SocketIoAdapterImpl = Symbol('SocketIoAdapterImpl');
export class SocketIoAdapter extends IoAdapter {}
export class SocketIoAdapter extends IoAdapter {
constructor(protected readonly app: INestApplicationContext) {
super(app);
}
override createIOServer(port: number, options?: any): Server {
const config = this.app.get(WEBSOCKET_OPTIONS);
return super.createIOServer(port, { ...config, ...options });
}
}
const SocketIoAdapterImplProvider: Provider = {
provide: SocketIoAdapterImpl,
useValue: SocketIoAdapter,
};
export const WEBSOCKET_OPTIONS = Symbol('WEBSOCKET_OPTIONS');
export const websocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config) => {
return config.websocket;
},
inject: [Config],
};
@Module({
providers: [SocketIoAdapterImplProvider],
exports: [SocketIoAdapterImplProvider],
providers: [SocketIoAdapterImplProvider, websocketOptionsProvider],
exports: [SocketIoAdapterImplProvider, websocketOptionsProvider],
})
export class WebSocketModule {}