diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 0caddcc529..6fcb935ea3 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -100,6 +100,17 @@ export class CopilotController { return controller.signal; } + private parseNumber(value: string | string[] | undefined) { + if (!value) { + return undefined; + } + const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10); + if (Number.isNaN(num)) { + return undefined; + } + return num; + } + private handleError(err: any) { if (err instanceof Error) { const ret = { @@ -256,6 +267,7 @@ export class CopilotController { return from( provider.generateImagesStream(session.finish(params), session.model, { + seed: this.parseNumber(params.seed), signal: this.getSignal(req), user: user.id, }) diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index b6b1731b7d..7752bb93c7 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -2,6 +2,7 @@ import assert from 'node:assert'; import { CopilotCapability, + CopilotImageOptions, CopilotImageToImageProvider, CopilotProviderType, CopilotTextToImageProvider, @@ -57,10 +58,7 @@ export class FalProvider async generateImages( messages: PromptMessage[], model: string = this.availableModels[0], - options: { - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotImageOptions = {} ): Promise> { const { content, attachments } = messages.pop() || {}; if (!this.availableModels.includes(model)) { @@ -82,7 +80,7 @@ export class FalProvider image_url: attachments?.[0], prompt: content, sync_mode: true, - seed: 42, + seed: options.seed || 42, enable_safety_checks: false, }), signal: options.signal, @@ -100,10 +98,7 @@ export class FalProvider async *generateImagesStream( messages: PromptMessage[], model: string = this.availableModels[0], - options: { - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotImageOptions = {} ): AsyncIterable { const ret = await this.generateImages(messages, model, options); for (const url of ret) { diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 21ef0eea5e..b44f7d4ba8 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -5,6 +5,9 @@ import { ClientOptions, OpenAI } from 'openai'; import { ChatMessageRole, CopilotCapability, + CopilotChatOptions, + CopilotEmbeddingOptions, + CopilotImageOptions, CopilotImageToTextProvider, CopilotProviderType, CopilotTextToEmbeddingProvider, @@ -147,12 +150,7 @@ export class OpenAIProvider async generateText( messages: PromptMessage[], model: string = 'gpt-3.5-turbo', - options: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotChatOptions = {} ): Promise { this.checkParams({ messages, model }); const result = await this.instance.chat.completions.create( @@ -175,12 +173,7 @@ export class OpenAIProvider async *generateTextStream( messages: PromptMessage[], model: string = 'gpt-3.5-turbo', - options: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotChatOptions = {} ): AsyncIterable { this.checkParams({ messages, model }); const result = await this.instance.chat.completions.create( @@ -214,11 +207,7 @@ export class OpenAIProvider async generateEmbedding( messages: string | string[], model: string, - options: { - dimensions: number; - signal?: AbortSignal; - user?: string; - } = { dimensions: DEFAULT_DIMENSIONS } + options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { messages = Array.isArray(messages) ? messages : [messages]; this.checkParams({ embeddings: messages, model }); @@ -236,10 +225,7 @@ export class OpenAIProvider async generateImages( messages: PromptMessage[], model: string = 'dall-e-3', - options: { - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotImageOptions = {} ): Promise> { const { content: prompt } = messages.pop() || {}; if (!prompt) { @@ -261,10 +247,7 @@ export class OpenAIProvider async *generateImagesStream( messages: PromptMessage[], model: string = 'dall-e-3', - options: { - signal?: AbortSignal; - user?: string; - } = {} + options: CopilotImageOptions = {} ): AsyncIterable { const ret = await this.generateImages(messages, model, options); for (const url of ret) { diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 2e707f96c7..64a770d635 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -143,6 +143,32 @@ export enum CopilotCapability { ImageToText = 'image-to-text', } +const CopilotProviderOptionsSchema = z.object({ + signal: z.instanceof(AbortSignal).optional(), + user: z.string().optional(), +}); + +const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({ + temperature: z.number().optional(), + maxTokens: z.number().optional(), +}).optional(); + +export type CopilotChatOptions = z.infer; + +const CopilotEmbeddingOptionsSchema = CopilotProviderOptionsSchema.extend({ + dimensions: z.number(), +}).optional(); + +export type CopilotEmbeddingOptions = z.infer< + typeof CopilotEmbeddingOptionsSchema +>; + +const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.extend({ + seed: z.number().optional(), +}).optional(); + +export type CopilotImageOptions = z.infer; + export interface CopilotProvider { readonly type: CopilotProviderType; getCapabilities(): CopilotCapability[]; @@ -153,22 +179,12 @@ export interface CopilotTextToTextProvider extends CopilotProvider { generateText( messages: PromptMessage[], model?: string, - options?: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } + options?: CopilotChatOptions ): Promise; generateTextStream( messages: PromptMessage[], model?: string, - options?: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } + options?: CopilotChatOptions ): AsyncIterable; } @@ -176,11 +192,7 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider { generateEmbedding( messages: string[] | string, model: string, - options: { - dimensions: number; - signal?: AbortSignal; - user?: string; - } + options?: CopilotEmbeddingOptions ): Promise; } @@ -188,18 +200,12 @@ export interface CopilotTextToImageProvider extends CopilotProvider { generateImages( messages: PromptMessage[], model: string, - options: { - signal?: AbortSignal; - user?: string; - } + options?: CopilotImageOptions ): Promise>; generateImagesStream( messages: PromptMessage[], model?: string, - options?: { - signal?: AbortSignal; - user?: string; - } + options?: CopilotImageOptions ): AsyncIterable; } @@ -207,22 +213,12 @@ export interface CopilotImageToTextProvider extends CopilotProvider { generateText( messages: PromptMessage[], model: string, - options: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } + options?: CopilotChatOptions ): Promise; generateTextStream( messages: PromptMessage[], model: string, - options: { - temperature?: number; - maxTokens?: number; - signal?: AbortSignal; - user?: string; - } + options?: CopilotChatOptions ): AsyncIterable; } @@ -230,18 +226,12 @@ export interface CopilotImageToImageProvider extends CopilotProvider { generateImages( messages: PromptMessage[], model: string, - options: { - signal?: AbortSignal; - user?: string; - } + options?: CopilotImageOptions ): Promise>; generateImagesStream( messages: PromptMessage[], model?: string, - options?: { - signal?: AbortSignal; - user?: string; - } + options?: CopilotImageOptions ): AsyncIterable; }