mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
feat: fork session support (#7367)
This commit is contained in:
@@ -440,7 +440,8 @@ export const USER_FRIENDLY_ERRORS = {
|
||||
},
|
||||
copilot_message_not_found: {
|
||||
type: 'resource_not_found',
|
||||
message: `Copilot message not found.`,
|
||||
args: { messageId: 'string' },
|
||||
message: ({ messageId }) => `Copilot message ${messageId} not found.`,
|
||||
},
|
||||
copilot_prompt_not_found: {
|
||||
type: 'resource_not_found',
|
||||
|
||||
@@ -391,10 +391,14 @@ export class CopilotActionTaken extends UserFriendlyError {
|
||||
super('action_forbidden', 'copilot_action_taken', message);
|
||||
}
|
||||
}
|
||||
@ObjectType()
|
||||
class CopilotMessageNotFoundDataType {
|
||||
@Field() messageId!: string
|
||||
}
|
||||
|
||||
export class CopilotMessageNotFound extends UserFriendlyError {
|
||||
constructor(message?: string) {
|
||||
super('resource_not_found', 'copilot_message_not_found', message);
|
||||
constructor(args: CopilotMessageNotFoundDataType, message?: string | ((args: CopilotMessageNotFoundDataType) => string)) {
|
||||
super('resource_not_found', 'copilot_message_not_found', message, args);
|
||||
}
|
||||
}
|
||||
@ObjectType()
|
||||
@@ -542,5 +546,5 @@ registerEnumType(ErrorNames, {
|
||||
export const ErrorDataUnionType = createUnionType({
|
||||
name: 'ErrorDataUnion',
|
||||
types: () =>
|
||||
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
|
||||
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
|
||||
});
|
||||
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
FileUpload,
|
||||
MutexService,
|
||||
Throttle,
|
||||
TooManyRequestsException,
|
||||
TooManyRequest,
|
||||
} from '../../fundamentals';
|
||||
import { PromptService } from './prompt';
|
||||
import { ChatSessionService } from './session';
|
||||
@@ -60,6 +60,24 @@ class CreateChatSessionInput {
|
||||
promptName!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class ForkChatSessionInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => String, {
|
||||
description:
|
||||
'Identify a message in the array and keep it with all previous messages into a forked session.',
|
||||
})
|
||||
latestMessageId!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class DeleteSessionInput {
|
||||
@Field(() => String)
|
||||
@@ -109,6 +127,10 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
|
||||
|
||||
@ObjectType('ChatMessage')
|
||||
class ChatMessageType implements Partial<ChatMessage> {
|
||||
// id will be null if message is a prompt message
|
||||
@Field(() => ID, { nullable: true })
|
||||
id!: string;
|
||||
|
||||
@Field(() => String)
|
||||
role!: 'system' | 'assistant' | 'user';
|
||||
|
||||
@@ -301,7 +323,7 @@ export class CopilotResolver {
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
@@ -313,6 +335,34 @@ export class CopilotResolver {
|
||||
return session;
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
})
|
||||
async forkCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => ForkChatSessionInput })
|
||||
options: ForkChatSessionInput
|
||||
) {
|
||||
await this.permissions.checkCloudPagePermission(
|
||||
options.workspaceId,
|
||||
options.docId,
|
||||
user.id
|
||||
);
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const session = await this.chatSession.fork({
|
||||
...options,
|
||||
userId: user.id,
|
||||
});
|
||||
return session;
|
||||
}
|
||||
|
||||
@Mutation(() => [String], {
|
||||
description: 'Cleanup sessions',
|
||||
})
|
||||
@@ -332,7 +382,7 @@ export class CopilotResolver {
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
return await this.chatSession.cleanup({
|
||||
@@ -352,7 +402,7 @@ export class CopilotResolver {
|
||||
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.chatSession.get(options.sessionId);
|
||||
if (!session || session.config.userId !== user.id) {
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
ChatHistory,
|
||||
ChatMessage,
|
||||
ChatMessageSchema,
|
||||
ChatSessionForkOptions,
|
||||
ChatSessionOptions,
|
||||
ChatSessionState,
|
||||
getTokenEncoder,
|
||||
@@ -81,7 +82,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async getMessageById(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new CopilotMessageNotFound();
|
||||
throw new CopilotMessageNotFound({ messageId });
|
||||
}
|
||||
return message;
|
||||
}
|
||||
@@ -89,7 +90,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async pushByMessageId(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new CopilotMessageNotFound();
|
||||
throw new CopilotMessageNotFound({ messageId });
|
||||
}
|
||||
|
||||
this.push({
|
||||
@@ -200,6 +201,7 @@ export class ChatSessionService {
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
@@ -271,8 +273,9 @@ export class ChatSessionService {
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
parentSessionId: true,
|
||||
messages: {
|
||||
select: { role: true, content: true, createdAt: true },
|
||||
select: { id: true, role: true, content: true, createdAt: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
@@ -291,6 +294,7 @@ export class ChatSessionService {
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
@@ -396,6 +400,7 @@ export class ChatSessionService {
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
@@ -430,7 +435,8 @@ export class ChatSessionService {
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
(preload as ChatMessage[]).forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
@@ -495,9 +501,39 @@ export class ChatSessionService {
|
||||
sessionId,
|
||||
prompt,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
});
|
||||
}
|
||||
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
const state = await this.getSession(options.sessionId);
|
||||
if (!state) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const lastMessageIdx = state.messages.findLastIndex(
|
||||
({ id, role }) =>
|
||||
role === AiPromptRole.assistant && id === options.latestMessageId
|
||||
);
|
||||
if (lastMessageIdx < 0) {
|
||||
throw new CopilotMessageNotFound({ messageId: options.latestMessageId });
|
||||
}
|
||||
const messages = state.messages
|
||||
.slice(0, lastMessageIdx + 1)
|
||||
.map(m => ({ ...m, id: undefined }));
|
||||
|
||||
const forkedState = {
|
||||
...state,
|
||||
sessionId: randomUUID(),
|
||||
messages: [],
|
||||
parentSessionId: options.sessionId,
|
||||
};
|
||||
// create session
|
||||
await this.setSession(forkedState);
|
||||
// save message
|
||||
return await this.setSession({ ...forkedState, messages });
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
|
||||
) {
|
||||
|
||||
@@ -64,6 +64,7 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
|
||||
export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
|
||||
export const ChatMessageSchema = PromptMessageSchema.extend({
|
||||
id: z.string().optional(),
|
||||
createdAt: z.date(),
|
||||
}).strict();
|
||||
|
||||
@@ -98,10 +99,17 @@ export interface ChatSessionOptions {
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionForkOptions
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
sessionId: string;
|
||||
latestMessageId: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionState
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
// connect ids
|
||||
sessionId: string;
|
||||
parentSessionId: string | null;
|
||||
// states
|
||||
prompt: ChatPrompt;
|
||||
messages: ChatMessage[];
|
||||
|
||||
@@ -11,6 +11,7 @@ type ChatMessage {
|
||||
attachments: [String!]
|
||||
content: String!
|
||||
createdAt: DateTime!
|
||||
id: ID
|
||||
params: JSON
|
||||
role: String!
|
||||
}
|
||||
@@ -39,6 +40,10 @@ type CopilotHistories {
|
||||
tokens: Int!
|
||||
}
|
||||
|
||||
type CopilotMessageNotFoundDataType {
|
||||
messageId: String!
|
||||
}
|
||||
|
||||
enum CopilotModels {
|
||||
DallE3
|
||||
Gpt4Omni
|
||||
@@ -175,7 +180,7 @@ enum EarlyAccessType {
|
||||
App
|
||||
}
|
||||
|
||||
union ErrorDataUnion = BlobNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
|
||||
union ErrorDataUnion = BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
|
||||
|
||||
enum ErrorNames {
|
||||
ACCESS_DENIED
|
||||
@@ -252,6 +257,17 @@ enum FeatureType {
|
||||
UnlimitedWorkspace
|
||||
}
|
||||
|
||||
input ForkChatSessionInput {
|
||||
docId: String!
|
||||
|
||||
"""
|
||||
Identify a message in the array and keep it with all previous messages into a forked session.
|
||||
"""
|
||||
latestMessageId: String!
|
||||
sessionId: String!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
type HumanReadableQuotaType {
|
||||
blobLimit: String!
|
||||
copilotActionLimit: String
|
||||
@@ -399,6 +415,9 @@ type Mutation {
|
||||
"""Delete a user account"""
|
||||
deleteUser(id: String!): DeleteAccount!
|
||||
deleteWorkspace(id: String!): Boolean!
|
||||
|
||||
"""Create a chat session"""
|
||||
forkCopilotSession(options: ForkChatSessionInput!): String!
|
||||
invite(email: String!, permission: Permission!, sendInviteMail: Boolean, workspaceId: String!): String!
|
||||
leaveWorkspace(sendLeaveMail: Boolean, workspaceId: String!, workspaceName: String!): Boolean!
|
||||
publishPage(mode: PublicPageMode = Page, pageId: String!, workspaceId: String!): WorkspacePage!
|
||||
|
||||
Reference in New Issue
Block a user