feat: detailed copilot histories (#6523)

This commit is contained in:
darkskygit
2024-04-12 08:39:32 +00:00
parent 9e7a2fcf0e
commit e77475aca5
5 changed files with 112 additions and 77 deletions

View File

@@ -107,11 +107,12 @@ export class ChatPrompt {
* @param params record of params, e.g. { name: 'Alice' } * @param params record of params, e.g. { name: 'Alice' }
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }] * @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
*/ */
finish(params: PromptParams) { finish(params: PromptParams): PromptMessage[] {
this.checkParams(params); this.checkParams(params);
return this.messages.map(m => ({ return this.messages.map(({ content, params: _, ...rest }) => ({
...m, ...rest,
content: Mustache.render(m.content, params), params,
content: Mustache.render(content, params),
})); }));
} }
@@ -122,6 +123,8 @@ export class ChatPrompt {
@Injectable() @Injectable()
export class PromptService { export class PromptService {
private readonly cache = new Map<string, ChatPrompt>();
constructor(private readonly db: PrismaClient) {} constructor(private readonly db: PrismaClient) {}
/** /**
@@ -140,8 +143,10 @@ export class PromptService {
* @returns prompt messages * @returns prompt messages
*/ */
async get(name: string): Promise<ChatPrompt | null> { async get(name: string): Promise<ChatPrompt | null> {
return this.db.aiPrompt const cached = this.cache.get(name);
.findUnique({ if (cached) return cached;
const prompt = await this.db.aiPrompt.findUnique({
where: { where: {
name, name,
}, },
@@ -160,14 +165,18 @@ export class PromptService {
}, },
}, },
}, },
}) });
.then(p => {
const messages = PromptMessageSchema.array().safeParse(p?.messages); const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
if (p && messages.success) { if (prompt && messages.success) {
return ChatPrompt.createFromPrompt({ ...p, messages: messages.data }); const chatPrompt = ChatPrompt.createFromPrompt({
...prompt,
messages: messages.data,
});
this.cache.set(name, chatPrompt);
return chatPrompt;
} }
return null; return null;
});
} }
async set(name: string, messages: PromptMessage[]) { async set(name: string, messages: PromptMessage[]) {
@@ -188,8 +197,7 @@ export class PromptService {
} }
async update(name: string, messages: PromptMessage[]) { async update(name: string, messages: PromptMessage[]) {
return this.db.aiPrompt const { id } = await this.db.aiPrompt.update({
.update({
where: { name }, where: { name },
data: { data: {
messages: { messages: {
@@ -202,11 +210,15 @@ export class PromptService {
})), })),
}, },
}, },
}) });
.then(ret => ret.id);
this.cache.delete(name);
return id;
} }
async delete(name: string) { async delete(name: string) {
return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id); const { id } = await this.db.aiPrompt.delete({ where: { name } });
this.cache.delete(name);
return id;
} }
} }

View File

