feat: update i2i model (#7041)

This commit is contained in:
darkskygit
2024-05-23 14:27:12 +00:00
parent 535254fdf6
commit 0c42849bc3
5 changed files with 99 additions and 8 deletions

View File

@@ -0,0 +1,13 @@
import { PrismaClient } from '@prisma/client';
import { refreshPrompts } from './utils/prompts';
export class UpdatePrompts1716451792364 {
  /**
   * Apply the migration: re-sync the prompt records in the database
   * with the current code-defined prompt list via `refreshPrompts`.
   */
  static async up(db: PrismaClient): Promise<void> {
    await refreshPrompts(db);
  }

  /**
   * Revert the migration — intentionally a no-op: the previous prompt
   * contents are not retained, so there is nothing to restore.
   */
  static async down(_db: PrismaClient): Promise<void> {}
}

View File

@@ -86,7 +86,7 @@ export const prompts: Prompt[] = [
{
name: 'debug:action:fal-sdturbo-clay',
action: 'image',
model: 'fast-turbo-diffusion',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
@@ -102,7 +102,7 @@ export const prompts: Prompt[] = [
{
name: 'debug:action:fal-sdturbo-pixel',
action: 'image',
model: 'fast-turbo-diffusion',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
@@ -116,7 +116,7 @@ export const prompts: Prompt[] = [
{
name: 'debug:action:fal-sdturbo-sketch',
action: 'image',
model: 'fast-turbo-diffusion',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
@@ -132,7 +132,7 @@ export const prompts: Prompt[] = [
{
name: 'debug:action:fal-sdturbo-fantasy',
action: 'image',
model: 'fast-turbo-diffusion',
model: 'fast-sdxl/image-to-image',
messages: [
{
role: 'user',
@@ -145,6 +145,24 @@ export const prompts: Prompt[] = [
},
],
},
{
name: 'debug:action:fal-face-to-sticker',
action: 'image',
model: 'face-to-sticker',
messages: [],
},
{
name: 'debug:action:fal-summary-caption',
action: 'image',
model: 'llava-next',
messages: [
{
role: 'user',
content:
'Please understand this image and generate a short caption. {{content}}',
},
],
},
{
name: 'Summary',
action: 'Summary',

View File

@@ -2,6 +2,7 @@ import assert from 'node:assert';
import {
CopilotCapability,
CopilotChatOptions,
CopilotImageOptions,
CopilotImageToImageProvider,
CopilotProviderType,
@@ -21,8 +22,12 @@ export type FalImage = {
export type FalResponse = {
  // error details returned by the fal.ai API on failure
  detail: Array<{ msg: string }> | string;
  // normal sd/sdxl response
  images?: Array<FalImage>;
  // special i2i model response
  image?: FalImage;
  // image2text response (e.g. llava-next); absent for image models,
  // so it must be optional like `images`/`image` above
  output?: string;
};
type FalPrompt = {
@@ -38,6 +43,7 @@ export class FalProvider
static readonly capabilities = [
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
CopilotCapability.ImageToText,
];
readonly availableModels = [
@@ -46,7 +52,11 @@ export class FalProvider
// image to image
'lcm-sd15-i2i',
'clarity-upscaler',
'face-to-sticker',
'imageutils/rembg',
'fast-sdxl/image-to-image',
// image to text
'llava-next',
];
constructor(private readonly config: FalConfig) {
@@ -96,11 +106,62 @@ export class FalProvider
).filter(v => typeof v === 'string' && v.length);
return {
image_url: attachments?.[0],
prompt: content,
prompt: content.trim(),
lora: lora.length ? lora : undefined,
};
}
/**
 * Run an image-to-text model (caption generation) on fal.ai.
 *
 * @param messages - prompt messages; only the last one is used
 * @param model - must be one of `availableModels`
 * @param options - chat options; `signal` aborts the request
 * @returns the model's text output
 * @throws Error on an unknown model or when the API returns no output
 */
async generateText(
  messages: PromptMessage[],
  model: string = 'llava-next',
  options: CopilotChatOptions = {}
): Promise<string> {
  if (!this.availableModels.includes(model)) {
    throw new Error(`Invalid model: ${model}`);
  }
  // by default, image prompt assumes there is only one message;
  // read the last message without mutating the caller's array
  const prompt = this.extractPrompt(messages[messages.length - 1]);
  const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
    method: 'POST',
    headers: {
      Authorization: `key ${this.config.apiKey}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      ...prompt,
      sync_mode: true,
      enable_safety_checks: false,
    }),
    signal: options.signal,
  }).then(res => res.json())) as FalResponse;
  if (!data.output) {
    const error = this.extractError(data);
    // this is the image-to-text path, so report a text failure (the
    // original message wrongly said "Failed to generate image")
    throw new Error(
      error ? `Failed to generate text: ${error}` : 'No output generated'
    );
  }
  return data.output;
}
async *generateTextStream(
messages: PromptMessage[],
model: string = 'llava-next',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const result = await this.generateText(messages, model, options);
for await (const content of result) {
if (content) {
yield content;
if (options.signal?.aborted) {
break;
}
}
}
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
@@ -113,7 +174,6 @@ export class FalProvider
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {

View File

@@ -470,7 +470,7 @@ test('should be able to get provider', async t => {
);
t.is(
p?.type.toString(),
'openai',
'fal',
'should get provider support image-to-text'
);
}

View File

@@ -31,7 +31,7 @@ export class MockCopilotTestProvider
{
override readonly availableModels = [
'test',
'fast-turbo-diffusion',
'fast-sdxl/image-to-image',
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',