feat: check quota correctly (#6561)

This commit is contained in:
darkskygit
2024-04-16 09:41:48 +00:00
parent 0ca8a23dd8
commit 1b0864eb60
26 changed files with 309 additions and 95 deletions

View File

@@ -1,8 +1,11 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import { FeatureManagementService, FeatureType } from '../../core/features';
import { QuotaService } from '../../core/quota';
import { PaymentRequiredException } from '../../fundamentals';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
@@ -120,6 +123,8 @@ export class ChatSessionService {
constructor(
private readonly db: PrismaClient,
private readonly feature: FeatureManagementService,
private readonly quota: QuotaService,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService
) {}
@@ -242,12 +247,24 @@ export class ChatSessionService {
.reduce((total, length) => total + length, 0);
}
async countUserActions(userId: string): Promise<number> {
private async countUserActions(userId: string): Promise<number> {
return await this.db.aiSession.count({
where: { userId, prompt: { action: { not: null } } },
});
}
private async countUserChats(userId: string): Promise<number> {
const chats = await this.db.aiSession.findMany({
where: { userId, prompt: { action: null } },
select: {
_count: {
select: { messages: { where: { role: AiPromptRole.user } } },
},
},
});
return chats.reduce((prev, chat) => prev + chat._count.messages, 0);
}
async listSessions(
userId: string,
workspaceId: string,
@@ -347,6 +364,32 @@ export class ChatSessionService {
);
}
async getQuota(userId: string) {
const hasCopilotFeature = await this.feature
.getActivatedUserFeatures(userId)
.then(f => f.includes(FeatureType.UnlimitedCopilot));
let limit: number | undefined;
if (!hasCopilotFeature) {
const quota = await this.quota.getUserQuota(userId);
limit = quota.feature.copilotActionLimit;
}
const actions = await this.countUserActions(userId);
const chats = await this.countUserChats(userId);
return { limit, used: actions + chats };
}
async checkQuota(userId: string) {
const { limit, used } = await this.getQuota(userId);
if (limit && Number.isFinite(limit) && used >= limit) {
throw new PaymentRequiredException(
`You have reached the limit of actions in this workspace, please upgrade your plan.`
);
}
}
async create(options: ChatSessionOptions): Promise<string> {
const sessionId = randomUUID();
const prompt = await this.prompt.get(options.promptName);