From 91adc533a8ad8e47fda0d45032c241d044df07d0 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Thu, 6 Mar 2025 07:32:26 +0000 Subject: [PATCH] fix(server): reuse params in retry (#10653) fix BS-2484 --- .../__snapshots__/copilot.spec.ts.md | 112 ++++++++++++++++++ .../__snapshots__/copilot.spec.ts.snap | Bin 686 -> 1110 bytes .../server/src/__tests__/copilot.spec.ts | 95 ++++++--------- .../server/src/plugins/copilot/controller.ts | 49 ++++++-- .../server/src/plugins/copilot/session.ts | 5 + 5 files changed, 192 insertions(+), 69 deletions(-) diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md index 4f239cbb40..c53b94d56c 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md @@ -4,6 +4,118 @@ The actual snapshot is saved in `copilot.spec.ts.snap`. Generated by [AVA](https://avajs.dev). +## should be able to manage chat session + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + role: 'user', + }, + ] + +> should generate different message with another params + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + role: 'user', + }, + ] + +## should be able to fork chat session + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + ] + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + ] + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + { + content: 'aaa', + params: {}, + role: 'user', + }, + { + content: 'bbb', + params: {}, + role: 'assistant', + }, + ] + ## should revert message correctly > should have three messages before revert diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap index 516503db9115986299cb84475d25fb58c8fc149a..887ca9407d82310d3593d7a7d6b8c3868ca518bd 100644 GIT binary patch literal 1110 zcmV-c1gZN$RzVsWT;ku|jNvvB!7faKo^!_f)Zg%+rgQj=fty_+^Y{%7B=vi&M zW6x@<*0e(_&_U4hI<25%9}AjZ7`(JxuiK%0NKPA!5xt}$#xCGK;9+39l&7S$&-d*` zk#saMF@cOD3aN_hzx2kc$n()(iPkVb08Rntfs5%nw4C_=eu=eo5-^Q>16qs?c41jvxMn; zuI`oui$SQ{aW9qAs}%QIYc=zDp#5lz>Sck~vfDHIwkO%u<~_gO?HBJ#ySAg8Eqd%E{#Pq*G`G_~)TvdVm@~iwz@xx{JXnkvk;7eBP9(5YMLvs$Dqivz;5XnV zfnC{etB)J=pui!47X{wQVW%SIrvhIH{3P&4w!FM?)GReLYI(g<)V7H{HdvPg!S(d$ z?L-gOV~ZUGtst~JVa)oHf)i?O)+?9Kme!md+YA>h;4a|)-@A>*3rnYZ98uB)u^^&IF1u{xG?-iR-&G1!nip)hsexsS#!2JT6am}3aoH$6l8ldd66_Y5h$T%%ZiGS`rdF9psDT*$qLjWGpV z72KJ1mh2ge@w9@&3SP+dEF<--GuhJq+hbq@`gQo5J?{xg?p~iVbJw|N03*p=Z1=zP zp=Jg!g-HW$$le|sdH$PDB-ga(B4PSkr?LwyM>jlQCu&UN?Z>MGSK0$S2s{oP1dim1 z&5RY)>Wb~%p4iUx#C9|QQyZ%$j?Pu;y{>_({l4DEM~;$iNS-oxAv zz|X*>z;r&|`vsm7@bmG0OW-|$&vWn|hS?uS_hZ=RGg}drUba`2wIa$CQRZ6qfdBvi literal 686 zcmV;f0#W@zRzVhwvA3I-LNQVD6kPkK|qjDq_~X`xTLs$fmQdQw{aTRIrOYnfqDN83?;NQ%|473o&g zT@TcC5QOexXo4DYMC6Bp?!&^IXS}9Oq;O0a=$^zb+ekO#S}LWMoY^juGjFxoS)Q}o ziguRl+<$6nY^P0?L{)NdESK24T;JwNGUZ}uYTZRX-ZLF|z7Cc~JIK|& z1?@(4qjeWM>-u^&G`d4$2RQX+3vl^y;39AZm;!F6!REvSHF~f;--K;_6Skin*h=CY z7Z?$^D&VIR_DJA~!0U9vJ_>vm_?49Q6Lvy@t6(C9u-#zeXV~;NRT8zL;A!%E|BKU4 z_70b3xS-u$*S%xLn~k(dxc&4Nw>x9NdEhcIndbD3vtp*_^p#DgzwJ1^B+eJ$8&DA# zO}F { t.is(s.config.promptName, 'prompt', 'should have prompt name'); t.is(s.model, 'model', 'should have model'); + const cleanObject = (obj: any[]) => + JSON.parse( + JSON.stringify(obj, (k, v) => + ['id', 'attachments', 'createdAt'].includes(k) || + v === null || + (typeof v === 'object' && !Object.keys(v).length) + ? undefined + : v + ) + ); + s.push({ role: 'user', content: 'hello', createdAt: new Date() }); - // @ts-expect-error - const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { content: 'hello world', params, role: 'system' }, - { content: 'hello', role: 'user' }, - ], - 'should generate the final message' - ); + + const finalMessages = cleanObject(s.finish(params)); + t.snapshot(finalMessages, 'should generate the final message'); await s.save(); const s1 = (await session.get(sessionId))!; t.deepEqual( - s1 - .finish(params) - // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m), + cleanObject(s1.finish(params)), finalMessages, 'should same as before message' ); - t.deepEqual( - // @ts-expect-error - s1.finish({}).map(({ id: _, attachments: __, createdAt: ___, ...m }) => m), - [ - { content: 'hello ', params: {}, role: 'system' }, - { content: 'hello', role: 'user' }, - ], + t.snapshot( + cleanObject(s1.finish(params)), 'should generate different message with another params' ); @@ -366,22 +362,19 @@ test('should be able to fork chat session', async t => { 'should fork new session with same params' ); + const cleanObject = (obj: any[]) => + JSON.parse( + JSON.stringify(obj, (k, v) => + ['id', 'createdAt'].includes(k) || v === null ? undefined : v + ) + ); + // check forked session messages { const s2 = (await session.get(forkedSessionId1))!; - const finalMessages = s2 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - ], - 'should generate the final message' - ); + const finalMessages = s2.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // check second times forked session @@ -391,38 +384,16 @@ test('should be able to fork chat session', async t => { // should overwrite user id t.is(s2.config.userId, newUser.id, 'should have same user id'); - const finalMessages = s2 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - ], - 'should generate the final message' - ); + const finalMessages = s2.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // check original session messages { const s3 = (await session.get(sessionId))!; - const finalMessages = s3 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - { role: 'user', content: 'aaa' }, - { role: 'assistant', content: 'bbb' }, - ], - 'should generate the final message' - ); + const finalMessages = s3.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // should get main session after fork if re-create a chat session for same docId and workspaceId @@ -612,7 +583,11 @@ test('should revert message correctly', async t => { const cleanObject = (obj: any[]) => JSON.parse( JSON.stringify(obj, (k, v) => - ['id', 'createdAt'].includes(k) || v === null ? undefined : v + ['id', 'createdAt'].includes(k) || + v === null || + (typeof v === 'object' && !Object.keys(v).length) + ? undefined + : v ) ); diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 658a3d2b2c..2d2c1a480c 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -48,7 +48,7 @@ import { CurrentUser, Public } from '../../core/auth'; import { CopilotProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; import { CopilotStorage } from './storage'; -import { CopilotCapability, CopilotTextProvider } from './types'; +import { ChatMessage, CopilotCapability, CopilotTextProvider } from './types'; import { CopilotWorkflowService, GraphExecutorState } from './workflow'; export interface ChatEvent { @@ -141,24 +141,28 @@ export class CopilotController implements BeforeApplicationShutdown { sessionId: string, messageId?: string, retry = false - ): Promise { + ): Promise<[ChatMessage | undefined, ChatSession]> { const session = await this.chatSession.get(sessionId); if (!session) { throw new CopilotSessionNotFound(); } + let latestMessage = undefined; if (!messageId || retry) { // revert the latest message generated by the assistant // if messageId is provided, we will also revert latest user message await this.chatSession.revertLatestMessage(sessionId, !!messageId); session.revertLatestMessage(!!messageId); + if (!messageId) { + latestMessage = session.latestUserMessage; + } } if (messageId) { await session.pushByMessageId(messageId); } - return session; + return [latestMessage, session]; } private prepareParams(params: Record) { @@ -226,7 +230,7 @@ export class CopilotController implements BeforeApplicationShutdown { messageId ); - const session = await this.appendSessionMessage( + const [latestMessage, session] = await this.appendSessionMessage( sessionId, messageId, retry @@ -234,6 +238,14 @@ export class CopilotController implements BeforeApplicationShutdown { info.model = session.model; metrics.ai.counter('chat_calls').add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + const finalMessage = session.finish(params); info.finalMessage = finalMessage; @@ -281,14 +293,22 @@ export class CopilotController implements BeforeApplicationShutdown { messageId ); - const session = await this.appendSessionMessage( + const [latestMessage, session] = await this.appendSessionMessage( sessionId, messageId, retry ); - info.model = session.model; + info.model = session.model; metrics.ai.counter('chat_stream_calls').add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); const finalMessage = session.finish(params); info.finalMessage = finalMessage; @@ -349,10 +369,11 @@ export class CopilotController implements BeforeApplicationShutdown { try { const { messageId } = this.prepareParams(params); - const session = await this.appendSessionMessage(sessionId, messageId); + const [, session] = await this.appendSessionMessage(sessionId, messageId); info.model = session.model; metrics.ai.counter('workflow_calls').add(1, { model: session.model }); + const latestMessage = session.stashMessages.findLast( m => m.role === 'user' ); @@ -463,12 +484,22 @@ export class CopilotController implements BeforeApplicationShutdown { throw new NoCopilotProviderAvailable(); } - const session = await this.appendSessionMessage(sessionId, messageId); + const [latestMessage, session] = await this.appendSessionMessage( + sessionId, + messageId + ); info.model = session.model; - metrics.ai .counter('images_stream_calls') .add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + const handleRemoteLink = this.storage.handleRemoteLink.bind( this.storage, user.id, diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 502c59ab2f..94ec7bea38 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -63,6 +63,10 @@ export class ChatSession implements AsyncDisposable { return this.state.messages.slice(-this.stashMessageCount); } + get latestUserMessage() { + return this.state.messages.findLast(m => m.role === 'user'); + } + push(message: ChatMessage) { if ( this.state.prompt.action && @@ -313,6 +317,7 @@ export class ChatSessionService { role: true, content: true, attachments: true, + params: true, createdAt: true, }, orderBy: { createdAt: 'asc' },