feat: add prompt level config (#7445)

This commit is contained in:
darkskygit
2024-07-08 08:11:22 +00:00
parent 9ef8829ef1
commit bf6c9a5955
12 changed files with 125 additions and 41 deletions

View File

@@ -6,10 +6,19 @@ type PromptMessage = {
params?: Record<string, string | string[]>;
};
type PromptConfig = {
jsonMode?: boolean;
frequencyPenalty?: number;
presencePenalty?: number;
temperature?: number;
maxTokens?: number;
};
type Prompt = {
name: string;
action?: string;
model: string;
config?: PromptConfig;
messages: PromptMessage[];
};
@@ -465,6 +474,7 @@ content: {{content}}`,
name: 'workflow:presentation:step1',
action: 'workflow:presentation:step1',
model: 'gpt-4o',
config: { temperature: 0.7 },
messages: [
{
role: 'system',
@@ -685,6 +695,7 @@ export async function refreshPrompts(db: PrismaClient) {
create: {
name: prompt.name,
action: prompt.action,
config: prompt.config,
model: prompt.model,
messages: {
create: prompt.messages.map((message, idx) => ({

View File

@@ -138,9 +138,8 @@ export class CopilotController {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
delete params.messageId;
return { messageId, jsonMode, params };
return { messageId, params };
}
private getSignal(req: Request) {
@@ -167,7 +166,7 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { messageId, jsonMode } = this.prepareParams(params);
const { messageId } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -180,7 +179,11 @@ export class CopilotController {
const content = await provider.generateText(
session.finish(params),
session.model,
{ jsonMode, signal: this.getSignal(req), user: user.id }
{
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
}
);
session.push({
@@ -204,7 +207,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const { messageId, jsonMode } = this.prepareParams(params);
const { messageId } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -215,7 +218,7 @@ export class CopilotController {
return from(
provider.generateTextStream(session.finish(params), session.model, {
jsonMode,
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
})
@@ -256,7 +259,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const { messageId, jsonMode } = this.prepareParams(params);
const { messageId } = this.prepareParams(params);
const session = await this.appendSessionMessage(sessionId, messageId);
const latestMessage = session.stashMessages.findLast(
m => m.role === 'user'
@@ -269,7 +272,7 @@ export class CopilotController {
return from(
this.workflow.runGraph(params, session.model, {
jsonMode,
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
})

View File

@@ -5,6 +5,8 @@ import Mustache from 'mustache';
import {
getTokenEncoder,
PromptConfig,
PromptConfigSchema,
PromptMessage,
PromptMessageSchema,
PromptParams,
@@ -35,14 +37,16 @@ export class ChatPrompt {
private readonly templateParams: PromptParams = {};
static createFromPrompt(
options: Omit<AiPrompt, 'id' | 'createdAt'> & {
options: Omit<AiPrompt, 'id' | 'createdAt' | 'config'> & {
messages: PromptMessage[];
config: PromptConfig | undefined;
}
) {
return new ChatPrompt(
options.name,
options.action || undefined,
options.model,
options.config,
options.messages
);
}
@@ -51,6 +55,7 @@ export class ChatPrompt {
public readonly name: string,
public readonly action: string | undefined,
public readonly model: string,
public readonly config: PromptConfig | undefined,
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);
@@ -185,6 +190,7 @@ export class PromptService {
name: true,
action: true,
model: true,
config: true,
messages: {
select: {
role: true,
@@ -199,9 +205,11 @@ export class PromptService {
});
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
if (prompt && messages.success) {
const config = PromptConfigSchema.safeParse(prompt?.config);
if (prompt && messages.success && config.success) {
const chatPrompt = ChatPrompt.createFromPrompt({
...prompt,
config: config.data,
messages: messages.data,
});
this.cache.set(name, chatPrompt);
@@ -210,12 +218,18 @@ export class PromptService {
return null;
}
async set(name: string, model: string, messages: PromptMessage[]) {
async set(
name: string,
model: string,
messages: PromptMessage[],
config?: PromptConfig
) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
config: config || undefined,
messages: {
create: messages.map((m, idx) => ({
idx,
@@ -229,10 +243,11 @@ export class PromptService {
.then(ret => ret.id);
}
async update(name: string, messages: PromptMessage[]) {
async update(name: string, messages: PromptMessage[], config?: PromptConfig) {
const { id } = await this.db.aiPrompt.update({
where: { name },
data: {
config: config || undefined,
messages: {
// cleanup old messages
deleteMany: {},

View File

@@ -125,21 +125,6 @@ export class OpenAIProvider
});
}
private extractOptionFromMessages(
messages: PromptMessage[],
options: CopilotChatOptions
) {
const params: Record<string, string | string[]> = {};
for (const message of messages) {
if (message.params) {
Object.assign(params, message.params);
}
}
if (params.jsonMode && options) {
options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
}
}
protected checkParams({
messages,
embeddings,
@@ -155,7 +140,6 @@ export class OpenAIProvider
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
this.extractOptionFromMessages(messages, options);
if (
messages.some(
m =>
@@ -257,7 +241,9 @@ export class OpenAIProvider
stream: true,
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
frequency_penalty: options.frequencyPenalty || 0,
presence_penalty: options.presencePenalty || 0,
temperature: options.temperature || 0.5,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',

View File

@@ -183,6 +183,25 @@ registerEnumType(AiPromptRole, {
name: 'CopilotPromptMessageRole',
});
@InputType('CopilotPromptConfigInput')
@ObjectType()
class CopilotPromptConfigType {
@Field(() => Boolean, { nullable: true })
jsonMode!: boolean | null;
@Field(() => Number, { nullable: true })
frequencyPenalty!: number | null;
@Field(() => Number, { nullable: true })
presencePenalty!: number | null;
@Field(() => Number, { nullable: true })
temperature!: number | null;
@Field(() => Number, { nullable: true })
topP!: number | null;
}
@InputType('CopilotPromptMessageInput')
@ObjectType()
class CopilotPromptMessageType {
@@ -209,6 +228,9 @@ class CopilotPromptType {
@Field(() => String, { nullable: true })
action!: string | null;
@Field(() => CopilotPromptConfigType, { nullable: true })
config!: CopilotPromptConfigType | null;
@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}
@@ -462,6 +484,9 @@ class CreateCopilotPromptInput {
@Field(() => String, { nullable: true })
action!: string | null;
@Field(() => CopilotPromptConfigType, { nullable: true })
config!: CopilotPromptConfigType | null;
@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}
@@ -485,7 +510,12 @@ export class PromptsManagementResolver {
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
input: CreateCopilotPromptInput
) {
await this.promptService.set(input.name, input.model, input.messages);
await this.promptService.set(
input.name,
input.model,
input.messages,
input.config
);
return this.promptService.get(input.name);
}

View File

@@ -49,10 +49,10 @@ export class ChatSession implements AsyncDisposable {
userId,
workspaceId,
docId,
prompt: { name: promptName },
prompt: { name: promptName, config: promptConfig },
} = this.state;
return { sessionId, userId, workspaceId, docId, promptName };
return { sessionId, userId, workspaceId, docId, promptName, promptConfig };
}
get stashMessages() {

View File

@@ -63,6 +63,20 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export const PromptConfigStrictSchema = z.object({
jsonMode: z.boolean().nullable().optional(),
frequencyPenalty: z.number().nullable().optional(),
presencePenalty: z.number().nullable().optional(),
temperature: z.number().nullable().optional(),
topP: z.number().nullable().optional(),
maxTokens: z.number().nullable().optional(),
});
export const PromptConfigSchema =
PromptConfigStrictSchema.nullable().optional();
export type PromptConfig = z.infer<typeof PromptConfigSchema>;
export const ChatMessageSchema = PromptMessageSchema.extend({
id: z.string().optional(),
createdAt: z.date(),
@@ -144,11 +158,9 @@ const CopilotProviderOptionsSchema = z.object({
user: z.string().optional(),
});
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
jsonMode: z.boolean().optional(),
temperature: z.number().optional(),
maxTokens: z.number().optional(),
}).optional();
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
PromptConfigStrictSchema
).optional();
export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;

View File

@@ -57,6 +57,22 @@ enum CopilotModels {
TextModerationStable
}
input CopilotPromptConfigInput {
frequencyPenalty: Int
jsonMode: Boolean
presencePenalty: Int
temperature: Int
topP: Int
}
type CopilotPromptConfigType {
frequencyPenalty: Int
jsonMode: Boolean
presencePenalty: Int
temperature: Int
topP: Int
}
input CopilotPromptMessageInput {
content: String!
params: JSON
@@ -81,6 +97,7 @@ type CopilotPromptNotFoundDataType {
type CopilotPromptType {
action: String
config: CopilotPromptConfigType
messages: [CopilotPromptMessageType!]!
model: CopilotModels!
name: String!
@@ -123,6 +140,7 @@ input CreateCheckoutSessionInput {
input CreateCopilotPromptInput {
action: String
config: CopilotPromptConfigInput
messages: [CopilotPromptMessageInput!]!
model: CopilotModels!
name: String!