feat(server): authenticate user before ws connected (#7777)

This commit is contained in:
forehalo
2024-08-08 08:30:55 +00:00
parent 83244f0201
commit f2eafc374c
17 changed files with 232 additions and 87 deletions

View File

@@ -152,6 +152,8 @@ function buildAppModule() {
factor
// common fundamental modules
.use(...FunctionalityModules)
.useIf(config => config.flavor.sync, WebSocketModule)
// auth
.use(AuthModule)
@@ -159,7 +161,7 @@ function buildAppModule() {
.use(DocModule)
// sync server only
.useIf(config => config.flavor.sync, WebSocketModule, SyncModule)
.useIf(config => config.flavor.sync, SyncModule)
// graphql server only
.useIf(

View File

@@ -1,6 +1,6 @@
import type { ExecutionContext } from '@nestjs/common';
import { createParamDecorator } from '@nestjs/common';
import { User } from '@prisma/client';
import { User, UserSession } from '@prisma/client';
import { getRequestResponseFromContext } from '../../fundamentals';
@@ -53,3 +53,5 @@ export interface CurrentUser
hasPassword: boolean | null;
emailVerified: boolean;
}
export { type UserSession };

View File

@@ -1,15 +1,22 @@
import type {
CanActivate,
ExecutionContext,
FactoryProvider,
OnModuleInit,
} from '@nestjs/common';
import { Injectable, SetMetadata, UseGuards } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request } from 'express';
import {
AuthenticationRequired,
Config,
getRequestResponseFromContext,
mapAnyError,
parseCookies,
} from '../../fundamentals';
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
import { CurrentUser, UserSession } from './current-user';
import { AuthService, parseAuthUserSeqNum } from './service';
function extractTokenFromHeader(authorization: string) {
@@ -38,37 +45,9 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
// check cookie
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
const { user, expiresAt } = await this.auth.getUser(
sessionToken,
userSeq
);
if (res && user && expiresAt) {
await this.auth.refreshUserSessionIfNeeded(
req,
res,
sessionToken,
user.id,
expiresAt
);
}
if (user) {
req.sid = sessionToken;
req.user = user;
}
const userSession = await this.signIn(req);
if (res && userSession && userSession.session.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session);
}
// api is public
@@ -84,9 +63,44 @@ export class AuthGuard implements CanActivate, OnModuleInit {
if (!req.user) {
throw new AuthenticationRequired();
}
return true;
}
async signIn(
req: Request
): Promise<{ user: CurrentUser; session: UserSession } | null> {
if (req.user && req.session) {
return {
user: req.user,
session: req.session,
};
}
parseCookies(req);
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
const userSession = await this.auth.getUserSession(sessionToken, userSeq);
if (userSession) {
req.session = userSession.session;
req.user = userSession.user;
}
return userSession;
}
return null;
}
}
/**
@@ -111,3 +125,35 @@ export const Auth = () => {
// api is public accessible
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config, guard: AuthGuard) => {
return {
...config.websocket,
allowRequest: async (
req: any,
pass: (err: string | null | undefined, success: boolean) => void
) => {
if (!config.websocket.requireAuthentication) {
return pass(null, true);
}
try {
const authentication = await guard.signIn(req);
if (authentication) {
return pass(null, true);
} else {
return pass('unauthenticated', false);
}
} catch (e) {
const error = mapAnyError(e);
error.log('Websocket');
return pass('unauthenticated', false);
}
},
};
},
inject: [Config, AuthGuard],
};

View File

@@ -6,15 +6,21 @@ import { FeatureModule } from '../features';
import { QuotaModule } from '../quota';
import { UserModule } from '../user';
import { AuthController } from './controller';
import { AuthGuard } from './guard';
import { AuthGuard, AuthWebsocketOptionsProvider } from './guard';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
import { TokenService, TokenType } from './token';
@Module({
imports: [FeatureModule, UserModule, QuotaModule],
providers: [AuthService, AuthResolver, TokenService, AuthGuard],
exports: [AuthService, AuthGuard],
providers: [
AuthService,
AuthResolver,
TokenService,
AuthGuard,
AuthWebsocketOptionsProvider,
],
exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider],
controllers: [AuthController],
})
export class AuthModule {}

View File

@@ -1,6 +1,6 @@
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import type { User } from '@prisma/client';
import type { User, UserSession } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import type { CookieOptions, Request, Response } from 'express';
import { assign, omit } from 'lodash-es';
@@ -121,27 +121,27 @@ export class AuthService implements OnApplicationBootstrap {
return sessionUser(user);
}
async getUser(
async getUserSession(
token: string,
seq = 0
): Promise<{ user: CurrentUser | null; expiresAt: Date | null }> {
): Promise<{ user: CurrentUser; session: UserSession } | null> {
const session = await this.getSession(token);
// no such session
if (!session) {
return { user: null, expiresAt: null };
return null;
}
const userSession = session.userSessions.at(seq);
// no such user session
if (!userSession) {
return { user: null, expiresAt: null };
return null;
}
// user session expired
if (userSession.expiresAt && userSession.expiresAt <= new Date()) {
return { user: null, expiresAt: null };
return null;
}
const user = await this.db.user.findUnique({
@@ -149,10 +149,10 @@ export class AuthService implements OnApplicationBootstrap {
});
if (!user) {
return { user: null, expiresAt: null };
return null;
}
return { user: sessionUser(user), expiresAt: userSession.expiresAt };
return { user: sessionUser(user), session: userSession };
}
async getUserList(token: string) {
@@ -251,12 +251,13 @@ export class AuthService implements OnApplicationBootstrap {
async refreshUserSessionIfNeeded(
_req: Request,
res: Response,
sessionId: string,
userId: string,
expiresAt: Date,
session: UserSession,
ttr = this.config.auth.session.ttr
): Promise<boolean> {
if (expiresAt && expiresAt.getTime() - Date.now() > ttr * 1000) {
if (
session.expiresAt &&
session.expiresAt.getTime() - Date.now() > ttr * 1000
) {
// no need to refresh
return false;
}
@@ -267,17 +268,14 @@ export class AuthService implements OnApplicationBootstrap {
await this.db.userSession.update({
where: {
sessionId_userId: {
sessionId,
userId,
},
id: session.id,
},
data: {
expiresAt: newExpiresAt,
},
});
res.cookie(AuthService.sessionCookieName, sessionId, {
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: newExpiresAt,
...this.cookieOptions,
});

View File

@@ -50,12 +50,7 @@ function Awareness(workspaceId: string): `${string}:awareness` {
return `${workspaceId}:awareness`;
}
@WebSocketGateway({
cors: !AFFiNE.node.prod,
transports: ['websocket'],
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
maxHttpBufferSize: 1e8, // 100 MB
})
@WebSocketGateway()
export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
protected logger = new Logger(EventsGateway.name);
private connectionCount = 0;

View File

@@ -36,5 +36,6 @@ export {
getRequestFromHost,
getRequestResponseFromContext,
getRequestResponseFromHost,
parseCookies,
} from './utils/request';
export type * from './utils/types';

View File

@@ -2,8 +2,10 @@ import { ArgumentsHost, Catch, Logger } from '@nestjs/common';
import { BaseExceptionFilter } from '@nestjs/core';
import { GqlContextType } from '@nestjs/graphql';
import { ThrottlerException } from '@nestjs/throttler';
import { BaseWsExceptionFilter } from '@nestjs/websockets';
import { Response } from 'express';
import { of } from 'rxjs';
import { Socket } from 'socket.io';
import {
InternalServerError,
@@ -44,6 +46,20 @@ export class GlobalExceptionFilter extends BaseExceptionFilter {
}
}
export class GlobalWsExceptionFilter extends BaseWsExceptionFilter {
// @ts-expect-error satisfies the override
override handleError(client: Socket, exception: any): void {
const error = mapAnyError(exception);
error.log('Websocket');
metrics.socketio
.counter('unhandled_error')
.add(1, { status: error.status });
client.emit('error', {
error: toWebsocketError(error),
});
}
}
/**
* Only exists for websocket error body backward compatibility
*

View File

@@ -57,7 +57,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
override getTracker(req: Request): Promise<string> {
return Promise.resolve(
// ↓ prefer session id if available
`throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
`throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
// ^ throttler prefix make the key in store recognizable
);
}

View File

@@ -66,3 +66,29 @@ export function getRequestFromHost(host: ArgumentsHost) {
export function getRequestResponseFromContext(ctx: ExecutionContext) {
return getRequestResponseFromHost(ctx);
}
/**
* simple patch for request not protected by `cookie-parser`
* only take effect if `req.cookies` is not defined
*/
export function parseCookies(req: Request) {
if (req.cookies) {
return;
}
const cookieStr = req?.headers?.cookie ?? '';
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');
if (key) {
cookies[decodeURIComponent(key.trim())] = val
? decodeURIComponent(val.trim())
: val;
}
return cookies;
},
{} as Record<string, string>
);
}

