feat: fork session support (#7367)

This commit is contained in:
darkskygit
2024-07-03 03:10:09 +00:00
parent 10df1fb4b7
commit 61870c04d0
12 changed files with 292 additions and 28 deletions

View File

@@ -440,7 +440,8 @@ export const USER_FRIENDLY_ERRORS = {
},
copilot_message_not_found: {
type: 'resource_not_found',
message: `Copilot message not found.`,
args: { messageId: 'string' },
message: ({ messageId }) => `Copilot message ${messageId} not found.`,
},
copilot_prompt_not_found: {
type: 'resource_not_found',

View File

@@ -391,10 +391,14 @@ export class CopilotActionTaken extends UserFriendlyError {
super('action_forbidden', 'copilot_action_taken', message);
}
}
@ObjectType()
class CopilotMessageNotFoundDataType {
@Field() messageId!: string
}
export class CopilotMessageNotFound extends UserFriendlyError {
constructor(message?: string) {
super('resource_not_found', 'copilot_message_not_found', message);
constructor(args: CopilotMessageNotFoundDataType, message?: string | ((args: CopilotMessageNotFoundDataType) => string)) {
super('resource_not_found', 'copilot_message_not_found', message, args);
}
}
@ObjectType()
@@ -542,5 +546,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
});

View File

@@ -27,7 +27,7 @@ import {
FileUpload,
MutexService,
Throttle,
TooManyRequestsException,
TooManyRequest,
} from '../../fundamentals';
import { PromptService } from './prompt';
import { ChatSessionService } from './session';
@@ -60,6 +60,24 @@ class CreateChatSessionInput {
promptName!: string;
}
@InputType()
class ForkChatSessionInput {
@Field(() => String)
workspaceId!: string;
@Field(() => String)
docId!: string;
@Field(() => String)
sessionId!: string;
@Field(() => String, {
description:
'Identify a message in the array and keep it with all previous messages into a forked session.',
})
latestMessageId!: string;
}
@InputType()
class DeleteSessionInput {
@Field(() => String)
@@ -109,6 +127,10 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@ObjectType('ChatMessage')
class ChatMessageType implements Partial<ChatMessage> {
// id will be null if message is a prompt message
@Field(() => ID, { nullable: true })
id!: string;
@Field(() => String)
role!: 'system' | 'assistant' | 'user';
@@ -301,7 +323,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
await this.chatSession.checkQuota(user.id);
@@ -313,6 +335,34 @@ export class CopilotResolver {
return session;
}
@Mutation(() => String, {
description: 'Create a chat session',
})
async forkCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => ForkChatSessionInput })
options: ForkChatSessionInput
) {
await this.permissions.checkCloudPagePermission(
options.workspaceId,
options.docId,
user.id
);
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
await this.chatSession.checkQuota(user.id);
const session = await this.chatSession.fork({
...options,
userId: user.id,
});
return session;
}
@Mutation(() => [String], {
description: 'Cleanup sessions',
})
@@ -332,7 +382,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
return await this.chatSession.cleanup({
@@ -352,7 +402,7 @@ export class CopilotResolver {
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
return new TooManyRequest('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session || session.config.userId !== user.id) {

View File

@@ -20,6 +20,7 @@ import {
ChatHistory,
ChatMessage,
ChatMessageSchema,
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionState,
getTokenEncoder,
@@ -81,7 +82,7 @@ export class ChatSession implements AsyncDisposable {
async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new CopilotMessageNotFound();
throw new CopilotMessageNotFound({ messageId });
}
return message;
}
@@ -89,7 +90,7 @@ export class ChatSession implements AsyncDisposable {
async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new CopilotMessageNotFound();
throw new CopilotMessageNotFound({ messageId });
}
this.push({
@@ -200,6 +201,7 @@ export class ChatSessionService {
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
parentSessionId: state.parentSessionId,
},
select: { id: true, deletedAt: true },
})) || {};
@@ -271,8 +273,9 @@ export class ChatSessionService {
userId: true,
workspaceId: true,
docId: true,
parentSessionId: true,
messages: {
select: { role: true, content: true, createdAt: true },
select: { id: true, role: true, content: true, createdAt: true },
orderBy: { createdAt: 'asc' },
},
promptName: true,
@@ -291,6 +294,7 @@ export class ChatSessionService {
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
@@ -396,6 +400,7 @@ export class ChatSessionService {
createdAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
@@ -430,7 +435,8 @@ export class ChatSessionService {
.filter(({ role }) => role !== 'system')
: [];
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
(preload as ChatMessage[]).forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
@@ -495,9 +501,39 @@ export class ChatSessionService {
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
});
}
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {
throw new CopilotSessionNotFound();
}
const lastMessageIdx = state.messages.findLastIndex(
({ id, role }) =>
role === AiPromptRole.assistant && id === options.latestMessageId
);
if (lastMessageIdx < 0) {
throw new CopilotMessageNotFound({ messageId: options.latestMessageId });
}
const messages = state.messages
.slice(0, lastMessageIdx + 1)
.map(m => ({ ...m, id: undefined }));
const forkedState = {
...state,
sessionId: randomUUID(),
messages: [],
parentSessionId: options.sessionId,
};
// create session
await this.setSession(forkedState);
// save message
return await this.setSession({ ...forkedState, messages });
}
async cleanup(
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
) {

View File

@@ -64,6 +64,7 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export const ChatMessageSchema = PromptMessageSchema.extend({
id: z.string().optional(),
createdAt: z.date(),
}).strict();
@@ -98,10 +99,17 @@ export interface ChatSessionOptions {
promptName: string;
}
export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'promptName'> {
sessionId: string;
latestMessageId: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
parentSessionId: string | null;
// states
prompt: ChatPrompt;
messages: ChatMessage[];

View File

@@ -11,6 +11,7 @@ type ChatMessage {
attachments: [String!]
content: String!
createdAt: DateTime!
id: ID
params: JSON
role: String!
}
@@ -39,6 +40,10 @@ type CopilotHistories {
tokens: Int!
}
type CopilotMessageNotFoundDataType {
messageId: String!
}
enum CopilotModels {
DallE3
Gpt4Omni
@@ -175,7 +180,7 @@ enum EarlyAccessType {
App
}
union ErrorDataUnion = BlobNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
union ErrorDataUnion = BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
enum ErrorNames {
ACCESS_DENIED
@@ -252,6 +257,17 @@ enum FeatureType {
UnlimitedWorkspace
}
input ForkChatSessionInput {
docId: String!
"""
Identify a message in the array and keep it with all previous messages into a forked session.
"""
latestMessageId: String!
sessionId: String!
workspaceId: String!
}
type HumanReadableQuotaType {
blobLimit: String!
copilotActionLimit: String
@@ -399,6 +415,9 @@ type Mutation {
"""Delete a user account"""
deleteUser(id: String!): DeleteAccount!
deleteWorkspace(id: String!): Boolean!
"""Create a chat session"""
forkCopilotSession(options: ForkChatSessionInput!): String!
invite(email: String!, permission: Permission!, sendInviteMail: Boolean, workspaceId: String!): String!
leaveWorkspace(sendLeaveMail: Boolean, workspaceId: String!, workspaceName: String!): Boolean!
publishPage(mode: PublicPageMode = Page, pageId: String!, workspaceId: String!): WorkspacePage!