mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 04:48:53 +00:00
feat: add retry support for copilot (#6947)
This commit is contained in:
@@ -82,14 +82,21 @@ export class CopilotController {
|
||||
|
||||
private async appendSessionMessage(
|
||||
sessionId: string,
|
||||
messageId: string
|
||||
messageId?: string
|
||||
): Promise<ChatSession> {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
await session.pushByMessageId(messageId);
|
||||
if (messageId) {
|
||||
await session.pushByMessageId(messageId);
|
||||
} else {
|
||||
// revert the latest message generated by the assistant
|
||||
// if messageId is not provided, then we can retry the action
|
||||
await this.chatSession.revertLatestMessage(sessionId);
|
||||
session.revertLatestMessage();
|
||||
}
|
||||
|
||||
return session;
|
||||
}
|
||||
@@ -129,7 +136,6 @@ export class CopilotController {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('messageId') messageId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
@@ -141,6 +147,9 @@ export class CopilotController {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
|
||||
try {
|
||||
@@ -174,7 +183,6 @@ export class CopilotController {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('messageId') messageId: string,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
@@ -187,6 +195,9 @@ export class CopilotController {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
@@ -237,10 +248,12 @@ export class CopilotController {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('messageId') messageId: string,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const { model, hasAttachment } = await this.checkRequest(
|
||||
user.id,
|
||||
sessionId,
|
||||
|
||||
@@ -64,6 +64,13 @@ export class ChatSession implements AsyncDisposable {
|
||||
this.stashMessageCount += 1;
|
||||
}
|
||||
|
||||
revertLatestMessage() {
|
||||
const messages = this.state.messages;
|
||||
messages.splice(
|
||||
messages.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
|
||||
);
|
||||
}
|
||||
|
||||
async getMessageById(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
@@ -287,6 +294,29 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
// revert the latest messages not generate by user
|
||||
// after revert, we can retry the action
|
||||
async revertLatestMessage(sessionId: string) {
|
||||
await this.db.$transaction(async tx => {
|
||||
const ids = await tx.aiSessionMessage
|
||||
.findMany({
|
||||
where: { sessionId },
|
||||
select: { id: true, role: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
})
|
||||
.then(roles =>
|
||||
roles
|
||||
.slice(
|
||||
roles.findLastIndex(({ role }) => role === AiPromptRole.user) + 1
|
||||
)
|
||||
.map(({ id }) => id)
|
||||
);
|
||||
if (ids.length) {
|
||||
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private calculateTokenSize(
|
||||
messages: PromptMessage[],
|
||||
model: AvailableModel
|
||||
|
||||
Reference in New Issue
Block a user