mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 04:18:54 +00:00
feat: use default params if not provided (#6701)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { AiPrompt, PrismaClient } from '@prisma/client';
|
||||
import Mustache from 'mustache';
|
||||
import { Tiktoken } from 'tiktoken';
|
||||
@@ -26,6 +26,7 @@ function extractMustacheParams(template: string) {
|
||||
}
|
||||
|
||||
export class ChatPrompt {
|
||||
private readonly logger = new Logger(ChatPrompt.name);
|
||||
public readonly encoder?: Tiktoken;
|
||||
private readonly promptTokenSize: number;
|
||||
private readonly templateParamKeys: string[] = [];
|
||||
@@ -88,7 +89,7 @@ export class ChatPrompt {
|
||||
return this.encoder?.encode_ordinary(message).length || 0;
|
||||
}
|
||||
|
||||
private checkParams(params: PromptParams) {
|
||||
private checkParams(params: PromptParams, sessionId?: string) {
|
||||
const selfParams = this.templateParams;
|
||||
for (const key of Object.keys(selfParams)) {
|
||||
const options = selfParams[key];
|
||||
@@ -97,7 +98,20 @@ export class ChatPrompt {
|
||||
typeof income !== 'string' ||
|
||||
(Array.isArray(options) && !options.includes(income))
|
||||
) {
|
||||
throw new Error(`Invalid param: ${key}`);
|
||||
if (sessionId) {
|
||||
const prefix = income
|
||||
? `Invalid param value: ${key}=${income}`
|
||||
: `Missing param value: ${key}`;
|
||||
this.logger.warn(
|
||||
`${prefix} in session ${sessionId}, use default options: ${options[0]}`
|
||||
);
|
||||
}
|
||||
if (Array.isArray(options)) {
|
||||
// use the first option if income is not in options
|
||||
params[key] = options[0];
|
||||
} else {
|
||||
params[key] = options;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,8 +121,8 @@ export class ChatPrompt {
|
||||
* @param params record of params, e.g. { name: 'Alice' }
|
||||
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
|
||||
*/
|
||||
finish(params: PromptParams): PromptMessage[] {
|
||||
this.checkParams(params);
|
||||
finish(params: PromptParams, sessionId?: string): PromptMessage[] {
|
||||
this.checkParams(params, sessionId);
|
||||
return this.messages.map(({ content, params: _, ...rest }) => ({
|
||||
...rest,
|
||||
params,
|
||||
|
||||
@@ -112,7 +112,8 @@ export class ChatSession implements AsyncDisposable {
|
||||
const messages = this.takeMessages();
|
||||
return [
|
||||
...this.state.prompt.finish(
|
||||
Object.keys(params).length ? params : messages[0]?.params || {}
|
||||
Object.keys(params).length ? params : messages[0]?.params || {},
|
||||
this.config.sessionId
|
||||
),
|
||||
...messages.filter(m => m.content || m.attachments?.length),
|
||||
];
|
||||
@@ -354,7 +355,7 @@ export class ChatSessionService {
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {})
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
|
||||
|
||||
@@ -105,9 +105,14 @@ test('should be able to render prompt', async t => {
|
||||
'should have param keys'
|
||||
);
|
||||
t.deepEqual(testPrompt?.params, msg.params, 'should have params');
|
||||
t.throws(() => testPrompt?.finish({ src_language: 'abc' }), {
|
||||
instanceOf: Error,
|
||||
});
|
||||
// will use first option if a params not provided
|
||||
t.deepEqual(testPrompt?.finish({ src_language: 'abc' }), [
|
||||
{
|
||||
content: 'translate eng to chs: ',
|
||||
params: { dest_language: 'chs', src_language: 'eng' },
|
||||
role: 'system',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('should be able to render listed prompt', async t => {
|
||||
|
||||
Reference in New Issue
Block a user