From e53d5e2e3d19ff5e44857259cf0e8ffadffb69d4 Mon Sep 17 00:00:00 2001 From: liuyi Date: Wed, 17 Apr 2024 16:32:26 +0800 Subject: [PATCH] chore(server): clean up throttler (#6326) --- packages/backend/server/src/app.module.ts | 15 +- packages/backend/server/src/app.ts | 9 +- .../server/src/core/auth/controller.ts | 9 +- .../backend/server/src/core/auth/guard.ts | 1 + .../backend/server/src/core/auth/resolver.ts | 82 +---- .../server/src/core/user/management.ts | 30 +- .../backend/server/src/core/user/resolver.ts | 42 +-- .../server/src/core/workspaces/management.ts | 22 +- .../src/core/workspaces/resolvers/blob.ts | 8 +- .../src/core/workspaces/resolvers/history.ts | 3 - .../src/core/workspaces/resolvers/page.ts | 4 +- .../core/workspaces/resolvers/workspace.ts | 19 +- .../server/src/fundamentals/cache/index.ts | 5 +- .../backend/server/src/fundamentals/index.ts | 2 +- .../src/fundamentals/throttler/decorators.ts | 38 ++ .../src/fundamentals/throttler/index.ts | 187 +++++++--- packages/backend/server/src/global.d.ts | 1 + packages/backend/server/src/schema.gql | 2 +- .../server/tests/nestjs/throttler.spec.ts | 331 ++++++++++++++++++ packages/backend/server/tests/utils/user.ts | 6 +- 20 files changed, 551 insertions(+), 265 deletions(-) create mode 100644 packages/backend/server/src/fundamentals/throttler/decorators.ts create mode 100644 packages/backend/server/tests/nestjs/throttler.spec.ts diff --git a/packages/backend/server/src/app.module.ts b/packages/backend/server/src/app.module.ts index 2af316e0ca..22a2ecdede 100644 --- a/packages/backend/server/src/app.module.ts +++ b/packages/backend/server/src/app.module.ts @@ -1,13 +1,12 @@ import { join } from 'node:path'; import { Logger, Module } from '@nestjs/common'; -import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core'; import { ScheduleModule } from '@nestjs/schedule'; import { ServeStaticModule } from '@nestjs/serve-static'; import { get } from 'lodash-es'; import { AppController } from './app.controller'; -import { AuthGuard, AuthModule } from './core/auth'; +import { AuthModule } from './core/auth'; import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config'; import { DocModule } from './core/doc'; import { FeatureModule } from './core/features'; @@ -17,7 +16,7 @@ import { SyncModule } from './core/sync'; import { UserModule } from './core/user'; import { WorkspaceModule } from './core/workspaces'; import { getOptionalModuleMetadata } from './fundamentals'; -import { CacheInterceptor, CacheModule } from './fundamentals/cache'; +import { CacheModule } from './fundamentals/cache'; import type { AvailablePlugins } from './fundamentals/config'; import { Config, ConfigModule } from './fundamentals/config'; import { EventModule } from './fundamentals/event'; @@ -103,16 +102,6 @@ export class AppModuleBuilder { compile() { @Module({ - providers: [ - { - provide: APP_INTERCEPTOR, - useClass: CacheInterceptor, - }, - { - provide: APP_GUARD, - useClass: AuthGuard, - }, - ], imports: this.modules, controllers: this.config.isSelfhosted ? [] : [AppController], }) diff --git a/packages/backend/server/src/app.ts b/packages/backend/server/src/app.ts index bb825398ce..c5cfca00f9 100644 --- a/packages/backend/server/src/app.ts +++ b/packages/backend/server/src/app.ts @@ -4,7 +4,12 @@ import type { NestExpressApplication } from '@nestjs/platform-express'; import cookieParser from 'cookie-parser'; import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs'; -import { GlobalExceptionFilter } from './fundamentals'; +import { AuthGuard } from './core/auth'; +import { + CacheInterceptor, + CloudThrottlerGuard, + GlobalExceptionFilter, +} from './fundamentals'; import { SocketIoAdapter, SocketIoAdapterImpl } from './fundamentals/websocket'; import { serverTimingAndCache } from './middleware/timing'; @@ -28,6 +33,8 @@ export async function createApp() { }) ); + app.useGlobalGuards(app.get(AuthGuard), app.get(CloudThrottlerGuard)); + app.useGlobalInterceptors(app.get(CacheInterceptor)); app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter())); app.use(cookieParser()); diff --git a/packages/backend/server/src/core/auth/controller.ts b/packages/backend/server/src/core/auth/controller.ts index 96980e769e..80f4a9b660 100644 --- a/packages/backend/server/src/core/auth/controller.ts +++ b/packages/backend/server/src/core/auth/controller.ts @@ -14,7 +14,11 @@ import { } from '@nestjs/common'; import type { Request, Response } from 'express'; -import { PaymentRequiredException, URLHelper } from '../../fundamentals'; +import { + PaymentRequiredException, + Throttle, + URLHelper, +} from '../../fundamentals'; import { UserService } from '../user'; import { validators } from '../utils/validators'; import { CurrentUser } from './current-user'; @@ -27,6 +31,7 @@ class SignInCredential { password?: string; } +@Throttle('strict') @Controller('/api/auth') export class AuthController { constructor( @@ -158,6 +163,7 @@ export class AuthController { return this.url.safeRedirect(res, redirectUri); } + @Throttle('default', { limit: 1200 }) @Public() @Get('/session') async currentSessionUser(@CurrentUser() user?: CurrentUser) { @@ -166,6 +172,7 @@ export class AuthController { }; } + @Throttle('default', { limit: 1200 }) @Public() @Get('/sessions') async currentSessionUsers(@Req() req: Request) { diff --git a/packages/backend/server/src/core/auth/guard.ts b/packages/backend/server/src/core/auth/guard.ts index a60822001e..99479ab02a 100644 --- a/packages/backend/server/src/core/auth/guard.ts +++ b/packages/backend/server/src/core/auth/guard.ts @@ -54,6 +54,7 @@ export class AuthGuard implements CanActivate, OnModuleInit { const user = await this.auth.getUser(sessionToken, userSeq); if (user) { + req.sid = sessionToken; req.user = user; } } diff --git a/packages/backend/server/src/core/auth/resolver.ts b/packages/backend/server/src/core/auth/resolver.ts index 1869ce22d8..b347ca1f28 100644 --- a/packages/backend/server/src/core/auth/resolver.ts +++ b/packages/backend/server/src/core/auth/resolver.ts @@ -1,8 +1,4 @@ -import { - BadRequestException, - ForbiddenException, - UseGuards, -} from '@nestjs/common'; +import { BadRequestException, ForbiddenException } from '@nestjs/common'; import { Args, Context, @@ -16,7 +12,7 @@ import { } from '@nestjs/graphql'; import type { Request, Response } from 'express'; -import { CloudThrottlerGuard, Config, Throttle } from '../../fundamentals'; +import { Config, Throttle } from '../../fundamentals'; import { UserService } from '../user'; import { UserType } from '../user/types'; import { validators } from '../utils/validators'; @@ -43,7 +39,7 @@ export class ClientTokenType { * Sign up/in rate limit: 10 req/m * Other rate limit: 5 req/m */ -@UseGuards(CloudThrottlerGuard) +@Throttle('strict') @Resolver(() => UserType) export class AuthResolver { constructor( @@ -53,12 +49,6 @@ export class AuthResolver { private readonly token: TokenService ) {} - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Public() @Query(() => UserType, { name: 'currentUser', @@ -69,12 +59,6 @@ export class AuthResolver { return user; } - @Throttle({ - default: { - limit: 20, - ttl: 60, - }, - }) @ResolveField(() => ClientTokenType, { name: 'token', deprecationReason: 'use [/api/auth/authorize]', @@ -101,12 +85,6 @@ export class AuthResolver { } @Public() - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => UserType) async signUp( @Context() ctx: { req: Request; res: Response }, @@ -122,12 +100,6 @@ export class AuthResolver { } @Public() - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => UserType) async signIn( @Context() ctx: { req: Request; res: Response }, @@ -141,12 +113,6 @@ export class AuthResolver { return user; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => UserType) async changePassword( @CurrentUser() user: CurrentUser, @@ -172,12 +138,6 @@ export class AuthResolver { return user; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => UserType) async changeEmail( @CurrentUser() user: CurrentUser, @@ -202,12 +162,6 @@ export class AuthResolver { return user; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async sendChangePasswordEmail( @CurrentUser() user: CurrentUser, @@ -235,12 +189,6 @@ export class AuthResolver { return !res.rejected.length; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async sendSetPasswordEmail( @CurrentUser() user: CurrentUser, @@ -273,12 +221,6 @@ export class AuthResolver { // 4. user open confirm email page from new email // 5. user click confirm button // 6. send notification email - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async sendChangeEmail( @CurrentUser() user: CurrentUser, @@ -299,12 +241,6 @@ export class AuthResolver { return !res.rejected.length; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async sendVerifyChangeEmail( @CurrentUser() user: CurrentUser, @@ -347,12 +283,6 @@ export class AuthResolver { return !res.rejected.length; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async sendVerifyEmail( @CurrentUser() user: CurrentUser, @@ -367,12 +297,6 @@ export class AuthResolver { return !res.rejected.length; } - @Throttle({ - default: { - limit: 5, - ttl: 60, - }, - }) @Mutation(() => Boolean) async verifyEmail( @CurrentUser() user: CurrentUser, diff --git a/packages/backend/server/src/core/user/management.ts b/packages/backend/server/src/core/user/management.ts index 224946918f..786acbfba2 100644 --- a/packages/backend/server/src/core/user/management.ts +++ b/packages/backend/server/src/core/user/management.ts @@ -1,8 +1,4 @@ -import { - BadRequestException, - ForbiddenException, - UseGuards, -} from '@nestjs/common'; +import { BadRequestException, ForbiddenException } from '@nestjs/common'; import { Args, Context, @@ -13,7 +9,6 @@ import { Resolver, } from '@nestjs/graphql'; -import { CloudThrottlerGuard, Throttle } from '../../fundamentals'; import { CurrentUser } from '../auth/current-user'; import { sessionUser } from '../auth/service'; import { EarlyAccessType, FeatureManagementService } from '../features'; @@ -24,11 +19,6 @@ registerEnumType(EarlyAccessType, { name: 'EarlyAccessType', }); -/** - * User resolver - * All op rate limit: 10 req/m - */ -@UseGuards(CloudThrottlerGuard) @Resolver(() => UserType) export class UserManagementResolver { constructor( @@ -36,12 +26,6 @@ export class UserManagementResolver { private readonly feature: FeatureManagementService ) {} - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => Int) async addToEarlyAccess( @CurrentUser() currentUser: CurrentUser, @@ -62,12 +46,6 @@ export class UserManagementResolver { } } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => Int) async removeEarlyAccess( @CurrentUser() currentUser: CurrentUser, @@ -83,12 +61,6 @@ export class UserManagementResolver { return this.feature.removeEarlyAccess(user.id); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Query(() => [UserType]) async earlyAccessUsers( @Context() ctx: { isAdminQuery: boolean }, diff --git a/packages/backend/server/src/core/user/resolver.ts b/packages/backend/server/src/core/user/resolver.ts index aaa0fd46b5..ec157ec61c 100644 --- a/packages/backend/server/src/core/user/resolver.ts +++ b/packages/backend/server/src/core/user/resolver.ts @@ -1,4 +1,4 @@ -import { BadRequestException, UseGuards } from '@nestjs/common'; +import { BadRequestException } from '@nestjs/common'; import { Args, Int, @@ -14,7 +14,6 @@ import { isNil, omitBy } from 'lodash-es'; import type { FileUpload } from '../../fundamentals'; import { - CloudThrottlerGuard, EventEmitter, PaymentRequiredException, Throttle, @@ -35,11 +34,6 @@ import { UserType, } from './types'; -/** - * User resolver - * All op rate limit: 10 req/m - */ -@UseGuards(CloudThrottlerGuard) @Resolver(() => UserType) export class UserResolver { constructor( @@ -51,12 +45,7 @@ export class UserResolver { private readonly event: EventEmitter ) {} - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) + @Throttle('strict') @Query(() => UserOrLimitedUser, { name: 'user', description: 'Get user by email', @@ -90,7 +79,6 @@ export class UserResolver { }; } - @Throttle({ default: { limit: 10, ttl: 60 } }) @ResolveField(() => UserQuotaType, { name: 'quota', nullable: true }) async getQuota(@CurrentUser() me: User) { const quota = await this.quota.getUserQuota(me.id); @@ -98,7 +86,6 @@ export class UserResolver { return quota.feature; } - @Throttle({ default: { limit: 10, ttl: 60 } }) @ResolveField(() => Int, { name: 'invoiceCount', description: 'Get user invoice count', @@ -109,7 +96,6 @@ export class UserResolver { }); } - @Throttle({ default: { limit: 10, ttl: 60 } }) @ResolveField(() => [FeatureType], { name: 'features', description: 'Enabled features of a user', @@ -118,12 +104,6 @@ export class UserResolver { return this.feature.getActivatedUserFeatures(user.id); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => UserType, { name: 'uploadAvatar', description: 'Upload user avatar', @@ -153,12 +133,6 @@ export class UserResolver { }); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => UserType, { name: 'updateProfile', }) @@ -180,12 +154,6 @@ export class UserResolver { ); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => RemoveAvatar, { name: 'removeAvatar', description: 'Remove user avatar', @@ -201,12 +169,6 @@ export class UserResolver { return { success: true }; } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => DeleteAccount) async deleteAccount( @CurrentUser() user: CurrentUser diff --git a/packages/backend/server/src/core/workspaces/management.ts b/packages/backend/server/src/core/workspaces/management.ts index a4bd38fd34..e28932f970 100644 --- a/packages/backend/server/src/core/workspaces/management.ts +++ b/packages/backend/server/src/core/workspaces/management.ts @@ -1,4 +1,4 @@ -import { ForbiddenException, UseGuards } from '@nestjs/common'; +import { ForbiddenException } from '@nestjs/common'; import { Args, Int, @@ -9,13 +9,11 @@ import { Resolver, } from '@nestjs/graphql'; -import { CloudThrottlerGuard, Throttle } from '../../fundamentals'; import { CurrentUser } from '../auth'; import { FeatureManagementService, FeatureType } from '../features'; import { PermissionService } from './permission'; import { WorkspaceType } from './types'; -@UseGuards(CloudThrottlerGuard) @Resolver(() => WorkspaceType) export class WorkspaceManagementResolver { constructor( @@ -23,12 +21,6 @@ export class WorkspaceManagementResolver { private readonly permission: PermissionService ) {} - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => Int) async addWorkspaceFeature( @CurrentUser() currentUser: CurrentUser, @@ -42,12 +34,6 @@ export class WorkspaceManagementResolver { return this.feature.addWorkspaceFeatures(workspaceId, feature); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Mutation(() => Int) async removeWorkspaceFeature( @CurrentUser() currentUser: CurrentUser, @@ -61,12 +47,6 @@ export class WorkspaceManagementResolver { return this.feature.removeWorkspaceFeature(workspaceId, feature); } - @Throttle({ - default: { - limit: 10, - ttl: 60, - }, - }) @Query(() => [WorkspaceType]) async listWorkspaceFeatures( @CurrentUser() user: CurrentUser, diff --git a/packages/backend/server/src/core/workspaces/resolvers/blob.ts b/packages/backend/server/src/core/workspaces/resolvers/blob.ts index a7e16347f0..335fea36bb 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/blob.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/blob.ts @@ -2,7 +2,6 @@ import { ForbiddenException, Logger, PayloadTooLargeException, - UseGuards, } from '@nestjs/common'; import { Args, @@ -17,11 +16,7 @@ import { SafeIntResolver } from 'graphql-scalars'; import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import type { FileUpload } from '../../../fundamentals'; -import { - CloudThrottlerGuard, - MakeCache, - PreventCache, -} from '../../../fundamentals'; +import { MakeCache, PreventCache } from '../../../fundamentals'; import { CurrentUser } from '../../auth'; import { FeatureManagementService, FeatureType } from '../../features'; import { QuotaManagementService } from '../../quota'; @@ -29,7 +24,6 @@ import { WorkspaceBlobStorage } from '../../storage'; import { PermissionService } from '../permission'; import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types'; -@UseGuards(CloudThrottlerGuard) @Resolver(() => WorkspaceType) export class WorkspaceBlobResolver { logger = new Logger(WorkspaceBlobResolver.name); diff --git a/packages/backend/server/src/core/workspaces/resolvers/history.ts b/packages/backend/server/src/core/workspaces/resolvers/history.ts index deef0851a4..9b3741c6fc 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/history.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/history.ts @@ -1,4 +1,3 @@ -import { UseGuards } from '@nestjs/common'; import { Args, Field, @@ -12,7 +11,6 @@ import { } from '@nestjs/graphql'; import type { SnapshotHistory } from '@prisma/client'; -import { CloudThrottlerGuard } from '../../../fundamentals'; import { CurrentUser } from '../../auth'; import { DocHistoryManager } from '../../doc'; import { DocID } from '../../utils/doc'; @@ -31,7 +29,6 @@ class DocHistoryType implements Partial { timestamp!: Date; } -@UseGuards(CloudThrottlerGuard) @Resolver(() => WorkspaceType) export class DocHistoryResolver { constructor( diff --git a/packages/backend/server/src/core/workspaces/resolvers/page.ts b/packages/backend/server/src/core/workspaces/resolvers/page.ts index efd3c3f27c..4dcb69b077 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/page.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/page.ts @@ -1,4 +1,4 @@ -import { BadRequestException, UseGuards } from '@nestjs/common'; +import { BadRequestException } from '@nestjs/common'; import { Args, Field, @@ -12,7 +12,6 @@ import { import type { WorkspacePage as PrismaWorkspacePage } from '@prisma/client'; import { PrismaClient } from '@prisma/client'; -import { CloudThrottlerGuard } from '../../../fundamentals'; import { CurrentUser } from '../../auth'; import { DocID } from '../../utils/doc'; import { PermissionService, PublicPageMode } from '../permission'; @@ -38,7 +37,6 @@ class WorkspacePage implements Partial { public!: boolean; } -@UseGuards(CloudThrottlerGuard) @Resolver(() => WorkspaceType) export class PagePermissionResolver { constructor( diff --git a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts index 121812b8a9..019bb7b32d 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts @@ -4,7 +4,6 @@ import { Logger, NotFoundException, PayloadTooLargeException, - UseGuards, } from '@nestjs/common'; import { Args, @@ -22,7 +21,6 @@ import { applyUpdate, Doc } from 'yjs'; import type { FileUpload } from '../../../fundamentals'; import { - CloudThrottlerGuard, EventEmitter, MailService, MutexService, @@ -48,7 +46,6 @@ import { defaultWorkspaceAvatar } from '../utils'; * Public apis rate limit: 10 req/m * Other rate limit: 120 req/m */ -@UseGuards(CloudThrottlerGuard) @Resolver(() => WorkspaceType) export class WorkspaceResolver { private readonly logger = new Logger(WorkspaceResolver.name); @@ -191,12 +188,7 @@ export class WorkspaceResolver { }); } - @Throttle({ - default: { - limit: 10, - ttl: 30, - }, - }) + @Throttle('strict') @Public() @Query(() => WorkspaceType, { description: 'Get public workspace by id', @@ -422,15 +414,10 @@ export class WorkspaceResolver { } } - @Throttle({ - default: { - limit: 10, - ttl: 30, - }, - }) + @Throttle('strict') @Public() @Query(() => InvitationType, { - description: 'Update workspace', + description: 'send workspace invitation', }) async getInviteInfo(@Args('inviteId') inviteId: string) { const workspaceId = await this.prisma.workspaceUserPermission diff --git a/packages/backend/server/src/fundamentals/cache/index.ts b/packages/backend/server/src/fundamentals/cache/index.ts index 7c325d64ad..86f92c4fc8 100644 --- a/packages/backend/server/src/fundamentals/cache/index.ts +++ b/packages/backend/server/src/fundamentals/cache/index.ts @@ -1,11 +1,12 @@ import { Global, Module } from '@nestjs/common'; import { Cache, SessionCache } from './instances'; +import { CacheInterceptor } from './interceptor'; @Global() @Module({ - providers: [Cache, SessionCache], - exports: [Cache, SessionCache], + providers: [Cache, SessionCache, CacheInterceptor], + exports: [Cache, SessionCache, CacheInterceptor], }) export class CacheModule {} export { Cache, SessionCache }; diff --git a/packages/backend/server/src/fundamentals/index.ts b/packages/backend/server/src/fundamentals/index.ts index 729ea3f9ee..5abf98febb 100644 --- a/packages/backend/server/src/fundamentals/index.ts +++ b/packages/backend/server/src/fundamentals/index.ts @@ -27,7 +27,7 @@ export { export type { PrismaTransaction } from './prisma'; export * from './storage'; export { type StorageProvider, StorageProviderFactory } from './storage'; -export { AuthThrottlerGuard, CloudThrottlerGuard, Throttle } from './throttler'; +export { CloudThrottlerGuard, Throttle } from './throttler'; export { getRequestFromHost, getRequestResponseFromContext, diff --git a/packages/backend/server/src/fundamentals/throttler/decorators.ts b/packages/backend/server/src/fundamentals/throttler/decorators.ts new file mode 100644 index 0000000000..1baa7d9dd0 --- /dev/null +++ b/packages/backend/server/src/fundamentals/throttler/decorators.ts @@ -0,0 +1,38 @@ +import { applyDecorators, SetMetadata } from '@nestjs/common'; +import { SkipThrottle, Throttle as RawThrottle } from '@nestjs/throttler'; + +export type Throttlers = 'default' | 'strict'; +export const THROTTLER_PROTECTED = 'affine_throttler:protected'; + +/** + * Choose what throttler to use + * + * If a Controller or Query do not protected behind a Throttler, + * it will never be rate limited. + * + * - Ease: 120 calls within 60 seconds + * - Strict: 10 calls within 60 seconds + * + * @example + * + * \@Throttle() + * \@Throttle('strict') + * + * // the config call be override by the second parameter, + * // and the call count will be calculated separately + * \@Throttle('default', { limit: 10, ttl: 10 }) + * + */ +export function Throttle( + type: Throttlers = 'default', + override: { limit?: number; ttl?: number } = {} +): MethodDecorator & ClassDecorator { + return applyDecorators( + SetMetadata(THROTTLER_PROTECTED, type), + RawThrottle({ + [type]: override, + }) + ); +} + +export { SkipThrottle }; diff --git a/packages/backend/server/src/fundamentals/throttler/index.ts b/packages/backend/server/src/fundamentals/throttler/index.ts index f43e588229..a75490f093 100644 --- a/packages/backend/server/src/fundamentals/throttler/index.ts +++ b/packages/backend/server/src/fundamentals/throttler/index.ts @@ -1,15 +1,20 @@ import { ExecutionContext, Global, Injectable, Module } from '@nestjs/common'; +import { Reflector } from '@nestjs/core'; import { - Throttle, + InjectThrottlerOptions, + InjectThrottlerStorage, ThrottlerGuard, ThrottlerModule, - ThrottlerModuleOptions, + type ThrottlerModuleOptions, + ThrottlerOptions, ThrottlerOptionsFactory, ThrottlerStorageService, } from '@nestjs/throttler'; +import type { Request } from 'express'; import { Config } from '../config'; import { getRequestResponseFromContext } from '../utils/request'; +import { THROTTLER_PROTECTED, Throttlers } from './decorators'; @Injectable() export class ThrottlerStorage extends ThrottlerStorageService {} @@ -25,13 +30,16 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory { const options: ThrottlerModuleOptions = { throttlers: [ { + name: 'default', ttl: this.config.rateLimiter.ttl * 1000, limit: this.config.rateLimiter.limit, }, + { + name: 'strict', + ttl: this.config.rateLimiter.ttl * 1000, + limit: 20, + }, ], - skipIf: () => { - return !this.config.node.prod || this.config.affine.canary; - }, storage: this.storage, }; @@ -39,6 +47,132 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory { } } +@Injectable() +export class CloudThrottlerGuard extends ThrottlerGuard { + constructor( + @InjectThrottlerOptions() options: ThrottlerModuleOptions, + @InjectThrottlerStorage() storageService: ThrottlerStorage, + reflector: Reflector, + private readonly config: Config + ) { + super(options, storageService, reflector); + } + + override getRequestResponse(context: ExecutionContext) { + return getRequestResponseFromContext(context) as any; + } + + override getTracker(req: Request): Promise { + return Promise.resolve( + // ↓ prefer session id if available + `throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}` + // ^ throttler prefix make the key in store recognizable + ); + } + + override generateKey( + context: ExecutionContext, + tracker: string, + throttler: string + ) { + if (tracker.endsWith(';custom')) { + return `${tracker};${throttler}:${context.getClass().name}.${context.getHandler().name}`; + } + + return `${tracker};${throttler}`; + } + + override async handleRequest( + context: ExecutionContext, + limit: number, + ttl: number, + throttlerOptions: ThrottlerOptions + ) { + // give it 'default' if no throttler is specified, + // so the unauthenticated users visits will always hit default throttler + // authenticated users will directly bypass unprotected APIs in [CloudThrottlerGuard.canActivate] + const throttler = this.getSpecifiedThrottler(context) ?? 'default'; + + // by pass unmatched throttlers + if (throttlerOptions.name !== throttler) { + return true; + } + + const { req, res } = this.getRequestResponse(context); + const ignoreUserAgents = + throttlerOptions.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents; + if (Array.isArray(ignoreUserAgents)) { + for (const pattern of ignoreUserAgents) { + const ua = req.headers['user-agent']; + if (ua && pattern.test(ua)) { + return true; + } + } + } + + let tracker = await this.getTracker(req); + + if (this.config.node.dev) { + limit = Number.MAX_SAFE_INTEGER; + } else { + // custom limit or ttl APIs will be treated standalone + if (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl) { + tracker += ';custom'; + } + } + + const key = this.generateKey( + context, + tracker, + throttlerOptions.name ?? 'default' + ); + const { timeToExpire, totalHits } = await this.storageService.increment( + key, + ttl + ); + + if (totalHits > limit) { + res.header('Retry-After', timeToExpire.toString()); + await this.throwThrottlingException(context, { + limit, + ttl, + key, + tracker, + totalHits, + timeToExpire, + }); + } + + res.header(`${this.headerPrefix}-Limit`, limit.toString()); + res.header( + `${this.headerPrefix}-Remaining`, + (limit - totalHits).toString() + ); + res.header(`${this.headerPrefix}-Reset`, timeToExpire.toString()); + return true; + } + + override async canActivate(context: ExecutionContext): Promise { + const { req } = this.getRequestResponse(context); + + const throttler = this.getSpecifiedThrottler(context); + + // if user is logged in, bypass non-protected handlers + if (!throttler && req.user) { + return true; + } + + return super.canActivate(context); + } + + getSpecifiedThrottler(context: ExecutionContext) { + return this.reflector.getAllAndOverride( + THROTTLER_PROTECTED, + [context.getHandler(), context.getClass()] + ); + } +} + @Global() @Module({ imports: [ @@ -46,46 +180,9 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory { useClass: CustomOptionsFactory, }), ], - providers: [ThrottlerStorage], - exports: [ThrottlerStorage], + providers: [ThrottlerStorage, CloudThrottlerGuard], + exports: [ThrottlerStorage, CloudThrottlerGuard], }) export class RateLimiterModule {} -@Injectable() -export class CloudThrottlerGuard extends ThrottlerGuard { - override getRequestResponse(context: ExecutionContext) { - return getRequestResponseFromContext(context) as any; - } - - protected override getTracker(req: Record): Promise { - return Promise.resolve( - req?.get('CF-Connecting-IP') ?? req?.get('CF-ray') ?? req?.ip - ); - } -} - -@Injectable() -export class AuthThrottlerGuard extends CloudThrottlerGuard { - override async handleRequest( - context: ExecutionContext, - limit: number, - ttl: number - ): Promise { - const { req } = this.getRequestResponse(context); - - if (req?.url === '/api/auth/session') { - // relax throttle for session auto renew - return super.handleRequest(context, limit * 20, ttl, { - ttl: ttl * 20, - limit: limit * 20, - }); - } - - return super.handleRequest(context, limit, ttl, { - ttl, - limit, - }); - } -} - -export { Throttle }; +export * from './decorators'; diff --git a/packages/backend/server/src/global.d.ts b/packages/backend/server/src/global.d.ts index ebb0fae5f1..ce59a7a2d4 100644 --- a/packages/backend/server/src/global.d.ts +++ b/packages/backend/server/src/global.d.ts @@ -1,6 +1,7 @@ declare namespace Express { interface Request { user?: import('./core/auth/current-user').CurrentUser; + sid?: string; } } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index cf112257c4..7915e8b1f8 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -265,7 +265,7 @@ type Query { currentUser: UserType earlyAccessUsers: [UserType!]! - """Update workspace""" + """send workspace invitation""" getInviteInfo(inviteId: String!): InvitationType! """Get is owner of workspace""" diff --git a/packages/backend/server/tests/nestjs/throttler.spec.ts b/packages/backend/server/tests/nestjs/throttler.spec.ts new file mode 100644 index 0000000000..a3fad649f1 --- /dev/null +++ b/packages/backend/server/tests/nestjs/throttler.spec.ts @@ -0,0 +1,331 @@ +import '../../src/plugins/config'; + +import { + Controller, + Get, + HttpStatus, + INestApplication, + UseGuards, +} from '@nestjs/common'; +import ava, { TestFn } from 'ava'; +import Sinon from 'sinon'; +import request, { type Response } from 'supertest'; + +import { AppModule } from '../../src/app.module'; +import { AuthService, Public } from '../../src/core/auth'; +import { ConfigModule } from '../../src/fundamentals/config'; +import { + CloudThrottlerGuard, + SkipThrottle, + Throttle, + ThrottlerStorage, +} from '../../src/fundamentals/throttler'; +import { createTestingApp, sessionCookie } from '../utils'; + +const test = ava as TestFn<{ + storage: ThrottlerStorage; + cookie: string; + app: INestApplication; +}>; + +@UseGuards(CloudThrottlerGuard) +@Throttle() +@Controller('/throttled') +class ThrottledController { + @Get('/default') + default() { + return 'default'; + } + + @Get('/default2') + default2() { + return 'default2'; + } + + @Get('/default3') + @Throttle('default', { limit: 10 }) + default3() { + return 'default3'; + } + + @Throttle('strict') + @Get('/strict') + strict() { + return 'strict'; + } + + @Public() + @SkipThrottle() + @Get('/skip') + skip() { + return 'skip'; + } +} + +@UseGuards(CloudThrottlerGuard) +@Controller('/nonthrottled') +class NonThrottledController { + @Public() + @SkipThrottle() + @Get('/skip') + skip() { + return 'skip'; + } + + @Public() + @Get('/default') + default() { + return 'default'; + } + + @Public() + @Throttle('strict') + @Get('/strict') + strict() { + return 'strict'; + } +} + +test.beforeEach(async t => { + const { app } = await createTestingApp({ + imports: [ + ConfigModule.forRoot({ + rateLimiter: { + ttl: 60, + limit: 120, + }, + }), + AppModule, + ], + controllers: [ThrottledController, NonThrottledController], + }); + + t.context.storage = app.get(ThrottlerStorage); + t.context.app = app; + + const auth = app.get(AuthService); + const u1 = await auth.signUp('u1', 'u1@affine.pro', 'test'); + + const res = await request(app.getHttpServer()) + .post('/api/auth/sign-in') + .send({ email: u1.email, password: 'test' }); + + t.context.cookie = sessionCookie(res.headers)!; +}); + +test.afterEach.always(async t => { + await t.context.app.close(); +}); + +function rateLimitHeaders(res: Response) { + return { + limit: res.header['x-ratelimit-limit'], + remaining: res.header['x-ratelimit-remaining'], + reset: res.header['x-ratelimit-reset'], + retryAfter: res.header['retry-after'], + }; +} + +test('should be able to prevent requests if limit is reached', async t => { + const { app } = t.context; + + const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({ + timeToExpire: 10, + totalHits: 21, + }); + const res = await request(app.getHttpServer()) + .get('/nonthrottled/strict') + .expect(HttpStatus.TOO_MANY_REQUESTS); + + const headers = rateLimitHeaders(res); + + t.is(headers.retryAfter, '10'); + + stub.restore(); +}); + +// ====== unauthenticated user visits ====== +test('should use default throttler for unauthenticated user when not specified', async t => { + const { app } = t.context; + + const res = await request(app.getHttpServer()) + .get('/nonthrottled/default') + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, '120'); + t.is(headers.remaining, '119'); + t.is(headers.reset, '60'); +}); + +test('should skip throttler for unauthenticated user when specified', async t => { + const { app } = t.context; + + let res = await request(app.getHttpServer()) + .get('/nonthrottled/skip') + .expect(200); + + let headers = rateLimitHeaders(res); + + t.is(headers.limit, undefined!); + t.is(headers.remaining, undefined!); + t.is(headers.reset, undefined!); + + res = await request(app.getHttpServer()).get('/throttled/skip').expect(200); + + headers = rateLimitHeaders(res); + + t.is(headers.limit, undefined!); + t.is(headers.remaining, undefined!); + t.is(headers.reset, undefined!); +}); + +test('should use specified throttler for unauthenticated user', async t => { + const { app } = t.context; + + const res = await request(app.getHttpServer()) + .get('/nonthrottled/strict') + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, '20'); + t.is(headers.remaining, '19'); + t.is(headers.reset, '60'); +}); + +// ==== authenticated user visits ==== +test('should not protect unspecified routes', async t => { + const { app, cookie } = t.context; + + const res = await request(app.getHttpServer()) + .get('/nonthrottled/default') + .set('Cookie', cookie) + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, undefined!); + t.is(headers.remaining, undefined!); + t.is(headers.reset, undefined!); +}); + +test('should use default throttler for authenticated user when not specified', async t => { + const { app, cookie } = t.context; + + const res = await request(app.getHttpServer()) + .get('/throttled/default') + .set('Cookie', cookie) + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, '120'); + t.is(headers.remaining, '119'); + t.is(headers.reset, '60'); +}); + +test('should use same throttler for multiple routes', async t => { + const { app, cookie } = t.context; + + let res = await request(app.getHttpServer()) + .get('/throttled/default') + .set('Cookie', cookie) + .expect(200); + + let headers = rateLimitHeaders(res); + + t.is(headers.limit, '120'); + t.is(headers.remaining, '119'); + t.is(headers.reset, '60'); + + res = await request(app.getHttpServer()) + .get('/throttled/default2') + .set('Cookie', cookie) + .expect(200); + + headers = rateLimitHeaders(res); + + t.is(headers.limit, '120'); + t.is(headers.remaining, '118'); +}); + +test('should use different throttler if specified', async t => { + const { app, cookie } = t.context; + + let res = await request(app.getHttpServer()) + .get('/throttled/default') + .set('Cookie', cookie) + .expect(200); + + let headers = rateLimitHeaders(res); + + t.is(headers.limit, '120'); + t.is(headers.remaining, '119'); + t.is(headers.reset, '60'); + + res = await request(app.getHttpServer()) + .get('/throttled/default3') + .set('Cookie', cookie) + .expect(200); + + headers = rateLimitHeaders(res); + + t.is(headers.limit, '10'); + t.is(headers.remaining, '9'); + t.is(headers.reset, '60'); +}); + +test('should skip throttler for authenticated user when specified', async t => { + const { app, cookie } = t.context; + + const res = await request(app.getHttpServer()) + .get('/throttled/skip') + .set('Cookie', cookie) + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, undefined!); + t.is(headers.remaining, undefined!); + t.is(headers.reset, undefined!); +}); + +test('should use specified throttler for authenticated user', async t => { + const { app, cookie } = t.context; + + const res = await request(app.getHttpServer()) + .get('/throttled/strict') + .set('Cookie', cookie) + .expect(200); + + const headers = rateLimitHeaders(res); + + t.is(headers.limit, '20'); + t.is(headers.remaining, '19'); + t.is(headers.reset, '60'); +}); + +test('should separate anonymous and authenticated user throttlers', async t => { + const { app, cookie } = t.context; + + const authenticatedUserRes = await request(app.getHttpServer()) + .get('/throttled/default') + .set('Cookie', cookie) + .expect(200); + const unauthenticatedUserRes = await request(app.getHttpServer()) + .get('/nonthrottled/default') + .expect(200); + + const authenticatedResHeaders = rateLimitHeaders(authenticatedUserRes); + const unauthenticatedResHeaders = rateLimitHeaders(unauthenticatedUserRes); + + t.is(authenticatedResHeaders.limit, '120'); + t.is(authenticatedResHeaders.remaining, '119'); + t.is(authenticatedResHeaders.reset, '60'); + + t.is(unauthenticatedResHeaders.limit, '120'); + t.is(unauthenticatedResHeaders.remaining, '119'); + t.is(unauthenticatedResHeaders.reset, '60'); +}); diff --git a/packages/backend/server/tests/utils/user.ts b/packages/backend/server/tests/utils/user.ts index 8a4849d970..1404115110 100644 --- a/packages/backend/server/tests/utils/user.ts +++ b/packages/backend/server/tests/utils/user.ts @@ -10,13 +10,13 @@ import { import type { UserType } from '../../src/core/user'; import { gql } from './common'; -export function sessionCookie(headers: any) { +export function sessionCookie(headers: any): string { const cookie = headers['set-cookie']?.find((c: string) => c.startsWith(`${AuthService.sessionCookieName}=`) ); if (!cookie) { - return null; + return ''; } return cookie.split(';')[0]; @@ -29,7 +29,7 @@ export async function getSession( const cookie = sessionCookie(signInRes.headers); const res = await request(app.getHttpServer()) .get('/api/auth/session') - .set('cookie', cookie) + .set('cookie', cookie!) .expect(200); return res.body;