diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 9b6c9d1a20..34a41e34d9 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -194,6 +194,12 @@ export class ChatSessionService { // find existing session if session is chat session if (!state.prompt.action) { + const extraCondition: Record = {}; + if (state.parentSessionId) { + // also check session id if provided session is forked session + extraCondition.id = state.sessionId; + extraCondition.parentSessionId = state.parentSessionId; + } const { id, deletedAt } = (await tx.aiSession.findFirst({ where: { @@ -201,7 +207,8 @@ export class ChatSessionService { workspaceId: state.workspaceId, docId: state.docId, prompt: { action: { equals: null } }, - parentSessionId: state.parentSessionId, + parentSessionId: null, + ...extraCondition, }, select: { id: true, deletedAt: true }, })) || {}; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 3b4a046f70..437dd4a0c8 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -290,17 +290,46 @@ test('should be able to fork chat session', async t => { const s1 = (await session.get(sessionId))!; // @ts-expect-error const latestMessageId = s1.finish({}).find(m => m.role === 'assistant')!.id; - const forkedSessionId = await session.fork({ + const forkedSessionId1 = await session.fork({ userId, sessionId, latestMessageId, ...commonParams, }); - t.not(sessionId, forkedSessionId, 'should fork a new session'); + t.not(sessionId, forkedSessionId1, 'should fork a new session'); + const forkedSessionId2 = await session.fork({ + userId, + sessionId, + latestMessageId, + ...commonParams, + }); + t.not( + forkedSessionId1, + forkedSessionId2, + 'should fork new session with same params' + ); // check forked session messages { - const s2 = (await session.get(forkedSessionId))!; + const s2 = (await session.get(forkedSessionId1))!; + + const finalMessages = s2 + .finish(params) // @ts-expect-error + .map(({ id: _, createdAt: __, ...m }) => m); + t.deepEqual( + finalMessages, + [ + { role: 'system', content: 'hello world', params }, + { role: 'user', content: 'hello' }, + { role: 'assistant', content: 'world' }, + ], + 'should generate the final message' + ); + } + + // check second times forked session messages + { + const s2 = (await session.get(forkedSessionId2))!; const finalMessages = s2 .finish(params) // @ts-expect-error