feat(core): support fork session without latestMessageId (#12587)

Close [AI-86](https://linear.app/affine-design/issue/AI-86)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Improved chat session forking to allow creating a fork without specifying the latest message, enabling more flexible session management.

- **Bug Fixes**
  - Forking a chat session with an invalid latest message ID now correctly returns an error.

- **Tests**
  - Added and updated test cases to cover session forking with missing or invalid latest message IDs, ensuring robust behavior in these scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
akumatus
2025-05-28 07:34:23 +00:00
parent a045786c6a
commit eb49ffaedb
11 changed files with 103 additions and 19 deletions

View File

@@ -116,6 +116,38 @@ Generated by [AVA](https://avajs.dev).
}, },
] ]
> should generate the final message
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: 'hello',
params: {},
role: 'user',
},
{
content: 'world',
params: {},
role: 'assistant',
},
{
content: 'aaa',
params: {},
role: 'user',
},
{
content: 'bbb',
params: {},
role: 'assistant',
},
]
## should revert message correctly ## should revert message correctly
> should have three messages before revert > should have three messages before revert

View File

@@ -281,7 +281,7 @@ test('should fork session correctly', async t => {
const assertForkSession = async ( const assertForkSession = async (
workspaceId: string, workspaceId: string,
sessionId: string, sessionId: string,
lastMessageId: string, lastMessageId: string | undefined,
error: string, error: string,
asserter = async (x: any) => { asserter = async (x: any) => {
const forkedSessionId = await x; const forkedSessionId = await x;
@@ -330,6 +330,27 @@ test('should fork session correctly', async t => {
); );
} }
// should be able to fork session without latestMessageId (copy all messages)
{
forkedSessionId = await assertForkSession(
id,
sessionId,
undefined,
'should be able to fork session without latestMessageId'
);
}
// should not be able to fork session with wrong latestMessageId
{
await assertForkSession(id, sessionId, 'wrong-message-id', '', async x => {
await t.throwsAsync(
x,
{ instanceOf: Error },
'should not able to fork session with wrong latestMessageId'
);
});
}
{ {
const u2 = await app.signupV1('u2@affine.pro'); const u2 = await app.signupV1('u2@affine.pro');
await assertForkSession(id, sessionId, randomUUID(), '', async x => { await assertForkSession(id, sessionId, randomUUID(), '', async x => {

View File

@@ -410,6 +410,27 @@ test('should be able to fork chat session', async t => {
'should fork new session with same params' 'should fork new session with same params'
); );
// fork session without latestMessageId
const forkedSessionId3 = await session.fork({
userId,
sessionId,
...commonParams,
});
// fork session with wrong latestMessageId
await t.throwsAsync(
session.fork({
userId,
sessionId,
latestMessageId: 'wrong-message-id',
...commonParams,
}),
{
instanceOf: Error,
},
'should not able to fork new session with wrong latestMessageId'
);
const cleanObject = (obj: any[]) => const cleanObject = (obj: any[]) =>
JSON.parse( JSON.parse(
JSON.stringify(obj, (k, v) => JSON.stringify(obj, (k, v) =>
@@ -436,11 +457,17 @@ test('should be able to fork chat session', async t => {
t.snapshot(cleanObject(finalMessages), 'should generate the final message'); t.snapshot(cleanObject(finalMessages), 'should generate the final message');
} }
// check third times forked session
{
const s3 = (await session.get(forkedSessionId3))!;
const finalMessages = s3.finish(params);
t.snapshot(cleanObject(finalMessages), 'should generate the final message');
}
// check original session messages // check original session messages
{ {
const s3 = (await session.get(sessionId))!; const s4 = (await session.get(sessionId))!;
const finalMessages = s4.finish(params);
const finalMessages = s3.finish(params);
t.snapshot(cleanObject(finalMessages), 'should generate the final message'); t.snapshot(cleanObject(finalMessages), 'should generate the final message');
} }

View File

@@ -57,7 +57,7 @@ export async function forkCopilotSession(
workspaceId: string, workspaceId: string,
docId: string, docId: string,
sessionId: string, sessionId: string,
latestMessageId: string latestMessageId?: string
): Promise<string> { ): Promise<string> {
const res = await app.gql( const res = await app.gql(
` `

View File

@@ -91,8 +91,9 @@ class ForkChatSessionInput {
@Field(() => String, { @Field(() => String, {
description: description:
'Identify a message in the array and keep it with all previous messages into a forked session.', 'Identify a message in the array and keep it with all previous messages into a forked session.',
nullable: true,
}) })
latestMessageId!: string; latestMessageId?: string;
} }
@InputType() @InputType()

View File

@@ -673,16 +673,19 @@ export class ChatSessionService {
if (!state) { if (!state) {
throw new CopilotSessionNotFound(); throw new CopilotSessionNotFound();
} }
const lastMessageIdx = state.messages.findLastIndex( let messages = state.messages.map(m => ({ ...m, id: undefined }));
({ id, role }) => if (options.latestMessageId) {
role === AiPromptRole.assistant && id === options.latestMessageId const lastMessageIdx = state.messages.findLastIndex(
); ({ id, role }) =>
if (lastMessageIdx < 0) { role === AiPromptRole.assistant && id === options.latestMessageId
throw new CopilotMessageNotFound({ messageId: options.latestMessageId }); );
if (lastMessageIdx < 0) {
throw new CopilotMessageNotFound({
messageId: options.latestMessageId,
});
}
messages = messages.slice(0, lastMessageIdx + 1);
} }
const messages = state.messages
.slice(0, lastMessageIdx + 1)
.map(m => ({ ...m, id: undefined }));
const forkedState = { const forkedState = {
...state, ...state,

View File

@@ -119,7 +119,7 @@ export interface ChatSessionPromptUpdateOptions
export interface ChatSessionForkOptions export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'promptName'> { extends Omit<ChatSessionOptions, 'promptName'> {
sessionId: string; sessionId: string;
latestMessageId: string; latestMessageId?: string;
} }
export interface ChatSessionState export interface ChatSessionState

View File

@@ -703,7 +703,7 @@ input ForkChatSessionInput {
""" """
Identify a message in the array and keep it with all previous messages into a forked session. Identify a message in the array and keep it with all previous messages into a forked session.
""" """
latestMessageId: String! latestMessageId: String
sessionId: String! sessionId: String!
workspaceId: String! workspaceId: String!
} }

View File

@@ -879,7 +879,7 @@ export enum FeatureType {
export interface ForkChatSessionInput { export interface ForkChatSessionInput {
docId: Scalars['String']['input']; docId: Scalars['String']['input'];
/** Identify a message in the array and keep it with all previous messages into a forked session. */ /** Identify a message in the array and keep it with all previous messages into a forked session. */
latestMessageId: Scalars['String']['input']; latestMessageId?: InputMaybe<Scalars['String']['input']>;
sessionId: Scalars['String']['input']; sessionId: Scalars['String']['input'];
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }

View File

@@ -93,7 +93,7 @@ declare global {
docId: string; docId: string;
workspaceId: string; workspaceId: string;
sessionId: string; sessionId: string;
latestMessageId: string; latestMessageId?: string;
} }
interface AIImageActionOptions extends AITextActionOptions { interface AIImageActionOptions extends AITextActionOptions {