feat: detailed copilot histories (#6523)

This commit is contained in:
darkskygit
2024-04-12 08:39:32 +00:00
parent 9e7a2fcf0e
commit e77475aca5
5 changed files with 112 additions and 77 deletions

View File

@@ -107,11 +107,12 @@ export class ChatPrompt {
* @param params record of params, e.g. { name: 'Alice' }
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
*/
finish(params: PromptParams) {
finish(params: PromptParams): PromptMessage[] {
this.checkParams(params);
return this.messages.map(m => ({
...m,
content: Mustache.render(m.content, params),
return this.messages.map(({ content, params: _, ...rest }) => ({
...rest,
params,
content: Mustache.render(content, params),
}));
}
@@ -122,6 +123,8 @@ export class ChatPrompt {
@Injectable()
export class PromptService {
private readonly cache = new Map<string, ChatPrompt>();
constructor(private readonly db: PrismaClient) {}
/**
@@ -140,34 +143,40 @@ export class PromptService {
* @returns prompt messages
*/
async get(name: string): Promise<ChatPrompt | null> {
return this.db.aiPrompt
.findUnique({
where: {
name,
},
select: {
name: true,
action: true,
model: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
const cached = this.cache.get(name);
if (cached) return cached;
const prompt = await this.db.aiPrompt.findUnique({
where: {
name,
},
select: {
name: true,
action: true,
model: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
},
})
.then(p => {
const messages = PromptMessageSchema.array().safeParse(p?.messages);
if (p && messages.success) {
return ChatPrompt.createFromPrompt({ ...p, messages: messages.data });
}
return null;
},
});
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
if (prompt && messages.success) {
const chatPrompt = ChatPrompt.createFromPrompt({
...prompt,
messages: messages.data,
});
this.cache.set(name, chatPrompt);
return chatPrompt;
}
return null;
}
async set(name: string, messages: PromptMessage[]) {
@@ -188,25 +197,28 @@ export class PromptService {
}
async update(name: string, messages: PromptMessage[]) {
return this.db.aiPrompt
.update({
where: { name },
data: {
messages: {
// cleanup old messages
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
params: m.params || undefined,
})),
},
const { id } = await this.db.aiPrompt.update({
where: { name },
data: {
messages: {
// cleanup old messages
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
params: m.params || undefined,
})),
},
})
.then(ret => ret.id);
},
});
this.cache.delete(name);
return id;
}
async delete(name: string) {
return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id);
const { id } = await this.db.aiPrompt.delete({ where: { name } });
this.cache.delete(name);
return id;
}
}

View File

@@ -11,7 +11,7 @@ import {
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../../core/auth';
import { QuotaService } from '../../core/quota';
@@ -45,12 +45,6 @@ class CreateChatSessionInput {
@Field(() => String)
docId!: string;
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
nullable: true,
})
action!: string | undefined;
@Field(() => String, {
description: 'The prompt name to use for the session',
})
@@ -58,18 +52,18 @@ class CreateChatSessionInput {
}
@InputType()
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> {
class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
@Field(() => String)
sessionId!: string;
@Field(() => String)
content!: string;
@Field(() => String, { nullable: true })
content!: string | undefined;
@Field(() => [String], { nullable: true })
attachments!: string[] | undefined;
@Field(() => String, { nullable: true })
params!: string | undefined;
@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | undefined;
}
@InputType()
@@ -100,6 +94,9 @@ class ChatMessageType implements Partial<ChatMessage> {
@Field(() => [String], { nullable: true })
attachments!: string[];
@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | undefined;
@Field(() => Date, { nullable: true })
createdAt!: Date | undefined;
}
@@ -227,12 +224,18 @@ export class CopilotResolver {
await this.permissions.checkCloudWorkspace(workspaceId, user.id);
}
return await this.chatSession.listHistories(
const histories = await this.chatSession.listHistories(
user.id,
workspaceId,
docId,
options
options,
true
);
return histories.map(h => ({
...h,
// filter out empty messages
messages: h.messages.filter(m => m.content || m.attachments?.length),
}));
}
@Mutation(() => String, {
@@ -282,12 +285,7 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
try {
const { params, ...rest } = options;
const record: SubmittedMessage['params'] = {};
new URLSearchParams(params).forEach((value, key) => {
record[key] = value;
});
return await this.chatSession.createMessage({ ...rest, params: record });
return await this.chatSession.createMessage(options);
} catch (e: any) {
this.logger.error(`Failed to create chat message: ${e.message}`);
throw new Error('Failed to create chat message');

View File

@@ -59,7 +59,7 @@ export class ChatSession implements AsyncDisposable {
this.push({
role: 'user',
content: message.content,
content: message.content || '',
attachments: message.attachments,
params: message.params,
createdAt: new Date(),
@@ -96,7 +96,12 @@ export class ChatSession implements AsyncDisposable {
finish(params: PromptParams): PromptMessage[] {
const messages = this.takeMessages();
return [...this.state.prompt.finish(params), ...messages];
return [
...this.state.prompt.finish(
Object.keys(params).length ? params : messages[0]?.params || {}
),
...messages.filter(m => m.content || m.attachments?.length),
];
}
async save() {
@@ -257,7 +262,8 @@ export class ChatSessionService {
userId: string,
workspaceId?: string,
docId?: string,
options?: ListHistoriesOptions
options?: ListHistoriesOptions,
withPrompt = false
): Promise<ChatHistory[]> {
return await this.db.aiSession
.findMany({
@@ -272,11 +278,12 @@ export class ChatSessionService {
},
select: {
id: true,
prompt: true,
promptName: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
createdAt: 'asc',
@@ -288,20 +295,30 @@ export class ChatSessionService {
orderBy: { createdAt: 'desc' },
})
.then(sessions =>
sessions
.map(({ id, prompt, messages }) => {
Promise.all(
sessions.map(async ({ id, promptName, messages }) => {
try {
const ret = PromptMessageSchema.array().safeParse(messages);
if (ret.success) {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new Error(`Prompt not found: ${promptName}`);
}
const tokens = this.calculateTokenSize(
ret.data,
prompt.model as AvailableModel
);
// render system prompt
const preload = withPrompt
? prompt.finish(ret.data[0]?.params || {})
: [];
return {
sessionId: id,
action: prompt.action || undefined,
tokens,
messages: ret.data,
messages: preload.concat(ret.data),
};
} else {
this.logger.error(
@@ -313,7 +330,10 @@ export class ChatSessionService {
}
return undefined;
})
.filter((v): v is NonNullable<typeof v> => !!v)
)
)
.then(histories =>
histories.filter((v): v is NonNullable<typeof v> => !!v)
);
}

View File

@@ -82,6 +82,7 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const SubmittedMessageSchema = PureMessageSchema.extend({
sessionId: z.string(),
content: z.string().optional(),
}).strict();
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;