From 1b0864eb60388ae46b9db6935b5e548387349335 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Tue, 16 Apr 2024 09:41:48 +0000 Subject: [PATCH] feat: check quota correctly (#6561) --- .../server/src/core/features/feature.ts | 13 ++++ .../server/src/core/features/management.ts | 9 ++- .../server/src/core/features/service.ts | 37 ++++++++++- .../server/src/core/features/types/common.ts | 5 +- .../server/src/core/features/types/index.ts | 8 +++ .../core/features/types/unlimited-copilot.ts | 8 +++ .../backend/server/src/core/quota/index.ts | 2 +- .../backend/server/src/core/quota/schema.ts | 51 +++++++++++++++- .../backend/server/src/core/quota/service.ts | 61 ++++++++++++++----- .../backend/server/src/core/user/resolver.ts | 2 +- .../server/src/core/workspaces/management.ts | 7 +-- .../migrations/1705395933447-new-free-plan.ts | 2 +- .../1706513866287-business-blob-limit.ts | 2 +- .../1713164714634-copilot-feature.ts | 23 +++++++ .../src/data/migrations/utils/user-quotas.ts | 18 ++++-- .../server/src/plugins/copilot/controller.ts | 14 +++-- .../server/src/plugins/copilot/index.ts | 3 + .../server/src/plugins/copilot/resolver.ts | 34 ++--------- .../server/src/plugins/copilot/session.ts | 47 +++++++++++++- .../server/src/plugins/payment/service.ts | 5 +- .../server/src/plugins/payment/types.ts | 5 +- packages/backend/server/src/schema.gql | 3 +- packages/backend/server/tests/feature.spec.ts | 2 +- packages/backend/server/tests/quota.spec.ts | 32 ++++++---- .../account-setting/ai-usage-panel.tsx | 4 +- packages/frontend/graphql/src/schema.ts | 7 ++- 26 files changed, 309 insertions(+), 95 deletions(-) create mode 100644 packages/backend/server/src/core/features/types/unlimited-copilot.ts create mode 100644 packages/backend/server/src/data/migrations/1713164714634-copilot-feature.ts diff --git a/packages/backend/server/src/core/features/feature.ts b/packages/backend/server/src/core/features/feature.ts index 61a99aa1af..ee68d57139 100644 --- a/packages/backend/server/src/core/features/feature.ts +++ b/packages/backend/server/src/core/features/feature.ts @@ -54,10 +54,23 @@ export class UnlimitedWorkspaceFeatureConfig extends FeatureConfig { } } +export class UnlimitedCopilotFeatureConfig extends FeatureConfig { + override config!: Feature & { feature: FeatureType.UnlimitedCopilot }; + + constructor(data: any) { + super(data); + + if (this.config.feature !== FeatureType.UnlimitedCopilot) { + throw new Error('Invalid feature config: type is not UnlimitedWorkspace'); + } + } +} + const FeatureConfigMap = { [FeatureType.Copilot]: CopilotFeatureConfig, [FeatureType.EarlyAccess]: EarlyAccessFeatureConfig, [FeatureType.UnlimitedWorkspace]: UnlimitedWorkspaceFeatureConfig, + [FeatureType.UnlimitedCopilot]: UnlimitedCopilotFeatureConfig, }; export type FeatureConfigType = InstanceType< diff --git a/packages/backend/server/src/core/features/management.ts b/packages/backend/server/src/core/features/management.ts index c5df3713d1..5d0cc40745 100644 --- a/packages/backend/server/src/core/features/management.ts +++ b/packages/backend/server/src/core/features/management.ts @@ -35,7 +35,6 @@ export class FeatureManagementService { return this.feature.addUserFeature( userId, FeatureType.EarlyAccess, - 2, 'Early access user' ); } @@ -116,9 +115,9 @@ export class FeatureManagementService { return this.feature.listFeatureWorkspaces(feature); } - async getUserFeatures(userId: string): Promise { - return (await this.feature.getUserFeatures(userId)).map( - f => f.feature.name - ); + // ======== User Feature ======== + async getActivatedUserFeatures(userId: string): Promise { + const features = await this.feature.getActivatedUserFeatures(userId); + return features.map(f => f.feature.name); } } diff --git a/packages/backend/server/src/core/features/service.ts b/packages/backend/server/src/core/features/service.ts index d90581be74..0ac9b23f28 100644 --- a/packages/backend/server/src/core/features/service.ts +++ b/packages/backend/server/src/core/features/service.ts @@ -59,11 +59,17 @@ export class FeatureService { async addUserFeature( userId: string, feature: FeatureType, - version: number, reason: string, expiredAt?: Date | string ) { return this.prisma.$transaction(async tx => { + const latestVersion = await tx.features + .aggregate({ + where: { feature }, + _max: { version: true }, + }) + .then(r => r._max.version || 1); + const latestFlag = await tx.userFeatures.findFirst({ where: { userId, @@ -95,7 +101,7 @@ export class FeatureService { connect: { feature_version: { feature, - version, + version: latestVersion, }, type: FeatureKind.Feature, }, @@ -157,6 +163,33 @@ export class FeatureService { return configs.filter(feature => !!feature.feature); } + async getActivatedUserFeatures(userId: string) { + const features = await this.prisma.userFeatures.findMany({ + where: { + user: { id: userId }, + feature: { type: FeatureKind.Feature }, + activated: true, + OR: [{ expiredAt: null }, { expiredAt: { gt: new Date() } }], + }, + select: { + activated: true, + reason: true, + createdAt: true, + expiredAt: true, + featureId: true, + }, + }); + + const configs = await Promise.all( + features.map(async feature => ({ + ...feature, + feature: await getFeature(this.prisma, feature.featureId), + })) + ); + + return configs.filter(feature => !!feature.feature); + } + async listFeatureUsers(feature: FeatureType) { return this.prisma.userFeatures .findMany({ diff --git a/packages/backend/server/src/core/features/types/common.ts b/packages/backend/server/src/core/features/types/common.ts index 3095b49e0f..c0ef64fa30 100644 --- a/packages/backend/server/src/core/features/types/common.ts +++ b/packages/backend/server/src/core/features/types/common.ts @@ -1,8 +1,11 @@ import { registerEnumType } from '@nestjs/graphql'; export enum FeatureType { - Copilot = 'copilot', + // user feature EarlyAccess = 'early_access', + UnlimitedCopilot = 'unlimited_copilot', + // workspace feature + Copilot = 'copilot', UnlimitedWorkspace = 'unlimited_workspace', } diff --git a/packages/backend/server/src/core/features/types/index.ts b/packages/backend/server/src/core/features/types/index.ts index f732bce242..7009a63184 100644 --- a/packages/backend/server/src/core/features/types/index.ts +++ b/packages/backend/server/src/core/features/types/index.ts @@ -3,6 +3,7 @@ import { z } from 'zod'; import { FeatureType } from './common'; import { featureCopilot } from './copilot'; import { featureEarlyAccess } from './early-access'; +import { featureUnlimitedCopilot } from './unlimited-copilot'; import { featureUnlimitedWorkspace } from './unlimited-workspace'; /// ======== common schema ======== @@ -52,6 +53,12 @@ export const Features: Feature[] = [ version: 1, configs: {}, }, + { + feature: FeatureType.UnlimitedCopilot, + type: FeatureKind.Feature, + version: 1, + configs: {}, + }, ]; /// ======== schema infer ======== @@ -65,6 +72,7 @@ export const FeatureSchema = commonFeatureSchema featureCopilot, featureEarlyAccess, featureUnlimitedWorkspace, + featureUnlimitedCopilot, ]) ); diff --git a/packages/backend/server/src/core/features/types/unlimited-copilot.ts b/packages/backend/server/src/core/features/types/unlimited-copilot.ts new file mode 100644 index 0000000000..fd69e791a6 --- /dev/null +++ b/packages/backend/server/src/core/features/types/unlimited-copilot.ts @@ -0,0 +1,8 @@ +import { z } from 'zod'; + +import { FeatureType } from './common'; + +export const featureUnlimitedCopilot = z.object({ + feature: z.literal(FeatureType.UnlimitedCopilot), + configs: z.object({}), +}); diff --git a/packages/backend/server/src/core/quota/index.ts b/packages/backend/server/src/core/quota/index.ts index a84d09a367..efeaa9caed 100644 --- a/packages/backend/server/src/core/quota/index.ts +++ b/packages/backend/server/src/core/quota/index.ts @@ -20,5 +20,5 @@ import { QuotaManagementService } from './storage'; export class QuotaModule {} export { QuotaManagementService, QuotaService }; -export { Quota_FreePlanV1_1, Quota_ProPlanV1, Quotas } from './schema'; +export { Quota_FreePlanV1_1, Quota_ProPlanV1 } from './schema'; export { QuotaQueryType, QuotaType } from './types'; diff --git a/packages/backend/server/src/core/quota/schema.ts b/packages/backend/server/src/core/quota/schema.ts index 5776b98481..6dc45f0fbd 100644 --- a/packages/backend/server/src/core/quota/schema.ts +++ b/packages/backend/server/src/core/quota/schema.ts @@ -117,14 +117,61 @@ export const Quotas: Quota[] = [ copilotActionLimit: 10, }, }, + { + feature: QuotaType.ProPlanV1, + type: FeatureKind.Quota, + version: 2, + configs: { + // quota name + name: 'Pro', + // single blob limit 100MB + blobLimit: 100 * OneMB, + // total blob limit 100GB + storageQuota: 100 * OneGB, + // history period of validity 30 days + historyPeriod: 30 * OneDay, + // member limit 10 + memberLimit: 10, + // copilot action limit 10 + copilotActionLimit: 10, + }, + }, + { + feature: QuotaType.RestrictedPlanV1, + type: FeatureKind.Quota, + version: 2, + configs: { + // quota name + name: 'Restricted', + // single blob limit 10MB + blobLimit: OneMB, + // total blob limit 1GB + storageQuota: 10 * OneMB, + // history period of validity 30 days + historyPeriod: 30 * OneDay, + // member limit 10 + memberLimit: 10, + // copilot action limit 10 + copilotActionLimit: 10, + }, + }, ]; +export function getLatestQuota(type: QuotaType) { + const quota = Quotas.filter(f => f.feature === type); + quota.sort((a, b) => b.version - a.version); + return quota[0]; +} + +export const FreePlan = getLatestQuota(QuotaType.FreePlanV1); +export const ProPlan = getLatestQuota(QuotaType.ProPlanV1); + export const Quota_FreePlanV1_1 = { feature: Quotas[5].feature, version: Quotas[5].version, }; export const Quota_ProPlanV1 = { - feature: Quotas[1].feature, - version: Quotas[1].version, + feature: Quotas[6].feature, + version: Quotas[6].version, }; diff --git a/packages/backend/server/src/core/quota/service.ts b/packages/backend/server/src/core/quota/service.ts index d25aa1ae50..03b8022800 100644 --- a/packages/backend/server/src/core/quota/service.ts +++ b/packages/backend/server/src/core/quota/service.ts @@ -3,13 +3,17 @@ import { PrismaClient } from '@prisma/client'; import type { EventPayload } from '../../fundamentals'; import { OnEvent, PrismaTransaction } from '../../fundamentals'; -import { FeatureKind } from '../features'; +import { SubscriptionPlan } from '../../plugins/payment/types'; +import { FeatureKind, FeatureService, FeatureType } from '../features'; import { QuotaConfig } from './quota'; import { QuotaType } from './types'; @Injectable() export class QuotaService { - constructor(private readonly prisma: PrismaClient) {} + constructor( + private readonly prisma: PrismaClient, + private readonly feature: FeatureService + ) {} // get activated user quota async getUserQuota(userId: string) { @@ -159,22 +163,49 @@ export class QuotaService { @OnEvent('user.subscription.activated') async onSubscriptionUpdated({ userId, + plan, }: EventPayload<'user.subscription.activated'>) { - await this.switchUserQuota( - userId, - QuotaType.ProPlanV1, - 'subscription activated' - ); + switch (plan) { + case SubscriptionPlan.AI: + await this.feature.addUserFeature( + userId, + FeatureType.UnlimitedCopilot, + 'subscription activated' + ); + break; + case SubscriptionPlan.Pro: + await this.switchUserQuota( + userId, + QuotaType.ProPlanV1, + 'subscription activated' + ); + break; + default: + break; + } } @OnEvent('user.subscription.canceled') - async onSubscriptionCanceled( - userId: EventPayload<'user.subscription.canceled'> - ) { - await this.switchUserQuota( - userId, - QuotaType.FreePlanV1, - 'subscription canceled' - ); + async onSubscriptionCanceled({ + userId, + plan, + }: EventPayload<'user.subscription.canceled'>) { + switch (plan) { + case SubscriptionPlan.AI: + await this.feature.removeUserFeature( + userId, + FeatureType.UnlimitedCopilot + ); + break; + case SubscriptionPlan.Pro: + await this.switchUserQuota( + userId, + QuotaType.FreePlanV1, + 'subscription canceled' + ); + break; + default: + break; + } } } diff --git a/packages/backend/server/src/core/user/resolver.ts b/packages/backend/server/src/core/user/resolver.ts index 347e6ab366..aaa0fd46b5 100644 --- a/packages/backend/server/src/core/user/resolver.ts +++ b/packages/backend/server/src/core/user/resolver.ts @@ -115,7 +115,7 @@ export class UserResolver { description: 'Enabled features of a user', }) async userFeatures(@CurrentUser() user: CurrentUser) { - return this.feature.getUserFeatures(user.id); + return this.feature.getActivatedUserFeatures(user.id); } @Throttle({ diff --git a/packages/backend/server/src/core/workspaces/management.ts b/packages/backend/server/src/core/workspaces/management.ts index c8625c4d43..a4bd38fd34 100644 --- a/packages/backend/server/src/core/workspaces/management.ts +++ b/packages/backend/server/src/core/workspaces/management.ts @@ -117,12 +117,7 @@ export class WorkspaceManagementResolver { async availableFeatures( @CurrentUser() user: CurrentUser ): Promise { - const isEarlyAccessUser = await this.feature.isEarlyAccessUser(user.email); - if (isEarlyAccessUser) { - return [FeatureType.Copilot]; - } else { - return []; - } + return await this.feature.getActivatedUserFeatures(user.id); } @ResolveField(() => [FeatureType], { diff --git a/packages/backend/server/src/data/migrations/1705395933447-new-free-plan.ts b/packages/backend/server/src/data/migrations/1705395933447-new-free-plan.ts index dc6bf27966..51b869e9c7 100644 --- a/packages/backend/server/src/data/migrations/1705395933447-new-free-plan.ts +++ b/packages/backend/server/src/data/migrations/1705395933447-new-free-plan.ts @@ -1,6 +1,6 @@ import { PrismaClient } from '@prisma/client'; -import { Quotas } from '../../core/quota'; +import { Quotas } from '../../core/quota/schema'; import { upgradeQuotaVersion } from './utils/user-quotas'; export class NewFreePlan1705395933447 { diff --git a/packages/backend/server/src/data/migrations/1706513866287-business-blob-limit.ts b/packages/backend/server/src/data/migrations/1706513866287-business-blob-limit.ts index f19aec6fd2..4c61590057 100644 --- a/packages/backend/server/src/data/migrations/1706513866287-business-blob-limit.ts +++ b/packages/backend/server/src/data/migrations/1706513866287-business-blob-limit.ts @@ -1,6 +1,6 @@ import { PrismaClient } from '@prisma/client'; -import { Quotas } from '../../core/quota'; +import { Quotas } from '../../core/quota/schema'; import { upgradeQuotaVersion } from './utils/user-quotas'; export class BusinessBlobLimit1706513866287 { diff --git a/packages/backend/server/src/data/migrations/1713164714634-copilot-feature.ts b/packages/backend/server/src/data/migrations/1713164714634-copilot-feature.ts new file mode 100644 index 0000000000..9b6e2033b3 --- /dev/null +++ b/packages/backend/server/src/data/migrations/1713164714634-copilot-feature.ts @@ -0,0 +1,23 @@ +import { PrismaClient } from '@prisma/client'; + +import { QuotaType } from '../../core/quota/types'; +import { upgradeLatestQuotaVersion } from './utils/user-quotas'; + +export class CopilotFeature1713164714634 { + // do the migration + static async up(db: PrismaClient) { + await upgradeLatestQuotaVersion( + db, + QuotaType.ProPlanV1, + 'pro plan 1.1 migration' + ); + await upgradeLatestQuotaVersion( + db, + QuotaType.RestrictedPlanV1, + 'restricted plan 1.1 migration' + ); + } + + // revert the migration + static async down(_db: PrismaClient) {} +} diff --git a/packages/backend/server/src/data/migrations/utils/user-quotas.ts b/packages/backend/server/src/data/migrations/utils/user-quotas.ts index c245453282..7a8c5f9677 100644 --- a/packages/backend/server/src/data/migrations/utils/user-quotas.ts +++ b/packages/backend/server/src/data/migrations/utils/user-quotas.ts @@ -1,7 +1,7 @@ import { PrismaClient } from '@prisma/client'; import { FeatureKind } from '../../../core/features'; -import { Quotas } from '../../../core/quota/schema'; +import { getLatestQuota } from '../../../core/quota/schema'; import { Quota, QuotaType } from '../../../core/quota/types'; import { upsertFeature } from './user-features'; @@ -21,10 +21,10 @@ export async function upgradeQuotaVersion( }); // find all users that have old free plan - const userIds = await db.user.findMany({ + const userIds = await tx.user.findMany({ where: { features: { - every: { + some: { feature: { type: FeatureKind.Quota, feature: quota.feature, @@ -65,13 +65,19 @@ export async function upgradeQuotaVersion( }); } +export async function upsertLatestQuotaVersion( + db: PrismaClient, + type: QuotaType +) { + const latestQuota = getLatestQuota(type); + await upsertFeature(db, latestQuota); +} + export async function upgradeLatestQuotaVersion( db: PrismaClient, type: QuotaType, reason: string ) { - const quota = Quotas.filter(f => f.feature === type); - quota.sort((a, b) => b.version - a.version); - const latestQuota = quota[0]; + const latestQuota = getLatestQuota(type); await upgradeQuotaVersion(db, latestQuota, reason); } diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index c074010ff1..d1b9fb9822 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -20,7 +20,6 @@ import { toArray, } from 'rxjs'; -import { Public } from '../../core/auth'; import { CurrentUser } from '../../core/auth/current-user'; import { CopilotProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; @@ -79,7 +78,6 @@ export class CopilotController { return session; } - @Public() @Get('/chat/:sessionId') async chat( @CurrentUser() user: CurrentUser, @@ -89,6 +87,8 @@ export class CopilotController { @Query('messageId') messageId: string | undefined, @Query() params: Record ): Promise { + await this.chatSession.checkQuota(user.id); + const model = await this.chatSession.get(sessionId).then(s => s?.model); const provider = this.provider.getProviderByCapability( CopilotCapability.TextToText, @@ -131,7 +131,6 @@ export class CopilotController { } } - @Public() @Sse('/chat/:sessionId/stream') async chatStream( @CurrentUser() user: CurrentUser, @@ -141,6 +140,8 @@ export class CopilotController { @Query('messageId') messageId: string | undefined, @Query() params: Record ): Promise> { + await this.chatSession.checkQuota(user.id); + const model = await this.chatSession.get(sessionId).then(s => s?.model); const provider = this.provider.getProviderByCapability( CopilotCapability.TextToText, @@ -188,16 +189,17 @@ export class CopilotController { ); } - @Public() @Sse('/chat/:sessionId/images') async chatImagesStream( - @CurrentUser() user: CurrentUser | undefined, + @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, @Query('message') message: string | undefined, @Query('messageId') messageId: string | undefined, @Query() params: Record ): Promise> { + await this.chatSession.checkQuota(user.id); + const hasAttachment = await this.hasAttachment(sessionId, messageId); const model = await this.chatSession.get(sessionId).then(s => s?.model); const provider = this.provider.getProviderByCapability( @@ -221,7 +223,7 @@ export class CopilotController { return from( provider.generateImagesStream(session.finish(params), session.model, { signal: req.signal, - user: user?.id, + user: user.id, }) ).pipe( connect(shared$ => diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index 370e17cec5..6d65f5f19d 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -1,4 +1,5 @@ import { ServerFeature } from '../../core/config'; +import { FeatureManagementService, FeatureService } from '../../core/features'; import { QuotaService } from '../../core/quota'; import { PermissionService } from '../../core/workspaces/permission'; import { Plugin } from '../registry'; @@ -22,6 +23,8 @@ registerCopilotProvider(OpenAIProvider); name: 'copilot', providers: [ PermissionService, + FeatureService, + FeatureManagementService, QuotaService, ChatSessionService, CopilotResolver, diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 44389a15cf..18a774d6c2 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -14,14 +14,9 @@ import { import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars'; import { CurrentUser } from '../../core/auth'; -import { QuotaService } from '../../core/quota'; import { UserType } from '../../core/user'; import { PermissionService } from '../../core/workspaces/permission'; -import { - MutexService, - PaymentRequiredException, - TooManyRequestsException, -} from '../../fundamentals'; +import { MutexService, TooManyRequestsException } from '../../fundamentals'; import { ChatSessionService } from './session'; import { AvailableModels, @@ -123,8 +118,8 @@ class CopilotHistoriesType implements Partial { @ObjectType('CopilotQuota') class CopilotQuotaType { - @Field(() => SafeIntResolver) - limit!: number; + @Field(() => SafeIntResolver, { nullable: true }) + limit?: number; @Field(() => SafeIntResolver) used!: number; @@ -144,7 +139,6 @@ export class CopilotResolver { constructor( private readonly permissions: PermissionService, - private readonly quota: QuotaService, private readonly mutex: MutexService, private readonly chatSession: ChatSessionService ) {} @@ -155,20 +149,7 @@ export class CopilotResolver { complexity: 2, }) async getQuota(@CurrentUser() user: CurrentUser) { - const quota = await this.quota.getUserQuota(user.id); - const limit = quota.feature.copilotActionLimit; - - const actions = await this.chatSession.countUserActions(user.id); - const chats = await this.chatSession - .listHistories(user.id) - .then(histories => - histories.reduce( - (acc, h) => acc + h.messages.filter(m => m.role === 'user').length, - 0 - ) - ); - - return { limit, used: actions + chats }; + return await this.chatSession.getQuota(user.id); } @ResolveField(() => [String], { @@ -257,12 +238,7 @@ export class CopilotResolver { return new TooManyRequestsException('Server is busy'); } - const { limit, used } = await this.getQuota(user); - if (limit && Number.isFinite(limit) && used >= limit) { - return new PaymentRequiredException( - `You have reached the limit of actions in this workspace, please upgrade your plan.` - ); - } + await this.chatSession.checkQuota(user.id); const session = await this.chatSession.create({ ...options, diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 50f6116c18..90014b3b75 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -1,8 +1,11 @@ import { randomUUID } from 'node:crypto'; import { Injectable, Logger } from '@nestjs/common'; -import { PrismaClient } from '@prisma/client'; +import { AiPromptRole, PrismaClient } from '@prisma/client'; +import { FeatureManagementService, FeatureType } from '../../core/features'; +import { QuotaService } from '../../core/quota'; +import { PaymentRequiredException } from '../../fundamentals'; import { ChatMessageCache } from './message'; import { ChatPrompt, PromptService } from './prompt'; import { @@ -120,6 +123,8 @@ export class ChatSessionService { constructor( private readonly db: PrismaClient, + private readonly feature: FeatureManagementService, + private readonly quota: QuotaService, private readonly messageCache: ChatMessageCache, private readonly prompt: PromptService ) {} @@ -242,12 +247,24 @@ export class ChatSessionService { .reduce((total, length) => total + length, 0); } - async countUserActions(userId: string): Promise { + private async countUserActions(userId: string): Promise { return await this.db.aiSession.count({ where: { userId, prompt: { action: { not: null } } }, }); } + private async countUserChats(userId: string): Promise { + const chats = await this.db.aiSession.findMany({ + where: { userId, prompt: { action: null } }, + select: { + _count: { + select: { messages: { where: { role: AiPromptRole.user } } }, + }, + }, + }); + return chats.reduce((prev, chat) => prev + chat._count.messages, 0); + } + async listSessions( userId: string, workspaceId: string, @@ -347,6 +364,32 @@ export class ChatSessionService { ); } + async getQuota(userId: string) { + const hasCopilotFeature = await this.feature + .getActivatedUserFeatures(userId) + .then(f => f.includes(FeatureType.UnlimitedCopilot)); + + let limit: number | undefined; + if (!hasCopilotFeature) { + const quota = await this.quota.getUserQuota(userId); + limit = quota.feature.copilotActionLimit; + } + + const actions = await this.countUserActions(userId); + const chats = await this.countUserChats(userId); + + return { limit, used: actions + chats }; + } + + async checkQuota(userId: string) { + const { limit, used } = await this.getQuota(userId); + if (limit && Number.isFinite(limit) && used >= limit) { + throw new PaymentRequiredException( + `You have reached the limit of actions in this workspace, please upgrade your plan.` + ); + } + } + async create(options: ChatSessionOptions): Promise { const sessionId = randomUUID(); const prompt = await this.prompt.get(options.promptName); diff --git a/packages/backend/server/src/plugins/payment/service.ts b/packages/backend/server/src/plugins/payment/service.ts index 1657cdf65f..bab3e2eac0 100644 --- a/packages/backend/server/src/plugins/payment/service.ts +++ b/packages/backend/server/src/plugins/payment/service.ts @@ -527,7 +527,10 @@ export class SubscriptionService { nextBillAt = new Date(subscription.current_period_end * 1000); } } else { - this.event.emit('user.subscription.canceled', user.id); + this.event.emit('user.subscription.canceled', { + userId: user.id, + plan, + }); } const commonData = { diff --git a/packages/backend/server/src/plugins/payment/types.ts b/packages/backend/server/src/plugins/payment/types.ts index 4b11a12ba9..dd51193613 100644 --- a/packages/backend/server/src/plugins/payment/types.ts +++ b/packages/backend/server/src/plugins/payment/types.ts @@ -53,7 +53,10 @@ declare module '../../fundamentals/event/def' { userId: User['id']; plan: SubscriptionPlan; }>; - canceled: Payload; + canceled: Payload<{ + userId: User['id']; + plan: SubscriptionPlan; + }>; }; } } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 76508052bb..3348daedf0 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -34,7 +34,7 @@ type CopilotHistories { } type CopilotQuota { - limit: SafeInt! + limit: SafeInt used: SafeInt! } @@ -84,6 +84,7 @@ type DocHistoryType { enum FeatureType { Copilot EarlyAccess + UnlimitedCopilot UnlimitedWorkspace } diff --git a/packages/backend/server/tests/feature.spec.ts b/packages/backend/server/tests/feature.spec.ts index 4710596dd9..52be2609fc 100644 --- a/packages/backend/server/tests/feature.spec.ts +++ b/packages/backend/server/tests/feature.spec.ts @@ -90,7 +90,7 @@ test('should be able to set user feature', async t => { const f1 = await feature.getUserFeatures(u1.id); t.is(f1.length, 0, 'should be empty'); - await feature.addUserFeature(u1.id, FeatureType.EarlyAccess, 2, 'test'); + await feature.addUserFeature(u1.id, FeatureType.EarlyAccess, 'test'); const f2 = await feature.getUserFeatures(u1.id); t.is(f2.length, 1, 'should have 1 feature'); diff --git a/packages/backend/server/tests/quota.spec.ts b/packages/backend/server/tests/quota.spec.ts index 58fa7b33db..89ebc42924 100644 --- a/packages/backend/server/tests/quota.spec.ts +++ b/packages/backend/server/tests/quota.spec.ts @@ -8,10 +8,10 @@ import { AuthService } from '../src/core/auth'; import { QuotaManagementService, QuotaModule, - Quotas, QuotaService, QuotaType, } from '../src/core/quota'; +import { FreePlan, ProPlan } from '../src/core/quota/schema'; import { StorageModule } from '../src/core/storage'; import { createTestingModule } from './utils'; @@ -63,33 +63,43 @@ test('should be able to set quota', async t => { test('should be able to check storage quota', async t => { const { auth, quota, quotaManager } = t.context; const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456'); + const freePlan = FreePlan.configs; + const proPlan = ProPlan.configs; const q1 = await quotaManager.getUserQuota(u1.id); - t.is(q1?.blobLimit, Quotas[5].configs.blobLimit, 'should be free plan'); - t.is(q1?.storageQuota, Quotas[5].configs.storageQuota, 'should be free plan'); + t.is(q1?.blobLimit, freePlan.blobLimit, 'should be free plan'); + t.is(q1?.storageQuota, freePlan.storageQuota, 'should be free plan'); await quota.switchUserQuota(u1.id, QuotaType.ProPlanV1); const q2 = await quotaManager.getUserQuota(u1.id); - t.is(q2?.blobLimit, Quotas[1].configs.blobLimit, 'should be pro plan'); - t.is(q2?.storageQuota, Quotas[1].configs.storageQuota, 'should be pro plan'); + t.is(q2?.blobLimit, proPlan.blobLimit, 'should be pro plan'); + t.is(q2?.storageQuota, proPlan.storageQuota, 'should be pro plan'); }); test('should be able revert quota', async t => { const { auth, quota, quotaManager } = t.context; const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456'); + const freePlan = FreePlan.configs; + const proPlan = ProPlan.configs; const q1 = await quotaManager.getUserQuota(u1.id); - t.is(q1?.blobLimit, Quotas[5].configs.blobLimit, 'should be free plan'); - t.is(q1?.storageQuota, Quotas[5].configs.storageQuota, 'should be free plan'); + + t.is(q1?.blobLimit, freePlan.blobLimit, 'should be free plan'); + t.is(q1?.storageQuota, freePlan.storageQuota, 'should be free plan'); await quota.switchUserQuota(u1.id, QuotaType.ProPlanV1); const q2 = await quotaManager.getUserQuota(u1.id); - t.is(q2?.blobLimit, Quotas[1].configs.blobLimit, 'should be pro plan'); - t.is(q2?.storageQuota, Quotas[1].configs.storageQuota, 'should be pro plan'); + t.is(q2?.blobLimit, proPlan.blobLimit, 'should be pro plan'); + t.is(q2?.storageQuota, proPlan.storageQuota, 'should be pro plan'); + t.is( + q2?.copilotActionLimit, + proPlan.copilotActionLimit!, + 'should be pro plan' + ); await quota.switchUserQuota(u1.id, QuotaType.FreePlanV1); const q3 = await quotaManager.getUserQuota(u1.id); - t.is(q3?.blobLimit, Quotas[5].configs.blobLimit, 'should be free plan'); + t.is(q3?.blobLimit, freePlan.blobLimit, 'should be free plan'); const quotas = await quota.getUserQuotas(u1.id); t.is(quotas.length, 3, 'should have 3 quotas'); @@ -104,9 +114,9 @@ test('should be able revert quota', async t => { test('should be able to check quota', async t => { const { auth, quotaManager } = t.context; const u1 = await auth.signUp('DarkSky', 'darksky@example.org', '123456'); + const freePlan = FreePlan.configs; const q1 = await quotaManager.getUserQuota(u1.id); - const freePlan = Quotas[5].configs; t.assert(q1, 'should have quota'); t.is(q1.blobLimit, freePlan.blobLimit, 'should be free plan'); t.is(q1.storageQuota, freePlan.storageQuota, 'should be free plan'); diff --git a/packages/frontend/core/src/components/affine/setting-modal/account-setting/ai-usage-panel.tsx b/packages/frontend/core/src/components/affine/setting-modal/account-setting/ai-usage-panel.tsx index 0ad6bda2fd..163263b53a 100644 --- a/packages/frontend/core/src/components/affine/setting-modal/account-setting/ai-usage-panel.tsx +++ b/packages/frontend/core/src/components/affine/setting-modal/account-setting/ai-usage-panel.tsx @@ -70,7 +70,9 @@ export const AIUsagePanelNotSubscripted = () => { const { data: quota } = useQuery({ query: getCopilotQuotaQuery, }); - const { limit = 10, used = 0 } = quota.currentUser?.copilot.quota || {}; + const { limit: nullableLimit, used = 0 } = + quota.currentUser?.copilot.quota || {}; + const limit = nullableLimit || 10; const percent = Math.min( 100, Math.max(0.5, Number(((used / limit) * 100).toFixed(4))) diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 968c0ed424..990453cbfc 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -62,6 +62,7 @@ export interface CreateCheckoutSessionInput { export enum FeatureType { Copilot = 'Copilot', EarlyAccess = 'EarlyAccess', + UnlimitedCopilot = 'UnlimitedCopilot', UnlimitedWorkspace = 'UnlimitedWorkspace', } @@ -387,7 +388,11 @@ export type GetCopilotQuotaQuery = { __typename?: 'UserType'; copilot: { __typename?: 'Copilot'; - quota: { __typename?: 'CopilotQuota'; limit: number; used: number }; + quota: { + __typename?: 'CopilotQuota'; + limit: number | null; + used: number; + }; }; } | null; };