mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-05 09:04:56 +00:00
@@ -0,0 +1,4 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "ai_sessions_metadata" ADD COLUMN "deleted_at" TIMESTAMPTZ(6),
|
||||
ADD COLUMN "messageCost" INTEGER NOT NULL DEFAULT 0,
|
||||
ADD COLUMN "tokenCost" INTEGER NOT NULL DEFAULT 0;
|
||||
@@ -1,3 +1,3 @@
|
||||
# Please do not edit this file manually
|
||||
# It should be added in your version-control system (i.e. Git)
|
||||
provider = "postgresql"
|
||||
provider = "postgresql"
|
||||
@@ -480,12 +480,15 @@ model AiSessionMessage {
|
||||
}
|
||||
|
||||
model AiSession {
|
||||
id String @id @default(uuid()) @db.VarChar(36)
|
||||
userId String @map("user_id") @db.VarChar(36)
|
||||
workspaceId String @map("workspace_id") @db.VarChar(36)
|
||||
docId String @map("doc_id") @db.VarChar(36)
|
||||
promptName String @map("prompt_name") @db.VarChar(32)
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
|
||||
id String @id @default(uuid()) @db.VarChar(36)
|
||||
userId String @map("user_id") @db.VarChar(36)
|
||||
workspaceId String @map("workspace_id") @db.VarChar(36)
|
||||
docId String @map("doc_id") @db.VarChar(36)
|
||||
promptName String @map("prompt_name") @db.VarChar(32)
|
||||
messageCost Int @default(0)
|
||||
tokenCost Int @default(0)
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
|
||||
deletedAt DateTime? @map("deleted_at") @db.Timestamptz(6)
|
||||
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createHash } from 'node:crypto';
|
||||
|
||||
import { BadRequestException, Logger } from '@nestjs/common';
|
||||
import { BadRequestException, Logger, NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -55,6 +55,18 @@ class CreateChatSessionInput {
|
||||
promptName!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class DeleteSessionInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
|
||||
@Field(() => [String])
|
||||
sessionIds!: string[];
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
|
||||
@Field(() => String)
|
||||
@@ -264,6 +276,35 @@ export class CopilotResolver {
|
||||
return session;
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Cleanup sessions',
|
||||
})
|
||||
async cleanupCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => DeleteSessionInput })
|
||||
options: DeleteSessionInput
|
||||
) {
|
||||
await this.permissions.checkCloudPagePermission(
|
||||
options.workspaceId,
|
||||
options.docId,
|
||||
user.id
|
||||
);
|
||||
if (!options.sessionIds.length) {
|
||||
return new NotFoundException('Session not found');
|
||||
}
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
|
||||
const ret = await this.chatSession.cleanup({
|
||||
...options,
|
||||
userId: user.id,
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat message',
|
||||
})
|
||||
|
||||
@@ -186,7 +186,7 @@ export class ChatSessionService {
|
||||
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const { id } =
|
||||
const { id, deletedAt } =
|
||||
(await tx.aiSession.findFirst({
|
||||
where: {
|
||||
userId: state.userId,
|
||||
@@ -194,8 +194,9 @@ export class ChatSessionService {
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
},
|
||||
select: { id: true },
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
if (deletedAt) throw new Error(`Session is deleted: ${id}`);
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
@@ -219,6 +220,21 @@ export class ChatSessionService {
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = state.messages.filter(m => m.role === 'user');
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: {
|
||||
increment: this.calculateTokenSize(
|
||||
userMessages,
|
||||
state.prompt.model as AvailableModel
|
||||
),
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
await tx.aiSession.create({
|
||||
@@ -242,21 +258,15 @@ export class ChatSessionService {
|
||||
): Promise<ChatSessionState | undefined> {
|
||||
return await this.db.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId },
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: 'asc',
|
||||
},
|
||||
select: { role: true, content: true, createdAt: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
},
|
||||
@@ -283,9 +293,18 @@ export class ChatSessionService {
|
||||
// after revert, we can retry the action
|
||||
async revertLatestMessage(sessionId: string) {
|
||||
await this.db.$transaction(async tx => {
|
||||
const id = await tx.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: { id: true },
|
||||
})
|
||||
.then(session => session?.id);
|
||||
if (!id) {
|
||||
throw new Error(`Session not found: ${sessionId}`);
|
||||
}
|
||||
const ids = await tx.aiSessionMessage
|
||||
.findMany({
|
||||
where: { sessionId },
|
||||
where: { sessionId: id },
|
||||
select: { id: true, role: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
})
|
||||
@@ -312,22 +331,14 @@ export class ChatSessionService {
|
||||
.reduce((total, length) => total + length, 0);
|
||||
}
|
||||
|
||||
private async countUserActions(userId: string): Promise<number> {
|
||||
return await this.db.aiSession.count({
|
||||
where: { userId, prompt: { action: { not: null } } },
|
||||
private async countUserMessages(userId: string): Promise<number> {
|
||||
const sessions = await this.db.aiSession.findMany({
|
||||
where: { userId },
|
||||
select: { messageCost: true, prompt: { select: { action: true } } },
|
||||
});
|
||||
}
|
||||
|
||||
private async countUserChats(userId: string): Promise<number> {
|
||||
const chats = await this.db.aiSession.findMany({
|
||||
where: { userId, prompt: { action: null } },
|
||||
select: {
|
||||
_count: {
|
||||
select: { messages: { where: { role: AiPromptRole.user } } },
|
||||
},
|
||||
},
|
||||
});
|
||||
return chats.reduce((prev, chat) => prev + chat._count.messages, 0);
|
||||
return sessions
|
||||
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
|
||||
.reduce((prev, cost) => prev + cost, 0);
|
||||
}
|
||||
|
||||
async listSessions(
|
||||
@@ -344,6 +355,7 @@ export class ChatSessionService {
|
||||
prompt: {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
deletedAt: null,
|
||||
},
|
||||
select: { id: true },
|
||||
})
|
||||
@@ -367,10 +379,12 @@ export class ChatSessionService {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
@@ -391,50 +405,48 @@ export class ChatSessionService {
|
||||
})
|
||||
.then(sessions =>
|
||||
Promise.all(
|
||||
sessions.map(async ({ id, promptName, messages, createdAt }) => {
|
||||
try {
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(`Prompt not found: ${promptName}`);
|
||||
}
|
||||
const tokens = this.calculateTokenSize(
|
||||
ret.data,
|
||||
prompt.model as AvailableModel
|
||||
);
|
||||
sessions.map(
|
||||
async ({ id, promptName, tokenCost, messages, createdAt }) => {
|
||||
try {
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(`Prompt not found: ${promptName}`);
|
||||
}
|
||||
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
|
||||
(preload as ChatMessage[]).forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
|
||||
(preload as ChatMessage[]).forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || undefined,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || undefined,
|
||||
tokens,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
return undefined;
|
||||
}
|
||||
return undefined;
|
||||
})
|
||||
)
|
||||
)
|
||||
)
|
||||
.then(histories =>
|
||||
@@ -451,10 +463,9 @@ export class ChatSessionService {
|
||||
limit = quota.feature.copilotActionLimit;
|
||||
}
|
||||
|
||||
const actions = await this.countUserActions(userId);
|
||||
const chats = await this.countUserChats(userId);
|
||||
const used = await this.countUserMessages(userId);
|
||||
|
||||
return { limit, used: actions + chats };
|
||||
return { limit, used };
|
||||
}
|
||||
|
||||
async checkQuota(userId: string) {
|
||||
@@ -481,6 +492,49 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
|
||||
) {
|
||||
return await this.db.$transaction(async tx => {
|
||||
const sessions = await tx.aiSession.findMany({
|
||||
where: {
|
||||
id: { in: options.sessionIds },
|
||||
userId: options.userId,
|
||||
workspaceId: options.workspaceId,
|
||||
docId: options.docId,
|
||||
deletedAt: null,
|
||||
},
|
||||
select: { id: true, promptName: true },
|
||||
});
|
||||
const sessionIds = sessions.map(({ id }) => id);
|
||||
// cleanup all messages
|
||||
await tx.aiSessionMessage.deleteMany({
|
||||
where: { sessionId: { in: sessionIds } },
|
||||
});
|
||||
|
||||
// only mark action session as deleted
|
||||
// chat session always can be reuse
|
||||
{
|
||||
const actionIds = (
|
||||
await Promise.all(
|
||||
sessions.map(({ id, promptName }) =>
|
||||
this.prompt
|
||||
.get(promptName)
|
||||
.then(prompt => ({ id, action: !!prompt?.action }))
|
||||
)
|
||||
)
|
||||
)
|
||||
.filter(({ action }) => action)
|
||||
.map(({ id }) => id);
|
||||
|
||||
await tx.aiSession.updateMany({
|
||||
where: { id: { in: actionIds } },
|
||||
data: { deletedAt: new Date() },
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async createMessage(message: SubmittedMessage): Promise<string | undefined> {
|
||||
return await this.messageCache.set(message);
|
||||
}
|
||||
|
||||
@@ -76,6 +76,12 @@ type DeleteAccount {
|
||||
success: Boolean!
|
||||
}
|
||||
|
||||
input DeleteSessionInput {
|
||||
docId: String!
|
||||
sessionIds: [String!]!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
type DocHistoryType {
|
||||
id: String!
|
||||
timestamp: DateTime!
|
||||
@@ -184,6 +190,9 @@ type Mutation {
|
||||
changeEmail(email: String!, token: String!): UserType!
|
||||
changePassword(newPassword: String!, token: String!): UserType!
|
||||
|
||||
"""Cleanup sessions"""
|
||||
cleanupCopilotSession(options: DeleteSessionInput!): String!
|
||||
|
||||
"""Create a subscription checkout link of stripe"""
|
||||
createCheckoutSession(input: CreateCheckoutSessionInput!): String!
|
||||
|
||||
|
||||
Reference in New Issue
Block a user