diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md index ee137de87d..39a6e2be9a 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md @@ -116,6 +116,38 @@ Generated by [AVA](https://avajs.dev). }, ] +> should generate the final message + + [ + { + content: 'hello world', + params: { + word: 'world', + }, + role: 'system', + }, + { + content: 'hello', + params: {}, + role: 'user', + }, + { + content: 'world', + params: {}, + role: 'assistant', + }, + { + content: 'aaa', + params: {}, + role: 'user', + }, + { + content: 'bbb', + params: {}, + role: 'assistant', + }, + ] + ## should revert message correctly > should have three messages before revert diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap index e2b0012467..943a031e2d 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot.e2e.ts index 3ee870cb78..3e5537a0dd 100644 --- a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot.e2e.ts @@ -281,7 +281,7 @@ test('should fork session correctly', async t => { const assertForkSession = async ( workspaceId: string, sessionId: string, - lastMessageId: string, + lastMessageId: string | undefined, error: string, asserter = async (x: any) => { const forkedSessionId = await x; @@ -330,6 +330,27 @@ test('should fork session correctly', async t => { ); } + // should be able to fork session without latestMessageId (copy all messages) + { + forkedSessionId = await assertForkSession( + id, + sessionId, + undefined, + 'should be able to fork session without latestMessageId' + ); + } + + // should not be able to fork session with wrong latestMessageId + { + await assertForkSession(id, sessionId, 'wrong-message-id', '', async x => { + await t.throwsAsync( + x, + { instanceOf: Error }, + 'should not able to fork session with wrong latestMessageId' + ); + }); + } + { const u2 = await app.signupV1('u2@affine.pro'); await assertForkSession(id, sessionId, randomUUID(), '', async x => { diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index 233bd1584e..547878928a 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -410,6 +410,27 @@ test('should be able to fork chat session', async t => { 'should fork new session with same params' ); + // fork session without latestMessageId + const forkedSessionId3 = await session.fork({ + userId, + sessionId, + ...commonParams, + }); + + // fork session with wrong latestMessageId + await t.throwsAsync( + session.fork({ + userId, + sessionId, + latestMessageId: 'wrong-message-id', + ...commonParams, + }), + { + instanceOf: Error, + }, + 'should not able to fork new session with wrong latestMessageId' + ); + const cleanObject = (obj: any[]) => JSON.parse( JSON.stringify(obj, (k, v) => @@ -436,11 +457,17 @@ test('should be able to fork chat session', async t => { t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } + // check third times forked session + { + const s3 = (await session.get(forkedSessionId3))!; + const finalMessages = s3.finish(params); + t.snapshot(cleanObject(finalMessages), 'should generate the final message'); + } + // check original session messages { - const s3 = (await session.get(sessionId))!; - - const finalMessages = s3.finish(params); + const s4 = (await session.get(sessionId))!; + const finalMessages = s4.finish(params); t.snapshot(cleanObject(finalMessages), 'should generate the final message'); } diff --git a/packages/backend/server/src/__tests__/utils/copilot.ts b/packages/backend/server/src/__tests__/utils/copilot.ts index 59f1af7b69..8ae7bb6a67 100644 --- a/packages/backend/server/src/__tests__/utils/copilot.ts +++ b/packages/backend/server/src/__tests__/utils/copilot.ts @@ -57,7 +57,7 @@ export async function forkCopilotSession( workspaceId: string, docId: string, sessionId: string, - latestMessageId: string + latestMessageId?: string ): Promise { const res = await app.gql( ` diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 53214f6559..b70634f8c4 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -91,8 +91,9 @@ class ForkChatSessionInput { @Field(() => String, { description: 'Identify a message in the array and keep it with all previous messages into a forked session.', + nullable: true, }) - latestMessageId!: string; + latestMessageId?: string; } @InputType() diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index e160098b27..c2dadeb31b 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -673,16 +673,19 @@ export class ChatSessionService { if (!state) { throw new CopilotSessionNotFound(); } - const lastMessageIdx = state.messages.findLastIndex( - ({ id, role }) => - role === AiPromptRole.assistant && id === options.latestMessageId - ); - if (lastMessageIdx < 0) { - throw new CopilotMessageNotFound({ messageId: options.latestMessageId }); + let messages = state.messages.map(m => ({ ...m, id: undefined })); + if (options.latestMessageId) { + const lastMessageIdx = state.messages.findLastIndex( + ({ id, role }) => + role === AiPromptRole.assistant && id === options.latestMessageId + ); + if (lastMessageIdx < 0) { + throw new CopilotMessageNotFound({ + messageId: options.latestMessageId, + }); + } + messages = messages.slice(0, lastMessageIdx + 1); } - const messages = state.messages - .slice(0, lastMessageIdx + 1) - .map(m => ({ ...m, id: undefined })); const forkedState = { ...state, diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index ae403bf43d..9aefd70f55 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -119,7 +119,7 @@ export interface ChatSessionPromptUpdateOptions export interface ChatSessionForkOptions extends Omit { sessionId: string; - latestMessageId: string; + latestMessageId?: string; } export interface ChatSessionState diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index ff30e3190e..f9b8da63a4 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -703,7 +703,7 @@ input ForkChatSessionInput { """ Identify a message in the array and keep it with all previous messages into a forked session. """ - latestMessageId: String! + latestMessageId: String sessionId: String! workspaceId: String! } diff --git a/packages/common/graphql/src/schema.ts b/packages/common/graphql/src/schema.ts index 1725f57530..ccebb1b72d 100644 --- a/packages/common/graphql/src/schema.ts +++ b/packages/common/graphql/src/schema.ts @@ -879,7 +879,7 @@ export enum FeatureType { export interface ForkChatSessionInput { docId: Scalars['String']['input']; /** Identify a message in the array and keep it with all previous messages into a forked session. */ - latestMessageId: Scalars['String']['input']; + latestMessageId?: InputMaybe; sessionId: Scalars['String']['input']; workspaceId: Scalars['String']['input']; } diff --git a/packages/frontend/core/src/blocksuite/ai/actions/types.ts b/packages/frontend/core/src/blocksuite/ai/actions/types.ts index 40e701c427..f690489d9d 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/types.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/types.ts @@ -93,7 +93,7 @@ declare global { docId: string; workspaceId: string; sessionId: string; - latestMessageId: string; + latestMessageId?: string; } interface AIImageActionOptions extends AITextActionOptions {