feat(server): authenticate user before ws connected (#7777)

This commit is contained in:
forehalo
2024-08-08 08:30:55 +00:00
parent 83244f0201
commit f2eafc374c
17 changed files with 232 additions and 87 deletions

View File

@@ -1,6 +1,6 @@
import type { ExecutionContext } from '@nestjs/common';
import { createParamDecorator } from '@nestjs/common';
import { User } from '@prisma/client';
import { User, UserSession } from '@prisma/client';
import { getRequestResponseFromContext } from '../../fundamentals';
@@ -53,3 +53,5 @@ export interface CurrentUser
hasPassword: boolean | null;
emailVerified: boolean;
}
export { type UserSession };

View File

@@ -1,15 +1,22 @@
import type {
CanActivate,
ExecutionContext,
FactoryProvider,
OnModuleInit,
} from '@nestjs/common';
import { Injectable, SetMetadata, UseGuards } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request } from 'express';
import {
AuthenticationRequired,
Config,
getRequestResponseFromContext,
mapAnyError,
parseCookies,
} from '../../fundamentals';
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
import { CurrentUser, UserSession } from './current-user';
import { AuthService, parseAuthUserSeqNum } from './service';
function extractTokenFromHeader(authorization: string) {
@@ -38,37 +45,9 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
// check cookie
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
const { user, expiresAt } = await this.auth.getUser(
sessionToken,
userSeq
);
if (res && user && expiresAt) {
await this.auth.refreshUserSessionIfNeeded(
req,
res,
sessionToken,
user.id,
expiresAt
);
}
if (user) {
req.sid = sessionToken;
req.user = user;
}
const userSession = await this.signIn(req);
if (res && userSession && userSession.session.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session);
}
// api is public
@@ -84,9 +63,44 @@ export class AuthGuard implements CanActivate, OnModuleInit {
if (!req.user) {
throw new AuthenticationRequired();
}
return true;
}
async signIn(
req: Request
): Promise<{ user: CurrentUser; session: UserSession } | null> {
if (req.user && req.session) {
return {
user: req.user,
session: req.session,
};
}
parseCookies(req);
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
const userSession = await this.auth.getUserSession(sessionToken, userSeq);
if (userSession) {
req.session = userSession.session;
req.user = userSession.user;
}
return userSession;
}
return null;
}
}
/**
@@ -111,3 +125,35 @@ export const Auth = () => {
// api is public accessible
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config, guard: AuthGuard) => {
return {
...config.websocket,
allowRequest: async (
req: any,
pass: (err: string | null | undefined, success: boolean) => void
) => {
if (!config.websocket.requireAuthentication) {
return pass(null, true);
}
try {
const authentication = await guard.signIn(req);
if (authentication) {
return pass(null, true);
} else {
return pass('unauthenticated', false);
}
} catch (e) {
const error = mapAnyError(e);
error.log('Websocket');
return pass('unauthenticated', false);
}
},
};
},
inject: [Config, AuthGuard],
};

View File

@@ -6,15 +6,21 @@ import { FeatureModule } from '../features';
import { QuotaModule } from '../quota';
import { UserModule } from '../user';
import { AuthController } from './controller';
import { AuthGuard } from './guard';
import { AuthGuard, AuthWebsocketOptionsProvider } from './guard';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
import { TokenService, TokenType } from './token';
@Module({
imports: [FeatureModule, UserModule, QuotaModule],
providers: [AuthService, AuthResolver, TokenService, AuthGuard],
exports: [AuthService, AuthGuard],
providers: [
AuthService,
AuthResolver,
TokenService,
AuthGuard,
AuthWebsocketOptionsProvider,
],
exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider],
controllers: [AuthController],
})
export class AuthModule {}

View File

@@ -1,6 +1,6 @@
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import type { User } from '@prisma/client';
import type { User, UserSession } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import type { CookieOptions, Request, Response } from 'express';
import { assign, omit } from 'lodash-es';
@@ -121,27 +121,27 @@ export class AuthService implements OnApplicationBootstrap {
return sessionUser(user);
}
async getUser(
async getUserSession(
token: string,
seq = 0
): Promise<{ user: CurrentUser | null; expiresAt: Date | null }> {
): Promise<{ user: CurrentUser; session: UserSession } | null> {
const session = await this.getSession(token);
// no such session
if (!session) {
return { user: null, expiresAt: null };
return null;
}
const userSession = session.userSessions.at(seq);
// no such user session
if (!userSession) {
return { user: null, expiresAt: null };
return null;
}
// user session expired
if (userSession.expiresAt && userSession.expiresAt <= new Date()) {
return { user: null, expiresAt: null };
return null;
}
const user = await this.db.user.findUnique({
@@ -149,10 +149,10 @@ export class AuthService implements OnApplicationBootstrap {
});
if (!user) {
return { user: null, expiresAt: null };
return null;
}
return { user: sessionUser(user), expiresAt: userSession.expiresAt };
return { user: sessionUser(user), session: userSession };
}
async getUserList(token: string) {
@@ -251,12 +251,13 @@ export class AuthService implements OnApplicationBootstrap {
async refreshUserSessionIfNeeded(
_req: Request,
res: Response,
sessionId: string,
userId: string,
expiresAt: Date,
session: UserSession,
ttr = this.config.auth.session.ttr
): Promise<boolean> {
if (expiresAt && expiresAt.getTime() - Date.now() > ttr * 1000) {
if (
session.expiresAt &&
session.expiresAt.getTime() - Date.now() > ttr * 1000
) {
// no need to refresh
return false;
}
@@ -267,17 +268,14 @@ export class AuthService implements OnApplicationBootstrap {
await this.db.userSession.update({
where: {
sessionId_userId: {
sessionId,
userId,
},
id: session.id,
},
data: {
expiresAt: newExpiresAt,
},
});
res.cookie(AuthService.sessionCookieName, sessionId, {
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: newExpiresAt,
...this.cookieOptions,
});

View File

@@ -50,12 +50,7 @@ function Awareness(workspaceId: string): `${string}:awareness` {
return `${workspaceId}:awareness`;
}
@WebSocketGateway({
cors: !AFFiNE.node.prod,
transports: ['websocket'],
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
maxHttpBufferSize: 1e8, // 100 MB
})
@WebSocketGateway()
export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
protected logger = new Logger(EventsGateway.name);
private connectionCount = 0;