feat: allow custom seed (#6709)

Author: darkskygit
Date: 2024-04-26 11:40:07 +00:00
Parent: 5d114ea965
Commit: b639e52dca
4 changed files with 59 additions and 79 deletions

View File

@@ -100,6 +100,17 @@ export class CopilotController {
     return controller.signal;
   }
 
+  private parseNumber(value: string | string[] | undefined) {
+    if (!value) {
+      return undefined;
+    }
+    const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
+    if (Number.isNaN(num)) {
+      return undefined;
+    }
+    return num;
+  }
+
   private handleError(err: any) {
     if (err instanceof Error) {
       const ret = {
@@ -256,6 +267,7 @@ export class CopilotController {
     return from(
       provider.generateImagesStream(session.finish(params), session.model, {
+        seed: this.parseNumber(params.seed),
         signal: this.getSignal(req),
         user: user.id,
       })
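As a side note on the controller change, here is a minimal standalone sketch of how the new parseNumber helper behaves. It assumes params.seed arrives as a raw query/body string (possibly repeated), which is what the helper's signature suggests; the sample values are illustrative only.

// Same logic as the helper added above, shown in isolation.
function parseNumber(value: string | string[] | undefined) {
  if (!value) {
    return undefined;
  }
  const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
  return Number.isNaN(num) ? undefined : num;
}

parseNumber('42'); // 42
parseNumber(['7', '8']); // 7 (only the first value is used)
parseNumber('not-a-number'); // undefined
parseNumber(undefined); // undefined (the provider then applies its own default seed)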

View File

@@ -2,6 +2,7 @@ import assert from 'node:assert';
 import {
   CopilotCapability,
+  CopilotImageOptions,
   CopilotImageToImageProvider,
   CopilotProviderType,
   CopilotTextToImageProvider,
@@ -57,10 +58,7 @@ export class FalProvider
   async generateImages(
     messages: PromptMessage[],
     model: string = this.availableModels[0],
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotImageOptions = {}
   ): Promise<Array<string>> {
     const { content, attachments } = messages.pop() || {};
     if (!this.availableModels.includes(model)) {
@@ -82,7 +80,7 @@ export class FalProvider
         image_url: attachments?.[0],
         prompt: content,
         sync_mode: true,
-        seed: 42,
+        seed: options.seed || 42,
         enable_safety_checks: false,
       }),
       signal: options.signal,
@@ -100,10 +98,7 @@ export class FalProvider
   async *generateImagesStream(
     messages: PromptMessage[],
     model: string = this.availableModels[0],
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotImageOptions = {}
   ): AsyncIterable<string> {
     const ret = await this.generateImages(messages, model, options);
     for (const url of ret) {
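One behavioral note on the `seed: options.seed || 42` change above, shown as a tiny sketch rather than a real provider call; the helper name is made up for illustration.

// `||` treats any falsy seed as missing, so 0 also falls back to the old default of 42.
const effectiveSeed = (seed?: number) => seed || 42;
effectiveSeed(12345); // 12345
effectiveSeed(undefined); // 42 (the previous hard-coded behavior)
effectiveSeed(0); // 42 as well; `seed ?? 42` would be needed if 0 should count as a valid seed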

View File

@@ -5,6 +5,9 @@ import { ClientOptions, OpenAI } from 'openai';
 import {
   ChatMessageRole,
   CopilotCapability,
+  CopilotChatOptions,
+  CopilotEmbeddingOptions,
+  CopilotImageOptions,
   CopilotImageToTextProvider,
   CopilotProviderType,
   CopilotTextToEmbeddingProvider,
@@ -147,12 +150,7 @@ export class OpenAIProvider
   async generateText(
     messages: PromptMessage[],
     model: string = 'gpt-3.5-turbo',
-    options: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotChatOptions = {}
   ): Promise<string> {
     this.checkParams({ messages, model });
     const result = await this.instance.chat.completions.create(
@@ -175,12 +173,7 @@ export class OpenAIProvider
   async *generateTextStream(
     messages: PromptMessage[],
     model: string = 'gpt-3.5-turbo',
-    options: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotChatOptions = {}
   ): AsyncIterable<string> {
     this.checkParams({ messages, model });
     const result = await this.instance.chat.completions.create(
@@ -214,11 +207,7 @@ export class OpenAIProvider
   async generateEmbedding(
     messages: string | string[],
     model: string,
-    options: {
-      dimensions: number;
-      signal?: AbortSignal;
-      user?: string;
-    } = { dimensions: DEFAULT_DIMENSIONS }
+    options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
   ): Promise<number[][]> {
     messages = Array.isArray(messages) ? messages : [messages];
     this.checkParams({ embeddings: messages, model });
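A hedged sketch of a call site for the embedding signature above; the provider instance and model name are assumptions, not part of this diff.

// Omitting the argument keeps the old default of { dimensions: DEFAULT_DIMENSIONS }.
// When an options object is passed, `dimensions` is required by CopilotEmbeddingOptions.
const vectors = await openai.generateEmbedding('hello world', 'text-embedding-3-small', {
  dimensions: 256,
  user: 'user-id',
});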
@@ -236,10 +225,7 @@ export class OpenAIProvider
   async generateImages(
     messages: PromptMessage[],
     model: string = 'dall-e-3',
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotImageOptions = {}
   ): Promise<Array<string>> {
     const { content: prompt } = messages.pop() || {};
     if (!prompt) {
@@ -261,10 +247,7 @@ export class OpenAIProvider
   async *generateImagesStream(
     messages: PromptMessage[],
     model: string = 'dall-e-3',
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    } = {}
+    options: CopilotImageOptions = {}
   ): AsyncIterable<string> {
     const ret = await this.generateImages(messages, model, options);
     for (const url of ret) {
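For the text path, a hedged example of what a call site looks like once the inline option literals are replaced by CopilotChatOptions; the `openai` instance, message content, and user id are placeholders rather than values from this diff.

// CopilotChatOptions carries the same fields the inline literal used to declare.
const controller = new AbortController();
const text = await openai.generateText(
  [{ role: 'user', content: 'Summarize the release notes' }],
  'gpt-3.5-turbo',
  { temperature: 0.2, maxTokens: 256, signal: controller.signal, user: 'user-id' }
);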

View File

@@ -143,6 +143,32 @@ export enum CopilotCapability {
   ImageToText = 'image-to-text',
 }
 
+const CopilotProviderOptionsSchema = z.object({
+  signal: z.instanceof(AbortSignal).optional(),
+  user: z.string().optional(),
+});
+
+const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
+  temperature: z.number().optional(),
+  maxTokens: z.number().optional(),
+}).optional();
+
+export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
+
+const CopilotEmbeddingOptionsSchema = CopilotProviderOptionsSchema.extend({
+  dimensions: z.number(),
+}).optional();
+
+export type CopilotEmbeddingOptions = z.infer<
+  typeof CopilotEmbeddingOptionsSchema
+>;
+
+const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.extend({
+  seed: z.number().optional(),
+}).optional();
+
+export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
+
 export interface CopilotProvider {
   readonly type: CopilotProviderType;
   getCapabilities(): CopilotCapability[];
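To make the new option schemas above concrete, a small illustrative check (standalone, not part of the commit); it only relies on zod behavior that follows directly from the definitions.

// Every field is optional except `dimensions` on the embedding schema,
// and each options bag may be omitted entirely because of the trailing .optional().
CopilotImageOptionsSchema.parse({ seed: 1234, user: 'user-id' }); // ok
CopilotImageOptionsSchema.parse(undefined); // ok, yields undefined
// CopilotEmbeddingOptionsSchema.parse({ user: 'user-id' }) would throw:
// `dimensions` is the only required field once an object is supplied.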
@@ -153,22 +179,12 @@ export interface CopilotTextToTextProvider extends CopilotProvider {
   generateText(
     messages: PromptMessage[],
     model?: string,
-    options?: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotChatOptions
   ): Promise<string>;
   generateTextStream(
     messages: PromptMessage[],
     model?: string,
-    options?: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotChatOptions
   ): AsyncIterable<string>;
 }
@@ -176,11 +192,7 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
   generateEmbedding(
     messages: string[] | string,
     model: string,
-    options: {
-      dimensions: number;
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotEmbeddingOptions
   ): Promise<number[][]>;
 }
@@ -188,18 +200,12 @@ export interface CopilotTextToImageProvider extends CopilotProvider {
   generateImages(
     messages: PromptMessage[],
     model: string,
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotImageOptions
   ): Promise<Array<string>>;
   generateImagesStream(
     messages: PromptMessage[],
     model?: string,
-    options?: {
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotImageOptions
   ): AsyncIterable<string>;
 }
@@ -207,22 +213,12 @@ export interface CopilotImageToTextProvider extends CopilotProvider {
   generateText(
     messages: PromptMessage[],
     model: string,
-    options: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotChatOptions
   ): Promise<string>;
   generateTextStream(
     messages: PromptMessage[],
     model: string,
-    options: {
-      temperature?: number;
-      maxTokens?: number;
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotChatOptions
   ): AsyncIterable<string>;
 }
@@ -230,18 +226,12 @@ export interface CopilotImageToImageProvider extends CopilotProvider {
   generateImages(
     messages: PromptMessage[],
     model: string,
-    options: {
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotImageOptions
   ): Promise<Array<string>>;
   generateImagesStream(
     messages: PromptMessage[],
     model?: string,
-    options?: {
-      signal?: AbortSignal;
-      user?: string;
-    }
+    options?: CopilotImageOptions
   ): AsyncIterable<string>;
 }
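Finally, a minimal sketch of a provider method written against the revised interface; the class name, URL, and fallback value are placeholders and not part of this commit, and per this diff only the fal provider actually reads `seed`.

class ExampleImageProvider {
  async generateImages(
    messages: PromptMessage[],
    model: string,
    options: CopilotImageOptions = {}
  ): Promise<Array<string>> {
    const prompt = messages.at(-1)?.content ?? '';
    // A provider may honor or ignore the seed; here it simply ends up in a fake URL.
    const seed = options.seed ?? 42;
    return [
      `https://images.example.invalid/${model}?prompt=${encodeURIComponent(prompt)}&seed=${seed}`,
    ];
  }
}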