mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-11 20:08:37 +00:00
feat: allow custom seed (#6709)
This commit is contained in:
@@ -100,6 +100,17 @@ export class CopilotController {
|
||||
return controller.signal;
|
||||
}
|
||||
|
||||
private parseNumber(value: string | string[] | undefined) {
|
||||
if (!value) {
|
||||
return undefined;
|
||||
}
|
||||
const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
|
||||
if (Number.isNaN(num)) {
|
||||
return undefined;
|
||||
}
|
||||
return num;
|
||||
}
|
||||
|
||||
private handleError(err: any) {
|
||||
if (err instanceof Error) {
|
||||
const ret = {
|
||||
@@ -256,6 +267,7 @@ export class CopilotController {
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
seed: this.parseNumber(params.seed),
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ import assert from 'node:assert';
|
||||
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageOptions,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToImageProvider,
|
||||
@@ -57,10 +58,7 @@ export class FalProvider
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string = this.availableModels[0],
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotImageOptions = {}
|
||||
): Promise<Array<string>> {
|
||||
const { content, attachments } = messages.pop() || {};
|
||||
if (!this.availableModels.includes(model)) {
|
||||
@@ -82,7 +80,7 @@ export class FalProvider
|
||||
image_url: attachments?.[0],
|
||||
prompt: content,
|
||||
sync_mode: true,
|
||||
seed: 42,
|
||||
seed: options.seed || 42,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
@@ -100,10 +98,7 @@ export class FalProvider
|
||||
async *generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = this.availableModels[0],
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotImageOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
|
||||
@@ -5,6 +5,9 @@ import { ClientOptions, OpenAI } from 'openai';
|
||||
import {
|
||||
ChatMessageRole,
|
||||
CopilotCapability,
|
||||
CopilotChatOptions,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
@@ -147,12 +150,7 @@ export class OpenAIProvider
|
||||
async generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
this.checkParams({ messages, model });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
@@ -175,12 +173,7 @@ export class OpenAIProvider
|
||||
async *generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
this.checkParams({ messages, model });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
@@ -214,11 +207,7 @@ export class OpenAIProvider
|
||||
async generateEmbedding(
|
||||
messages: string | string[],
|
||||
model: string,
|
||||
options: {
|
||||
dimensions: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = { dimensions: DEFAULT_DIMENSIONS }
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
this.checkParams({ embeddings: messages, model });
|
||||
@@ -236,10 +225,7 @@ export class OpenAIProvider
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'dall-e-3',
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotImageOptions = {}
|
||||
): Promise<Array<string>> {
|
||||
const { content: prompt } = messages.pop() || {};
|
||||
if (!prompt) {
|
||||
@@ -261,10 +247,7 @@ export class OpenAIProvider
|
||||
async *generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'dall-e-3',
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
options: CopilotImageOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
|
||||
@@ -143,6 +143,32 @@ export enum CopilotCapability {
|
||||
ImageToText = 'image-to-text',
|
||||
}
|
||||
|
||||
const CopilotProviderOptionsSchema = z.object({
|
||||
signal: z.instanceof(AbortSignal).optional(),
|
||||
user: z.string().optional(),
|
||||
});
|
||||
|
||||
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
|
||||
temperature: z.number().optional(),
|
||||
maxTokens: z.number().optional(),
|
||||
}).optional();
|
||||
|
||||
export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
|
||||
|
||||
const CopilotEmbeddingOptionsSchema = CopilotProviderOptionsSchema.extend({
|
||||
dimensions: z.number(),
|
||||
}).optional();
|
||||
|
||||
export type CopilotEmbeddingOptions = z.infer<
|
||||
typeof CopilotEmbeddingOptionsSchema
|
||||
>;
|
||||
|
||||
const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.extend({
|
||||
seed: z.number().optional(),
|
||||
}).optional();
|
||||
|
||||
export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
|
||||
|
||||
export interface CopilotProvider {
|
||||
readonly type: CopilotProviderType;
|
||||
getCapabilities(): CopilotCapability[];
|
||||
@@ -153,22 +179,12 @@ export interface CopilotTextToTextProvider extends CopilotProvider {
|
||||
generateText(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotChatOptions
|
||||
): Promise<string>;
|
||||
generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
@@ -176,11 +192,7 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
|
||||
generateEmbedding(
|
||||
messages: string[] | string,
|
||||
model: string,
|
||||
options: {
|
||||
dimensions: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotEmbeddingOptions
|
||||
): Promise<number[][]>;
|
||||
}
|
||||
|
||||
@@ -188,18 +200,12 @@ export interface CopilotTextToImageProvider extends CopilotProvider {
|
||||
generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotImageOptions
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotImageOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
@@ -207,22 +213,12 @@ export interface CopilotImageToTextProvider extends CopilotProvider {
|
||||
generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotChatOptions
|
||||
): Promise<string>;
|
||||
generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
@@ -230,18 +226,12 @@ export interface CopilotImageToImageProvider extends CopilotProvider {
|
||||
generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotImageOptions
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
options?: CopilotImageOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user