From a0c219e0363a3b3ab14fdf36e92312df8d86fba8 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Thu, 25 Apr 2024 10:59:46 +0000 Subject: [PATCH] feat: use default params if not provided (#6701) --- .../server/src/plugins/copilot/prompt.ts | 24 +++++++++++++++---- .../server/src/plugins/copilot/session.ts | 5 ++-- packages/backend/server/tests/copilot.spec.ts | 11 ++++++--- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index c71c4fc2fd..74c51127e8 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -1,4 +1,4 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, Logger } from '@nestjs/common'; import { AiPrompt, PrismaClient } from '@prisma/client'; import Mustache from 'mustache'; import { Tiktoken } from 'tiktoken'; @@ -26,6 +26,7 @@ function extractMustacheParams(template: string) { } export class ChatPrompt { + private readonly logger = new Logger(ChatPrompt.name); public readonly encoder?: Tiktoken; private readonly promptTokenSize: number; private readonly templateParamKeys: string[] = []; @@ -88,7 +89,7 @@ export class ChatPrompt { return this.encoder?.encode_ordinary(message).length || 0; } - private checkParams(params: PromptParams) { + private checkParams(params: PromptParams, sessionId?: string) { const selfParams = this.templateParams; for (const key of Object.keys(selfParams)) { const options = selfParams[key]; @@ -97,7 +98,20 @@ export class ChatPrompt { typeof income !== 'string' || (Array.isArray(options) && !options.includes(income)) ) { - throw new Error(`Invalid param: ${key}`); + if (sessionId) { + const prefix = income + ? `Invalid param value: ${key}=${income}` + : `Missing param value: ${key}`; + this.logger.warn( + `${prefix} in session ${sessionId}, use default options: ${options[0]}` + ); + } + if (Array.isArray(options)) { + // use the first option if income is not in options + params[key] = options[0]; + } else { + params[key] = options; + } } } } @@ -107,8 +121,8 @@ export class ChatPrompt { * @param params record of params, e.g. { name: 'Alice' } * @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }] */ - finish(params: PromptParams): PromptMessage[] { - this.checkParams(params); + finish(params: PromptParams, sessionId?: string): PromptMessage[] { + this.checkParams(params, sessionId); return this.messages.map(({ content, params: _, ...rest }) => ({ ...rest, params, diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index aa5b033a3c..c7e28c22c1 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -112,7 +112,8 @@ export class ChatSession implements AsyncDisposable { const messages = this.takeMessages(); return [ ...this.state.prompt.finish( - Object.keys(params).length ? params : messages[0]?.params || {} + Object.keys(params).length ? params : messages[0]?.params || {}, + this.config.sessionId ), ...messages.filter(m => m.content || m.attachments?.length), ]; @@ -354,7 +355,7 @@ export class ChatSessionService { // render system prompt const preload = withPrompt ? prompt - .finish(ret.data[0]?.params || {}) + .finish(ret.data[0]?.params || {}, id) .filter(({ role }) => role !== 'system') : []; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 141e6c477b..75b023ec3d 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -105,9 +105,14 @@ test('should be able to render prompt', async t => { 'should have param keys' ); t.deepEqual(testPrompt?.params, msg.params, 'should have params'); - t.throws(() => testPrompt?.finish({ src_language: 'abc' }), { - instanceOf: Error, - }); + // will use first option if a params not provided + t.deepEqual(testPrompt?.finish({ src_language: 'abc' }), [ + { + content: 'translate eng to chs: ', + params: { dest_language: 'chs', src_language: 'eng' }, + role: 'system', + }, + ]); }); test('should be able to render listed prompt', async t => {