feat: improve prompt management (#7853)

This commit is contained in:
darkskygit
2024-08-14 08:38:36 +00:00
parent cd3924b8fc
commit 339c39c1ec
13 changed files with 161 additions and 105 deletions

View File

@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "ai_prompts_metadata" ADD COLUMN "modified" BOOLEAN NOT NULL DEFAULT false,
ADD COLUMN "updated_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP;

View File

@@ -367,6 +367,9 @@ model AiPrompt {
model String @db.VarChar
config Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @default(now()) @map("updated_at") @db.Timestamptz(3)
// whether the prompt is modified by the admin panel
modified Boolean @default(false)
messages AiPromptMessage[]
sessions AiSession[]

View File

@@ -33,7 +33,10 @@ export class ChatPrompt {
private readonly templateParams: PromptParams = {};
static createFromPrompt(
options: Omit<AiPrompt, 'id' | 'createdAt' | 'config'> & {
options: Omit<
AiPrompt,
'id' | 'createdAt' | 'updatedAt' | 'modified' | 'config'
> & {
messages: PromptMessage[];
config: PromptConfig | undefined;
}

View File

@@ -1,8 +1,12 @@
import { Logger } from '@nestjs/common';
import { AiPrompt, PrismaClient } from '@prisma/client';
import { PromptConfig, PromptMessage } from '../types';
type Prompt = Omit<AiPrompt, 'id' | 'createdAt' | 'action' | 'config'> & {
type Prompt = Omit<
AiPrompt,
'id' | 'createdAt' | 'updatedAt' | 'modified' | 'action' | 'config'
> & {
action?: string;
messages: PromptMessage[];
config?: PromptConfig;
@@ -830,7 +834,7 @@ const chat: Prompt[] = [
],
},
{
name: 'chat:gpt4',
name: 'Chat With AFFiNE AI',
model: 'gpt-4o',
messages: [
{
@@ -845,7 +849,20 @@ const chat: Prompt[] = [
export const prompts: Prompt[] = [...actions, ...chat, ...workflows];
export async function refreshPrompts(db: PrismaClient) {
const needToSkip = await db.aiPrompt
.findMany({
where: { modified: true },
select: { name: true },
})
.then(p => p.map(p => p.name));
for (const prompt of prompts) {
// skip prompt update if already modified by admin panel
if (needToSkip.includes(prompt.name)) {
new Logger('CopilotPrompt').warn(`Skip modified prompt: ${prompt.name}`);
return;
}
await db.aiPrompt.upsert({
create: {
name: prompt.name,
@@ -865,6 +882,7 @@ export async function refreshPrompts(db: PrismaClient) {
update: {
action: prompt.action,
model: prompt.model,
updatedAt: new Date(),
messages: {
deleteMany: {},
create: prompt.messages.map((message, idx) => ({

View File

@@ -38,16 +38,11 @@ export class PromptService implements OnModuleInit {
model: true,
config: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
select: { role: true, content: true, params: true },
orderBy: { idx: 'asc' },
},
},
orderBy: { action: { sort: 'asc', nulls: 'first' } },
});
}
@@ -121,11 +116,18 @@ export class PromptService implements OnModuleInit {
.then(ret => ret.id);
}
async update(name: string, messages: PromptMessage[], config?: PromptConfig) {
async update(
name: string,
messages: PromptMessage[],
modifyByApi: boolean = false,
config?: PromptConfig
) {
const { id } = await this.db.aiPrompt.update({
where: { name },
data: {
config: config || undefined,
updatedAt: new Date(),
modified: modifyByApi,
messages: {
// cleanup old messages
deleteMany: {},

View File

@@ -517,7 +517,16 @@ export class PromptsManagementResolver {
description: 'List all copilot prompts',
})
async listCopilotPrompts() {
return this.promptService.list();
const prompts = await this.promptService.list();
return prompts.filter(
p =>
p.messages.length > 0 &&
// ignore internal prompts
!p.name.startsWith('workflow:') &&
!p.name.startsWith('debug:') &&
!p.name.startsWith('chat:') &&
!p.name.startsWith('action:')
);
}
@Mutation(() => CopilotPromptType, {
@@ -544,7 +553,7 @@ export class PromptsManagementResolver {
@Args('messages', { type: () => [CopilotPromptMessageType] })
messages: CopilotPromptMessageType[]
) {
await this.promptService.update(name, messages);
await this.promptService.update(name, messages, true);
return this.promptService.get(name);
}
}