mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat: use default params if not provided (#6701)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { AiPrompt, PrismaClient } from '@prisma/client';
|
||||
import Mustache from 'mustache';
|
||||
import { Tiktoken } from 'tiktoken';
|
||||
@@ -26,6 +26,7 @@ function extractMustacheParams(template: string) {
|
||||
}
|
||||
|
||||
export class ChatPrompt {
|
||||
private readonly logger = new Logger(ChatPrompt.name);
|
||||
public readonly encoder?: Tiktoken;
|
||||
private readonly promptTokenSize: number;
|
||||
private readonly templateParamKeys: string[] = [];
|
||||
@@ -88,7 +89,7 @@ export class ChatPrompt {
|
||||
return this.encoder?.encode_ordinary(message).length || 0;
|
||||
}
|
||||
|
||||
private checkParams(params: PromptParams) {
|
||||
private checkParams(params: PromptParams, sessionId?: string) {
|
||||
const selfParams = this.templateParams;
|
||||
for (const key of Object.keys(selfParams)) {
|
||||
const options = selfParams[key];
|
||||
@@ -97,7 +98,20 @@ export class ChatPrompt {
|
||||
typeof income !== 'string' ||
|
||||
(Array.isArray(options) && !options.includes(income))
|
||||
) {
|
||||
throw new Error(`Invalid param: ${key}`);
|
||||
if (sessionId) {
|
||||
const prefix = income
|
||||
? `Invalid param value: ${key}=${income}`
|
||||
: `Missing param value: ${key}`;
|
||||
this.logger.warn(
|
||||
`${prefix} in session ${sessionId}, use default options: ${options[0]}`
|
||||
);
|
||||
}
|
||||
if (Array.isArray(options)) {
|
||||
// use the first option if income is not in options
|
||||
params[key] = options[0];
|
||||
} else {
|
||||
params[key] = options;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,8 +121,8 @@ export class ChatPrompt {
|
||||
* @param params record of params, e.g. { name: 'Alice' }
|
||||
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
|
||||
*/
|
||||
finish(params: PromptParams): PromptMessage[] {
|
||||
this.checkParams(params);
|
||||
finish(params: PromptParams, sessionId?: string): PromptMessage[] {
|
||||
this.checkParams(params, sessionId);
|
||||
return this.messages.map(({ content, params: _, ...rest }) => ({
|
||||
...rest,
|
||||
params,
|
||||
|
||||
@@ -112,7 +112,8 @@ export class ChatSession implements AsyncDisposable {
|
||||
const messages = this.takeMessages();
|
||||
return [
|
||||
...this.state.prompt.finish(
|
||||
Object.keys(params).length ? params : messages[0]?.params || {}
|
||||
Object.keys(params).length ? params : messages[0]?.params || {},
|
||||
this.config.sessionId
|
||||
),
|
||||
...messages.filter(m => m.content || m.attachments?.length),
|
||||
];
|
||||
@@ -354,7 +355,7 @@ export class ChatSessionService {
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {})
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user