From e77475aca53fcd744559c1f5a9aa41f6e4ce81f4 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Fri, 12 Apr 2024 08:39:32 +0000 Subject: [PATCH] feat: detailed copilot histories (#6523) --- .../server/src/plugins/copilot/prompt.ts | 102 ++++++++++-------- .../server/src/plugins/copilot/resolver.ts | 38 ++++--- .../server/src/plugins/copilot/session.ts | 36 +++++-- .../server/src/plugins/copilot/types.ts | 1 + packages/backend/server/src/schema.gql | 12 ++- 5 files changed, 112 insertions(+), 77 deletions(-) diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index 7d7f4929d8..f7ab278e77 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -107,11 +107,12 @@ export class ChatPrompt { * @param params record of params, e.g. { name: 'Alice' } * @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }] */ - finish(params: PromptParams) { + finish(params: PromptParams): PromptMessage[] { this.checkParams(params); - return this.messages.map(m => ({ - ...m, - content: Mustache.render(m.content, params), + return this.messages.map(({ content, params: _, ...rest }) => ({ + ...rest, + params, + content: Mustache.render(content, params), })); } @@ -122,6 +123,8 @@ export class ChatPrompt { @Injectable() export class PromptService { + private readonly cache = new Map(); + constructor(private readonly db: PrismaClient) {} /** @@ -140,34 +143,40 @@ export class PromptService { * @returns prompt messages */ async get(name: string): Promise { - return this.db.aiPrompt - .findUnique({ - where: { - name, - }, - select: { - name: true, - action: true, - model: true, - messages: { - select: { - role: true, - content: true, - params: true, - }, - orderBy: { - idx: 'asc', - }, + const cached = this.cache.get(name); + if (cached) return cached; + + const prompt = await this.db.aiPrompt.findUnique({ + where: { + name, + }, + select: { + name: true, + action: true, + model: true, + messages: { + select: { + role: true, + content: true, + params: true, + }, + orderBy: { + idx: 'asc', }, }, - }) - .then(p => { - const messages = PromptMessageSchema.array().safeParse(p?.messages); - if (p && messages.success) { - return ChatPrompt.createFromPrompt({ ...p, messages: messages.data }); - } - return null; + }, + }); + + const messages = PromptMessageSchema.array().safeParse(prompt?.messages); + if (prompt && messages.success) { + const chatPrompt = ChatPrompt.createFromPrompt({ + ...prompt, + messages: messages.data, }); + this.cache.set(name, chatPrompt); + return chatPrompt; + } + return null; } async set(name: string, messages: PromptMessage[]) { @@ -188,25 +197,28 @@ export class PromptService { } async update(name: string, messages: PromptMessage[]) { - return this.db.aiPrompt - .update({ - where: { name }, - data: { - messages: { - // cleanup old messages - deleteMany: {}, - create: messages.map((m, idx) => ({ - idx, - ...m, - params: m.params || undefined, - })), - }, + const { id } = await this.db.aiPrompt.update({ + where: { name }, + data: { + messages: { + // cleanup old messages + deleteMany: {}, + create: messages.map((m, idx) => ({ + idx, + ...m, + params: m.params || undefined, + })), }, - }) - .then(ret => ret.id); + }, + }); + + this.cache.delete(name); + return id; } async delete(name: string) { - return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id); + const { id } = await this.db.aiPrompt.delete({ where: { name } }); + this.cache.delete(name); + return id; } } diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 7d767a2591..44389a15cf 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -11,7 +11,7 @@ import { ResolveField, Resolver, } from '@nestjs/graphql'; -import { SafeIntResolver } from 'graphql-scalars'; +import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars'; import { CurrentUser } from '../../core/auth'; import { QuotaService } from '../../core/quota'; @@ -45,12 +45,6 @@ class CreateChatSessionInput { @Field(() => String) docId!: string; - @Field(() => String, { - description: 'An mark identifying which view to use to display the session', - nullable: true, - }) - action!: string | undefined; - @Field(() => String, { description: 'The prompt name to use for the session', }) @@ -58,18 +52,18 @@ class CreateChatSessionInput { } @InputType() -class CreateChatMessageInput implements Omit { +class CreateChatMessageInput implements Omit { @Field(() => String) sessionId!: string; - @Field(() => String) - content!: string; + @Field(() => String, { nullable: true }) + content!: string | undefined; @Field(() => [String], { nullable: true }) attachments!: string[] | undefined; - @Field(() => String, { nullable: true }) - params!: string | undefined; + @Field(() => GraphQLJSON, { nullable: true }) + params!: Record | undefined; } @InputType() @@ -100,6 +94,9 @@ class ChatMessageType implements Partial { @Field(() => [String], { nullable: true }) attachments!: string[]; + @Field(() => GraphQLJSON, { nullable: true }) + params!: Record | undefined; + @Field(() => Date, { nullable: true }) createdAt!: Date | undefined; } @@ -227,12 +224,18 @@ export class CopilotResolver { await this.permissions.checkCloudWorkspace(workspaceId, user.id); } - return await this.chatSession.listHistories( + const histories = await this.chatSession.listHistories( user.id, workspaceId, docId, - options + options, + true ); + return histories.map(h => ({ + ...h, + // filter out empty messages + messages: h.messages.filter(m => m.content || m.attachments?.length), + })); } @Mutation(() => String, { @@ -282,12 +285,7 @@ export class CopilotResolver { return new TooManyRequestsException('Server is busy'); } try { - const { params, ...rest } = options; - const record: SubmittedMessage['params'] = {}; - new URLSearchParams(params).forEach((value, key) => { - record[key] = value; - }); - return await this.chatSession.createMessage({ ...rest, params: record }); + return await this.chatSession.createMessage(options); } catch (e: any) { this.logger.error(`Failed to create chat message: ${e.message}`); throw new Error('Failed to create chat message'); diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 9de9b83980..f29076d11d 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -59,7 +59,7 @@ export class ChatSession implements AsyncDisposable { this.push({ role: 'user', - content: message.content, + content: message.content || '', attachments: message.attachments, params: message.params, createdAt: new Date(), @@ -96,7 +96,12 @@ export class ChatSession implements AsyncDisposable { finish(params: PromptParams): PromptMessage[] { const messages = this.takeMessages(); - return [...this.state.prompt.finish(params), ...messages]; + return [ + ...this.state.prompt.finish( + Object.keys(params).length ? params : messages[0]?.params || {} + ), + ...messages.filter(m => m.content || m.attachments?.length), + ]; } async save() { @@ -257,7 +262,8 @@ export class ChatSessionService { userId: string, workspaceId?: string, docId?: string, - options?: ListHistoriesOptions + options?: ListHistoriesOptions, + withPrompt = false ): Promise { return await this.db.aiSession .findMany({ @@ -272,11 +278,12 @@ export class ChatSessionService { }, select: { id: true, - prompt: true, + promptName: true, messages: { select: { role: true, content: true, + params: true, }, orderBy: { createdAt: 'asc', @@ -288,20 +295,30 @@ export class ChatSessionService { orderBy: { createdAt: 'desc' }, }) .then(sessions => - sessions - .map(({ id, prompt, messages }) => { + Promise.all( + sessions.map(async ({ id, promptName, messages }) => { try { const ret = PromptMessageSchema.array().safeParse(messages); if (ret.success) { + const prompt = await this.prompt.get(promptName); + if (!prompt) { + throw new Error(`Prompt not found: ${promptName}`); + } const tokens = this.calculateTokenSize( ret.data, prompt.model as AvailableModel ); + + // render system prompt + const preload = withPrompt + ? prompt.finish(ret.data[0]?.params || {}) + : []; + return { sessionId: id, action: prompt.action || undefined, tokens, - messages: ret.data, + messages: preload.concat(ret.data), }; } else { this.logger.error( @@ -313,7 +330,10 @@ export class ChatSessionService { } return undefined; }) - .filter((v): v is NonNullable => !!v) + ) + ) + .then(histories => + histories.filter((v): v is NonNullable => !!v) ); } diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 59870d0888..9ae30a954a 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -82,6 +82,7 @@ export type ChatMessage = z.infer; export const SubmittedMessageSchema = PureMessageSchema.extend({ sessionId: z.string(), + content: z.string().optional(), }).strict(); export type SubmittedMessage = z.infer; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index bea9322f39..76508052bb 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -6,6 +6,7 @@ type ChatMessage { attachments: [String!] content: String! createdAt: DateTime + params: JSON role: String! } @@ -39,14 +40,12 @@ type CopilotQuota { input CreateChatMessageInput { attachments: [String!] - content: String! - params: String + content: String + params: JSON sessionId: String! } input CreateChatSessionInput { - """An mark identifying which view to use to display the session""" - action: String docId: String! """The prompt name to use for the session""" @@ -155,6 +154,11 @@ enum InvoiceStatus { Void } +""" +The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf). +""" +scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf") + type LimitedUserType { """User email""" email: String!