mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
feat: check quota correctly (#6561)
This commit is contained in:
@@ -20,7 +20,6 @@ import {
|
||||
toArray,
|
||||
} from 'rxjs';
|
||||
|
||||
import { Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth/current-user';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
@@ -79,7 +78,6 @@ export class CopilotController {
|
||||
return session;
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/chat/:sessionId')
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@@ -89,6 +87,8 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
@@ -131,7 +131,6 @@ export class CopilotController {
|
||||
}
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Sse('/chat/:sessionId/stream')
|
||||
async chatStream(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@@ -141,6 +140,8 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
@@ -188,16 +189,17 @@ export class CopilotController {
|
||||
);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Sse('/chat/:sessionId/images')
|
||||
async chatImagesStream(
|
||||
@CurrentUser() user: CurrentUser | undefined,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const hasAttachment = await this.hasAttachment(sessionId, messageId);
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
@@ -221,7 +223,7 @@ export class CopilotController {
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: req.signal,
|
||||
user: user?.id,
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { ServerFeature } from '../../core/config';
|
||||
import { FeatureManagementService, FeatureService } from '../../core/features';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import { Plugin } from '../registry';
|
||||
@@ -22,6 +23,8 @@ registerCopilotProvider(OpenAIProvider);
|
||||
name: 'copilot',
|
||||
providers: [
|
||||
PermissionService,
|
||||
FeatureService,
|
||||
FeatureManagementService,
|
||||
QuotaService,
|
||||
ChatSessionService,
|
||||
CopilotResolver,
|
||||
|
||||
@@ -14,14 +14,9 @@ import {
|
||||
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
|
||||
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { UserType } from '../../core/user';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import {
|
||||
MutexService,
|
||||
PaymentRequiredException,
|
||||
TooManyRequestsException,
|
||||
} from '../../fundamentals';
|
||||
import { MutexService, TooManyRequestsException } from '../../fundamentals';
|
||||
import { ChatSessionService } from './session';
|
||||
import {
|
||||
AvailableModels,
|
||||
@@ -123,8 +118,8 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
|
||||
|
||||
@ObjectType('CopilotQuota')
|
||||
class CopilotQuotaType {
|
||||
@Field(() => SafeIntResolver)
|
||||
limit!: number;
|
||||
@Field(() => SafeIntResolver, { nullable: true })
|
||||
limit?: number;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
used!: number;
|
||||
@@ -144,7 +139,6 @@ export class CopilotResolver {
|
||||
|
||||
constructor(
|
||||
private readonly permissions: PermissionService,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly mutex: MutexService,
|
||||
private readonly chatSession: ChatSessionService
|
||||
) {}
|
||||
@@ -155,20 +149,7 @@ export class CopilotResolver {
|
||||
complexity: 2,
|
||||
})
|
||||
async getQuota(@CurrentUser() user: CurrentUser) {
|
||||
const quota = await this.quota.getUserQuota(user.id);
|
||||
const limit = quota.feature.copilotActionLimit;
|
||||
|
||||
const actions = await this.chatSession.countUserActions(user.id);
|
||||
const chats = await this.chatSession
|
||||
.listHistories(user.id)
|
||||
.then(histories =>
|
||||
histories.reduce(
|
||||
(acc, h) => acc + h.messages.filter(m => m.role === 'user').length,
|
||||
0
|
||||
)
|
||||
);
|
||||
|
||||
return { limit, used: actions + chats };
|
||||
return await this.chatSession.getQuota(user.id);
|
||||
}
|
||||
|
||||
@ResolveField(() => [String], {
|
||||
@@ -257,12 +238,7 @@ export class CopilotResolver {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
|
||||
const { limit, used } = await this.getQuota(user);
|
||||
if (limit && Number.isFinite(limit) && used >= limit) {
|
||||
return new PaymentRequiredException(
|
||||
`You have reached the limit of actions in this workspace, please upgrade your plan.`
|
||||
);
|
||||
}
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const session = await this.chatSession.create({
|
||||
...options,
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { AiPromptRole, PrismaClient } from '@prisma/client';
|
||||
|
||||
import { FeatureManagementService, FeatureType } from '../../core/features';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { PaymentRequiredException } from '../../fundamentals';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt, PromptService } from './prompt';
|
||||
import {
|
||||
@@ -120,6 +123,8 @@ export class ChatSessionService {
|
||||
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly feature: FeatureManagementService,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
@@ -242,12 +247,24 @@ export class ChatSessionService {
|
||||
.reduce((total, length) => total + length, 0);
|
||||
}
|
||||
|
||||
async countUserActions(userId: string): Promise<number> {
|
||||
private async countUserActions(userId: string): Promise<number> {
|
||||
return await this.db.aiSession.count({
|
||||
where: { userId, prompt: { action: { not: null } } },
|
||||
});
|
||||
}
|
||||
|
||||
private async countUserChats(userId: string): Promise<number> {
|
||||
const chats = await this.db.aiSession.findMany({
|
||||
where: { userId, prompt: { action: null } },
|
||||
select: {
|
||||
_count: {
|
||||
select: { messages: { where: { role: AiPromptRole.user } } },
|
||||
},
|
||||
},
|
||||
});
|
||||
return chats.reduce((prev, chat) => prev + chat._count.messages, 0);
|
||||
}
|
||||
|
||||
async listSessions(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
@@ -347,6 +364,32 @@ export class ChatSessionService {
|
||||
);
|
||||
}
|
||||
|
||||
async getQuota(userId: string) {
|
||||
const hasCopilotFeature = await this.feature
|
||||
.getActivatedUserFeatures(userId)
|
||||
.then(f => f.includes(FeatureType.UnlimitedCopilot));
|
||||
|
||||
let limit: number | undefined;
|
||||
if (!hasCopilotFeature) {
|
||||
const quota = await this.quota.getUserQuota(userId);
|
||||
limit = quota.feature.copilotActionLimit;
|
||||
}
|
||||
|
||||
const actions = await this.countUserActions(userId);
|
||||
const chats = await this.countUserChats(userId);
|
||||
|
||||
return { limit, used: actions + chats };
|
||||
}
|
||||
|
||||
async checkQuota(userId: string) {
|
||||
const { limit, used } = await this.getQuota(userId);
|
||||
if (limit && Number.isFinite(limit) && used >= limit) {
|
||||
throw new PaymentRequiredException(
|
||||
`You have reached the limit of actions in this workspace, please upgrade your plan.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async create(options: ChatSessionOptions): Promise<string> {
|
||||
const sessionId = randomUUID();
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
|
||||
@@ -527,7 +527,10 @@ export class SubscriptionService {
|
||||
nextBillAt = new Date(subscription.current_period_end * 1000);
|
||||
}
|
||||
} else {
|
||||
this.event.emit('user.subscription.canceled', user.id);
|
||||
this.event.emit('user.subscription.canceled', {
|
||||
userId: user.id,
|
||||
plan,
|
||||
});
|
||||
}
|
||||
|
||||
const commonData = {
|
||||
|
||||
@@ -53,7 +53,10 @@ declare module '../../fundamentals/event/def' {
|
||||
userId: User['id'];
|
||||
plan: SubscriptionPlan;
|
||||
}>;
|
||||
canceled: Payload<User['id']>;
|
||||
canceled: Payload<{
|
||||
userId: User['id'];
|
||||
plan: SubscriptionPlan;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user