diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md index 4f239cbb40..c53b94d56c 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md @@ -4,6 +4,118 @@ The actual snapshot is saved in `copilot.spec.ts.snap`. Generated by [AVA](https://avajs.dev). +## should be able to manage chat session + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + role: 'user', + }, + ] + +> should generate different message with another params + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + role: 'user', + }, + ] + +## should be able to fork chat session + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + ] + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + ] + +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + { + content: 'aaa', + params: {}, + role: 'user', + }, + { + content: 'bbb', + params: {}, + role: 'assistant', + }, + ] + ## should revert message correctly > should have three messages before revert diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap index 516503db91..887ca9407d 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index 85e435b63e..a93a4b1da0 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -241,35 +241,31 @@ test('should be able to manage chat session', async t => { t.is(s.config.promptName, 'prompt', 'should have prompt name'); t.is(s.model, 'model', 'should have model'); + const cleanObject = (obj: any[]) => + JSON.parse( + JSON.stringify(obj, (k, v) => + ['id', 'attachments', 'createdAt'].includes(k) || + v === null || + (typeof v === 'object' && !Object.keys(v).length) + ? undefined + : v + ) + ); + s.push({ role: 'user', content: 'hello', createdAt: new Date() }); - // @ts-expect-error - const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { content: 'hello world', params, role: 'system' }, - { content: 'hello', role: 'user' }, - ], - 'should generate the final message' - ); + + const finalMessages = cleanObject(s.finish(params)); + t.snapshot(finalMessages, 'should generate the final message'); await s.save(); const s1 = (await session.get(sessionId))!; t.deepEqual( - s1 - .finish(params) - // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m), + cleanObject(s1.finish(params)), finalMessages, 'should same as before message' ); - t.deepEqual( - // @ts-expect-error - s1.finish({}).map(({ id: _, attachments: __, createdAt: ___, ...m }) => m), - [ - { content: 'hello ', params: {}, role: 'system' }, - { content: 'hello', role: 'user' }, - ], + t.snapshot( + cleanObject(s1.finish(params)), 'should generate different message with another params' ); @@ -366,22 +362,19 @@ test('should be able to fork chat session', async t => { 'should fork new session with same params' ); + const cleanObject = (obj: any[]) => + JSON.parse( + JSON.stringify(obj, (k, v) => + ['id', 'createdAt'].includes(k) || v === null ? undefined : v + ) + ); + // check forked session messages { const s2 = (await session.get(forkedSessionId1))!; - const finalMessages = s2 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - ], - 'should generate the final message' - ); + const finalMessages = s2.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // check second times forked session @@ -391,38 +384,16 @@ test('should be able to fork chat session', async t => { // should overwrite user id t.is(s2.config.userId, newUser.id, 'should have same user id'); - const finalMessages = s2 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - ], - 'should generate the final message' - ); + const finalMessages = s2.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // check original session messages { const s3 = (await session.get(sessionId))!; - const finalMessages = s3 - .finish(params) // @ts-expect-error - .map(({ id: _, attachments: __, createdAt: ___, ...m }) => m); - t.deepEqual( - finalMessages, - [ - { role: 'system', content: 'hello world', params }, - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'world' }, - { role: 'user', content: 'aaa' }, - { role: 'assistant', content: 'bbb' }, - ], - 'should generate the final message' - ); + const finalMessages = s3.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } // should get main session after fork if re-create a chat session for same docId and workspaceId @@ -612,7 +583,11 @@ test('should revert message correctly', async t => { const cleanObject = (obj: any[]) => JSON.parse( JSON.stringify(obj, (k, v) => - ['id', 'createdAt'].includes(k) || v === null ? undefined : v + ['id', 'createdAt'].includes(k) || + v === null || + (typeof v === 'object' && !Object.keys(v).length) + ? undefined + : v ) ); diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 658a3d2b2c..2d2c1a480c 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -48,7 +48,7 @@ import { CurrentUser, Public } from '../../core/auth'; import { CopilotProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; import { CopilotStorage } from './storage'; -import { CopilotCapability, CopilotTextProvider } from './types'; +import { ChatMessage, CopilotCapability, CopilotTextProvider } from './types'; import { CopilotWorkflowService, GraphExecutorState } from './workflow'; export interface ChatEvent { @@ -141,24 +141,28 @@ export class CopilotController implements BeforeApplicationShutdown { sessionId: string, messageId?: string, retry = false - ): Promise { + ): Promise<[ChatMessage | undefined, ChatSession]> { const session = await this.chatSession.get(sessionId); if (!session) { throw new CopilotSessionNotFound(); } + let latestMessage = undefined; if (!messageId || retry) { // revert the latest message generated by the assistant // if messageId is provided, we will also revert latest user message await this.chatSession.revertLatestMessage(sessionId, !!messageId); session.revertLatestMessage(!!messageId); + if (!messageId) { + latestMessage = session.latestUserMessage; + } } if (messageId) { await session.pushByMessageId(messageId); } - return session; + return [latestMessage, session]; } private prepareParams(params: Record) { @@ -226,7 +230,7 @@ export class CopilotController implements BeforeApplicationShutdown { messageId ); - const session = await this.appendSessionMessage( + const [latestMessage, session] = await this.appendSessionMessage( sessionId, messageId, retry @@ -234,6 +238,14 @@ export class CopilotController implements BeforeApplicationShutdown { info.model = session.model; metrics.ai.counter('chat_calls').add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + const finalMessage = session.finish(params); info.finalMessage = finalMessage; @@ -281,14 +293,22 @@ export class CopilotController implements BeforeApplicationShutdown { messageId ); - const session = await this.appendSessionMessage( + const [latestMessage, session] = await this.appendSessionMessage( sessionId, messageId, retry ); - info.model = session.model; + info.model = session.model; metrics.ai.counter('chat_stream_calls').add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); const finalMessage = session.finish(params); info.finalMessage = finalMessage; @@ -349,10 +369,11 @@ export class CopilotController implements BeforeApplicationShutdown { try { const { messageId } = this.prepareParams(params); - const session = await this.appendSessionMessage(sessionId, messageId); + const [, session] = await this.appendSessionMessage(sessionId, messageId); info.model = session.model; metrics.ai.counter('workflow_calls').add(1, { model: session.model }); + const latestMessage = session.stashMessages.findLast( m => m.role === 'user' ); @@ -463,12 +484,22 @@ export class CopilotController implements BeforeApplicationShutdown { throw new NoCopilotProviderAvailable(); } - const session = await this.appendSessionMessage(sessionId, messageId); + const [latestMessage, session] = await this.appendSessionMessage( + sessionId, + messageId + ); info.model = session.model; - metrics.ai .counter('images_stream_calls') .add(1, { model: session.model }); + + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + attachments: latestMessage.attachments, + }); + } + const handleRemoteLink = this.storage.handleRemoteLink.bind( this.storage, user.id, diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 502c59ab2f..94ec7bea38 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -63,6 +63,10 @@ export class ChatSession implements AsyncDisposable { return this.state.messages.slice(-this.stashMessageCount); } + get latestUserMessage() { + return this.state.messages.findLast(m => m.role === 'user'); + } + push(message: ChatMessage) { if ( this.state.prompt.action && @@ -313,6 +317,7 @@ export class ChatSessionService { role: true, content: true, attachments: true, + params: true, createdAt: true, }, orderBy: { createdAt: 'asc' },