mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 12:55:00 +00:00
feat: add prompt level config (#7445)
This commit is contained in:
@@ -6,10 +6,19 @@ type PromptMessage = {
|
||||
params?: Record<string, string | string[]>;
|
||||
};
|
||||
|
||||
type PromptConfig = {
|
||||
jsonMode?: boolean;
|
||||
frequencyPenalty?: number;
|
||||
presencePenalty?: number;
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
};
|
||||
|
||||
type Prompt = {
|
||||
name: string;
|
||||
action?: string;
|
||||
model: string;
|
||||
config?: PromptConfig;
|
||||
messages: PromptMessage[];
|
||||
};
|
||||
|
||||
@@ -465,6 +474,7 @@ content: {{content}}`,
|
||||
name: 'workflow:presentation:step1',
|
||||
action: 'workflow:presentation:step1',
|
||||
model: 'gpt-4o',
|
||||
config: { temperature: 0.7 },
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
@@ -685,6 +695,7 @@ export async function refreshPrompts(db: PrismaClient) {
|
||||
create: {
|
||||
name: prompt.name,
|
||||
action: prompt.action,
|
||||
config: prompt.config,
|
||||
model: prompt.model,
|
||||
messages: {
|
||||
create: prompt.messages.map((message, idx) => ({
|
||||
|
||||
@@ -138,9 +138,8 @@ export class CopilotController {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
|
||||
delete params.messageId;
|
||||
return { messageId, jsonMode, params };
|
||||
return { messageId, params };
|
||||
}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
@@ -167,7 +166,7 @@ export class CopilotController {
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const { messageId } = this.prepareParams(params);
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
@@ -180,7 +179,11 @@ export class CopilotController {
|
||||
const content = await provider.generateText(
|
||||
session.finish(params),
|
||||
session.model,
|
||||
{ jsonMode, signal: this.getSignal(req), user: user.id }
|
||||
{
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
}
|
||||
);
|
||||
|
||||
session.push({
|
||||
@@ -204,7 +207,7 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const { messageId } = this.prepareParams(params);
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
@@ -215,7 +218,7 @@ export class CopilotController {
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
jsonMode,
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
@@ -256,7 +259,7 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const { messageId } = this.prepareParams(params);
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
const latestMessage = session.stashMessages.findLast(
|
||||
m => m.role === 'user'
|
||||
@@ -269,7 +272,7 @@ export class CopilotController {
|
||||
|
||||
return from(
|
||||
this.workflow.runGraph(params, session.model, {
|
||||
jsonMode,
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
|
||||
@@ -5,6 +5,8 @@ import Mustache from 'mustache';
|
||||
|
||||
import {
|
||||
getTokenEncoder,
|
||||
PromptConfig,
|
||||
PromptConfigSchema,
|
||||
PromptMessage,
|
||||
PromptMessageSchema,
|
||||
PromptParams,
|
||||
@@ -35,14 +37,16 @@ export class ChatPrompt {
|
||||
private readonly templateParams: PromptParams = {};
|
||||
|
||||
static createFromPrompt(
|
||||
options: Omit<AiPrompt, 'id' | 'createdAt'> & {
|
||||
options: Omit<AiPrompt, 'id' | 'createdAt' | 'config'> & {
|
||||
messages: PromptMessage[];
|
||||
config: PromptConfig | undefined;
|
||||
}
|
||||
) {
|
||||
return new ChatPrompt(
|
||||
options.name,
|
||||
options.action || undefined,
|
||||
options.model,
|
||||
options.config,
|
||||
options.messages
|
||||
);
|
||||
}
|
||||
@@ -51,6 +55,7 @@ export class ChatPrompt {
|
||||
public readonly name: string,
|
||||
public readonly action: string | undefined,
|
||||
public readonly model: string,
|
||||
public readonly config: PromptConfig | undefined,
|
||||
private readonly messages: PromptMessage[]
|
||||
) {
|
||||
this.encoder = getTokenEncoder(model);
|
||||
@@ -185,6 +190,7 @@ export class PromptService {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
@@ -199,9 +205,11 @@ export class PromptService {
|
||||
});
|
||||
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
|
||||
if (prompt && messages.success) {
|
||||
const config = PromptConfigSchema.safeParse(prompt?.config);
|
||||
if (prompt && messages.success && config.success) {
|
||||
const chatPrompt = ChatPrompt.createFromPrompt({
|
||||
...prompt,
|
||||
config: config.data,
|
||||
messages: messages.data,
|
||||
});
|
||||
this.cache.set(name, chatPrompt);
|
||||
@@ -210,12 +218,18 @@ export class PromptService {
|
||||
return null;
|
||||
}
|
||||
|
||||
async set(name: string, model: string, messages: PromptMessage[]) {
|
||||
async set(
|
||||
name: string,
|
||||
model: string,
|
||||
messages: PromptMessage[],
|
||||
config?: PromptConfig
|
||||
) {
|
||||
return await this.db.aiPrompt
|
||||
.create({
|
||||
data: {
|
||||
name,
|
||||
model,
|
||||
config: config || undefined,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
@@ -229,10 +243,11 @@ export class PromptService {
|
||||
.then(ret => ret.id);
|
||||
}
|
||||
|
||||
async update(name: string, messages: PromptMessage[]) {
|
||||
async update(name: string, messages: PromptMessage[], config?: PromptConfig) {
|
||||
const { id } = await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
config: config || undefined,
|
||||
messages: {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
|
||||
@@ -125,21 +125,6 @@ export class OpenAIProvider
|
||||
});
|
||||
}
|
||||
|
||||
private extractOptionFromMessages(
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions
|
||||
) {
|
||||
const params: Record<string, string | string[]> = {};
|
||||
for (const message of messages) {
|
||||
if (message.params) {
|
||||
Object.assign(params, message.params);
|
||||
}
|
||||
}
|
||||
if (params.jsonMode && options) {
|
||||
options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
|
||||
}
|
||||
}
|
||||
|
||||
protected checkParams({
|
||||
messages,
|
||||
embeddings,
|
||||
@@ -155,7 +140,6 @@ export class OpenAIProvider
|
||||
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
|
||||
}
|
||||
if (Array.isArray(messages) && messages.length > 0) {
|
||||
this.extractOptionFromMessages(messages, options);
|
||||
if (
|
||||
messages.some(
|
||||
m =>
|
||||
@@ -257,7 +241,9 @@ export class OpenAIProvider
|
||||
stream: true,
|
||||
messages: this.chatToGPTMessage(messages),
|
||||
model: model,
|
||||
temperature: options.temperature || 0,
|
||||
frequency_penalty: options.frequencyPenalty || 0,
|
||||
presence_penalty: options.presencePenalty || 0,
|
||||
temperature: options.temperature || 0.5,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
response_format: {
|
||||
type: options.jsonMode ? 'json_object' : 'text',
|
||||
|
||||
@@ -183,6 +183,25 @@ registerEnumType(AiPromptRole, {
|
||||
name: 'CopilotPromptMessageRole',
|
||||
});
|
||||
|
||||
@InputType('CopilotPromptConfigInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptConfigType {
|
||||
@Field(() => Boolean, { nullable: true })
|
||||
jsonMode!: boolean | null;
|
||||
|
||||
@Field(() => Number, { nullable: true })
|
||||
frequencyPenalty!: number | null;
|
||||
|
||||
@Field(() => Number, { nullable: true })
|
||||
presencePenalty!: number | null;
|
||||
|
||||
@Field(() => Number, { nullable: true })
|
||||
temperature!: number | null;
|
||||
|
||||
@Field(() => Number, { nullable: true })
|
||||
topP!: number | null;
|
||||
}
|
||||
|
||||
@InputType('CopilotPromptMessageInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptMessageType {
|
||||
@@ -209,6 +228,9 @@ class CopilotPromptType {
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
@@ -462,6 +484,9 @@ class CreateCopilotPromptInput {
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
@@ -485,7 +510,12 @@ export class PromptsManagementResolver {
|
||||
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
|
||||
input: CreateCopilotPromptInput
|
||||
) {
|
||||
await this.promptService.set(input.name, input.model, input.messages);
|
||||
await this.promptService.set(
|
||||
input.name,
|
||||
input.model,
|
||||
input.messages,
|
||||
input.config
|
||||
);
|
||||
return this.promptService.get(input.name);
|
||||
}
|
||||
|
||||
|
||||
@@ -49,10 +49,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
userId,
|
||||
workspaceId,
|
||||
docId,
|
||||
prompt: { name: promptName },
|
||||
prompt: { name: promptName, config: promptConfig },
|
||||
} = this.state;
|
||||
|
||||
return { sessionId, userId, workspaceId, docId, promptName };
|
||||
return { sessionId, userId, workspaceId, docId, promptName, promptConfig };
|
||||
}
|
||||
|
||||
get stashMessages() {
|
||||
|
||||
@@ -63,6 +63,20 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
|
||||
|
||||
export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
|
||||
export const PromptConfigStrictSchema = z.object({
|
||||
jsonMode: z.boolean().nullable().optional(),
|
||||
frequencyPenalty: z.number().nullable().optional(),
|
||||
presencePenalty: z.number().nullable().optional(),
|
||||
temperature: z.number().nullable().optional(),
|
||||
topP: z.number().nullable().optional(),
|
||||
maxTokens: z.number().nullable().optional(),
|
||||
});
|
||||
|
||||
export const PromptConfigSchema =
|
||||
PromptConfigStrictSchema.nullable().optional();
|
||||
|
||||
export type PromptConfig = z.infer<typeof PromptConfigSchema>;
|
||||
|
||||
export const ChatMessageSchema = PromptMessageSchema.extend({
|
||||
id: z.string().optional(),
|
||||
createdAt: z.date(),
|
||||
@@ -144,11 +158,9 @@ const CopilotProviderOptionsSchema = z.object({
|
||||
user: z.string().optional(),
|
||||
});
|
||||
|
||||
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
|
||||
jsonMode: z.boolean().optional(),
|
||||
temperature: z.number().optional(),
|
||||
maxTokens: z.number().optional(),
|
||||
}).optional();
|
||||
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
|
||||
PromptConfigStrictSchema
|
||||
).optional();
|
||||
|
||||
export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
|
||||
|
||||
|
||||
@@ -57,6 +57,22 @@ enum CopilotModels {
|
||||
TextModerationStable
|
||||
}
|
||||
|
||||
input CopilotPromptConfigInput {
|
||||
frequencyPenalty: Int
|
||||
jsonMode: Boolean
|
||||
presencePenalty: Int
|
||||
temperature: Int
|
||||
topP: Int
|
||||
}
|
||||
|
||||
type CopilotPromptConfigType {
|
||||
frequencyPenalty: Int
|
||||
jsonMode: Boolean
|
||||
presencePenalty: Int
|
||||
temperature: Int
|
||||
topP: Int
|
||||
}
|
||||
|
||||
input CopilotPromptMessageInput {
|
||||
content: String!
|
||||
params: JSON
|
||||
@@ -81,6 +97,7 @@ type CopilotPromptNotFoundDataType {
|
||||
|
||||
type CopilotPromptType {
|
||||
action: String
|
||||
config: CopilotPromptConfigType
|
||||
messages: [CopilotPromptMessageType!]!
|
||||
model: CopilotModels!
|
||||
name: String!
|
||||
@@ -123,6 +140,7 @@ input CreateCheckoutSessionInput {
|
||||
|
||||
input CreateCopilotPromptInput {
|
||||
action: String
|
||||
config: CopilotPromptConfigInput
|
||||
messages: [CopilotPromptMessageInput!]!
|
||||
model: CopilotModels!
|
||||
name: String!
|
||||
|
||||
Reference in New Issue
Block a user