mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat(server): authenticate user before ws connected (#7777)
This commit is contained in:
@@ -152,6 +152,8 @@ function buildAppModule() {
|
||||
factor
|
||||
// common fundamental modules
|
||||
.use(...FunctionalityModules)
|
||||
.useIf(config => config.flavor.sync, WebSocketModule)
|
||||
|
||||
// auth
|
||||
.use(AuthModule)
|
||||
|
||||
@@ -159,7 +161,7 @@ function buildAppModule() {
|
||||
.use(DocModule)
|
||||
|
||||
// sync server only
|
||||
.useIf(config => config.flavor.sync, WebSocketModule, SyncModule)
|
||||
.useIf(config => config.flavor.sync, SyncModule)
|
||||
|
||||
// graphql server only
|
||||
.useIf(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { ExecutionContext } from '@nestjs/common';
|
||||
import { createParamDecorator } from '@nestjs/common';
|
||||
import { User } from '@prisma/client';
|
||||
import { User, UserSession } from '@prisma/client';
|
||||
|
||||
import { getRequestResponseFromContext } from '../../fundamentals';
|
||||
|
||||
@@ -53,3 +53,5 @@ export interface CurrentUser
|
||||
hasPassword: boolean | null;
|
||||
emailVerified: boolean;
|
||||
}
|
||||
|
||||
export { type UserSession };
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
import type {
|
||||
CanActivate,
|
||||
ExecutionContext,
|
||||
FactoryProvider,
|
||||
OnModuleInit,
|
||||
} from '@nestjs/common';
|
||||
import { Injectable, SetMetadata, UseGuards } from '@nestjs/common';
|
||||
import { ModuleRef, Reflector } from '@nestjs/core';
|
||||
import type { Request } from 'express';
|
||||
|
||||
import {
|
||||
AuthenticationRequired,
|
||||
Config,
|
||||
getRequestResponseFromContext,
|
||||
mapAnyError,
|
||||
parseCookies,
|
||||
} from '../../fundamentals';
|
||||
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
|
||||
import { CurrentUser, UserSession } from './current-user';
|
||||
import { AuthService, parseAuthUserSeqNum } from './service';
|
||||
|
||||
function extractTokenFromHeader(authorization: string) {
|
||||
@@ -38,37 +45,9 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
async canActivate(context: ExecutionContext) {
|
||||
const { req, res } = getRequestResponseFromContext(context);
|
||||
|
||||
// check cookie
|
||||
let sessionToken: string | undefined =
|
||||
req.cookies[AuthService.sessionCookieName];
|
||||
|
||||
if (!sessionToken && req.headers.authorization) {
|
||||
sessionToken = extractTokenFromHeader(req.headers.authorization);
|
||||
}
|
||||
|
||||
if (sessionToken) {
|
||||
const userSeq = parseAuthUserSeqNum(
|
||||
req.headers[AuthService.authUserSeqHeaderName]
|
||||
);
|
||||
|
||||
const { user, expiresAt } = await this.auth.getUser(
|
||||
sessionToken,
|
||||
userSeq
|
||||
);
|
||||
if (res && user && expiresAt) {
|
||||
await this.auth.refreshUserSessionIfNeeded(
|
||||
req,
|
||||
res,
|
||||
sessionToken,
|
||||
user.id,
|
||||
expiresAt
|
||||
);
|
||||
}
|
||||
|
||||
if (user) {
|
||||
req.sid = sessionToken;
|
||||
req.user = user;
|
||||
}
|
||||
const userSession = await this.signIn(req);
|
||||
if (res && userSession && userSession.session.expiresAt) {
|
||||
await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session);
|
||||
}
|
||||
|
||||
// api is public
|
||||
@@ -84,9 +63,44 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
if (!req.user) {
|
||||
throw new AuthenticationRequired();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
async signIn(
|
||||
req: Request
|
||||
): Promise<{ user: CurrentUser; session: UserSession } | null> {
|
||||
if (req.user && req.session) {
|
||||
return {
|
||||
user: req.user,
|
||||
session: req.session,
|
||||
};
|
||||
}
|
||||
|
||||
parseCookies(req);
|
||||
let sessionToken: string | undefined =
|
||||
req.cookies[AuthService.sessionCookieName];
|
||||
|
||||
if (!sessionToken && req.headers.authorization) {
|
||||
sessionToken = extractTokenFromHeader(req.headers.authorization);
|
||||
}
|
||||
|
||||
if (sessionToken) {
|
||||
const userSeq = parseAuthUserSeqNum(
|
||||
req.headers[AuthService.authUserSeqHeaderName]
|
||||
);
|
||||
|
||||
const userSession = await this.auth.getUserSession(sessionToken, userSeq);
|
||||
|
||||
if (userSession) {
|
||||
req.session = userSession.session;
|
||||
req.user = userSession.user;
|
||||
}
|
||||
|
||||
return userSession;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -111,3 +125,35 @@ export const Auth = () => {
|
||||
|
||||
// api is public accessible
|
||||
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
|
||||
|
||||
export const AuthWebsocketOptionsProvider: FactoryProvider = {
|
||||
provide: WEBSOCKET_OPTIONS,
|
||||
useFactory: (config: Config, guard: AuthGuard) => {
|
||||
return {
|
||||
...config.websocket,
|
||||
allowRequest: async (
|
||||
req: any,
|
||||
pass: (err: string | null | undefined, success: boolean) => void
|
||||
) => {
|
||||
if (!config.websocket.requireAuthentication) {
|
||||
return pass(null, true);
|
||||
}
|
||||
|
||||
try {
|
||||
const authentication = await guard.signIn(req);
|
||||
|
||||
if (authentication) {
|
||||
return pass(null, true);
|
||||
} else {
|
||||
return pass('unauthenticated', false);
|
||||
}
|
||||
} catch (e) {
|
||||
const error = mapAnyError(e);
|
||||
error.log('Websocket');
|
||||
return pass('unauthenticated', false);
|
||||
}
|
||||
},
|
||||
};
|
||||
},
|
||||
inject: [Config, AuthGuard],
|
||||
};
|
||||
|
||||
@@ -6,15 +6,21 @@ import { FeatureModule } from '../features';
|
||||
import { QuotaModule } from '../quota';
|
||||
import { UserModule } from '../user';
|
||||
import { AuthController } from './controller';
|
||||
import { AuthGuard } from './guard';
|
||||
import { AuthGuard, AuthWebsocketOptionsProvider } from './guard';
|
||||
import { AuthResolver } from './resolver';
|
||||
import { AuthService } from './service';
|
||||
import { TokenService, TokenType } from './token';
|
||||
|
||||
@Module({
|
||||
imports: [FeatureModule, UserModule, QuotaModule],
|
||||
providers: [AuthService, AuthResolver, TokenService, AuthGuard],
|
||||
exports: [AuthService, AuthGuard],
|
||||
providers: [
|
||||
AuthService,
|
||||
AuthResolver,
|
||||
TokenService,
|
||||
AuthGuard,
|
||||
AuthWebsocketOptionsProvider,
|
||||
],
|
||||
exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider],
|
||||
controllers: [AuthController],
|
||||
})
|
||||
export class AuthModule {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
|
||||
import { Cron, CronExpression } from '@nestjs/schedule';
|
||||
import type { User } from '@prisma/client';
|
||||
import type { User, UserSession } from '@prisma/client';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import type { CookieOptions, Request, Response } from 'express';
|
||||
import { assign, omit } from 'lodash-es';
|
||||
@@ -121,27 +121,27 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
return sessionUser(user);
|
||||
}
|
||||
|
||||
async getUser(
|
||||
async getUserSession(
|
||||
token: string,
|
||||
seq = 0
|
||||
): Promise<{ user: CurrentUser | null; expiresAt: Date | null }> {
|
||||
): Promise<{ user: CurrentUser; session: UserSession } | null> {
|
||||
const session = await this.getSession(token);
|
||||
|
||||
// no such session
|
||||
if (!session) {
|
||||
return { user: null, expiresAt: null };
|
||||
return null;
|
||||
}
|
||||
|
||||
const userSession = session.userSessions.at(seq);
|
||||
|
||||
// no such user session
|
||||
if (!userSession) {
|
||||
return { user: null, expiresAt: null };
|
||||
return null;
|
||||
}
|
||||
|
||||
// user session expired
|
||||
if (userSession.expiresAt && userSession.expiresAt <= new Date()) {
|
||||
return { user: null, expiresAt: null };
|
||||
return null;
|
||||
}
|
||||
|
||||
const user = await this.db.user.findUnique({
|
||||
@@ -149,10 +149,10 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
return { user: null, expiresAt: null };
|
||||
return null;
|
||||
}
|
||||
|
||||
return { user: sessionUser(user), expiresAt: userSession.expiresAt };
|
||||
return { user: sessionUser(user), session: userSession };
|
||||
}
|
||||
|
||||
async getUserList(token: string) {
|
||||
@@ -251,12 +251,13 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
async refreshUserSessionIfNeeded(
|
||||
_req: Request,
|
||||
res: Response,
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
expiresAt: Date,
|
||||
session: UserSession,
|
||||
ttr = this.config.auth.session.ttr
|
||||
): Promise<boolean> {
|
||||
if (expiresAt && expiresAt.getTime() - Date.now() > ttr * 1000) {
|
||||
if (
|
||||
session.expiresAt &&
|
||||
session.expiresAt.getTime() - Date.now() > ttr * 1000
|
||||
) {
|
||||
// no need to refresh
|
||||
return false;
|
||||
}
|
||||
@@ -267,17 +268,14 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
|
||||
await this.db.userSession.update({
|
||||
where: {
|
||||
sessionId_userId: {
|
||||
sessionId,
|
||||
userId,
|
||||
},
|
||||
id: session.id,
|
||||
},
|
||||
data: {
|
||||
expiresAt: newExpiresAt,
|
||||
},
|
||||
});
|
||||
|
||||
res.cookie(AuthService.sessionCookieName, sessionId, {
|
||||
res.cookie(AuthService.sessionCookieName, session.sessionId, {
|
||||
expires: newExpiresAt,
|
||||
...this.cookieOptions,
|
||||
});
|
||||
|
||||
@@ -50,12 +50,7 @@ function Awareness(workspaceId: string): `${string}:awareness` {
|
||||
return `${workspaceId}:awareness`;
|
||||
}
|
||||
|
||||
@WebSocketGateway({
|
||||
cors: !AFFiNE.node.prod,
|
||||
transports: ['websocket'],
|
||||
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
|
||||
maxHttpBufferSize: 1e8, // 100 MB
|
||||
})
|
||||
@WebSocketGateway()
|
||||
export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
|
||||
protected logger = new Logger(EventsGateway.name);
|
||||
private connectionCount = 0;
|
||||
|
||||
@@ -36,5 +36,6 @@ export {
|
||||
getRequestFromHost,
|
||||
getRequestResponseFromContext,
|
||||
getRequestResponseFromHost,
|
||||
parseCookies,
|
||||
} from './utils/request';
|
||||
export type * from './utils/types';
|
||||
|
||||
@@ -2,8 +2,10 @@ import { ArgumentsHost, Catch, Logger } from '@nestjs/common';
|
||||
import { BaseExceptionFilter } from '@nestjs/core';
|
||||
import { GqlContextType } from '@nestjs/graphql';
|
||||
import { ThrottlerException } from '@nestjs/throttler';
|
||||
import { BaseWsExceptionFilter } from '@nestjs/websockets';
|
||||
import { Response } from 'express';
|
||||
import { of } from 'rxjs';
|
||||
import { Socket } from 'socket.io';
|
||||
|
||||
import {
|
||||
InternalServerError,
|
||||
@@ -44,6 +46,20 @@ export class GlobalExceptionFilter extends BaseExceptionFilter {
|
||||
}
|
||||
}
|
||||
|
||||
export class GlobalWsExceptionFilter extends BaseWsExceptionFilter {
|
||||
// @ts-expect-error satisfies the override
|
||||
override handleError(client: Socket, exception: any): void {
|
||||
const error = mapAnyError(exception);
|
||||
error.log('Websocket');
|
||||
metrics.socketio
|
||||
.counter('unhandled_error')
|
||||
.add(1, { status: error.status });
|
||||
client.emit('error', {
|
||||
error: toWebsocketError(error),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Only exists for websocket error body backward compatibility
|
||||
*
|
||||
|
||||
@@ -57,7 +57,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
|
||||
override getTracker(req: Request): Promise<string> {
|
||||
return Promise.resolve(
|
||||
// ↓ prefer session id if available
|
||||
`throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
|
||||
`throttler:${req.session?.sessionId ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
|
||||
// ^ throttler prefix make the key in store recognizable
|
||||
);
|
||||
}
|
||||
|
||||
@@ -66,3 +66,29 @@ export function getRequestFromHost(host: ArgumentsHost) {
|
||||
export function getRequestResponseFromContext(ctx: ExecutionContext) {
|
||||
return getRequestResponseFromHost(ctx);
|
||||
}
|
||||
|
||||
/**
|
||||
* simple patch for request not protected by `cookie-parser`
|
||||
* only take effect if `req.cookies` is not defined
|
||||
*/
|
||||
export function parseCookies(req: Request) {
|
||||
if (req.cookies) {
|
||||
return;
|
||||
}
|
||||
|
||||
const cookieStr = req?.headers?.cookie ?? '';
|
||||
req.cookies = cookieStr.split(';').reduce(
|
||||
(cookies, cookie) => {
|
||||
const [key, val] = cookie.split('=');
|
||||
|
||||
if (key) {
|
||||
cookies[decodeURIComponent(key.trim())] = val
|
||||
? decodeURIComponent(val.trim())
|
||||
: val;
|
||||
}
|
||||
|
||||
return cookies;
|
||||
},
|
||||
{} as Record<string, string>
|
||||
);
|
||||
}
|
||||
|
||||
20
packages/backend/server/src/fundamentals/websocket/config.ts
Normal file
20
packages/backend/server/src/fundamentals/websocket/config.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import { GatewayMetadata } from '@nestjs/websockets';
|
||||
|
||||
import { defineStartupConfig, ModuleConfig } from '../config';
|
||||
|
||||
declare module '../config' {
|
||||
interface AppConfig {
|
||||
websocket: ModuleConfig<
|
||||
GatewayMetadata & {
|
||||
requireAuthentication?: boolean;
|
||||
}
|
||||
>;
|
||||
}
|
||||
}
|
||||
|
||||
defineStartupConfig('websocket', {
|
||||
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
|
||||
transports: ['websocket'],
|
||||
maxHttpBufferSize: 1e8, // 100 MB
|
||||
requireAuthentication: true,
|
||||
});
|
||||
@@ -1,17 +1,46 @@
|
||||
import { Module, Provider } from '@nestjs/common';
|
||||
import './config';
|
||||
|
||||
import {
|
||||
FactoryProvider,
|
||||
INestApplicationContext,
|
||||
Module,
|
||||
Provider,
|
||||
} from '@nestjs/common';
|
||||
import { IoAdapter } from '@nestjs/platform-socket.io';
|
||||
import { Server } from 'socket.io';
|
||||
|
||||
import { Config } from '../config';
|
||||
|
||||
export const SocketIoAdapterImpl = Symbol('SocketIoAdapterImpl');
|
||||
|
||||
export class SocketIoAdapter extends IoAdapter {}
|
||||
export class SocketIoAdapter extends IoAdapter {
|
||||
constructor(protected readonly app: INestApplicationContext) {
|
||||
super(app);
|
||||
}
|
||||
|
||||
override createIOServer(port: number, options?: any): Server {
|
||||
const config = this.app.get(WEBSOCKET_OPTIONS);
|
||||
return super.createIOServer(port, { ...config, ...options });
|
||||
}
|
||||
}
|
||||
|
||||
const SocketIoAdapterImplProvider: Provider = {
|
||||
provide: SocketIoAdapterImpl,
|
||||
useValue: SocketIoAdapter,
|
||||
};
|
||||
|
||||
export const WEBSOCKET_OPTIONS = Symbol('WEBSOCKET_OPTIONS');
|
||||
|
||||
export const websocketOptionsProvider: FactoryProvider = {
|
||||
provide: WEBSOCKET_OPTIONS,
|
||||
useFactory: (config: Config) => {
|
||||
return config.websocket;
|
||||
},
|
||||
inject: [Config],
|
||||
};
|
||||
|
||||
@Module({
|
||||
providers: [SocketIoAdapterImplProvider],
|
||||
exports: [SocketIoAdapterImplProvider],
|
||||
providers: [SocketIoAdapterImplProvider, websocketOptionsProvider],
|
||||
exports: [SocketIoAdapterImplProvider, websocketOptionsProvider],
|
||||
})
|
||||
export class WebSocketModule {}
|
||||
|
||||
2
packages/backend/server/src/global.d.ts
vendored
2
packages/backend/server/src/global.d.ts
vendored
@@ -1,7 +1,7 @@
|
||||
declare namespace Express {
|
||||
interface Request {
|
||||
user?: import('./core/auth/current-user').CurrentUser;
|
||||
sid?: string;
|
||||
session?: import('./core/auth/current-user').UserSession;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ export function createSockerIoAdapterImpl(
|
||||
console.error(err);
|
||||
});
|
||||
|
||||
const server = super.createIOServer(port, options) as Server;
|
||||
const server = super.createIOServer(port, options);
|
||||
server.adapter(createAdapter(pubClient, subClient));
|
||||
return server;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user