From 0c42849bc36e9869c2aee9ed3741cdafed22acfe Mon Sep 17 00:00:00 2001 From: darkskygit Date: Thu, 23 May 2024 14:27:12 +0000 Subject: [PATCH] feat: update i2i model (#7041) --- .../1716451792364-update-prompts.ts | 13 ++++ .../src/data/migrations/utils/prompts.ts | 26 ++++++-- .../src/plugins/copilot/providers/fal.ts | 64 ++++++++++++++++++- packages/backend/server/tests/copilot.spec.ts | 2 +- .../backend/server/tests/utils/copilot.ts | 2 +- 5 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 packages/backend/server/src/data/migrations/1716451792364-update-prompts.ts diff --git a/packages/backend/server/src/data/migrations/1716451792364-update-prompts.ts b/packages/backend/server/src/data/migrations/1716451792364-update-prompts.ts new file mode 100644 index 0000000000..1a07b7e02d --- /dev/null +++ b/packages/backend/server/src/data/migrations/1716451792364-update-prompts.ts @@ -0,0 +1,13 @@ +import { PrismaClient } from '@prisma/client'; + +import { refreshPrompts } from './utils/prompts'; + +export class UpdatePrompts1716451792364 { + // do the migration + static async up(db: PrismaClient) { + await refreshPrompts(db); + } + + // revert the migration + static async down(_db: PrismaClient) {} +} diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index afd631d0b2..0f7558e2a8 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -86,7 +86,7 @@ export const prompts: Prompt[] = [ { name: 'debug:action:fal-sdturbo-clay', action: 'image', - model: 'fast-turbo-diffusion', + model: 'fast-sdxl/image-to-image', messages: [ { role: 'user', @@ -102,7 +102,7 @@ export const prompts: Prompt[] = [ { name: 'debug:action:fal-sdturbo-pixel', action: 'image', - model: 'fast-turbo-diffusion', + model: 'fast-sdxl/image-to-image', messages: [ { role: 'user', @@ -116,7 +116,7 @@ export const prompts: Prompt[] = [ { name: 'debug:action:fal-sdturbo-sketch', action: 'image', - model: 'fast-turbo-diffusion', + model: 'fast-sdxl/image-to-image', messages: [ { role: 'user', @@ -132,7 +132,7 @@ export const prompts: Prompt[] = [ { name: 'debug:action:fal-sdturbo-fantasy', action: 'image', - model: 'fast-turbo-diffusion', + model: 'fast-sdxl/image-to-image', messages: [ { role: 'user', @@ -145,6 +145,24 @@ export const prompts: Prompt[] = [ }, ], }, + { + name: 'debug:action:fal-face-to-sticker', + action: 'image', + model: 'face-to-sticker', + messages: [], + }, + { + name: 'debug:action:fal-summary-caption', + action: 'image', + model: 'llava-next', + messages: [ + { + role: 'user', + content: + 'Please understand this image and generate a short caption. {{content}}', + }, + ], + }, { name: 'Summary', action: 'Summary', diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 151f2233ce..0c99e3bebe 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -2,6 +2,7 @@ import assert from 'node:assert'; import { CopilotCapability, + CopilotChatOptions, CopilotImageOptions, CopilotImageToImageProvider, CopilotProviderType, @@ -21,8 +22,12 @@ export type FalImage = { export type FalResponse = { detail: Array<{ msg: string }> | string; + // normal sd/sdxl response images?: Array; + // special i2i model response image?: FalImage; + // image2text response + output: string; }; type FalPrompt = { @@ -38,6 +43,7 @@ export class FalProvider static readonly capabilities = [ CopilotCapability.TextToImage, CopilotCapability.ImageToImage, + CopilotCapability.ImageToText, ]; readonly availableModels = [ @@ -46,7 +52,11 @@ export class FalProvider // image to image 'lcm-sd15-i2i', 'clarity-upscaler', + 'face-to-sticker', 'imageutils/rembg', + 'fast-sdxl/image-to-image', + // image to text + 'llava-next', ]; constructor(private readonly config: FalConfig) { @@ -96,11 +106,62 @@ export class FalProvider ).filter(v => typeof v === 'string' && v.length); return { image_url: attachments?.[0], - prompt: content, + prompt: content.trim(), lora: lora.length ? lora : undefined, }; } + async generateText( + messages: PromptMessage[], + model: string = 'llava-next', + options: CopilotChatOptions = {} + ): Promise { + if (!this.availableModels.includes(model)) { + throw new Error(`Invalid model: ${model}`); + } + + // by default, image prompt assumes there is only one message + const prompt = this.extractPrompt(messages.pop()); + const data = (await fetch(`https://fal.run/fal-ai/${model}`, { + method: 'POST', + headers: { + Authorization: `key ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + ...prompt, + sync_mode: true, + enable_safety_checks: false, + }), + signal: options.signal, + }).then(res => res.json())) as FalResponse; + + if (!data.output) { + const error = this.extractError(data); + throw new Error( + error ? `Failed to generate image: ${error}` : 'No images generated' + ); + } + return data.output; + } + + async *generateTextStream( + messages: PromptMessage[], + model: string = 'llava-next', + options: CopilotChatOptions = {} + ): AsyncIterable { + const result = await this.generateText(messages, model, options); + + for await (const content of result) { + if (content) { + yield content; + if (options.signal?.aborted) { + break; + } + } + } + } + // ====== image to image ====== async generateImages( messages: PromptMessage[], @@ -113,7 +174,6 @@ export class FalProvider // by default, image prompt assumes there is only one message const prompt = this.extractPrompt(messages.pop()); - const data = (await fetch(`https://fal.run/fal-ai/${model}`, { method: 'POST', headers: { diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 6d0cf31eae..dc8d4590ed 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -470,7 +470,7 @@ test('should be able to get provider', async t => { ); t.is( p?.type.toString(), - 'openai', + 'fal', 'should get provider support image-to-text' ); } diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index d349a7671e..92a367e6f1 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -31,7 +31,7 @@ export class MockCopilotTestProvider { override readonly availableModels = [ 'test', - 'fast-turbo-diffusion', + 'fast-sdxl/image-to-image', 'lcm-sd15-i2i', 'clarity-upscaler', 'imageutils/rembg',