feat: check quota correctly (#6561)

This commit is contained in:
darkskygit
2024-04-16 09:41:48 +00:00
parent 0ca8a23dd8
commit 1b0864eb60
26 changed files with 309 additions and 95 deletions

View File

@@ -20,7 +20,6 @@ import {
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
@@ -79,7 +78,6 @@ export class CopilotController {
return session;
}
@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@@ -89,6 +87,8 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
await this.chatSession.checkQuota(user.id);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
@@ -131,7 +131,6 @@ export class CopilotController {
}
}
@Public()
@Sse('/chat/:sessionId/stream')
async chatStream(
@CurrentUser() user: CurrentUser,
@@ -141,6 +140,8 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
await this.chatSession.checkQuota(user.id);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
@@ -188,16 +189,17 @@ export class CopilotController {
);
}
@Public()
@Sse('/chat/:sessionId/images')
async chatImagesStream(
@CurrentUser() user: CurrentUser | undefined,
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
await this.chatSession.checkQuota(user.id);
const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
@@ -221,7 +223,7 @@ export class CopilotController {
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: req.signal,
user: user?.id,
user: user.id,
})
).pipe(
connect(shared$ =>

View File

@@ -1,4 +1,5 @@
import { ServerFeature } from '../../core/config';
import { FeatureManagementService, FeatureService } from '../../core/features';
import { QuotaService } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
@@ -22,6 +23,8 @@ registerCopilotProvider(OpenAIProvider);
name: 'copilot',
providers: [
PermissionService,
FeatureService,
FeatureManagementService,
QuotaService,
ChatSessionService,
CopilotResolver,

View File

@@ -14,14 +14,9 @@ import {
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../../core/auth';
import { QuotaService } from '../../core/quota';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
import {
MutexService,
PaymentRequiredException,
TooManyRequestsException,
} from '../../fundamentals';
import { MutexService, TooManyRequestsException } from '../../fundamentals';
import { ChatSessionService } from './session';
import {
AvailableModels,
@@ -123,8 +118,8 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
@ObjectType('CopilotQuota')
class CopilotQuotaType {
@Field(() => SafeIntResolver)
limit!: number;
@Field(() => SafeIntResolver, { nullable: true })
limit?: number;
@Field(() => SafeIntResolver)
used!: number;
@@ -144,7 +139,6 @@ export class CopilotResolver {
constructor(
private readonly permissions: PermissionService,
private readonly quota: QuotaService,
private readonly mutex: MutexService,
private readonly chatSession: ChatSessionService
) {}
@@ -155,20 +149,7 @@ export class CopilotResolver {
complexity: 2,
})
async getQuota(@CurrentUser() user: CurrentUser) {
const quota = await this.quota.getUserQuota(user.id);
const limit = quota.feature.copilotActionLimit;
const actions = await this.chatSession.countUserActions(user.id);
const chats = await this.chatSession
.listHistories(user.id)
.then(histories =>
histories.reduce(
(acc, h) => acc + h.messages.filter(m => m.role === 'user').length,
0
)
);
return { limit, used: actions + chats };
return await this.chatSession.getQuota(user.id);
}
@ResolveField(() => [String], {
@@ -257,12 +238,7 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
const { limit, used } = await this.getQuota(user);
if (limit && Number.isFinite(limit) && used >= limit) {
return new PaymentRequiredException(
`You have reached the limit of actions in this workspace, please upgrade your plan.`
);
}
await this.chatSession.checkQuota(user.id);
const session = await this.chatSession.create({
...options,

View File

@@ -1,8 +1,11 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import { FeatureManagementService, FeatureType } from '../../core/features';
import { QuotaService } from '../../core/quota';
import { PaymentRequiredException } from '../../fundamentals';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
@@ -120,6 +123,8 @@ export class ChatSessionService {
constructor(
private readonly db: PrismaClient,
private readonly feature: FeatureManagementService,
private readonly quota: QuotaService,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService
) {}
@@ -242,12 +247,24 @@ export class ChatSessionService {
.reduce((total, length) => total + length, 0);
}
async countUserActions(userId: string): Promise<number> {
private async countUserActions(userId: string): Promise<number> {
return await this.db.aiSession.count({
where: { userId, prompt: { action: { not: null } } },
});
}
private async countUserChats(userId: string): Promise<number> {
const chats = await this.db.aiSession.findMany({
where: { userId, prompt: { action: null } },
select: {
_count: {
select: { messages: { where: { role: AiPromptRole.user } } },
},
},
});
return chats.reduce((prev, chat) => prev + chat._count.messages, 0);
}
async listSessions(
userId: string,
workspaceId: string,
@@ -347,6 +364,32 @@ export class ChatSessionService {
);
}
async getQuota(userId: string) {
const hasCopilotFeature = await this.feature
.getActivatedUserFeatures(userId)
.then(f => f.includes(FeatureType.UnlimitedCopilot));
let limit: number | undefined;
if (!hasCopilotFeature) {
const quota = await this.quota.getUserQuota(userId);
limit = quota.feature.copilotActionLimit;
}
const actions = await this.countUserActions(userId);
const chats = await this.countUserChats(userId);
return { limit, used: actions + chats };
}
async checkQuota(userId: string) {
const { limit, used } = await this.getQuota(userId);
if (limit && Number.isFinite(limit) && used >= limit) {
throw new PaymentRequiredException(
`You have reached the limit of actions in this workspace, please upgrade your plan.`
);
}
}
async create(options: ChatSessionOptions): Promise<string> {
const sessionId = randomUUID();
const prompt = await this.prompt.get(options.promptName);

View File

@@ -527,7 +527,10 @@ export class SubscriptionService {
nextBillAt = new Date(subscription.current_period_end * 1000);
}
} else {
this.event.emit('user.subscription.canceled', user.id);
this.event.emit('user.subscription.canceled', {
userId: user.id,
plan,
});
}
const commonData = {

View File

@@ -53,7 +53,10 @@ declare module '../../fundamentals/event/def' {
userId: User['id'];
plan: SubscriptionPlan;
}>;
canceled: Payload<User['id']>;
canceled: Payload<{
userId: User['id'];
plan: SubscriptionPlan;
}>;
};
}
}