View File

@@ -0,0 +1,20 @@
import { GatewayMetadata } from '@nestjs/websockets';
import { defineStartupConfig, ModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
websocket: ModuleConfig<
GatewayMetadata & {
requireAuthentication?: boolean;
}
>;
}
}
defineStartupConfig('websocket', {
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
transports: ['websocket'],
maxHttpBufferSize: 1e8, // 100 MB
requireAuthentication: true,
});

View File

@@ -1,17 +1,46 @@
import { Module, Provider } from '@nestjs/common';
import './config';
import {
FactoryProvider,
INestApplicationContext,
Module,
Provider,
} from '@nestjs/common';
import { IoAdapter } from '@nestjs/platform-socket.io';
import { Server } from 'socket.io';
import { Config } from '../config';
export const SocketIoAdapterImpl = Symbol('SocketIoAdapterImpl');
export class SocketIoAdapter extends IoAdapter {}
export class SocketIoAdapter extends IoAdapter {
constructor(protected readonly app: INestApplicationContext) {
super(app);
}
override createIOServer(port: number, options?: any): Server {
const config = this.app.get(WEBSOCKET_OPTIONS);
return super.createIOServer(port, { ...config, ...options });
}
}
const SocketIoAdapterImplProvider: Provider = {
provide: SocketIoAdapterImpl,
useValue: SocketIoAdapter,
};
export const WEBSOCKET_OPTIONS = Symbol('WEBSOCKET_OPTIONS');
export const websocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config) => {
return config.websocket;
},
inject: [Config],
};
@Module({
providers: [SocketIoAdapterImplProvider],
exports: [SocketIoAdapterImplProvider],
providers: [SocketIoAdapterImplProvider, websocketOptionsProvider],
exports: [SocketIoAdapterImplProvider, websocketOptionsProvider],
})
export class WebSocketModule {}

View File

@@ -1,7 +1,7 @@
declare namespace Express {
interface Request {
user?: import('./core/auth/current-user').CurrentUser;
sid?: string;
session?: import('./core/auth/current-user').UserSession;
}
}

View File

@@ -18,7 +18,7 @@ export function createSockerIoAdapterImpl(
console.error(err);
});
const server = super.createIOServer(port, options) as Server;
const server = super.createIOServer(port, options);
server.adapter(createAdapter(pubClient, subClient));
return server;
}