diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 6ff2a9065d..31934bc139 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -20,6 +20,7 @@ "dependencies": { "@apollo/server": "^4.10.2", "@aws-sdk/client-s3": "^3.552.0", + "@fal-ai/serverless-client": "^0.10.2", "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.18.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^2.2.0", "@google-cloud/opentelemetry-resource-util": "^2.2.0", diff --git a/packages/backend/server/src/data/migrations/1717490700326-update-prompts.ts b/packages/backend/server/src/data/migrations/1717490700326-update-prompts.ts new file mode 100644 index 0000000000..d2889f074b --- /dev/null +++ b/packages/backend/server/src/data/migrations/1717490700326-update-prompts.ts @@ -0,0 +1,13 @@ +import { PrismaClient } from '@prisma/client'; + +import { refreshPrompts } from './utils/prompts'; + +export class UpdatePrompts1717490700326 { + // do the migration + static async up(db: PrismaClient) { + await refreshPrompts(db); + } + + // revert the migration + static async down(_db: PrismaClient) {} +} diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index b4a2f28261..aeb446eae6 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -86,64 +86,26 @@ export const prompts: Prompt[] = [ { name: 'debug:action:fal-sdturbo-clay', action: 'AI image filter clay style', - model: 'fast-sdxl/image-to-image', - messages: [ - { - role: 'user', - content: 'claymation, clay, {{content}}', - params: { - lora: [ - 'https://models.affine.pro/fal/Clay_AFFiNEAI_SDXL1_CLAYMATION.safetensors', - ], - }, - }, - ], + model: 'workflows/darkskygit/clay', + messages: [], }, { name: 'debug:action:fal-sdturbo-pixel', action: 'AI image filter pixel style', - model: 'fast-sdxl/image-to-image', - messages: [ - { - role: 'user', - content: 'pixel art, very high detail, masterpiece, {{content}}', - params: { - lora: ['https://models.affine.pro/fal/pixel-art-xl-v1.1.safetensors'], - }, - }, - ], + model: 'workflows/darkskygit/pixel-art', + messages: [], }, { name: 'debug:action:fal-sdturbo-sketch', action: 'AI image filter sketch style', - model: 'fast-sdxl/image-to-image', - messages: [ - { - role: 'user', - content: 'sketch for art examination, {{content}}', - params: { - lora: [ - 'https://models.affine.pro/fal/sketch_for_art_examination.safetensors', - ], - }, - }, - ], + model: 'workflows/darkskygit/sketch', + messages: [], }, { name: 'debug:action:fal-sdturbo-fantasy', action: 'AI image filter anime style', - model: 'fast-sdxl/image-to-image', - messages: [ - { - role: 'user', - content: 'fansty world, {{content}}', - params: { - lora: [ - 'https://models.affine.pro/fal/fansty%20world-000020.safetensors', - ], - }, - }, - ], + model: 'workflows/darkskygit/animie', + messages: [], }, { name: 'debug:action:fal-face-to-sticker', diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 0c99e3bebe..9fc1e56630 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -1,5 +1,12 @@ import assert from 'node:assert'; +import { + config as falConfig, + stream as falStream, +} from '@fal-ai/serverless-client'; +import { Logger } from '@nestjs/common'; +import { z } from 'zod'; + import { CopilotCapability, CopilotChatOptions, @@ -14,21 +21,35 @@ export type FalConfig = { apiKey: string; }; -export type FalImage = { - url: string; - seed: number; - file_name: string; -}; +const FalImageSchema = z + .object({ + url: z.string(), + seed: z.number().optional(), + content_type: z.string(), + file_name: z.string(), + file_size: z.number(), + width: z.number(), + height: z.number(), + }) + .optional(); -export type FalResponse = { - detail: Array<{ msg: string }> | string; - // normal sd/sdxl response - images?: Array; - // special i2i model response - image?: FalImage; - // image2text response - output: string; -}; +type FalImage = z.infer; + +const FalResponseSchema = z.object({ + detail: z + .union([z.array(z.object({ msg: z.string() })), z.string()]) + .optional(), + images: z.array(FalImageSchema).optional(), + image: FalImageSchema.optional(), + output: z.string().optional(), +}); + +type FalResponse = z.infer; + +const FalStreamOutputSchema = z.object({ + type: z.literal('output'), + output: FalResponseSchema, +}); type FalPrompt = { image_url?: string; @@ -55,12 +76,19 @@ export class FalProvider 'face-to-sticker', 'imageutils/rembg', 'fast-sdxl/image-to-image', + 'workflows/darkskygit/animie', + 'workflows/darkskygit/clay', + 'workflows/darkskygit/pixel-art', + 'workflows/darkskygit/sketch', // image to text 'llava-next', ]; + private readonly logger = new Logger(FalProvider.name); + constructor(private readonly config: FalConfig) { assert(FalProvider.assetsConfig(config)); + falConfig({ credentials: this.config.apiKey }); } static assetsConfig(config: FalConfig) { @@ -162,6 +190,37 @@ export class FalProvider } } + private async buildResponse( + messages: PromptMessage[], + model: string = this.availableModels[0], + options: CopilotImageOptions = {} + ) { + // by default, image prompt assumes there is only one message + const prompt = this.extractPrompt(messages.pop()); + if (model.startsWith('workflows/')) { + const stream = await falStream(model, { input: prompt }); + + const result = FalStreamOutputSchema.parse(await stream.done()); + return result.output; + } else { + const response = await fetch(`https://fal.run/fal-ai/${model}`, { + method: 'POST', + headers: { + Authorization: `key ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + ...prompt, + sync_mode: true, + seed: options.seed || 42, + enable_safety_checks: false, + }), + signal: options.signal, + }); + return FalResponseSchema.parse(await response.json()); + } + } + // ====== image to image ====== async generateImages( messages: PromptMessage[], @@ -172,35 +231,32 @@ export class FalProvider throw new Error(`Invalid model: ${model}`); } - // by default, image prompt assumes there is only one message - const prompt = this.extractPrompt(messages.pop()); - const data = (await fetch(`https://fal.run/fal-ai/${model}`, { - method: 'POST', - headers: { - Authorization: `key ${this.config.apiKey}`, - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - ...prompt, - sync_mode: true, - seed: options.seed || 42, - enable_safety_checks: false, - }), - signal: options.signal, - }).then(res => res.json())) as FalResponse; + try { + const data = await this.buildResponse(messages, model, options); - if (!data.images?.length && !data.image?.url) { - const error = this.extractError(data); - throw new Error( - error ? `Failed to generate image: ${error}` : 'No images generated' + if (!data.images?.length && !data.image?.url) { + const error = this.extractError(data); + const finalError = error + ? `Failed to generate image: ${error}` + : 'No images generated'; + this.logger.error(finalError); + throw new Error(finalError); + } + + if (data.image?.url) { + return [data.image.url]; + } + + return ( + data.images + ?.filter((image): image is NonNullable => !!image) + .map(image => image.url) || [] ); + } catch (e: any) { + const error = `Failed to generate image: ${e.message}`; + this.logger.error(error, e.stack); + throw new Error(error); } - - if (data.image?.url) { - return [data.image.url]; - } - - return data.images?.map(image => image.url) || []; } async *generateImagesStream( diff --git a/yarn.lock b/yarn.lock index 399842c5ea..f2c591aeb8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -658,6 +658,7 @@ __metadata: "@affine/server-native": "workspace:*" "@apollo/server": "npm:^4.10.2" "@aws-sdk/client-s3": "npm:^3.552.0" + "@fal-ai/serverless-client": "npm:^0.10.2" "@google-cloud/opentelemetry-cloud-monitoring-exporter": "npm:^0.18.0" "@google-cloud/opentelemetry-cloud-trace-exporter": "npm:^2.2.0" "@google-cloud/opentelemetry-resource-util": "npm:^2.2.0" @@ -5427,15 +5428,15 @@ __metadata: languageName: node linkType: hard -"@fal-ai/serverless-client@npm:^0.10.0": - version: 0.10.0 - resolution: "@fal-ai/serverless-client@npm:0.10.0" +"@fal-ai/serverless-client@npm:^0.10.0, @fal-ai/serverless-client@npm:^0.10.2": + version: 0.10.2 + resolution: "@fal-ai/serverless-client@npm:0.10.2" dependencies: "@msgpack/msgpack": "npm:^3.0.0-beta2" eventsource-parser: "npm:^1.1.2" robot3: "npm:^0.4.1" uuid-random: "npm:^1.3.2" - checksum: 10/46bf17fa08523ad6847c063535458b2f132e2baa0e40c70f09b881112d8aa3fa8d3be085e4f915cfe5106f8ad6abe31e7a8236e05acf7a884f17a78ae24a705b + checksum: 10/d96951b606179ed06d5d14cc31db7c1e55372bfbef34c1bc894c76e338d5e3dde3686848d866e273e033b0190aa730f48fcbcac72449f7047c50319f552d2423 languageName: node linkType: hard