From bf6c9a5955b09bcdcdc285adb085554250434d5d Mon Sep 17 00:00:00 2001 From: darkskygit Date: Mon, 8 Jul 2024 08:11:22 +0000 Subject: [PATCH] feat: add prompt level config (#7445) --- .../migration.sql | 2 ++ packages/backend/server/schema.prisma | 1 + .../src/data/migrations/utils/prompts.ts | 11 +++++++ .../server/src/plugins/copilot/controller.ts | 19 ++++++----- .../server/src/plugins/copilot/prompt.ts | 23 ++++++++++--- .../src/plugins/copilot/providers/openai.ts | 20 ++---------- .../server/src/plugins/copilot/resolver.ts | 32 ++++++++++++++++++- .../server/src/plugins/copilot/session.ts | 4 +-- .../server/src/plugins/copilot/types.ts | 22 ++++++++++--- packages/backend/server/src/schema.gql | 18 +++++++++++ packages/backend/server/tests/copilot.spec.ts | 6 ++-- .../backend/server/tests/utils/copilot.ts | 8 ++++- 12 files changed, 125 insertions(+), 41 deletions(-) create mode 100644 packages/backend/server/migrations/20240708034904_prompt_level_config/migration.sql diff --git a/packages/backend/server/migrations/20240708034904_prompt_level_config/migration.sql b/packages/backend/server/migrations/20240708034904_prompt_level_config/migration.sql new file mode 100644 index 0000000000..5cac4f21a1 --- /dev/null +++ b/packages/backend/server/migrations/20240708034904_prompt_level_config/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "ai_prompts_metadata" ADD COLUMN "config" JSON; diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 2d9be52c88..e5162ee508 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -457,6 +457,7 @@ model AiPrompt { // it is only used in the frontend and does not affect the backend action String? @db.VarChar model String @db.VarChar + config Json? @db.Json createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) messages AiPromptMessage[] diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index 9a0d4b5d7f..2f5334f651 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -6,10 +6,19 @@ type PromptMessage = { params?: Record; }; +type PromptConfig = { + jsonMode?: boolean; + frequencyPenalty?: number; + presencePenalty?: number; + temperature?: number; + maxTokens?: number; +}; + type Prompt = { name: string; action?: string; model: string; + config?: PromptConfig; messages: PromptMessage[]; }; @@ -465,6 +474,7 @@ content: {{content}}`, name: 'workflow:presentation:step1', action: 'workflow:presentation:step1', model: 'gpt-4o', + config: { temperature: 0.7 }, messages: [ { role: 'system', @@ -685,6 +695,7 @@ export async function refreshPrompts(db: PrismaClient) { create: { name: prompt.name, action: prompt.action, + config: prompt.config, model: prompt.model, messages: { create: prompt.messages.map((message, idx) => ({ diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 5795993474..ec46455d13 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -138,9 +138,8 @@ export class CopilotController { const messageId = Array.isArray(params.messageId) ? params.messageId[0] : params.messageId; - const jsonMode = String(params.jsonMode).toLowerCase() === 'true'; delete params.messageId; - return { messageId, jsonMode, params }; + return { messageId, params }; } private getSignal(req: Request) { @@ -167,7 +166,7 @@ export class CopilotController { @Param('sessionId') sessionId: string, @Query() params: Record ): Promise { - const { messageId, jsonMode } = this.prepareParams(params); + const { messageId } = this.prepareParams(params); const provider = await this.chooseTextProvider( user.id, sessionId, @@ -180,7 +179,11 @@ export class CopilotController { const content = await provider.generateText( session.finish(params), session.model, - { jsonMode, signal: this.getSignal(req), user: user.id } + { + ...session.config.promptConfig, + signal: this.getSignal(req), + user: user.id, + } ); session.push({ @@ -204,7 +207,7 @@ export class CopilotController { @Query() params: Record ): Promise> { try { - const { messageId, jsonMode } = this.prepareParams(params); + const { messageId } = this.prepareParams(params); const provider = await this.chooseTextProvider( user.id, sessionId, @@ -215,7 +218,7 @@ export class CopilotController { return from( provider.generateTextStream(session.finish(params), session.model, { - jsonMode, + ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, }) @@ -256,7 +259,7 @@ export class CopilotController { @Query() params: Record ): Promise> { try { - const { messageId, jsonMode } = this.prepareParams(params); + const { messageId } = this.prepareParams(params); const session = await this.appendSessionMessage(sessionId, messageId); const latestMessage = session.stashMessages.findLast( m => m.role === 'user' @@ -269,7 +272,7 @@ export class CopilotController { return from( this.workflow.runGraph(params, session.model, { - jsonMode, + ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, }) diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index e07eb434bb..d9b0835d1e 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -5,6 +5,8 @@ import Mustache from 'mustache'; import { getTokenEncoder, + PromptConfig, + PromptConfigSchema, PromptMessage, PromptMessageSchema, PromptParams, @@ -35,14 +37,16 @@ export class ChatPrompt { private readonly templateParams: PromptParams = {}; static createFromPrompt( - options: Omit & { + options: Omit & { messages: PromptMessage[]; + config: PromptConfig | undefined; } ) { return new ChatPrompt( options.name, options.action || undefined, options.model, + options.config, options.messages ); } @@ -51,6 +55,7 @@ export class ChatPrompt { public readonly name: string, public readonly action: string | undefined, public readonly model: string, + public readonly config: PromptConfig | undefined, private readonly messages: PromptMessage[] ) { this.encoder = getTokenEncoder(model); @@ -185,6 +190,7 @@ export class PromptService { name: true, action: true, model: true, + config: true, messages: { select: { role: true, @@ -199,9 +205,11 @@ export class PromptService { }); const messages = PromptMessageSchema.array().safeParse(prompt?.messages); - if (prompt && messages.success) { + const config = PromptConfigSchema.safeParse(prompt?.config); + if (prompt && messages.success && config.success) { const chatPrompt = ChatPrompt.createFromPrompt({ ...prompt, + config: config.data, messages: messages.data, }); this.cache.set(name, chatPrompt); @@ -210,12 +218,18 @@ export class PromptService { return null; } - async set(name: string, model: string, messages: PromptMessage[]) { + async set( + name: string, + model: string, + messages: PromptMessage[], + config?: PromptConfig + ) { return await this.db.aiPrompt .create({ data: { name, model, + config: config || undefined, messages: { create: messages.map((m, idx) => ({ idx, @@ -229,10 +243,11 @@ export class PromptService { .then(ret => ret.id); } - async update(name: string, messages: PromptMessage[]) { + async update(name: string, messages: PromptMessage[], config?: PromptConfig) { const { id } = await this.db.aiPrompt.update({ where: { name }, data: { + config: config || undefined, messages: { // cleanup old messages deleteMany: {}, diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 31d27b913f..8cc8927806 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -125,21 +125,6 @@ export class OpenAIProvider }); } - private extractOptionFromMessages( - messages: PromptMessage[], - options: CopilotChatOptions - ) { - const params: Record = {}; - for (const message of messages) { - if (message.params) { - Object.assign(params, message.params); - } - } - if (params.jsonMode && options) { - options.jsonMode = String(params.jsonMode).toLowerCase() === 'true'; - } - } - protected checkParams({ messages, embeddings, @@ -155,7 +140,6 @@ export class OpenAIProvider throw new CopilotPromptInvalid(`Invalid model: ${model}`); } if (Array.isArray(messages) && messages.length > 0) { - this.extractOptionFromMessages(messages, options); if ( messages.some( m => @@ -257,7 +241,9 @@ export class OpenAIProvider stream: true, messages: this.chatToGPTMessage(messages), model: model, - temperature: options.temperature || 0, + frequency_penalty: options.frequencyPenalty || 0, + presence_penalty: options.presencePenalty || 0, + temperature: options.temperature || 0.5, max_tokens: options.maxTokens || 4096, response_format: { type: options.jsonMode ? 'json_object' : 'text', diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 067105df72..ef37de1559 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -183,6 +183,25 @@ registerEnumType(AiPromptRole, { name: 'CopilotPromptMessageRole', }); +@InputType('CopilotPromptConfigInput') +@ObjectType() +class CopilotPromptConfigType { + @Field(() => Boolean, { nullable: true }) + jsonMode!: boolean | null; + + @Field(() => Number, { nullable: true }) + frequencyPenalty!: number | null; + + @Field(() => Number, { nullable: true }) + presencePenalty!: number | null; + + @Field(() => Number, { nullable: true }) + temperature!: number | null; + + @Field(() => Number, { nullable: true }) + topP!: number | null; +} + @InputType('CopilotPromptMessageInput') @ObjectType() class CopilotPromptMessageType { @@ -209,6 +228,9 @@ class CopilotPromptType { @Field(() => String, { nullable: true }) action!: string | null; + @Field(() => CopilotPromptConfigType, { nullable: true }) + config!: CopilotPromptConfigType | null; + @Field(() => [CopilotPromptMessageType]) messages!: CopilotPromptMessageType[]; } @@ -462,6 +484,9 @@ class CreateCopilotPromptInput { @Field(() => String, { nullable: true }) action!: string | null; + @Field(() => CopilotPromptConfigType, { nullable: true }) + config!: CopilotPromptConfigType | null; + @Field(() => [CopilotPromptMessageType]) messages!: CopilotPromptMessageType[]; } @@ -485,7 +510,12 @@ export class PromptsManagementResolver { @Args({ type: () => CreateCopilotPromptInput, name: 'input' }) input: CreateCopilotPromptInput ) { - await this.promptService.set(input.name, input.model, input.messages); + await this.promptService.set( + input.name, + input.model, + input.messages, + input.config + ); return this.promptService.get(input.name); } diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index f4f0496ac9..6f6098a7e8 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -49,10 +49,10 @@ export class ChatSession implements AsyncDisposable { userId, workspaceId, docId, - prompt: { name: promptName }, + prompt: { name: promptName, config: promptConfig }, } = this.state; - return { sessionId, userId, workspaceId, docId, promptName }; + return { sessionId, userId, workspaceId, docId, promptName, promptConfig }; } get stashMessages() { diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 12a649d795..4f5ce3c1d6 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -63,6 +63,20 @@ export type PromptMessage = z.infer; export type PromptParams = NonNullable; +export const PromptConfigStrictSchema = z.object({ + jsonMode: z.boolean().nullable().optional(), + frequencyPenalty: z.number().nullable().optional(), + presencePenalty: z.number().nullable().optional(), + temperature: z.number().nullable().optional(), + topP: z.number().nullable().optional(), + maxTokens: z.number().nullable().optional(), +}); + +export const PromptConfigSchema = + PromptConfigStrictSchema.nullable().optional(); + +export type PromptConfig = z.infer; + export const ChatMessageSchema = PromptMessageSchema.extend({ id: z.string().optional(), createdAt: z.date(), @@ -144,11 +158,9 @@ const CopilotProviderOptionsSchema = z.object({ user: z.string().optional(), }); -const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({ - jsonMode: z.boolean().optional(), - temperature: z.number().optional(), - maxTokens: z.number().optional(), -}).optional(); +const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge( + PromptConfigStrictSchema +).optional(); export type CopilotChatOptions = z.infer; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index e65530e4fd..f3c524eb78 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -57,6 +57,22 @@ enum CopilotModels { TextModerationStable } +input CopilotPromptConfigInput { + frequencyPenalty: Int + jsonMode: Boolean + presencePenalty: Int + temperature: Int + topP: Int +} + +type CopilotPromptConfigType { + frequencyPenalty: Int + jsonMode: Boolean + presencePenalty: Int + temperature: Int + topP: Int +} + input CopilotPromptMessageInput { content: String! params: JSON @@ -81,6 +97,7 @@ type CopilotPromptNotFoundDataType { type CopilotPromptType { action: String + config: CopilotPromptConfigType messages: [CopilotPromptMessageType!]! model: CopilotModels! name: String! @@ -123,6 +140,7 @@ input CreateCheckoutSessionInput { input CreateCopilotPromptInput { action: String + config: CopilotPromptConfigInput messages: [CopilotPromptMessageInput!]! model: CopilotModels! name: String! diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 8ec3506a42..3b4a046f70 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -676,7 +676,7 @@ test.skip('should be able to preview workflow', async t => { registerCopilotProvider(OpenAIProvider); for (const p of prompts) { - await prompt.set(p.name, p.model, p.messages); + await prompt.set(p.name, p.model, p.messages, p.config); } let result = ''; @@ -726,7 +726,7 @@ test('should be able to run pre defined workflow', async t => { const { graph, prompts, callCount, input, params, result } = testCase; console.log('running workflow test:', graph.name); for (const p of prompts) { - await prompt.set(p.name, p.model, p.messages); + await prompt.set(p.name, p.model, p.messages, p.config); } for (const [idx, i] of input.entries()) { @@ -773,7 +773,7 @@ test('should be able to run workflow', async t => { const executor = Sinon.spy(executors.text, 'next'); for (const p of prompts) { - await prompt.set(p.name, p.model, p.messages); + await prompt.set(p.name, p.model, p.messages, p.config); } const graphName = 'presentation'; diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index d6e32a2472..83ec4d7dea 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -17,6 +17,7 @@ import { CopilotTextToEmbeddingProvider, CopilotTextToImageProvider, CopilotTextToTextProvider, + PromptConfig, PromptMessage, } from '../../src/plugins/copilot/types'; import { NodeExecutorType } from '../../src/plugins/copilot/workflow/executor'; @@ -383,7 +384,12 @@ export async function getHistories( return res.body.data.currentUser?.copilot?.histories || []; } -type Prompt = { name: string; model: string; messages: PromptMessage[] }; +type Prompt = { + name: string; + model: string; + messages: PromptMessage[]; + config?: PromptConfig; +}; type WorkflowTestCase = { graph: WorkflowGraph; prompts: Prompt[];