feat: fork session support (#7367)

This commit is contained in:
darkskygit
2024-07-03 03:10:09 +00:00
parent 10df1fb4b7
commit 61870c04d0
12 changed files with 292 additions and 28 deletions

View File

@@ -27,7 +27,7 @@ import {
FileUpload,
MutexService,
Throttle,
TooManyRequestsException,
TooManyRequest,
} from '../../fundamentals';
import { PromptService } from './prompt';
import { ChatSessionService } from './session';
@@ -60,6 +60,24 @@ class CreateChatSessionInput {
promptName!: string;
}
@InputType()
class ForkChatSessionInput {
@Field(() => String)
workspaceId!: string;
@Field(() => String)
docId!: string;
@Field(() => String)
sessionId!: string;
@Field(() => String, {
description:
'Identify a message in the array and keep it with all previous messages into a forked session.',
})
latestMessageId!: string;
}
@InputType()
class DeleteSessionInput {
@Field(() => String)
@@ -109,6 +127,10 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@ObjectType('ChatMessage')
class ChatMessageType implements Partial<ChatMessage> {
// id will be null if message is a prompt message
@Field(() => ID, { nullable: true })
id!: string;
@Field(() => String)
role!: 'system' | 'assistant' | 'user';
@@ -301,7 +323,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
await this.chatSession.checkQuota(user.id);
@@ -313,6 +335,34 @@ export class CopilotResolver {
return session;
}
@Mutation(() => String, {
description: 'Create a chat session',
})
async forkCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => ForkChatSessionInput })
options: ForkChatSessionInput
) {
await this.permissions.checkCloudPagePermission(
options.workspaceId,
options.docId,
user.id
);
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
await this.chatSession.checkQuota(user.id);
const session = await this.chatSession.fork({
...options,
userId: user.id,
});
return session;
}
@Mutation(() => [String], {
description: 'Cleanup sessions',
})
@@ -332,7 +382,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
return await this.chatSession.cleanup({
@@ -352,7 +402,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session || session.config.userId !== user.id) {

View File

@@ -20,6 +20,7 @@ import {
ChatHistory,
ChatMessage,
ChatMessageSchema,
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionState,
getTokenEncoder,
@@ -81,7 +82,7 @@ export class ChatSession implements AsyncDisposable {
async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new CopilotMessageNotFound();
throw new CopilotMessageNotFound({ messageId });
}
return message;
}
@@ -89,7 +90,7 @@ export class ChatSession implements AsyncDisposable {
async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new CopilotMessageNotFound();
throw new CopilotMessageNotFound({ messageId });
}
this.push({
@@ -200,6 +201,7 @@ export class ChatSessionService {
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
parentSessionId: state.parentSessionId,
},
select: { id: true, deletedAt: true },
})) || {};
@@ -271,8 +273,9 @@ export class ChatSessionService {
userId: true,
workspaceId: true,
docId: true,
parentSessionId: true,
messages: {
select: { role: true, content: true, createdAt: true },
select: { id: true, role: true, content: true, createdAt: true },
orderBy: { createdAt: 'asc' },
},
promptName: true,
@@ -291,6 +294,7 @@ export class ChatSessionService {
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
@@ -396,6 +400,7 @@ export class ChatSessionService {
createdAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
@@ -430,7 +435,8 @@ export class ChatSessionService {
.filter(({ role }) => role !== 'system')
: [];
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
(preload as ChatMessage[]).forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
@@ -495,9 +501,39 @@ export class ChatSessionService {
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
});
}
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {
throw new CopilotSessionNotFound();
}
const lastMessageIdx = state.messages.findLastIndex(
({ id, role }) =>
role === AiPromptRole.assistant && id === options.latestMessageId
);
if (lastMessageIdx < 0) {
throw new CopilotMessageNotFound({ messageId: options.latestMessageId });
}
const messages = state.messages
.slice(0, lastMessageIdx + 1)
.map(m => ({ ...m, id: undefined }));
const forkedState = {
...state,
sessionId: randomUUID(),
messages: [],
parentSessionId: options.sessionId,
};
// create session
await this.setSession(forkedState);
// save message
return await this.setSession({ ...forkedState, messages });
}
async cleanup(
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
) {

View File

@@ -64,6 +64,7 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export const ChatMessageSchema = PromptMessageSchema.extend({
id: z.string().optional(),
createdAt: z.date(),
}).strict();
@@ -98,10 +99,17 @@ export interface ChatSessionOptions {
promptName: string;
}
export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'promptName'> {
sessionId: string;
latestMessageId: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
parentSessionId: string | null;
// states
prompt: ChatPrompt;
messages: ChatMessage[];