chore(server): clean up throttler (#6326)

This commit is contained in:
liuyi
2024-04-17 16:32:26 +08:00
committed by GitHub
parent 5b315bfc81
commit e53d5e2e3d
20 changed files with 551 additions and 265 deletions

View File

@@ -1,13 +1,12 @@
import { join } from 'node:path'; import { join } from 'node:path';
import { Logger, Module } from '@nestjs/common'; import { Logger, Module } from '@nestjs/common';
import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core';
import { ScheduleModule } from '@nestjs/schedule'; import { ScheduleModule } from '@nestjs/schedule';
import { ServeStaticModule } from '@nestjs/serve-static'; import { ServeStaticModule } from '@nestjs/serve-static';
import { get } from 'lodash-es'; import { get } from 'lodash-es';
import { AppController } from './app.controller'; import { AppController } from './app.controller';
import { AuthGuard, AuthModule } from './core/auth'; import { AuthModule } from './core/auth';
import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config'; import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config';
import { DocModule } from './core/doc'; import { DocModule } from './core/doc';
import { FeatureModule } from './core/features'; import { FeatureModule } from './core/features';
@@ -17,7 +16,7 @@ import { SyncModule } from './core/sync';
import { UserModule } from './core/user'; import { UserModule } from './core/user';
import { WorkspaceModule } from './core/workspaces'; import { WorkspaceModule } from './core/workspaces';
import { getOptionalModuleMetadata } from './fundamentals'; import { getOptionalModuleMetadata } from './fundamentals';
import { CacheInterceptor, CacheModule } from './fundamentals/cache'; import { CacheModule } from './fundamentals/cache';
import type { AvailablePlugins } from './fundamentals/config'; import type { AvailablePlugins } from './fundamentals/config';
import { Config, ConfigModule } from './fundamentals/config'; import { Config, ConfigModule } from './fundamentals/config';
import { EventModule } from './fundamentals/event'; import { EventModule } from './fundamentals/event';
@@ -103,16 +102,6 @@ export class AppModuleBuilder {
compile() { compile() {
@Module({ @Module({
providers: [
{
provide: APP_INTERCEPTOR,
useClass: CacheInterceptor,
},
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
imports: this.modules, imports: this.modules,
controllers: this.config.isSelfhosted ? [] : [AppController], controllers: this.config.isSelfhosted ? [] : [AppController],
}) })

View File

@@ -4,7 +4,12 @@ import type { NestExpressApplication } from '@nestjs/platform-express';
import cookieParser from 'cookie-parser'; import cookieParser from 'cookie-parser';
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs'; import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
import { GlobalExceptionFilter } from './fundamentals'; import { AuthGuard } from './core/auth';
import {
CacheInterceptor,
CloudThrottlerGuard,
GlobalExceptionFilter,
} from './fundamentals';
import { SocketIoAdapter, SocketIoAdapterImpl } from './fundamentals/websocket'; import { SocketIoAdapter, SocketIoAdapterImpl } from './fundamentals/websocket';
import { serverTimingAndCache } from './middleware/timing'; import { serverTimingAndCache } from './middleware/timing';
@@ -28,6 +33,8 @@ export async function createApp() {
}) })
); );
app.useGlobalGuards(app.get(AuthGuard), app.get(CloudThrottlerGuard));
app.useGlobalInterceptors(app.get(CacheInterceptor));
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter())); app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
app.use(cookieParser()); app.use(cookieParser());

View File

