From e32c9a814a5fcc768a41dcbf17737c38224cfdd5 Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Wed, 25 Jun 2025 20:02:21 +0800 Subject: [PATCH] feat(server): improve session modify (#12928) fix AI-248 --- .../__tests__/__snapshots__/copilot.e2e.ts.md | 4 +- .../__snapshots__/copilot.e2e.ts.snap | Bin 1110 -> 1110 bytes .../server/src/__tests__/copilot.spec.ts | 2 +- .../__tests__/models/copilot-context.spec.ts | 14 +- .../__tests__/models/copilot-session.spec.ts | 53 ++-- .../server/src/models/copilot-session.ts | 245 ++++++++++++++---- packages/backend/server/src/native.ts | 17 +- .../src/plugins/copilot/prompt/chat-prompt.ts | 5 +- .../server/src/plugins/copilot/resolver.ts | 11 +- .../server/src/plugins/copilot/session.ts | 164 +++--------- .../server/src/plugins/copilot/types.ts | 37 --- 11 files changed, 282 insertions(+), 270 deletions(-) diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md index 3c0ca41331..9ce1e157b4 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md @@ -16,7 +16,7 @@ Generated by [AVA](https://avajs.dev). role: 'assistant', }, ], - tokens: 0, + tokens: 8, }, ] @@ -30,7 +30,7 @@ Generated by [AVA](https://avajs.dev). role: 'assistant', }, ], - tokens: 0, + tokens: 8, }, ] diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap index 22de4d5a0c97bbb16f02a200f7a37c0ae7254477..b70f3bf91e0d70e880dfead62afd653117f66233 100644 GIT binary patch literal 1110 zcmV-c1gZN$RzVa68AVvsm@Sq@w zA9$$guGy)b?y9D$X0xL=T@*zOh+b6C1jLJYSPve=no z%(5F2C3Bgoo_f{q{a(F)Rn3)lE74KcZeJBnxCg(s<=3Tyegu%HE;BCmX2bGj>( z&AS2BbMJlt@MFb0e{b)88F?2H&>-MRs(2Ud<=vPMbCFg~C|CCHUKdIj?gVpU%MIub zzAHw7(TONK^W0i#9al~yyd{#YhbA+WNb6WfTBl8)@wO9Yn6x_yy4@3W^yCKB6t)uGx)_gWwq+tN#?8v8`KZ%V-RI=1Ludi2 zJQYtzJdatPW{FtN;&%0A)^XRfqk);xk+4>(^@9Ap1;!kAL5>YvvO{Pud~;-l@ZL;? zFy_7|N1ST0jwUfpjqjel;)``;%XM_$IGEK3GIyd#zMD#(S|Gi$g26p<6{A`1-Ix)GUBLfUf{t129j(xvG>*1{GGNghvD4 z?17}DN+L?QOhsIs+V;|voBn~O4)ZlJ&Ey2p%e1T2$HzG5OMaVxcm5gj0Zox*FX#Fx zaxoF$t%tlP}b$05O z|3YWNSMKP{P1Tw3T>{=G;1dGAAYf;j-UK<{??!s_d!aWCqWe9Rdi4CD>E~~F+XF6p zz}Fsd%>(|Pu0kt5aMlN2@`2ZU;K~#gs#Exf5B%i=rvu=L0EmN%No~U9qX76U0KN}^ z-vXc!R!rvZXx**Qx|u>~9u4W{O(xwu8q&?%U{@(lyBEsAMhx^OmhHAMLOEs!r)I5~ z>u16;u41;y6B!Q%Wk#o2et5RdaH*W#Jr=)FY*%mIvh6C9N{M0Cyu|{^)8kQ=vE~2A zjQ?~=o371xR2h`(@G$Noz?~zz27P$&YfzZqeReN7dB_jPI@%A$LIGoKW*9^3{}KtM cP0EvLB+6xRVP+Dsj&6wII;eZZVE_~W0RCYjVgLXD literal 1110 zcmV-c1gZN$RzVYJ6%<@3P%ko$iLQGX7G;HnSv)8R z;s+io-Ib&#-BnFhWp<)B9TY`e5WT3Ny9i#y!+7u@-n@7cWD&dy-USb$CpVq$jOoly zGR!Qy%H)!&PQB{)ey`rYs&cj6idE1v+t-ARY%I2z>8gGlvX)@H6$@q+)57WjyBygr z<7pHb$y3u+*1Uc#47ueml5nwr;fJc^tN?fvz-g+SBL9Ncd_YvRmY0`7OCSQ(h|U*A zZU;aD;H|>2@=R+(1eTm}vtF-*PJkTHNutjI+(jv|*SegN=e$vJm#=YAPG=Mk_(z>`#R&fCkm2^SUuC9ROQ=-=HDQfO`kvtr8*;SRq` zMxIu&C_0PWn8;Wzt%!L`#9I$dCn%1LW!=bFr3Z|+t_AFOe8xw`FnGW8Sb1M8@gmh&|diF$m-y|x$0oT zeJ2`os)Q<-#xyg&d-jSiRDmhh(S6hF$Z9T4priC(YiK*iapkZ$oB%*H8wKqj>4NS{ zxf|w#c@(DKN+Qd6Ka3{gP?Bw^R$d3LI&2UX17HB)D*)F4ED~_OEG5%H`DH2L!O%Bz zASo%62x1;3A}mjB`bom|;6PLR*_xPTa+2s}+Lh`P6P&XpzkNep-?=g51DYYrPR8|9 z%DC2uej;SBtK|C*P;-Ds9N<|8*mU++>MHqPI>3$t{OkaKI=~sXWa7_MCs)XE$yFPT zwMJv@?Bi2tQ^9*e7k#==#%L_pwlvLSjnb$|ZpG93x`<=7!o;S~10}`kd~P<^n&tVY z+mpwL6eb!c*ZK!Bik<4}lON6jlTZMeEniuH0(NHUjhFHLZmc)I=Xz5ky5B*eN6!zNZuW+^9pH)s zeC+_&9pLZTD%5a+b1v|b3%uq6S7)eDmHalwy=Mr=BBudNRr)w(Q@S z$)7GO-7_fT9>tvpxN~IJpbrmz4f3 { t.truthy(sessionId, 'should create session'); // Update the session - const updatedSessionId = await session.updateSession({ + const updatedSessionId = await session.update({ sessionId, promptName: 'Search With AFFiNE AI', userId, diff --git a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts index f3648f32eb..f768b2a6c1 100644 --- a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts +++ b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts @@ -1,6 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { AiSession, PrismaClient, User, Workspace } from '@prisma/client'; +import { PrismaClient, User, Workspace } from '@prisma/client'; import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; @@ -43,7 +43,7 @@ test.before(async t => { let user: User; let workspace: Workspace; -let session: AiSession; +let sessionId: string; let docId = 'doc1'; test.beforeEach(async t => { @@ -53,7 +53,7 @@ test.beforeEach(async t => { email: 'test@affine.pro', }); workspace = await t.context.workspace.create(user.id); - session = await t.context.copilotSession.create({ + sessionId = await t.context.copilotSession.create({ sessionId: randomUUID(), workspaceId: workspace.id, docId, @@ -68,7 +68,7 @@ test.after(async t => { }); test('should create a copilot context', async t => { - const { id: contextId } = await t.context.copilotContext.create(session.id); + const { id: contextId } = await t.context.copilotContext.create(sessionId); t.truthy(contextId); const context = await t.context.copilotContext.get(contextId); @@ -77,7 +77,7 @@ test('should create a copilot context', async t => { const config = await t.context.copilotContext.getConfig(contextId); t.is(config?.workspaceId, workspace.id, 'should get context config'); - const context1 = await t.context.copilotContext.getBySessionId(session.id); + const context1 = await t.context.copilotContext.getBySessionId(sessionId); t.is(context1?.id, contextId, 'should get context by session id'); }); @@ -87,7 +87,7 @@ test('should get null for non-exist job', async t => { }); test('should update context', async t => { - const { id: contextId } = await t.context.copilotContext.create(session.id); + const { id: contextId } = await t.context.copilotContext.create(sessionId); const config = await t.context.copilotContext.getConfig(contextId); const doc = { @@ -102,7 +102,7 @@ test('should update context', async t => { }); test('should insert embedding by doc id', async t => { - const { id: contextId } = await t.context.copilotContext.create(session.id); + const { id: contextId } = await t.context.copilotContext.create(sessionId); { await t.context.copilotContext.insertFileEmbedding(contextId, 'file-id', [ diff --git a/packages/backend/server/src/__tests__/models/copilot-session.spec.ts b/packages/backend/server/src/__tests__/models/copilot-session.spec.ts index 0e90c2d5cc..ddf01ed315 100644 --- a/packages/backend/server/src/__tests__/models/copilot-session.spec.ts +++ b/packages/backend/server/src/__tests__/models/copilot-session.spec.ts @@ -6,7 +6,7 @@ import ava, { ExecutionContext, TestFn } from 'ava'; import { CopilotPromptInvalid, CopilotSessionInvalidInput } from '../../base'; import { CopilotSessionModel, - UpdateChatSessionData, + UpdateChatSessionOptions, UserModel, WorkspaceModel, } from '../../models'; @@ -174,7 +174,10 @@ test('should check session validation for prompts', async t => { sessionTypes.forEach(({ name, session }) => { t.notThrows( () => - copilotSession.checkSessionPrompt(session, 'test-prompt', undefined), + copilotSession.checkSessionPrompt(session, { + name: 'test-prompt', + action: undefined, + }), `${name} session should allow non-action prompts` ); }); @@ -195,14 +198,20 @@ test('should check session validation for prompts', async t => { if (shouldThrow) { t.throws( () => - copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'), + copilotSession.checkSessionPrompt(session, { + name: 'action-prompt', + action: 'edit', + }), { instanceOf: CopilotPromptInvalid }, `${name} session should reject action prompts` ); } else { t.notThrows( () => - copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'), + copilotSession.checkSessionPrompt(session, { + name: 'action-prompt', + action: 'edit', + }), `${name} session should allow action prompts` ); } @@ -323,14 +332,19 @@ test('should handle session updates and validations', async t => { }, }); + type UpdateData = Omit; const assertUpdateThrows = async ( t: ExecutionContext, sessionId: string, - updateData: UpdateChatSessionData, + updateData: UpdateData, message: string ) => { await t.throwsAsync( - t.context.copilotSession.update(user.id, sessionId, updateData), + t.context.copilotSession.update({ + ...updateData, + userId: user.id, + sessionId, + }), { instanceOf: CopilotSessionInvalidInput }, message ); @@ -339,11 +353,15 @@ test('should handle session updates and validations', async t => { const assertUpdate = async ( t: ExecutionContext, sessionId: string, - updateData: UpdateChatSessionData, + updateData: UpdateData, message: string ) => { await t.notThrowsAsync( - t.context.copilotSession.update(user.id, sessionId, updateData), + t.context.copilotSession.update({ + ...updateData, + userId: user.id, + sessionId, + }), message ); }; @@ -386,7 +404,6 @@ test('should handle session updates and validations', async t => { 'forked session should reject docId update' ); } - { // case 3: prompt update validation await assertUpdate( @@ -415,14 +432,13 @@ test('should handle session updates and validations', async t => { await createTestSession(t, { sessionId: existingPinnedId, pinned: true }); // should unpin existing when pinning new session - await copilotSession.update(user.id, sessionId, { pinned: true }); + await copilotSession.update({ userId: user.id, sessionId, pinned: true }); - const sessionStatesAfterPin = await Promise.all([ - getSessionState(db, sessionId), - getSessionState(db, existingPinnedId), - ]); t.snapshot( - sessionStatesAfterPin, + [ + await getSessionState(db, sessionId), + await getSessionState(db, existingPinnedId), + ], 'should unpin existing when pinning new session' ); } @@ -430,11 +446,8 @@ test('should handle session updates and validations', async t => { // test type conversions { const conversionSteps: any[] = []; - const convertSession = async ( - step: string, - data: UpdateChatSessionData - ) => { - await copilotSession.update(user.id, sessionId, data); + const convertSession = async (step: string, data: UpdateData) => { + await copilotSession.update({ ...data, userId: user.id, sessionId }); const session = await db.aiSession.findUnique({ where: { id: sessionId }, select: { docId: true, pinned: true }, diff --git a/packages/backend/server/src/models/copilot-session.ts b/packages/backend/server/src/models/copilot-session.ts index 5d64e3667d..521f865f3a 100644 --- a/packages/backend/server/src/models/copilot-session.ts +++ b/packages/backend/server/src/models/copilot-session.ts @@ -9,6 +9,7 @@ import { CopilotSessionInvalidInput, CopilotSessionNotFound, } from '../base'; +import { getTokenEncoder } from '../native'; import { BaseModel } from './base'; export enum SessionType { @@ -17,6 +18,12 @@ export enum SessionType { Doc = 'doc', // docId points to specific document } +type ChatPrompt = { + name: string; + action?: string | null; + model: string; +}; + type ChatAttachment = { attachment: string; mimeType: string } | string; type ChatStreamObject = { @@ -38,7 +45,7 @@ type ChatMessage = { createdAt: Date; }; -type ChatSession = { +type PureChatSession = { sessionId: string; workspaceId: string; docId?: string | null; @@ -46,22 +53,44 @@ type ChatSession = { messages?: ChatMessage[]; // connect ids userId: string; - promptName: string; - promptAction: string | null; parentSessionId?: string | null; }; -export type UpdateChatSessionData = Partial< - Pick ->; -export type UpdateChatSession = Pick & - UpdateChatSessionData; +type ChatSession = PureChatSession & { + // connect ids + promptName: string; + promptAction: string | null; +}; -export type ListSessionOptions = { +type ChatSessionWithPrompt = PureChatSession & { + prompt: ChatPrompt; +}; + +type ChatSessionBaseState = Pick; + +export type ForkSessionOptions = Omit< + ChatSession, + 'messages' | 'promptName' | 'promptAction' +> & { + prompt: { name: string; action: string | null | undefined; model: string }; + messages: ChatMessage[]; +}; + +type UpdateChatSessionMessage = ChatSessionBaseState & { + prompt: { model: string }; + messages: ChatMessage[]; +}; + +export type UpdateChatSessionOptions = ChatSessionBaseState & + Pick, 'docId' | 'pinned' | 'promptName'>; + +export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions; + +export type ListSessionOptions = Pick< + Partial, + 'sessionId' | 'workspaceId' | 'docId' | 'pinned' +> & { userId: string; - sessionId?: string; - workspaceId?: string; - docId?: string; action?: boolean; fork?: boolean; limit?: number; @@ -74,6 +103,13 @@ export type ListSessionOptions = { withMessages?: boolean; }; +export type CleanupSessionOptions = Pick< + ChatSession, + 'userId' | 'workspaceId' | 'docId' +> & { + sessionIds: string[]; +}; + @Injectable() export class CopilotSessionModel extends BaseModel { getSessionType(session: Pick): SessionType { @@ -84,10 +120,10 @@ export class CopilotSessionModel extends BaseModel { checkSessionPrompt( session: Pick, - promptName: string, - promptAction: string | undefined + prompt: Partial ): boolean { const sessionType = this.getSessionType(session); + const { name: promptName, action: promptAction } = prompt; // workspace and pinned sessions cannot use action prompts if ( @@ -110,12 +146,18 @@ export class CopilotSessionModel extends BaseModel { } @Transactional() - async create(state: ChatSession) { + async create(state: ChatSession, reuseChat = false): Promise { + // find and return existing session if session is chat session + if (reuseChat && !state.promptAction) { + const sessionId = await this.find(state); + if (sessionId) return sessionId; + } + if (state.pinned) { await this.unpin(state.workspaceId, state.userId); } - const row = await this.db.aiSession.create({ + const session = await this.db.aiSession.create({ data: { id: state.sessionId, workspaceId: state.workspaceId, @@ -127,8 +169,46 @@ export class CopilotSessionModel extends BaseModel { promptAction: state.promptAction, parentSessionId: state.parentSessionId, }, + select: { id: true }, }); - return row; + return session.id; + } + + @Transactional() + async createWithPrompt( + state: ChatSessionWithPrompt, + reuseChat = false + ): Promise { + const { prompt, ...rest } = state; + return await this.models.copilotSession.create( + { ...rest, promptName: prompt.name, promptAction: prompt.action ?? null }, + reuseChat + ); + } + + @Transactional() + async fork(options: ForkSessionOptions): Promise { + if (!options.messages?.length) { + throw new CopilotSessionInvalidInput( + 'Cannot fork session without messages' + ); + } + if (options.pinned) { + await this.unpin(options.workspaceId, options.userId); + } + const { messages, ...forkedState } = options; + + // create session + const sessionId = await this.createWithPrompt({ + ...forkedState, + messages: [], + }); + // save message + await this.models.copilotSession.updateMessages({ + ...forkedState, + messages, + }); + return sessionId; } @Transactional() @@ -143,9 +223,7 @@ export class CopilotSessionModel extends BaseModel { } @Transactional() - async getChatSessionId( - state: Omit - ) { + async find(state: PureChatSession) { const extraCondition: Record = {}; if (state.parentSessionId) { // also check session id if provided session is forked session @@ -287,11 +365,8 @@ export class CopilotSessionModel extends BaseModel { } @Transactional() - async update( - userId: string, - sessionId: string, - data: UpdateChatSessionData - ): Promise { + async update(options: UpdateChatSessionOptions): Promise { + const { userId, sessionId, docId, promptName, pinned } = options; const session = await this.getExists( sessionId, { @@ -313,33 +388,71 @@ export class CopilotSessionModel extends BaseModel { throw new CopilotSessionInvalidInput( `Cannot update action: ${session.id}` ); - } else if (data.docId && session.parentSessionId) { + } else if (docId && session.parentSessionId) { throw new CopilotSessionInvalidInput( `Cannot update docId for forked session: ${session.id}` ); } - if (data.promptName) { + if (promptName) { const prompt = await this.db.aiPrompt.findFirst({ - where: { name: data.promptName }, + where: { name: promptName }, }); // always not allow to update to action prompt if (!prompt || prompt.action) { throw new CopilotSessionInvalidInput( - `Prompt ${data.promptName} not found or not available for session ${sessionId}` + `Prompt ${promptName} not found or not available for session ${sessionId}` ); } } - if (data.pinned && data.pinned !== session.pinned) { + if (pinned && pinned !== session.pinned) { // if pin the session, unpin exists session in the workspace await this.unpin(session.workspaceId, userId); } - await this.db.aiSession.update({ where: { id: sessionId }, data }); + await this.db.aiSession.update({ + where: { id: sessionId }, + data: { docId, promptName, pinned }, + }); return sessionId; } + @Transactional() + async cleanup(options: CleanupSessionOptions): Promise { + const sessions = await this.db.aiSession.findMany({ + where: { + id: { in: options.sessionIds }, + userId: options.userId, + workspaceId: options.workspaceId, + docId: options.docId, + deletedAt: null, + }, + select: { id: true, prompt: true }, + }); + const sessionIds = sessions.map(({ id }) => id); + // cleanup all messages + await this.db.aiSessionMessage.deleteMany({ + where: { sessionId: { in: sessionIds } }, + }); + + // only mark action session as deleted + // chat session always can be reuse + const actionIds = sessions + .filter(({ prompt }) => !!prompt.action) + .map(({ id }) => id); + + // 标记 action session 为已删除 + if (actionIds.length > 0) { + await this.db.aiSession.updateMany({ + where: { id: { in: actionIds } }, + data: { pinned: false, deletedAt: new Date() }, + }); + } + + return sessionIds; + } + @Transactional() async getMessages( sessionId: string, @@ -353,31 +466,42 @@ export class CopilotSessionModel extends BaseModel { }); } - @Transactional() - async setMessages( - sessionId: string, - messages: ChatMessage[], - tokenCost: number - ) { - await this.db.aiSessionMessage.createMany({ - data: messages.map(m => ({ - ...m, - attachments: m.attachments || undefined, - params: omit(m.params, ['docs']) || undefined, - streamObjects: m.streamObjects || undefined, - sessionId, - })), - }); + private calculateTokenSize(messages: any[], model: string): number { + const encoder = getTokenEncoder(model); + const content = messages.map(m => m.content).join(''); + return encoder?.count(content) || 0; + } - // only count message generated by user - const userMessages = messages.filter(m => m.role === 'user'); - await this.db.aiSession.update({ - where: { id: sessionId }, - data: { - messageCost: { increment: userMessages.length }, - tokenCost: { increment: tokenCost }, - }, - }); + @Transactional() + async updateMessages(state: UpdateChatSessionMessage) { + const { sessionId, userId, messages } = state; + const haveSession = await this.has(sessionId, userId); + if (!haveSession) { + throw new CopilotSessionNotFound(); + } + + if (messages.length) { + const tokenCost = this.calculateTokenSize(messages, state.prompt.model); + await this.db.aiSessionMessage.createMany({ + data: messages.map(m => ({ + ...m, + attachments: m.attachments || undefined, + params: omit(m.params, ['docs']) || undefined, + streamObjects: m.streamObjects || undefined, + sessionId, + })), + }); + + // only count message generated by user + const userMessages = messages.filter(m => m.role === 'user'); + await this.db.aiSession.update({ + where: { id: sessionId }, + data: { + messageCost: { increment: userMessages.length }, + tokenCost: { increment: tokenCost }, + }, + }); + } } @Transactional() @@ -404,4 +528,15 @@ export class CopilotSessionModel extends BaseModel { await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } }); } } + + @Transactional() + async countUserMessages(userId: string): Promise { + const sessions = await this.db.aiSession.findMany({ + where: { userId }, + select: { messageCost: true, prompt: { select: { action: true } } }, + }); + return sessions + .map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost)) + .reduce((prev, cost) => prev + cost, 0); + } } diff --git a/packages/backend/server/src/native.ts b/packages/backend/server/src/native.ts index 7b3a451fdf..135d47424a 100644 --- a/packages/backend/server/src/native.ts +++ b/packages/backend/server/src/native.ts @@ -1,4 +1,4 @@ -import serverNativeModule from '@affine/server-native'; +import serverNativeModule, { type Tokenizer } from '@affine/server-native'; export const mergeUpdatesInApplyWay = serverNativeModule.mergeUpdatesInApplyWay; @@ -16,10 +16,21 @@ export const mintChallengeResponse = async (resource: string, bits: number) => { return serverNativeModule.mintChallengeResponse(resource, bits); }; +export function getTokenEncoder(model?: string | null): Tokenizer | null { + if (!model) return null; + if (model.startsWith('gpt')) { + return serverNativeModule.fromModelName(model); + } else if (model.startsWith('dall')) { + // dalle don't need to calc the token + return null; + } else { + // c100k based model + return serverNativeModule.fromModelName('gpt-4'); + } +} + export const getMime = serverNativeModule.getMime; export const parseDoc = serverNativeModule.parseDoc; -export const Tokenizer = serverNativeModule.Tokenizer; -export const fromModelName = serverNativeModule.fromModelName; export const htmlSanitize = serverNativeModule.htmlSanitize; export const AFFINE_PRO_PUBLIC_KEY = serverNativeModule.AFFINE_PRO_PUBLIC_KEY; export const AFFINE_PRO_LICENSE_AES_KEY = diff --git a/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts b/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts index 5a80225f1a..317302683d 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts @@ -3,8 +3,8 @@ import { Logger } from '@nestjs/common'; import { AiPrompt } from '@prisma/client'; import Mustache from 'mustache'; +import { getTokenEncoder } from '../../../native'; import { PromptConfig, PromptMessage, PromptParams } from '../providers'; -import { getTokenEncoder } from '../types'; // disable escaping Mustache.escape = (text: string) => text; @@ -56,8 +56,7 @@ export class ChatPrompt { private readonly messages: PromptMessage[] ) { this.encoder = getTokenEncoder(model); - this.promptTokenSize = - this.encoder?.count(messages.map(m => m.content).join('') || '') || 0; + this.promptTokenSize = this.encode(messages.map(m => m.content).join('')); this.templateParamKeys = extractMustacheParams( messages.map(m => m.content).join('') ); diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 98f22ee8dd..51832d0b4d 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -39,15 +39,12 @@ import { PromptMessage, StreamObject } from './providers'; import { ChatSessionService } from './session'; import { CopilotStorage } from './storage'; import { - AvailableModels, type ChatHistory, type ChatMessage, type ChatSessionState, SubmittedMessage, } from './types'; -registerEnumType(AvailableModels, { name: 'CopilotModel' }); - export const COPILOT_LOCKER = 'copilot'; // ================== Input Types ================== @@ -301,8 +298,6 @@ class CopilotPromptMessageType { params!: Record | null; } -registerEnumType(AvailableModels, { name: 'CopilotModels' }); - @ObjectType() class CopilotPromptType { @Field(() => String) @@ -533,7 +528,7 @@ export class CopilotResolver { } await this.chatSession.checkQuota(user.id); - return await this.chatSession.updateSession({ + return await this.chatSession.update({ ...options, userId: user.id, }); @@ -682,8 +677,8 @@ class CreateCopilotPromptInput { @Field(() => String) name!: string; - @Field(() => AvailableModels) - model!: AvailableModels; + @Field(() => String) + model!: string; @Field(() => String, { nullable: true }) action!: string | null; diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index e4c01bc3cc..d072e74cac 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto'; import { Injectable, Logger } from '@nestjs/common'; import { Transactional } from '@nestjs-cls/transactional'; -import { AiPromptRole, PrismaClient } from '@prisma/client'; +import { AiPromptRole } from '@prisma/client'; import { CopilotActionTaken, @@ -14,10 +14,11 @@ import { } from '../../base'; import { QuotaService } from '../../core/quota'; import { + CleanupSessionOptions, ListSessionOptions, Models, type UpdateChatSession, - UpdateChatSessionData, + UpdateChatSessionOptions, } from '../../models'; import { ChatMessageCache } from './message'; import { PromptService } from './prompt'; @@ -29,7 +30,6 @@ import { type ChatSessionForkOptions, type ChatSessionOptions, type ChatSessionState, - getTokenEncoder, type SubmittedMessage, } from './types'; @@ -224,46 +224,12 @@ export class ChatSessionService { private readonly logger = new Logger(ChatSessionService.name); constructor( - private readonly db: PrismaClient, private readonly quota: QuotaService, private readonly messageCache: ChatMessageCache, private readonly prompt: PromptService, private readonly models: Models ) {} - @Transactional() - private async setSession(state: ChatSessionState): Promise { - const session = this.models.copilotSession; - let sessionId = state.sessionId; - - // find existing session if session is chat session - if (!state.prompt.action) { - const id = await session.getChatSessionId(state); - if (id) sessionId = id; - } - - const haveSession = await session.has(sessionId, state.userId); - if (haveSession) { - // message will only exists when setSession call by session.save - if (state.messages.length) { - await session.setMessages( - sessionId, - state.messages, - this.calculateTokenSize(state.messages, state.prompt.model) - ); - } - } else { - await session.create({ - ...state, - sessionId, - promptName: state.prompt.name, - promptAction: state.prompt.action ?? null, - }); - } - - return sessionId; - } - async getSession(sessionId: string): Promise { const session = await this.models.copilotSession.get(sessionId); if (!session) return; @@ -296,23 +262,6 @@ export class ChatSessionService { ); } - private calculateTokenSize(messages: PromptMessage[], model: string): number { - const encoder = getTokenEncoder(model); - return messages - .map(m => encoder?.count(m.content) ?? 0) - .reduce((total, length) => total + length, 0); - } - - private async countUserMessages(userId: string): Promise { - const sessions = await this.db.aiSession.findMany({ - where: { userId }, - select: { messageCost: true, prompt: { select: { action: true } } }, - }); - return sessions - .map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost)) - .reduce((prev, cost) => prev + cost, 0); - } - async listSessions( options: ListSessionOptions ): Promise[]> { @@ -431,7 +380,7 @@ export class ChatSessionService { limit = quota.copilotActionLimit; } - const used = await this.countUserMessages(userId); + const used = await this.models.copilotSession.countUserMessages(userId); return { limit, used }; } @@ -456,20 +405,19 @@ export class ChatSessionService { } // validate prompt compatibility with session type - this.models.copilotSession.checkSessionPrompt( - options, - prompt.name, - prompt.action - ); + this.models.copilotSession.checkSessionPrompt(options, prompt); - return await this.setSession({ - ...options, - sessionId, - prompt, - messages: [], - // when client create chat session, we always find root session - parentSessionId: null, - }); + return await this.models.copilotSession.createWithPrompt( + { + ...options, + sessionId, + prompt, + messages: [], + // when client create chat session, we always find root session + parentSessionId: null, + }, + true + ); } @Transactional() @@ -478,13 +426,16 @@ export class ChatSessionService { } @Transactional() - async updateSession(options: UpdateChatSession): Promise { + async update(options: UpdateChatSession): Promise { const session = await this.getSession(options.sessionId); if (!session) { throw new CopilotSessionNotFound(); } - const finalData: UpdateChatSessionData = {}; + const finalData: UpdateChatSessionOptions = { + userId: options.userId, + sessionId: options.sessionId, + }; if (options.promptName) { const prompt = await this.prompt.get(options.promptName); if (!prompt) { @@ -492,11 +443,7 @@ export class ChatSessionService { throw new CopilotPromptNotFound({ name: options.promptName }); } - this.models.copilotSession.checkSessionPrompt( - session, - prompt.name, - prompt.action - ); + this.models.copilotSession.checkSessionPrompt(session, prompt); finalData.promptName = prompt.name; } finalData.pinned = options.pinned; @@ -508,21 +455,15 @@ export class ChatSessionService { ); } - return await this.models.copilotSession.update( - options.userId, - options.sessionId, - finalData - ); + return await this.models.copilotSession.update(finalData); } + @Transactional() async fork(options: ChatSessionForkOptions): Promise { const state = await this.getSession(options.sessionId); if (!state) { throw new CopilotSessionNotFound(); } - if (state.pinned) { - await this.unpin(options.workspaceId, options.userId); - } let messages = state.messages.map(m => ({ ...m, id: undefined })); if (options.latestMessageId) { @@ -538,62 +479,17 @@ export class ChatSessionService { messages = messages.slice(0, lastMessageIdx + 1); } - const forkedState = { + return await this.models.copilotSession.fork({ ...state, userId: options.userId, sessionId: randomUUID(), - messages: [], parentSessionId: options.sessionId, - }; - // create session - await this.setSession(forkedState); - // save message - return await this.setSession({ ...forkedState, messages }); + messages, + }); } - async cleanup( - options: Omit & { - sessionIds: string[]; - } - ) { - return await this.db.$transaction(async tx => { - const sessions = await tx.aiSession.findMany({ - where: { - id: { in: options.sessionIds }, - userId: options.userId, - workspaceId: options.workspaceId, - docId: options.docId, - deletedAt: null, - }, - select: { id: true, promptName: true }, - }); - const sessionIds = sessions.map(({ id }) => id); - // cleanup all messages - await tx.aiSessionMessage.deleteMany({ - where: { sessionId: { in: sessionIds } }, - }); - - // only mark action session as deleted - // chat session always can be reuse - const actionIds = ( - await Promise.all( - sessions.map(({ id, promptName }) => - this.prompt - .get(promptName) - .then(prompt => ({ id, action: !!prompt?.action })) - ) - ) - ) - .filter(({ action }) => action) - .map(({ id }) => id); - - await tx.aiSession.updateMany({ - where: { id: { in: actionIds } }, - data: { pinned: false, deletedAt: new Date() }, - }); - - return [...sessionIds, ...actionIds]; - }); + async cleanup(options: CleanupSessionOptions) { + return await this.models.copilotSession.cleanup(options); } async createMessage(message: SubmittedMessage): Promise { @@ -617,7 +513,7 @@ export class ChatSessionService { const state = await this.getSession(sessionId); if (state) { return new ChatSession(this.messageCache, state, async state => { - await this.setSession(state); + await this.models.copilotSession.updateMessages(state); }); } return null; diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 62ab7d6a3d..c5465b8602 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -1,8 +1,6 @@ -import { type Tokenizer } from '@affine/server-native'; import { z } from 'zod'; import { OneMB } from '../../base'; -import { fromModelName } from '../../native'; import type { ChatPrompt } from './prompt'; import { PromptMessageSchema, PureMessageSchema } from './providers'; @@ -38,41 +36,6 @@ export const ChatQuerySchema = z }) ); -export enum AvailableModels { - // text to text - Gpt4Omni = 'gpt-4o', - Gpt4Omni0806 = 'gpt-4o-2024-08-06', - Gpt4OmniMini = 'gpt-4o-mini', - Gpt4OmniMini0718 = 'gpt-4o-mini-2024-07-18', - Gpt41 = 'gpt-4.1', - Gpt410414 = 'gpt-4.1-2025-04-14', - Gpt41Mini = 'gpt-4.1-mini', - Gpt41Nano = 'gpt-4.1-nano', - // embeddings - TextEmbedding3Large = 'text-embedding-3-large', - TextEmbedding3Small = 'text-embedding-3-small', - TextEmbeddingAda002 = 'text-embedding-ada-002', - // text to image - DallE3 = 'dall-e-3', - GptImage = 'gpt-image-1', -} - -const availableModels = Object.values(AvailableModels); - -export function getTokenEncoder(model?: string | null): Tokenizer | null { - if (!model) return null; - if (!availableModels.includes(model as AvailableModels)) return null; - if (model.startsWith('gpt')) { - return fromModelName(model); - } else if (model.startsWith('dall')) { - // dalle don't need to calc the token - return null; - } else { - // c100k based model - return fromModelName('gpt-4'); - } -} - // ======== ChatMessage ======== export const ChatMessageSchema = PromptMessageSchema.extend({