feat: use default params if not provided (#6701)

This commit is contained in:
darkskygit
2024-04-25 10:59:46 +00:00
parent 3297486e31
commit a0c219e036
3 changed files with 30 additions and 10 deletions

View File

@@ -1,4 +1,4 @@
import { Injectable } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { AiPrompt, PrismaClient } from '@prisma/client';
import Mustache from 'mustache';
import { Tiktoken } from 'tiktoken';
@@ -26,6 +26,7 @@ function extractMustacheParams(template: string) {
}
export class ChatPrompt {
private readonly logger = new Logger(ChatPrompt.name);
public readonly encoder?: Tiktoken;
private readonly promptTokenSize: number;
private readonly templateParamKeys: string[] = [];
@@ -88,7 +89,7 @@ export class ChatPrompt {
return this.encoder?.encode_ordinary(message).length || 0;
}
private checkParams(params: PromptParams) {
private checkParams(params: PromptParams, sessionId?: string) {
const selfParams = this.templateParams;
for (const key of Object.keys(selfParams)) {
const options = selfParams[key];
@@ -97,7 +98,20 @@ export class ChatPrompt {
typeof income !== 'string' ||
(Array.isArray(options) && !options.includes(income))
) {
throw new Error(`Invalid param: ${key}`);
if (sessionId) {
const prefix = income
? `Invalid param value: ${key}=${income}`
: `Missing param value: ${key}`;
this.logger.warn(
`${prefix} in session ${sessionId}, use default options: ${options[0]}`
);
}
if (Array.isArray(options)) {
// use the first option if income is not in options
params[key] = options[0];
} else {
params[key] = options;
}
}
}
}
@@ -107,8 +121,8 @@ export class ChatPrompt {
* @param params record of params, e.g. { name: 'Alice' }
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
*/
finish(params: PromptParams): PromptMessage[] {
this.checkParams(params);
finish(params: PromptParams, sessionId?: string): PromptMessage[] {
this.checkParams(params, sessionId);
return this.messages.map(({ content, params: _, ...rest }) => ({
...rest,
params,

View File

@@ -112,7 +112,8 @@ export class ChatSession implements AsyncDisposable {
const messages = this.takeMessages();
return [
...this.state.prompt.finish(
Object.keys(params).length ? params : messages[0]?.params || {}
Object.keys(params).length ? params : messages[0]?.params || {},
this.config.sessionId
),
...messages.filter(m => m.content || m.attachments?.length),
];
@@ -354,7 +355,7 @@ export class ChatSessionService {
// render system prompt
const preload = withPrompt
? prompt
.finish(ret.data[0]?.params || {})
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: [];