feat: text to image impl (#6437)

fix CLOUD-18
fix CLOUD-28
fix CLOUD-29
darkskygit
2024-04-10 12:13:39 +00:00
parent 7c38a54f81
commit 9f349a2300
19 changed files with 601 additions and 99 deletions

@@ -0,0 +1,92 @@
import assert from 'node:assert';
import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotProviderType,
PromptMessage,
} from '../types';
export type FalConfig = {
apiKey: string;
};
export type FalResponse = {
images: Array<{ url: string }>;
};
export class FalProvider implements CopilotImageToImageProvider {
static readonly type = CopilotProviderType.FAL;
static readonly capabilities = [CopilotCapability.ImageToImage];
readonly availableModels = [
// image to image
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
'110602490-lcm-sd15-i2i',
];
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
}
static assetsConfig(config: FalConfig) {
return !!config.apiKey;
}
getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content, attachments } = messages.pop() || {};
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (!content) {
throw new Error('Prompt is required');
}
if (!Array.isArray(attachments) || !attachments.length) {
throw new Error('Attachments are required');
}
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments[0],
prompt: content,
sync_mode: true,
seed: 42,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
return data.images.map(image => image.url);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}
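
For illustration only (not part of the diff), a minimal sketch of how this provider might be driven; it assumes PromptMessage accepts a plain 'user' role string, and the API key and image URL are placeholders:

import { FalProvider } from './fal';

// Hypothetical caller: runs the image-to-image model listed above against a
// single prompt message that carries one attachment URL.
async function sketchToWatercolor(apiKey: string): Promise<string[]> {
  const fal = new FalProvider({ apiKey });
  return fal.generateImages(
    [
      {
        role: 'user',
        content: 'turn this sketch into a watercolor painting',
        attachments: ['https://example.com/sketch.png'], // placeholder URL
      },
    ],
    '110602490-lcm-sd15-i2i'
  );
}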

@@ -134,4 +134,5 @@ export class CopilotProviderService {
}
}
export { FalProvider } from './fal';
export { OpenAIProvider } from './openai';

@@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessageRole,
CopilotCapability,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';
const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
export class OpenAIProvider
- implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
+ implements
+ CopilotTextToTextProvider,
+ CopilotTextToEmbeddingProvider,
+ CopilotTextToImageProvider,
+ CopilotImageToTextProvider
{
static readonly type = CopilotProviderType.OpenAI;
static readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToText,
];
readonly availableModels = [
@@ -35,6 +44,8 @@ export class OpenAIProvider
// moderation
'text-moderation-latest',
'text-moderation-stable',
// text to image
'dall-e-3',
];
private readonly instance: OpenAI;
@@ -52,12 +63,29 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
- private chatToGPTMessage(messages: PromptMessage[]) {
+ private chatToGPTMessage(
+ messages: PromptMessage[]
+ ): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
- return messages.map(message => ({
- role: message.role,
- content: message.content,
- }));
+ return messages.map(({ role, content, attachments }) => {
+ if (Array.isArray(attachments)) {
+ const contents = [
+ { type: 'text', text: content },
+ ...attachments
+ .filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
+ .map(url => ({
+ type: 'image_url',
+ image_url: { url, detail: 'low' },
+ })),
+ ];
+ return {
+ role,
+ content: contents,
+ } as OpenAI.Chat.Completions.ChatCompletionMessageParam;
+ } else {
+ return { role, content };
+ }
+ });
}
private checkParams({
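
To make the mapping above concrete: a message whose attachments match SIMPLE_IMAGE_URL_REGEX is expanded into OpenAI's multi-part content format, while a message without attachments keeps its plain string content. A sketch of the resulting value (the image URL is a placeholder):

import { OpenAI } from 'openai';

// Shape produced by chatToGPTMessage for a message with content
// 'Describe this image' and one attachment URL.
const visionMessage: OpenAI.Chat.Completions.ChatCompletionMessageParam = {
  role: 'user',
  content: [
    { type: 'text', text: 'Describe this image' },
    {
      type: 'image_url',
      image_url: { url: 'https://example.com/cat.png', detail: 'low' },
    },
  ],
};
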
@@ -194,4 +222,44 @@ export class OpenAIProvider
});
return result.data.map(e => e.embedding);
}
// ====== text to image ======
async generateImages(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
throw new Error('Prompt is required');
}
const result = await this.instance.images.generate(
{
prompt,
model,
response_format: 'url',
user: options.user,
},
{ signal: options.signal }
);
return result.data.map(image => image.url).filter((v): v is string => !!v);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}
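
For illustration only (not part of the diff), a sketch of the dall-e-3 text-to-image path above; it assumes an already constructed OpenAIProvider instance and, as before, that PromptMessage accepts a plain 'user' role string:

// Hypothetical caller: streams generated image URLs for a single text prompt.
async function generatePosterImages(provider: OpenAIProvider): Promise<void> {
  const abort = new AbortController();
  for await (const url of provider.generateImagesStream(
    [{ role: 'user', content: 'an isometric pixel-art reading room' }],
    'dall-e-3',
    { signal: abort.signal, user: 'user-123' } // user id is a placeholder
  )) {
    console.log('generated image:', url);
  }
}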