feat: allow retry with new message (#10307)

fix AF-1630
This commit is contained in:
darkskygit
2025-02-20 06:07:53 +00:00
parent 50820482df
commit 13f1859cdf
9 changed files with 313 additions and 44 deletions

View File

@@ -4,6 +4,36 @@ The actual snapshot is saved in `copilot.e2e.ts.snap`.
Generated by [AVA](https://avajs.dev).
## should be able to retry with api
> should be able to list history after retry
[
{
messages: [
{
content: 'generate text to text',
role: 'assistant',
},
],
tokens: 0,
},
]
> should be able to list history after retry
[
{
messages: [
{
content: 'generate text to text',
role: 'assistant',
},
],
tokens: 0,
},
]
## should be able to manage context
> should list context files
@@ -13,14 +43,3 @@ Generated by [AVA](https://avajs.dev).
id: 'docId1',
},
]
> should list context docs
[
{
blobId: 'fileId1',
chunkSize: 3,
name: 'sample.pdf',
status: 'finished',
},
]

View File

@@ -0,0 +1,151 @@
# Snapshot report for `src/__tests__/copilot.spec.ts`
The actual snapshot is saved in `copilot.spec.ts.snap`.
Generated by [AVA](https://avajs.dev).
## should revert message correctly
> should have three messages before revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
{
content: '3',
role: 'user',
},
{
content: '4',
role: 'assistant',
},
]
> should remove assistant message after revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
{
content: '3',
role: 'user',
},
]
> should remove assistant message after revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
]
> should have three messages before revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
{
content: '3',
role: 'user',
},
{
content: '4',
role: 'assistant',
},
]
> should remove assistant message after revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
{
content: '3',
role: 'user',
},
]
> should remove assistant message after revert
[
{
content: 'hello world',
params: {
word: 'world',
},
role: 'system',
},
{
content: '1',
role: 'user',
},
{
content: '2',
role: 'assistant',
},
]

View File

@@ -498,6 +498,15 @@ test('should be able to retry with api', async t => {
);
}
const cleanObject = (obj: any[]) =>
JSON.parse(
JSON.stringify(obj, (k, v) =>
['id', 'sessionId', 'createdAt'].includes(k) || v === null
? undefined
: v
)
);
// retry chat
{
const { id } = await createWorkspace(app);
@@ -514,10 +523,32 @@ test('should be able to retry with api', async t => {
// should only have 1 message
const histories = await getHistories(app, { workspaceId: id });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
'should be able to list history'
t.snapshot(
cleanObject(histories),
'should be able to list history after retry'
);
}
// retry chat with new message id
{
const { id } = await createWorkspace(app);
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, messageId);
// retry with new message id
const newMessageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, newMessageId, '', true);
// should only have 1 message
const histories = await getHistories(app, { workspaceId: id });
t.snapshot(
cleanObject(histories),
'should be able to list history after retry'
);
}

View File

@@ -602,36 +602,81 @@ test('should revert message correctly', async t => {
const message = (await session.createMessage({
sessionId,
content: 'hello',
content: '1',
}))!;
await s.pushByMessageId(message);
await s.save();
}
const cleanObject = (obj: any[]) =>
JSON.parse(
JSON.stringify(obj, (k, v) =>
['id', 'createdAt'].includes(k) || v === null ? undefined : v
)
);
// check ChatSession behavior
{
const s = (await session.get(sessionId))!;
s.push({ role: 'assistant', content: 'hi', createdAt: new Date() });
s.push({ role: 'assistant', content: '2', createdAt: new Date() });
s.push({ role: 'user', content: '3', createdAt: new Date() });
s.push({ role: 'assistant', content: '4', createdAt: new Date() });
await s.save();
const beforeRevert = s.finish({ word: 'world' });
t.is(beforeRevert.length, 3, 'should have three messages before revert');
t.snapshot(
cleanObject(beforeRevert),
'should have three messages before revert'
);
s.revertLatestMessage();
const afterRevert = s.finish({ word: 'world' });
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
{
s.revertLatestMessage(false);
const afterRevert = s.finish({ word: 'world' });
t.snapshot(
cleanObject(afterRevert),
'should remove assistant message after revert'
);
}
{
s.revertLatestMessage(true);
const afterRevert = s.finish({ word: 'world' });
t.snapshot(
cleanObject(afterRevert),
'should remove assistant message after revert'
);
}
}
// check database behavior
{
let s = (await session.get(sessionId))!;
const beforeRevert = s.finish({ word: 'world' });
t.is(beforeRevert.length, 3, 'should have three messages before revert');
await session.revertLatestMessage(sessionId);
s = (await session.get(sessionId))!;
const afterRevert = s.finish({ word: 'world' });
t.is(afterRevert.length, 2, 'should remove assistant message after revert');
const beforeRevert = s.finish({ word: 'world' });
t.snapshot(
cleanObject(beforeRevert),
'should have three messages before revert'
);
{
await session.revertLatestMessage(sessionId, false);
s = (await session.get(sessionId))!;
const afterRevert = s.finish({ word: 'world' });
t.snapshot(
cleanObject(afterRevert),
'should remove assistant message after revert'
);
}
{
await session.revertLatestMessage(sessionId, true);
s = (await session.get(sessionId))!;
const afterRevert = s.finish({ word: 'world' });
t.snapshot(
cleanObject(afterRevert),
'should remove assistant message after revert'
);
}
}
});