@@ -11,7 +11,7 @@ import {
ResolveField, ResolveField,
Resolver, Resolver,
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars'; import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../../core/auth'; import { CurrentUser } from '../../core/auth';
import { QuotaService } from '../../core/quota'; import { QuotaService } from '../../core/quota';
@@ -45,12 +45,6 @@ class CreateChatSessionInput {
@Field(() => String) @Field(() => String)
docId!: string; docId!: string;
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
nullable: true,
})
action!: string | undefined;
@Field(() => String, { @Field(() => String, {
description: 'The prompt name to use for the session', description: 'The prompt name to use for the session',
}) })
@@ -58,18 +52,18 @@ class CreateChatSessionInput {
} }
@InputType() @InputType()
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> { class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
@Field(() => String) @Field(() => String)
sessionId!: string; sessionId!: string;
@Field(() => String) @Field(() => String, { nullable: true })
content!: string; content!: string | undefined;
@Field(() => [String], { nullable: true }) @Field(() => [String], { nullable: true })
attachments!: string[] | undefined; attachments!: string[] | undefined;
@Field(() => String, { nullable: true }) @Field(() => GraphQLJSON, { nullable: true })
params!: string | undefined; params!: Record<string, string> | undefined;
} }
@InputType() @InputType()
@@ -100,6 +94,9 @@ class ChatMessageType implements Partial<ChatMessage> {
@Field(() => [String], { nullable: true }) @Field(() => [String], { nullable: true })
attachments!: string[]; attachments!: string[];
@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | undefined;
@Field(() => Date, { nullable: true }) @Field(() => Date, { nullable: true })
createdAt!: Date | undefined; createdAt!: Date | undefined;
} }
@@ -227,12 +224,18 @@ export class CopilotResolver {
await this.permissions.checkCloudWorkspace(workspaceId, user.id); await this.permissions.checkCloudWorkspace(workspaceId, user.id);
} }
return await this.chatSession.listHistories( const histories = await this.chatSession.listHistories(
user.id, user.id,
workspaceId, workspaceId,
docId, docId,
options options,
true
); );
return histories.map(h => ({
...h,
// filter out empty messages
messages: h.messages.filter(m => m.content || m.attachments?.length),
}));
} }
@Mutation(() => String, { @Mutation(() => String, {
@@ -282,12 +285,7 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy'); return new TooManyRequestsException('Server is busy');
} }
try { try {
const { params, ...rest } = options; return await this.chatSession.createMessage(options);
const record: SubmittedMessage['params'] = {};
new URLSearchParams(params).forEach((value, key) => {
record[key] = value;
});
return await this.chatSession.createMessage({ ...rest, params: record });
} catch (e: any) { } catch (e: any) {
this.logger.error(`Failed to create chat message: ${e.message}`); this.logger.error(`Failed to create chat message: ${e.message}`);
throw new Error('Failed to create chat message'); throw new Error('Failed to create chat message');

View File

@@ -59,7 +59,7 @@ export class ChatSession implements AsyncDisposable {
this.push({ this.push({
role: 'user', role: 'user',
content: message.content, content: message.content || '',
attachments: message.attachments, attachments: message.attachments,
params: message.params, params: message.params,
createdAt: new Date(), createdAt: new Date(),
@@ -96,7 +96,12 @@ export class ChatSession implements AsyncDisposable {
finish(params: PromptParams): PromptMessage[] { finish(params: PromptParams): PromptMessage[] {
const messages = this.takeMessages(); const messages = this.takeMessages();
return [...this.state.prompt.finish(params), ...messages]; return [
...this.state.prompt.finish(
Object.keys(params).length ? params : messages[0]?.params || {}
),
...messages.filter(m => m.content || m.attachments?.length),
];
} }
async save() { async save() {
@@ -257,7 +262,8 @@ export class ChatSessionService {
userId: string, userId: string,
workspaceId?: string, workspaceId?: string,
docId?: string, docId?: string,
options?: ListHistoriesOptions options?: ListHistoriesOptions,
withPrompt = false
): Promise<ChatHistory[]> { ): Promise<ChatHistory[]> {
return await this.db.aiSession return await this.db.aiSession
.findMany({ .findMany({
@@ -272,11 +278,12 @@ export class ChatSessionService {
}, },
select: { select: {
id: true, id: true,
prompt: true, promptName: true,
messages: { messages: {
select: { select: {
role: true, role: true,
content: true, content: true,
params: true,
}, },
orderBy: { orderBy: {
createdAt: 'asc', createdAt: 'asc',
@@ -288,20 +295,30 @@ export class ChatSessionService {
orderBy: { createdAt: 'desc' }, orderBy: { createdAt: 'desc' },
}) })
.then(sessions => .then(sessions =>
sessions Promise.all(
.map(({ id, prompt, messages }) => { sessions.map(async ({ id, promptName, messages }) => {
try { try {
const ret = PromptMessageSchema.array().safeParse(messages); const ret = PromptMessageSchema.array().safeParse(messages);
if (ret.success) { if (ret.success) {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new Error(`Prompt not found: ${promptName}`);
}
const tokens = this.calculateTokenSize( const tokens = this.calculateTokenSize(
ret.data, ret.data,
prompt.model as AvailableModel prompt.model as AvailableModel
); );
// render system prompt
const preload = withPrompt
? prompt.finish(ret.data[0]?.params || {})
: [];
return { return {
sessionId: id, sessionId: id,
action: prompt.action || undefined, action: prompt.action || undefined,
tokens, tokens,
messages: ret.data, messages: preload.concat(ret.data),
}; };
} else { } else {
this.logger.error( this.logger.error(
@@ -313,7 +330,10 @@ export class ChatSessionService {
} }
return undefined; return undefined;
}) })
.filter((v): v is NonNullable<typeof v> => !!v) )
)
.then(histories =>
histories.filter((v): v is NonNullable<typeof v> => !!v)
); );
} }

View File

@@ -82,6 +82,7 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const SubmittedMessageSchema = PureMessageSchema.extend({ export const SubmittedMessageSchema = PureMessageSchema.extend({
sessionId: z.string(), sessionId: z.string(),
content: z.string().optional(),
}).strict(); }).strict();
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>; export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;

View File

@@ -6,6 +6,7 @@ type ChatMessage {
attachments: [String!] attachments: [String!]
content: String! content: String!
createdAt: DateTime createdAt: DateTime
params: JSON
role: String! role: String!
} }
@@ -39,14 +40,12 @@ type CopilotQuota {
input CreateChatMessageInput { input CreateChatMessageInput {
attachments: [String!] attachments: [String!]
content: String! content: String
params: String params: JSON
sessionId: String! sessionId: String!
} }
input CreateChatSessionInput { input CreateChatSessionInput {
"""An mark identifying which view to use to display the session"""
action: String
docId: String! docId: String!
"""The prompt name to use for the session""" """The prompt name to use for the session"""
@@ -155,6 +154,11 @@ enum InvoiceStatus {
Void Void
} }
"""
The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf).
"""
scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf")
type LimitedUserType { type LimitedUserType {
"""User email""" """User email"""
email: String! email: String!