@@ -14,7 +14,11 @@ import {
} from '@nestjs/common'; } from '@nestjs/common';
import type { Request, Response } from 'express'; import type { Request, Response } from 'express';
import { PaymentRequiredException, URLHelper } from '../../fundamentals'; import {
PaymentRequiredException,
Throttle,
URLHelper,
} from '../../fundamentals';
import { UserService } from '../user'; import { UserService } from '../user';
import { validators } from '../utils/validators'; import { validators } from '../utils/validators';
import { CurrentUser } from './current-user'; import { CurrentUser } from './current-user';
@@ -27,6 +31,7 @@ class SignInCredential {
password?: string; password?: string;
} }
@Throttle('strict')
@Controller('/api/auth') @Controller('/api/auth')
export class AuthController { export class AuthController {
constructor( constructor(
@@ -158,6 +163,7 @@ export class AuthController {
return this.url.safeRedirect(res, redirectUri); return this.url.safeRedirect(res, redirectUri);
} }
@Throttle('default', { limit: 1200 })
@Public() @Public()
@Get('/session') @Get('/session')
async currentSessionUser(@CurrentUser() user?: CurrentUser) { async currentSessionUser(@CurrentUser() user?: CurrentUser) {
@@ -166,6 +172,7 @@ export class AuthController {
}; };
} }
@Throttle('default', { limit: 1200 })
@Public() @Public()
@Get('/sessions') @Get('/sessions')
async currentSessionUsers(@Req() req: Request) { async currentSessionUsers(@Req() req: Request) {

View File

@@ -54,6 +54,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const user = await this.auth.getUser(sessionToken, userSeq); const user = await this.auth.getUser(sessionToken, userSeq);
if (user) { if (user) {
req.sid = sessionToken;
req.user = user; req.user = user;
} }
} }

View File

@@ -1,8 +1,4 @@
import { import { BadRequestException, ForbiddenException } from '@nestjs/common';
BadRequestException,
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import { import {
Args, Args,
Context, Context,
@@ -16,7 +12,7 @@ import {
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import type { Request, Response } from 'express'; import type { Request, Response } from 'express';
import { CloudThrottlerGuard, Config, Throttle } from '../../fundamentals'; import { Config, Throttle } from '../../fundamentals';
import { UserService } from '../user'; import { UserService } from '../user';
import { UserType } from '../user/types'; import { UserType } from '../user/types';
import { validators } from '../utils/validators'; import { validators } from '../utils/validators';
@@ -43,7 +39,7 @@ export class ClientTokenType {
* Sign up/in rate limit: 10 req/m * Sign up/in rate limit: 10 req/m
* Other rate limit: 5 req/m * Other rate limit: 5 req/m
*/ */
@UseGuards(CloudThrottlerGuard) @Throttle('strict')
@Resolver(() => UserType) @Resolver(() => UserType)
export class AuthResolver { export class AuthResolver {
constructor( constructor(
@@ -53,12 +49,6 @@ export class AuthResolver {
private readonly token: TokenService private readonly token: TokenService
) {} ) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Public() @Public()
@Query(() => UserType, { @Query(() => UserType, {
name: 'currentUser', name: 'currentUser',
@@ -69,12 +59,6 @@ export class AuthResolver {
return user; return user;
} }
@Throttle({
default: {
limit: 20,
ttl: 60,
},
})
@ResolveField(() => ClientTokenType, { @ResolveField(() => ClientTokenType, {
name: 'token', name: 'token',
deprecationReason: 'use [/api/auth/authorize]', deprecationReason: 'use [/api/auth/authorize]',
@@ -101,12 +85,6 @@ export class AuthResolver {
} }
@Public() @Public()
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType) @Mutation(() => UserType)
async signUp( async signUp(
@Context() ctx: { req: Request; res: Response }, @Context() ctx: { req: Request; res: Response },
@@ -122,12 +100,6 @@ export class AuthResolver {
} }
@Public() @Public()
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType) @Mutation(() => UserType)
async signIn( async signIn(
@Context() ctx: { req: Request; res: Response }, @Context() ctx: { req: Request; res: Response },
@@ -141,12 +113,6 @@ export class AuthResolver {
return user; return user;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType) @Mutation(() => UserType)
async changePassword( async changePassword(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -172,12 +138,6 @@ export class AuthResolver {
return user; return user;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType) @Mutation(() => UserType)
async changeEmail( async changeEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -202,12 +162,6 @@ export class AuthResolver {
return user; return user;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async sendChangePasswordEmail( async sendChangePasswordEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -235,12 +189,6 @@ export class AuthResolver {
return !res.rejected.length; return !res.rejected.length;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async sendSetPasswordEmail( async sendSetPasswordEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -273,12 +221,6 @@ export class AuthResolver {
// 4. user open confirm email page from new email // 4. user open confirm email page from new email
// 5. user click confirm button // 5. user click confirm button
// 6. send notification email // 6. send notification email
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async sendChangeEmail( async sendChangeEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -299,12 +241,6 @@ export class AuthResolver {
return !res.rejected.length; return !res.rejected.length;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async sendVerifyChangeEmail( async sendVerifyChangeEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -347,12 +283,6 @@ export class AuthResolver {
return !res.rejected.length; return !res.rejected.length;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async sendVerifyEmail( async sendVerifyEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,
@@ -367,12 +297,6 @@ export class AuthResolver {
return !res.rejected.length; return !res.rejected.length;
} }
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean) @Mutation(() => Boolean)
async verifyEmail( async verifyEmail(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,

View File

@@ -1,8 +1,4 @@
import { import { BadRequestException, ForbiddenException } from '@nestjs/common';
BadRequestException,
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import { import {
Args, Args,
Context, Context,
@@ -13,7 +9,6 @@ import {
Resolver, Resolver,
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { CurrentUser } from '../auth/current-user'; import { CurrentUser } from '../auth/current-user';
import { sessionUser } from '../auth/service'; import { sessionUser } from '../auth/service';
import { EarlyAccessType, FeatureManagementService } from '../features'; import { EarlyAccessType, FeatureManagementService } from '../features';
@@ -24,11 +19,6 @@ registerEnumType(EarlyAccessType, {
name: 'EarlyAccessType', name: 'EarlyAccessType',
}); });
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => UserType) @Resolver(() => UserType)
export class UserManagementResolver { export class UserManagementResolver {
constructor( constructor(
@@ -36,12 +26,6 @@ export class UserManagementResolver {
private readonly feature: FeatureManagementService private readonly feature: FeatureManagementService
) {} ) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int) @Mutation(() => Int)
async addToEarlyAccess( async addToEarlyAccess(
@CurrentUser() currentUser: CurrentUser, @CurrentUser() currentUser: CurrentUser,
@@ -62,12 +46,6 @@ export class UserManagementResolver {
} }
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int) @Mutation(() => Int)
async removeEarlyAccess( async removeEarlyAccess(
@CurrentUser() currentUser: CurrentUser, @CurrentUser() currentUser: CurrentUser,
@@ -83,12 +61,6 @@ export class UserManagementResolver {
return this.feature.removeEarlyAccess(user.id); return this.feature.removeEarlyAccess(user.id);
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => [UserType]) @Query(() => [UserType])
async earlyAccessUsers( async earlyAccessUsers(
@Context() ctx: { isAdminQuery: boolean }, @Context() ctx: { isAdminQuery: boolean },

View File

@@ -1,4 +1,4 @@
import { BadRequestException, UseGuards } from '@nestjs/common'; import { BadRequestException } from '@nestjs/common';
import { import {
Args, Args,
Int, Int,
@@ -14,7 +14,6 @@ import { isNil, omitBy } from 'lodash-es';
import type { FileUpload } from '../../fundamentals'; import type { FileUpload } from '../../fundamentals';
import { import {
CloudThrottlerGuard,
EventEmitter, EventEmitter,
PaymentRequiredException, PaymentRequiredException,
Throttle, Throttle,
@@ -35,11 +34,6 @@ import {
UserType, UserType,
} from './types'; } from './types';
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => UserType) @Resolver(() => UserType)
export class UserResolver { export class UserResolver {
constructor( constructor(
@@ -51,12 +45,7 @@ export class UserResolver {
private readonly event: EventEmitter private readonly event: EventEmitter
) {} ) {}
@Throttle({ @Throttle('strict')
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => UserOrLimitedUser, { @Query(() => UserOrLimitedUser, {
name: 'user', name: 'user',
description: 'Get user by email', description: 'Get user by email',
@@ -90,7 +79,6 @@ export class UserResolver {
}; };
} }
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => UserQuotaType, { name: 'quota', nullable: true }) @ResolveField(() => UserQuotaType, { name: 'quota', nullable: true })
async getQuota(@CurrentUser() me: User) { async getQuota(@CurrentUser() me: User) {
const quota = await this.quota.getUserQuota(me.id); const quota = await this.quota.getUserQuota(me.id);
@@ -98,7 +86,6 @@ export class UserResolver {
return quota.feature; return quota.feature;
} }
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => Int, { @ResolveField(() => Int, {
name: 'invoiceCount', name: 'invoiceCount',
description: 'Get user invoice count', description: 'Get user invoice count',
@@ -109,7 +96,6 @@ export class UserResolver {
}); });
} }
@Throttle({ default: { limit: 10, ttl: 60 } })
@ResolveField(() => [FeatureType], { @ResolveField(() => [FeatureType], {
name: 'features', name: 'features',
description: 'Enabled features of a user', description: 'Enabled features of a user',
@@ -118,12 +104,6 @@ export class UserResolver {
return this.feature.getActivatedUserFeatures(user.id); return this.feature.getActivatedUserFeatures(user.id);
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, { @Mutation(() => UserType, {
name: 'uploadAvatar', name: 'uploadAvatar',
description: 'Upload user avatar', description: 'Upload user avatar',
@@ -153,12 +133,6 @@ export class UserResolver {
}); });
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, { @Mutation(() => UserType, {
name: 'updateProfile', name: 'updateProfile',
}) })
@@ -180,12 +154,6 @@ export class UserResolver {
); );
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => RemoveAvatar, { @Mutation(() => RemoveAvatar, {
name: 'removeAvatar', name: 'removeAvatar',
description: 'Remove user avatar', description: 'Remove user avatar',
@@ -201,12 +169,6 @@ export class UserResolver {
return { success: true }; return { success: true };
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => DeleteAccount) @Mutation(() => DeleteAccount)
async deleteAccount( async deleteAccount(
@CurrentUser() user: CurrentUser @CurrentUser() user: CurrentUser

View File

@@ -1,4 +1,4 @@
import { ForbiddenException, UseGuards } from '@nestjs/common'; import { ForbiddenException } from '@nestjs/common';
import { import {
Args, Args,
Int, Int,
@@ -9,13 +9,11 @@ import {
Resolver, Resolver,
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import { CloudThrottlerGuard, Throttle } from '../../fundamentals';
import { CurrentUser } from '../auth'; import { CurrentUser } from '../auth';
import { FeatureManagementService, FeatureType } from '../features'; import { FeatureManagementService, FeatureType } from '../features';
import { PermissionService } from './permission'; import { PermissionService } from './permission';
import { WorkspaceType } from './types'; import { WorkspaceType } from './types';
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class WorkspaceManagementResolver { export class WorkspaceManagementResolver {
constructor( constructor(
@@ -23,12 +21,6 @@ export class WorkspaceManagementResolver {
private readonly permission: PermissionService private readonly permission: PermissionService
) {} ) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int) @Mutation(() => Int)
async addWorkspaceFeature( async addWorkspaceFeature(
@CurrentUser() currentUser: CurrentUser, @CurrentUser() currentUser: CurrentUser,
@@ -42,12 +34,6 @@ export class WorkspaceManagementResolver {
return this.feature.addWorkspaceFeatures(workspaceId, feature); return this.feature.addWorkspaceFeatures(workspaceId, feature);
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => Int) @Mutation(() => Int)
async removeWorkspaceFeature( async removeWorkspaceFeature(
@CurrentUser() currentUser: CurrentUser, @CurrentUser() currentUser: CurrentUser,
@@ -61,12 +47,6 @@ export class WorkspaceManagementResolver {
return this.feature.removeWorkspaceFeature(workspaceId, feature); return this.feature.removeWorkspaceFeature(workspaceId, feature);
} }
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => [WorkspaceType]) @Query(() => [WorkspaceType])
async listWorkspaceFeatures( async listWorkspaceFeatures(
@CurrentUser() user: CurrentUser, @CurrentUser() user: CurrentUser,

View File

@@ -2,7 +2,6 @@ import {
ForbiddenException, ForbiddenException,
Logger, Logger,
PayloadTooLargeException, PayloadTooLargeException,
UseGuards,
} from '@nestjs/common'; } from '@nestjs/common';
import { import {
Args, Args,
@@ -17,11 +16,7 @@ import { SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import type { FileUpload } from '../../../fundamentals'; import type { FileUpload } from '../../../fundamentals';
import { import { MakeCache, PreventCache } from '../../../fundamentals';
CloudThrottlerGuard,
MakeCache,
PreventCache,
} from '../../../fundamentals';
import { CurrentUser } from '../../auth'; import { CurrentUser } from '../../auth';
import { FeatureManagementService, FeatureType } from '../../features'; import { FeatureManagementService, FeatureType } from '../../features';
import { QuotaManagementService } from '../../quota'; import { QuotaManagementService } from '../../quota';
@@ -29,7 +24,6 @@ import { WorkspaceBlobStorage } from '../../storage';
import { PermissionService } from '../permission'; import { PermissionService } from '../permission';
import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types'; import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types';
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class WorkspaceBlobResolver { export class WorkspaceBlobResolver {
logger = new Logger(WorkspaceBlobResolver.name); logger = new Logger(WorkspaceBlobResolver.name);

View File

@@ -1,4 +1,3 @@
import { UseGuards } from '@nestjs/common';
import { import {
Args, Args,
Field, Field,
@@ -12,7 +11,6 @@ import {
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import type { SnapshotHistory } from '@prisma/client'; import type { SnapshotHistory } from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { CurrentUser } from '../../auth'; import { CurrentUser } from '../../auth';
import { DocHistoryManager } from '../../doc'; import { DocHistoryManager } from '../../doc';
import { DocID } from '../../utils/doc'; import { DocID } from '../../utils/doc';
@@ -31,7 +29,6 @@ class DocHistoryType implements Partial<SnapshotHistory> {
timestamp!: Date; timestamp!: Date;
} }
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class DocHistoryResolver { export class DocHistoryResolver {
constructor( constructor(

View File

@@ -1,4 +1,4 @@
import { BadRequestException, UseGuards } from '@nestjs/common'; import { BadRequestException } from '@nestjs/common';
import { import {
Args, Args,
Field, Field,
@@ -12,7 +12,6 @@ import {
import type { WorkspacePage as PrismaWorkspacePage } from '@prisma/client'; import type { WorkspacePage as PrismaWorkspacePage } from '@prisma/client';
import { PrismaClient } from '@prisma/client'; import { PrismaClient } from '@prisma/client';
import { CloudThrottlerGuard } from '../../../fundamentals';
import { CurrentUser } from '../../auth'; import { CurrentUser } from '../../auth';
import { DocID } from '../../utils/doc'; import { DocID } from '../../utils/doc';
import { PermissionService, PublicPageMode } from '../permission'; import { PermissionService, PublicPageMode } from '../permission';
@@ -38,7 +37,6 @@ class WorkspacePage implements Partial<PrismaWorkspacePage> {
public!: boolean; public!: boolean;
} }
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class PagePermissionResolver { export class PagePermissionResolver {
constructor( constructor(

View File

@@ -4,7 +4,6 @@ import {
Logger, Logger,
NotFoundException, NotFoundException,
PayloadTooLargeException, PayloadTooLargeException,
UseGuards,
} from '@nestjs/common'; } from '@nestjs/common';
import { import {
Args, Args,
@@ -22,7 +21,6 @@ import { applyUpdate, Doc } from 'yjs';
import type { FileUpload } from '../../../fundamentals'; import type { FileUpload } from '../../../fundamentals';
import { import {
CloudThrottlerGuard,
EventEmitter, EventEmitter,
MailService, MailService,
MutexService, MutexService,
@@ -48,7 +46,6 @@ import { defaultWorkspaceAvatar } from '../utils';
* Public apis rate limit: 10 req/m * Public apis rate limit: 10 req/m
* Other rate limit: 120 req/m * Other rate limit: 120 req/m
*/ */
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class WorkspaceResolver { export class WorkspaceResolver {
private readonly logger = new Logger(WorkspaceResolver.name); private readonly logger = new Logger(WorkspaceResolver.name);
@@ -191,12 +188,7 @@ export class WorkspaceResolver {
}); });
} }
@Throttle({ @Throttle('strict')
default: {
limit: 10,
ttl: 30,
},
})
@Public() @Public()
@Query(() => WorkspaceType, { @Query(() => WorkspaceType, {
description: 'Get public workspace by id', description: 'Get public workspace by id',
@@ -422,15 +414,10 @@ export class WorkspaceResolver {
} }
} }
@Throttle({ @Throttle('strict')
default: {
limit: 10,
ttl: 30,
},
})
@Public() @Public()
@Query(() => InvitationType, { @Query(() => InvitationType, {
description: 'Update workspace', description: 'send workspace invitation',
}) })
async getInviteInfo(@Args('inviteId') inviteId: string) { async getInviteInfo(@Args('inviteId') inviteId: string) {
const workspaceId = await this.prisma.workspaceUserPermission const workspaceId = await this.prisma.workspaceUserPermission

View File

@@ -1,11 +1,12 @@
import { Global, Module } from '@nestjs/common'; import { Global, Module } from '@nestjs/common';
import { Cache, SessionCache } from './instances'; import { Cache, SessionCache } from './instances';
import { CacheInterceptor } from './interceptor';
@Global() @Global()
@Module({ @Module({
providers: [Cache, SessionCache], providers: [Cache, SessionCache, CacheInterceptor],
exports: [Cache, SessionCache], exports: [Cache, SessionCache, CacheInterceptor],
}) })
export class CacheModule {} export class CacheModule {}
export { Cache, SessionCache }; export { Cache, SessionCache };

View File

@@ -27,7 +27,7 @@ export {
export type { PrismaTransaction } from './prisma'; export type { PrismaTransaction } from './prisma';
export * from './storage'; export * from './storage';
export { type StorageProvider, StorageProviderFactory } from './storage'; export { type StorageProvider, StorageProviderFactory } from './storage';
export { AuthThrottlerGuard, CloudThrottlerGuard, Throttle } from './throttler'; export { CloudThrottlerGuard, Throttle } from './throttler';
export { export {
getRequestFromHost, getRequestFromHost,
getRequestResponseFromContext, getRequestResponseFromContext,

View File

@@ -0,0 +1,38 @@
import { applyDecorators, SetMetadata } from '@nestjs/common';
import { SkipThrottle, Throttle as RawThrottle } from '@nestjs/throttler';
export type Throttlers = 'default' | 'strict';
export const THROTTLER_PROTECTED = 'affine_throttler:protected';
/**
* Choose what throttler to use
*
* If a Controller or Query do not protected behind a Throttler,
* it will never be rate limited.
*
* - Ease: 120 calls within 60 seconds
* - Strict: 10 calls within 60 seconds
*
* @example
*
* \@Throttle()
* \@Throttle('strict')
*
* // the config call be override by the second parameter,
* // and the call count will be calculated separately
* \@Throttle('default', { limit: 10, ttl: 10 })
*
*/
export function Throttle(
type: Throttlers = 'default',
override: { limit?: number; ttl?: number } = {}
): MethodDecorator & ClassDecorator {
return applyDecorators(
SetMetadata(THROTTLER_PROTECTED, type),
RawThrottle({
[type]: override,
})
);
}
export { SkipThrottle };

View File

@@ -1,15 +1,20 @@
import { ExecutionContext, Global, Injectable, Module } from '@nestjs/common'; import { ExecutionContext, Global, Injectable, Module } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { import {
Throttle, InjectThrottlerOptions,
InjectThrottlerStorage,
ThrottlerGuard, ThrottlerGuard,
ThrottlerModule, ThrottlerModule,
ThrottlerModuleOptions, type ThrottlerModuleOptions,
ThrottlerOptions,
ThrottlerOptionsFactory, ThrottlerOptionsFactory,
ThrottlerStorageService, ThrottlerStorageService,
} from '@nestjs/throttler'; } from '@nestjs/throttler';
import type { Request } from 'express';
import { Config } from '../config'; import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request'; import { getRequestResponseFromContext } from '../utils/request';
import { THROTTLER_PROTECTED, Throttlers } from './decorators';
@Injectable() @Injectable()
export class ThrottlerStorage extends ThrottlerStorageService {} export class ThrottlerStorage extends ThrottlerStorageService {}
@@ -25,13 +30,16 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
const options: ThrottlerModuleOptions = { const options: ThrottlerModuleOptions = {
throttlers: [ throttlers: [
{ {
name: 'default',
ttl: this.config.rateLimiter.ttl * 1000, ttl: this.config.rateLimiter.ttl * 1000,
limit: this.config.rateLimiter.limit, limit: this.config.rateLimiter.limit,
}, },
{
name: 'strict',
ttl: this.config.rateLimiter.ttl * 1000,
limit: 20,
},
], ],
skipIf: () => {
return !this.config.node.prod || this.config.affine.canary;
},
storage: this.storage, storage: this.storage,
}; };
@@ -39,6 +47,132 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
} }
} }
@Injectable()
export class CloudThrottlerGuard extends ThrottlerGuard {
constructor(
@InjectThrottlerOptions() options: ThrottlerModuleOptions,
@InjectThrottlerStorage() storageService: ThrottlerStorage,
reflector: Reflector,
private readonly config: Config
) {
super(options, storageService, reflector);
}
override getRequestResponse(context: ExecutionContext) {
return getRequestResponseFromContext(context) as any;
}
override getTracker(req: Request): Promise<string> {
return Promise.resolve(
// ↓ prefer session id if available
`throttler:${req.sid ?? req.get('CF-Connecting-IP') ?? req.get('CF-ray') ?? req.ip}`
// ^ throttler prefix make the key in store recognizable
);
}
override generateKey(
context: ExecutionContext,
tracker: string,
throttler: string
) {
if (tracker.endsWith(';custom')) {
return `${tracker};${throttler}:${context.getClass().name}.${context.getHandler().name}`;
}
return `${tracker};${throttler}`;
}
override async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number,
throttlerOptions: ThrottlerOptions
) {
// give it 'default' if no throttler is specified,
// so the unauthenticated users visits will always hit default throttler
// authenticated users will directly bypass unprotected APIs in [CloudThrottlerGuard.canActivate]
const throttler = this.getSpecifiedThrottler(context) ?? 'default';
// by pass unmatched throttlers
if (throttlerOptions.name !== throttler) {
return true;
}
const { req, res } = this.getRequestResponse(context);
const ignoreUserAgents =
throttlerOptions.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents;
if (Array.isArray(ignoreUserAgents)) {
for (const pattern of ignoreUserAgents) {
const ua = req.headers['user-agent'];
if (ua && pattern.test(ua)) {
return true;
}
}
}
let tracker = await this.getTracker(req);
if (this.config.node.dev) {
limit = Number.MAX_SAFE_INTEGER;
} else {
// custom limit or ttl APIs will be treated standalone
if (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl) {
tracker += ';custom';
}
}
const key = this.generateKey(
context,
tracker,
throttlerOptions.name ?? 'default'
);
const { timeToExpire, totalHits } = await this.storageService.increment(
key,
ttl
);
if (totalHits > limit) {
res.header('Retry-After', timeToExpire.toString());
await this.throwThrottlingException(context, {
limit,
ttl,
key,
tracker,
totalHits,
timeToExpire,
});
}
res.header(`${this.headerPrefix}-Limit`, limit.toString());
res.header(
`${this.headerPrefix}-Remaining`,
(limit - totalHits).toString()
);
res.header(`${this.headerPrefix}-Reset`, timeToExpire.toString());
return true;
}
override async canActivate(context: ExecutionContext): Promise<boolean> {
const { req } = this.getRequestResponse(context);
const throttler = this.getSpecifiedThrottler(context);
// if user is logged in, bypass non-protected handlers
if (!throttler && req.user) {
return true;
}
return super.canActivate(context);
}
getSpecifiedThrottler(context: ExecutionContext) {
return this.reflector.getAllAndOverride<Throttlers | undefined>(
THROTTLER_PROTECTED,
[context.getHandler(), context.getClass()]
);
}
}
@Global() @Global()
@Module({ @Module({
imports: [ imports: [
@@ -46,46 +180,9 @@ class CustomOptionsFactory implements ThrottlerOptionsFactory {
useClass: CustomOptionsFactory, useClass: CustomOptionsFactory,
}), }),
], ],
providers: [ThrottlerStorage], providers: [ThrottlerStorage, CloudThrottlerGuard],
exports: [ThrottlerStorage], exports: [ThrottlerStorage, CloudThrottlerGuard],
}) })
export class RateLimiterModule {} export class RateLimiterModule {}
@Injectable() export * from './decorators';
export class CloudThrottlerGuard extends ThrottlerGuard {
override getRequestResponse(context: ExecutionContext) {
return getRequestResponseFromContext(context) as any;
}
protected override getTracker(req: Record<string, any>): Promise<string> {
return Promise.resolve(
req?.get('CF-Connecting-IP') ?? req?.get('CF-ray') ?? req?.ip
);
}
}
@Injectable()
export class AuthThrottlerGuard extends CloudThrottlerGuard {
override async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number
): Promise<boolean> {
const { req } = this.getRequestResponse(context);
if (req?.url === '/api/auth/session') {
// relax throttle for session auto renew
return super.handleRequest(context, limit * 20, ttl, {
ttl: ttl * 20,
limit: limit * 20,
});
}
return super.handleRequest(context, limit, ttl, {
ttl,
limit,
});
}
}
export { Throttle };

View File

@@ -1,6 +1,7 @@
declare namespace Express { declare namespace Express {
interface Request { interface Request {
user?: import('./core/auth/current-user').CurrentUser; user?: import('./core/auth/current-user').CurrentUser;
sid?: string;
} }
} }

View File

@@ -265,7 +265,7 @@ type Query {
currentUser: UserType currentUser: UserType
earlyAccessUsers: [UserType!]! earlyAccessUsers: [UserType!]!
"""Update workspace""" """send workspace invitation"""
getInviteInfo(inviteId: String!): InvitationType! getInviteInfo(inviteId: String!): InvitationType!
"""Get is owner of workspace""" """Get is owner of workspace"""

View File

@@ -0,0 +1,331 @@
import '../../src/plugins/config';
import {
Controller,
Get,
HttpStatus,
INestApplication,
UseGuards,
} from '@nestjs/common';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import request, { type Response } from 'supertest';
import { AppModule } from '../../src/app.module';
import { AuthService, Public } from '../../src/core/auth';
import { ConfigModule } from '../../src/fundamentals/config';
import {
CloudThrottlerGuard,
SkipThrottle,
Throttle,
ThrottlerStorage,
} from '../../src/fundamentals/throttler';
import { createTestingApp, sessionCookie } from '../utils';
const test = ava as TestFn<{
storage: ThrottlerStorage;
cookie: string;
app: INestApplication;
}>;
@UseGuards(CloudThrottlerGuard)
@Throttle()
@Controller('/throttled')
class ThrottledController {
@Get('/default')
default() {
return 'default';
}
@Get('/default2')
default2() {
return 'default2';
}
@Get('/default3')
@Throttle('default', { limit: 10 })
default3() {
return 'default3';
}
@Throttle('strict')
@Get('/strict')
strict() {
return 'strict';
}
@Public()
@SkipThrottle()
@Get('/skip')
skip() {
return 'skip';
}
}
@UseGuards(CloudThrottlerGuard)
@Controller('/nonthrottled')
class NonThrottledController {
@Public()
@SkipThrottle()
@Get('/skip')
skip() {
return 'skip';
}
@Public()
@Get('/default')
default() {
return 'default';
}
@Public()
@Throttle('strict')
@Get('/strict')
strict() {
return 'strict';
}
}
test.beforeEach(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
rateLimiter: {
ttl: 60,
limit: 120,
},
}),
AppModule,
],
controllers: [ThrottledController, NonThrottledController],
});
t.context.storage = app.get(ThrottlerStorage);
t.context.app = app;
const auth = app.get(AuthService);
const u1 = await auth.signUp('u1', 'u1@affine.pro', 'test');
const res = await request(app.getHttpServer())
.post('/api/auth/sign-in')
.send({ email: u1.email, password: 'test' });
t.context.cookie = sessionCookie(res.headers)!;
});
test.afterEach.always(async t => {
await t.context.app.close();
});
function rateLimitHeaders(res: Response) {
return {
limit: res.header['x-ratelimit-limit'],
remaining: res.header['x-ratelimit-remaining'],
reset: res.header['x-ratelimit-reset'],
retryAfter: res.header['retry-after'],
};
}
test('should be able to prevent requests if limit is reached', async t => {
const { app } = t.context;
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 21,
});
const res = await request(app.getHttpServer())
.get('/nonthrottled/strict')
.expect(HttpStatus.TOO_MANY_REQUESTS);
const headers = rateLimitHeaders(res);
t.is(headers.retryAfter, '10');
stub.restore();
});
// ====== unauthenticated user visits ======
test('should use default throttler for unauthenticated user when not specified', async t => {
const { app } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/default')
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
});
test('should skip throttler for unauthenticated user when specified', async t => {
const { app } = t.context;
let res = await request(app.getHttpServer())
.get('/nonthrottled/skip')
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
res = await request(app.getHttpServer()).get('/throttled/skip').expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use specified throttler for unauthenticated user', async t => {
const { app } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/strict')
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '20');
t.is(headers.remaining, '19');
t.is(headers.reset, '60');
});
// ==== authenticated user visits ====
test('should not protect unspecified routes', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/nonthrottled/default')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use default throttler for authenticated user when not specified', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
});
test('should use same throttler for multiple routes', async t => {
const { app, cookie } = t.context;
let res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
res = await request(app.getHttpServer())
.get('/throttled/default2')
.set('Cookie', cookie)
.expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '118');
});
test('should use different throttler if specified', async t => {
const { app, cookie } = t.context;
let res = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
let headers = rateLimitHeaders(res);
t.is(headers.limit, '120');
t.is(headers.remaining, '119');
t.is(headers.reset, '60');
res = await request(app.getHttpServer())
.get('/throttled/default3')
.set('Cookie', cookie)
.expect(200);
headers = rateLimitHeaders(res);
t.is(headers.limit, '10');
t.is(headers.remaining, '9');
t.is(headers.reset, '60');
});
test('should skip throttler for authenticated user when specified', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/skip')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, undefined!);
t.is(headers.remaining, undefined!);
t.is(headers.reset, undefined!);
});
test('should use specified throttler for authenticated user', async t => {
const { app, cookie } = t.context;
const res = await request(app.getHttpServer())
.get('/throttled/strict')
.set('Cookie', cookie)
.expect(200);
const headers = rateLimitHeaders(res);
t.is(headers.limit, '20');
t.is(headers.remaining, '19');
t.is(headers.reset, '60');
});
test('should separate anonymous and authenticated user throttlers', async t => {
const { app, cookie } = t.context;
const authenticatedUserRes = await request(app.getHttpServer())
.get('/throttled/default')
.set('Cookie', cookie)
.expect(200);
const unauthenticatedUserRes = await request(app.getHttpServer())
.get('/nonthrottled/default')
.expect(200);
const authenticatedResHeaders = rateLimitHeaders(authenticatedUserRes);
const unauthenticatedResHeaders = rateLimitHeaders(unauthenticatedUserRes);
t.is(authenticatedResHeaders.limit, '120');
t.is(authenticatedResHeaders.remaining, '119');
t.is(authenticatedResHeaders.reset, '60');
t.is(unauthenticatedResHeaders.limit, '120');
t.is(unauthenticatedResHeaders.remaining, '119');
t.is(unauthenticatedResHeaders.reset, '60');
});

View File

@@ -10,13 +10,13 @@ import {
import type { UserType } from '../../src/core/user'; import type { UserType } from '../../src/core/user';
import { gql } from './common'; import { gql } from './common';
export function sessionCookie(headers: any) { export function sessionCookie(headers: any): string {
const cookie = headers['set-cookie']?.find((c: string) => const cookie = headers['set-cookie']?.find((c: string) =>
c.startsWith(`${AuthService.sessionCookieName}=`) c.startsWith(`${AuthService.sessionCookieName}=`)
); );
if (!cookie) { if (!cookie) {
return null; return '';
} }
return cookie.split(';')[0]; return cookie.split(';')[0];
@@ -29,7 +29,7 @@ export async function getSession(
const cookie = sessionCookie(signInRes.headers); const cookie = sessionCookie(signInRes.headers);
const res = await request(app.getHttpServer()) const res = await request(app.getHttpServer())
.get('/api/auth/session') .get('/api/auth/session')
.set('cookie', cookie) .set('cookie', cookie!)
.expect(200); .expect(200);
return res.body; return res.body;