mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-11 20:08:37 +00:00
feat: detailed copilot histories (#6523)
This commit is contained in:
@@ -107,11 +107,12 @@ export class ChatPrompt {
|
||||
* @param params record of params, e.g. { name: 'Alice' }
|
||||
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
|
||||
*/
|
||||
finish(params: PromptParams) {
|
||||
finish(params: PromptParams): PromptMessage[] {
|
||||
this.checkParams(params);
|
||||
return this.messages.map(m => ({
|
||||
...m,
|
||||
content: Mustache.render(m.content, params),
|
||||
return this.messages.map(({ content, params: _, ...rest }) => ({
|
||||
...rest,
|
||||
params,
|
||||
content: Mustache.render(content, params),
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -122,6 +123,8 @@ export class ChatPrompt {
|
||||
|
||||
@Injectable()
|
||||
export class PromptService {
|
||||
private readonly cache = new Map<string, ChatPrompt>();
|
||||
|
||||
constructor(private readonly db: PrismaClient) {}
|
||||
|
||||
/**
|
||||
@@ -140,34 +143,40 @@ export class PromptService {
|
||||
* @returns prompt messages
|
||||
*/
|
||||
async get(name: string): Promise<ChatPrompt | null> {
|
||||
return this.db.aiPrompt
|
||||
.findUnique({
|
||||
where: {
|
||||
name,
|
||||
},
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
params: true,
|
||||
},
|
||||
orderBy: {
|
||||
idx: 'asc',
|
||||
},
|
||||
const cached = this.cache.get(name);
|
||||
if (cached) return cached;
|
||||
|
||||
const prompt = await this.db.aiPrompt.findUnique({
|
||||
where: {
|
||||
name,
|
||||
},
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
params: true,
|
||||
},
|
||||
orderBy: {
|
||||
idx: 'asc',
|
||||
},
|
||||
},
|
||||
})
|
||||
.then(p => {
|
||||
const messages = PromptMessageSchema.array().safeParse(p?.messages);
|
||||
if (p && messages.success) {
|
||||
return ChatPrompt.createFromPrompt({ ...p, messages: messages.data });
|
||||
}
|
||||
return null;
|
||||
},
|
||||
});
|
||||
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
|
||||
if (prompt && messages.success) {
|
||||
const chatPrompt = ChatPrompt.createFromPrompt({
|
||||
...prompt,
|
||||
messages: messages.data,
|
||||
});
|
||||
this.cache.set(name, chatPrompt);
|
||||
return chatPrompt;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async set(name: string, messages: PromptMessage[]) {
|
||||
@@ -188,25 +197,28 @@ export class PromptService {
|
||||
}
|
||||
|
||||
async update(name: string, messages: PromptMessage[]) {
|
||||
return this.db.aiPrompt
|
||||
.update({
|
||||
where: { name },
|
||||
data: {
|
||||
messages: {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
const { id } = await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
messages: {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
})
|
||||
.then(ret => ret.id);
|
||||
},
|
||||
});
|
||||
|
||||
this.cache.delete(name);
|
||||
return id;
|
||||
}
|
||||
|
||||
async delete(name: string) {
|
||||
return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id);
|
||||
const { id } = await this.db.aiPrompt.delete({ where: { name } });
|
||||
this.cache.delete(name);
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { SafeIntResolver } from 'graphql-scalars';
|
||||
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
|
||||
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
@@ -45,12 +45,6 @@ class CreateChatSessionInput {
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'An mark identifying which view to use to display the session',
|
||||
nullable: true,
|
||||
})
|
||||
action!: string | undefined;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'The prompt name to use for the session',
|
||||
})
|
||||
@@ -58,18 +52,18 @@ class CreateChatSessionInput {
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> {
|
||||
class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
@Field(() => String, { nullable: true })
|
||||
content!: string | undefined;
|
||||
|
||||
@Field(() => [String], { nullable: true })
|
||||
attachments!: string[] | undefined;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
params!: string | undefined;
|
||||
@Field(() => GraphQLJSON, { nullable: true })
|
||||
params!: Record<string, string> | undefined;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
@@ -100,6 +94,9 @@ class ChatMessageType implements Partial<ChatMessage> {
|
||||
@Field(() => [String], { nullable: true })
|
||||
attachments!: string[];
|
||||
|
||||
@Field(() => GraphQLJSON, { nullable: true })
|
||||
params!: Record<string, string> | undefined;
|
||||
|
||||
@Field(() => Date, { nullable: true })
|
||||
createdAt!: Date | undefined;
|
||||
}
|
||||
@@ -227,12 +224,18 @@ export class CopilotResolver {
|
||||
await this.permissions.checkCloudWorkspace(workspaceId, user.id);
|
||||
}
|
||||
|
||||
return await this.chatSession.listHistories(
|
||||
const histories = await this.chatSession.listHistories(
|
||||
user.id,
|
||||
workspaceId,
|
||||
docId,
|
||||
options
|
||||
options,
|
||||
true
|
||||
);
|
||||
return histories.map(h => ({
|
||||
...h,
|
||||
// filter out empty messages
|
||||
messages: h.messages.filter(m => m.content || m.attachments?.length),
|
||||
}));
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
@@ -282,12 +285,7 @@ export class CopilotResolver {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
try {
|
||||
const { params, ...rest } = options;
|
||||
const record: SubmittedMessage['params'] = {};
|
||||
new URLSearchParams(params).forEach((value, key) => {
|
||||
record[key] = value;
|
||||
});
|
||||
return await this.chatSession.createMessage({ ...rest, params: record });
|
||||
return await this.chatSession.createMessage(options);
|
||||
} catch (e: any) {
|
||||
this.logger.error(`Failed to create chat message: ${e.message}`);
|
||||
throw new Error('Failed to create chat message');
|
||||
|
||||
@@ -59,7 +59,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
|
||||
this.push({
|
||||
role: 'user',
|
||||
content: message.content,
|
||||
content: message.content || '',
|
||||
attachments: message.attachments,
|
||||
params: message.params,
|
||||
createdAt: new Date(),
|
||||
@@ -96,7 +96,12 @@ export class ChatSession implements AsyncDisposable {
|
||||
|
||||
finish(params: PromptParams): PromptMessage[] {
|
||||
const messages = this.takeMessages();
|
||||
return [...this.state.prompt.finish(params), ...messages];
|
||||
return [
|
||||
...this.state.prompt.finish(
|
||||
Object.keys(params).length ? params : messages[0]?.params || {}
|
||||
),
|
||||
...messages.filter(m => m.content || m.attachments?.length),
|
||||
];
|
||||
}
|
||||
|
||||
async save() {
|
||||
@@ -257,7 +262,8 @@ export class ChatSessionService {
|
||||
userId: string,
|
||||
workspaceId?: string,
|
||||
docId?: string,
|
||||
options?: ListHistoriesOptions
|
||||
options?: ListHistoriesOptions,
|
||||
withPrompt = false
|
||||
): Promise<ChatHistory[]> {
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
@@ -272,11 +278,12 @@ export class ChatSessionService {
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
prompt: true,
|
||||
promptName: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
params: true,
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: 'asc',
|
||||
@@ -288,20 +295,30 @@ export class ChatSessionService {
|
||||
orderBy: { createdAt: 'desc' },
|
||||
})
|
||||
.then(sessions =>
|
||||
sessions
|
||||
.map(({ id, prompt, messages }) => {
|
||||
Promise.all(
|
||||
sessions.map(async ({ id, promptName, messages }) => {
|
||||
try {
|
||||
const ret = PromptMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(`Prompt not found: ${promptName}`);
|
||||
}
|
||||
const tokens = this.calculateTokenSize(
|
||||
ret.data,
|
||||
prompt.model as AvailableModel
|
||||
);
|
||||
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt.finish(ret.data[0]?.params || {})
|
||||
: [];
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || undefined,
|
||||
tokens,
|
||||
messages: ret.data,
|
||||
messages: preload.concat(ret.data),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
@@ -313,7 +330,10 @@ export class ChatSessionService {
|
||||
}
|
||||
return undefined;
|
||||
})
|
||||
.filter((v): v is NonNullable<typeof v> => !!v)
|
||||
)
|
||||
)
|
||||
.then(histories =>
|
||||
histories.filter((v): v is NonNullable<typeof v> => !!v)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
|
||||
|
||||
export const SubmittedMessageSchema = PureMessageSchema.extend({
|
||||
sessionId: z.string(),
|
||||
content: z.string().optional(),
|
||||
}).strict();
|
||||
|
||||
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;
|
||||
|
||||
@@ -6,6 +6,7 @@ type ChatMessage {
|
||||
attachments: [String!]
|
||||
content: String!
|
||||
createdAt: DateTime
|
||||
params: JSON
|
||||
role: String!
|
||||
}
|
||||
|
||||
@@ -39,14 +40,12 @@ type CopilotQuota {
|
||||
|
||||
input CreateChatMessageInput {
|
||||
attachments: [String!]
|
||||
content: String!
|
||||
params: String
|
||||
content: String
|
||||
params: JSON
|
||||
sessionId: String!
|
||||
}
|
||||
|
||||
input CreateChatSessionInput {
|
||||
"""An mark identifying which view to use to display the session"""
|
||||
action: String
|
||||
docId: String!
|
||||
|
||||
"""The prompt name to use for the session"""
|
||||
@@ -155,6 +154,11 @@ enum InvoiceStatus {
|
||||
Void
|
||||
}
|
||||
|
||||
"""
|
||||
The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf).
|
||||
"""
|
||||
scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf")
|
||||
|
||||
type LimitedUserType {
|
||||
"""User email"""
|
||||
email: String!
|
||||
|
||||
Reference in New Issue
Block a user