mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
feat: text to image impl (#6437)
fix CLOUD-18 fix CLOUD-28 fix CLOUD-29
This commit is contained in:
92
packages/backend/server/src/plugins/copilot/providers/fal.ts
Normal file
92
packages/backend/server/src/plugins/copilot/providers/fal.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotProviderType,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
/**
 * Configuration required to talk to the fal.ai gateway.
 */
export type FalConfig = {
  /** Secret API key; sent in the `Authorization` header of every request. */
  apiKey: string;
};
|
||||
|
||||
/**
 * Subset of the fal.ai JSON response body that this provider reads.
 */
export type FalResponse = {
  /** Generated images; only the `url` of each entry is used. */
  images: Array<{ url: string }>;
};
|
||||
|
||||
/**
 * Copilot provider that performs image-to-image generation through the
 * fal.ai gateway.
 *
 * The last prompt message supplies both the text prompt (`content`) and the
 * source image (its first attachment).
 */
export class FalProvider implements CopilotImageToImageProvider {
  static readonly type = CopilotProviderType.FAL;
  static readonly capabilities = [CopilotCapability.ImageToImage];

  /** Model identifiers accepted by {@link generateImages}. */
  readonly availableModels = [
    // image to image
    // https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
    '110602490-lcm-sd15-i2i',
  ];

  /**
   * @param config Provider configuration.
   * @throws AssertionError when `config.apiKey` is empty.
   */
  constructor(private readonly config: FalConfig) {
    assert(FalProvider.assetsConfig(config));
  }

  /**
   * Check whether a configuration is usable.
   *
   * @returns `true` when `config.apiKey` is a non-empty value.
   */
  static assetsConfig(config: FalConfig) {
    return !!config.apiKey;
  }

  /** @returns The capabilities declared on the class. */
  getCapabilities(): CopilotCapability[] {
    return FalProvider.capabilities;
  }

  // ====== image to image ======

  /**
   * Generate images from the last prompt message.
   *
   * The `messages` array is only read; it is not modified.
   *
   * @param messages Conversation so far; only the last entry is used.
   * @param model Model identifier; must be one of `availableModels`.
   * @param options `signal` aborts the HTTP request. `user` is accepted for
   *   interface compatibility but is not sent to fal.ai.
   * @returns The URL of every image in the response.
   * @throws Error when the model is unknown, the prompt or attachments are
   *   missing, the gateway answers with a non-2xx status, or the response
   *   body carries no `images` array.
   */
  async generateImages(
    messages: PromptMessage[],
    model: string = this.availableModels[0],
    options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): Promise<Array<string>> {
    // Read the last message without popping it so the caller's array stays intact.
    const last = messages[messages.length - 1];
    const content = last?.content;
    const attachments = last?.attachments;

    if (!this.availableModels.includes(model)) {
      throw new Error(`Invalid model: ${model}`);
    }
    if (!content) {
      throw new Error('Prompt is required');
    }
    if (!Array.isArray(attachments) || !attachments.length) {
      throw new Error('Attachments is required');
    }

    const res = await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
      method: 'POST',
      headers: {
        Authorization: `key ${this.config.apiKey}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        // Only the first attachment is used as the source image.
        image_url: attachments[0],
        prompt: content,
        sync_mode: true,
        // NOTE(review): fixed seed, presumably for reproducible output — confirm.
        seed: 42,
        // NOTE(review): safety checks are switched off for every request — confirm this is intended.
        enable_safety_checks: false,
      }),
      signal: options.signal,
    });

    // Surface HTTP failures explicitly instead of failing later on a missing field.
    if (!res.ok) {
      throw new Error(
        `Failed to generate images: ${res.status} ${res.statusText}`
      );
    }

    const data = (await res.json()) as Partial<FalResponse> | null;
    if (!data || !Array.isArray(data.images)) {
      throw new Error('Failed to generate images: unexpected response body');
    }

    return data.images.map(image => image.url);
  }

  /**
   * Streaming variant of {@link generateImages}.
   *
   * Waits for the complete result, then yields each image URL in order.
   */
  async *generateImagesStream(
    messages: PromptMessage[],
    model: string = this.availableModels[0],
    options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): AsyncIterable<string> {
    const ret = await this.generateImages(messages, model, options);
    for (const url of ret) {
      yield url;
    }
  }
}
|
||||
@@ -134,4 +134,5 @@ export class CopilotProviderService {
|
||||
}
|
||||
}
|
||||
|
||||
export { FalProvider } from './fal';
|
||||
export { OpenAIProvider } from './openai';
|
||||
|
||||
@@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai';
|
||||
import {
|
||||
ChatMessageRole,
|
||||
CopilotCapability,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
export class OpenAIProvider
|
||||
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
|
||||
implements
|
||||
CopilotTextToTextProvider,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotImageToTextProvider
|
||||
{
|
||||
static readonly type = CopilotProviderType.OpenAI;
|
||||
static readonly capabilities = [
|
||||
CopilotCapability.TextToText,
|
||||
CopilotCapability.TextToEmbedding,
|
||||
CopilotCapability.TextToImage,
|
||||
CopilotCapability.ImageToText,
|
||||
];
|
||||
|
||||
readonly availableModels = [
|
||||
@@ -35,6 +44,8 @@ export class OpenAIProvider
|
||||
// moderation
|
||||
'text-moderation-latest',
|
||||
'text-moderation-stable',
|
||||
// text to image
|
||||
'dall-e-3',
|
||||
];
|
||||
|
||||
private readonly instance: OpenAI;
|
||||
@@ -52,12 +63,29 @@ export class OpenAIProvider
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
|
||||
/**
 * Convert prompt messages into OpenAI chat-completion message params.
 *
 * A message whose `attachments` is an array becomes a multi-part message:
 * one text part holding `content`, followed by one `image_url` part (with
 * `detail: 'low'`) for every attachment that matches
 * `SIMPLE_IMAGE_URL_REGEX`; other attachments are dropped. Any other
 * message is reduced to its `role` and `content`.
 */
private chatToGPTMessage(
  messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
  // filter redundant fields
  return messages.map(({ role, content, attachments }) => {
    if (Array.isArray(attachments)) {
      const contents = [
        { type: 'text', text: content },
        ...attachments
          .filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
          .map(url => ({
            type: 'image_url',
            image_url: { url, detail: 'low' },
          })),
      ];
      return {
        role,
        content: contents,
      } as OpenAI.Chat.Completions.ChatCompletionMessageParam;
    } else {
      return { role, content };
    }
  });
}
|
||||
|
||||
private checkParams({
|
||||
@@ -194,4 +222,44 @@ export class OpenAIProvider
|
||||
});
|
||||
return result.data.map(e => e.embedding);
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
/**
 * Generate images from the text of the last prompt message.
 *
 * The `messages` array is only read; it is not modified.
 *
 * @param messages Conversation so far; only the last entry's `content` is used.
 * @param model Image model passed to the OpenAI images API.
 * @param options `signal` aborts the request; `user` is forwarded as the
 *   request's `user` field.
 * @returns The URLs of the generated images; entries without a URL are skipped.
 * @throws Error when the last message has no content.
 */
async generateImages(
  messages: PromptMessage[],
  model: string = 'dall-e-3',
  options: {
    signal?: AbortSignal;
    user?: string;
  } = {}
): Promise<Array<string>> {
  // Read the last message without popping it so the caller's array stays intact.
  const prompt = messages[messages.length - 1]?.content;
  if (!prompt) {
    throw new Error('Prompt is required');
  }

  const result = await this.instance.images.generate(
    {
      prompt,
      model,
      response_format: 'url',
      user: options.user,
    },
    { signal: options.signal }
  );

  return result.data.map(image => image.url).filter((v): v is string => !!v);
}
|
||||
|
||||
/**
 * Streaming counterpart of `generateImages`.
 *
 * The full list of image URLs is awaited first and then emitted one by one,
 * in the order `generateImages` returned them.
 */
async *generateImagesStream(
  messages: PromptMessage[],
  model: string = 'dall-e-3',
  options: {
    signal?: AbortSignal;
    user?: string;
  } = {}
): AsyncIterable<string> {
  const urls = await this.generateImages(messages, model, options);
  // Delegate iteration to the resolved array instead of looping by hand.
  yield* urls;
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user