View File

@@ -444,9 +444,12 @@ export async function chatWithText(
app: TestingApp,
sessionId: string,
messageId?: string,
prefix = ''
prefix = '',
retry?: boolean
): Promise<string> {
const query = messageId ? `?messageId=${messageId}` : '';
const query = messageId
? `?messageId=${messageId}` + (retry ? '&retry=true' : '')
: '';
const res = await app
.GET(`/api/copilot/chat/${sessionId}${prefix}${query}`)
.expect(200);

View File

@@ -137,20 +137,23 @@ export class CopilotController implements BeforeApplicationShutdown {
private async appendSessionMessage(
sessionId: string,
messageId?: string
messageId?: string,
retry = false
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
if (!messageId || retry) {
// revert the latest message generated by the assistant
// if messageId is provided, we will also revert latest user message
await this.chatSession.revertLatestMessage(sessionId, !messageId);
session.revertLatestMessage(!messageId);
}
if (messageId) {
await session.pushByMessageId(messageId);
} else {
// revert the latest message generated by the assistant
// if messageId is not provided, then we can retry the action
await this.chatSession.revertLatestMessage(sessionId);
session.revertLatestMessage();
}
return session;
@@ -160,8 +163,12 @@ export class CopilotController implements BeforeApplicationShutdown {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const retry = Array.isArray(params.retry)
? Boolean(params.retry[0])
: Boolean(params.retry);
delete params.messageId;
return { messageId, params };
delete params.retry;
return { messageId, retry, params };
}
private getSignal(req: Request) {
@@ -202,7 +209,7 @@ export class CopilotController implements BeforeApplicationShutdown {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { messageId } = this.prepareParams(params);
const { messageId, retry } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
@@ -210,7 +217,11 @@ export class CopilotController implements BeforeApplicationShutdown {
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
const session = await this.appendSessionMessage(
sessionId,
messageId,
retry
);
try {
metrics.ai.counter('chat_calls').add(1, { model: session.model });
const content = await provider.generateText(
@@ -248,7 +259,7 @@ export class CopilotController implements BeforeApplicationShutdown {
const info: any = { sessionId, params, throwInStream: false };
try {
const { messageId } = this.prepareParams(params);
const { messageId, retry } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
@@ -256,7 +267,11 @@ export class CopilotController implements BeforeApplicationShutdown {
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
const session = await this.appendSessionMessage(
sessionId,
messageId,
retry
);
info.model = session.model;
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });

View File

@@ -75,10 +75,11 @@ export class ChatSession implements AsyncDisposable {
this.stashMessageCount += 1;
}
revertLatestMessage() {
revertLatestMessage(removeLatestUserMessage: boolean) {
const messages = this.state.messages;
messages.splice(
messages.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
messages.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
);
}
@@ -341,7 +342,10 @@ export class ChatSessionService {
// revert the latest messages not generate by user
// after revert, we can retry the action
async revertLatestMessage(sessionId: string) {
async revertLatestMessage(
sessionId: string,
removeLatestUserMessage: boolean
) {
await this.db.$transaction(async tx => {
const id = await tx.aiSession
.findUnique({
@@ -361,7 +365,8 @@ export class ChatSessionService {
.then(roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id)
);