mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 04:18:54 +00:00
feat: update i2i model (#7041)
This commit is contained in:
@@ -0,0 +1,13 @@
|
|||||||
|
import { PrismaClient } from '@prisma/client';
|
||||||
|
|
||||||
|
import { refreshPrompts } from './utils/prompts';
|
||||||
|
|
||||||
|
export class UpdatePrompts1716451792364 {
|
||||||
|
// do the migration
|
||||||
|
static async up(db: PrismaClient) {
|
||||||
|
await refreshPrompts(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
// revert the migration
|
||||||
|
static async down(_db: PrismaClient) {}
|
||||||
|
}
|
||||||
@@ -86,7 +86,7 @@ export const prompts: Prompt[] = [
|
|||||||
{
|
{
|
||||||
name: 'debug:action:fal-sdturbo-clay',
|
name: 'debug:action:fal-sdturbo-clay',
|
||||||
action: 'image',
|
action: 'image',
|
||||||
model: 'fast-turbo-diffusion',
|
model: 'fast-sdxl/image-to-image',
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -102,7 +102,7 @@ export const prompts: Prompt[] = [
|
|||||||
{
|
{
|
||||||
name: 'debug:action:fal-sdturbo-pixel',
|
name: 'debug:action:fal-sdturbo-pixel',
|
||||||
action: 'image',
|
action: 'image',
|
||||||
model: 'fast-turbo-diffusion',
|
model: 'fast-sdxl/image-to-image',
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -116,7 +116,7 @@ export const prompts: Prompt[] = [
|
|||||||
{
|
{
|
||||||
name: 'debug:action:fal-sdturbo-sketch',
|
name: 'debug:action:fal-sdturbo-sketch',
|
||||||
action: 'image',
|
action: 'image',
|
||||||
model: 'fast-turbo-diffusion',
|
model: 'fast-sdxl/image-to-image',
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -132,7 +132,7 @@ export const prompts: Prompt[] = [
|
|||||||
{
|
{
|
||||||
name: 'debug:action:fal-sdturbo-fantasy',
|
name: 'debug:action:fal-sdturbo-fantasy',
|
||||||
action: 'image',
|
action: 'image',
|
||||||
model: 'fast-turbo-diffusion',
|
model: 'fast-sdxl/image-to-image',
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -145,6 +145,24 @@ export const prompts: Prompt[] = [
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: 'debug:action:fal-face-to-sticker',
|
||||||
|
action: 'image',
|
||||||
|
model: 'face-to-sticker',
|
||||||
|
messages: [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'debug:action:fal-summary-caption',
|
||||||
|
action: 'image',
|
||||||
|
model: 'llava-next',
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content:
|
||||||
|
'Please understand this image and generate a short caption. {{content}}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: 'Summary',
|
name: 'Summary',
|
||||||
action: 'Summary',
|
action: 'Summary',
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import assert from 'node:assert';
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
CopilotCapability,
|
CopilotCapability,
|
||||||
|
CopilotChatOptions,
|
||||||
CopilotImageOptions,
|
CopilotImageOptions,
|
||||||
CopilotImageToImageProvider,
|
CopilotImageToImageProvider,
|
||||||
CopilotProviderType,
|
CopilotProviderType,
|
||||||
@@ -21,8 +22,12 @@ export type FalImage = {
|
|||||||
|
|
||||||
export type FalResponse = {
|
export type FalResponse = {
|
||||||
detail: Array<{ msg: string }> | string;
|
detail: Array<{ msg: string }> | string;
|
||||||
|
// normal sd/sdxl response
|
||||||
images?: Array<FalImage>;
|
images?: Array<FalImage>;
|
||||||
|
// special i2i model response
|
||||||
image?: FalImage;
|
image?: FalImage;
|
||||||
|
// image2text response
|
||||||
|
output: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type FalPrompt = {
|
type FalPrompt = {
|
||||||
@@ -38,6 +43,7 @@ export class FalProvider
|
|||||||
static readonly capabilities = [
|
static readonly capabilities = [
|
||||||
CopilotCapability.TextToImage,
|
CopilotCapability.TextToImage,
|
||||||
CopilotCapability.ImageToImage,
|
CopilotCapability.ImageToImage,
|
||||||
|
CopilotCapability.ImageToText,
|
||||||
];
|
];
|
||||||
|
|
||||||
readonly availableModels = [
|
readonly availableModels = [
|
||||||
@@ -46,7 +52,11 @@ export class FalProvider
|
|||||||
// image to image
|
// image to image
|
||||||
'lcm-sd15-i2i',
|
'lcm-sd15-i2i',
|
||||||
'clarity-upscaler',
|
'clarity-upscaler',
|
||||||
|
'face-to-sticker',
|
||||||
'imageutils/rembg',
|
'imageutils/rembg',
|
||||||
|
'fast-sdxl/image-to-image',
|
||||||
|
// image to text
|
||||||
|
'llava-next',
|
||||||
];
|
];
|
||||||
|
|
||||||
constructor(private readonly config: FalConfig) {
|
constructor(private readonly config: FalConfig) {
|
||||||
@@ -96,11 +106,62 @@ export class FalProvider
|
|||||||
).filter(v => typeof v === 'string' && v.length);
|
).filter(v => typeof v === 'string' && v.length);
|
||||||
return {
|
return {
|
||||||
image_url: attachments?.[0],
|
image_url: attachments?.[0],
|
||||||
prompt: content,
|
prompt: content.trim(),
|
||||||
lora: lora.length ? lora : undefined,
|
lora: lora.length ? lora : undefined,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async generateText(
|
||||||
|
messages: PromptMessage[],
|
||||||
|
model: string = 'llava-next',
|
||||||
|
options: CopilotChatOptions = {}
|
||||||
|
): Promise<string> {
|
||||||
|
if (!this.availableModels.includes(model)) {
|
||||||
|
throw new Error(`Invalid model: ${model}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// by default, image prompt assumes there is only one message
|
||||||
|
const prompt = this.extractPrompt(messages.pop());
|
||||||
|
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Authorization: `key ${this.config.apiKey}`,
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
...prompt,
|
||||||
|
sync_mode: true,
|
||||||
|
enable_safety_checks: false,
|
||||||
|
}),
|
||||||
|
signal: options.signal,
|
||||||
|
}).then(res => res.json())) as FalResponse;
|
||||||
|
|
||||||
|
if (!data.output) {
|
||||||
|
const error = this.extractError(data);
|
||||||
|
throw new Error(
|
||||||
|
error ? `Failed to generate image: ${error}` : 'No images generated'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return data.output;
|
||||||
|
}
|
||||||
|
|
||||||
|
async *generateTextStream(
|
||||||
|
messages: PromptMessage[],
|
||||||
|
model: string = 'llava-next',
|
||||||
|
options: CopilotChatOptions = {}
|
||||||
|
): AsyncIterable<string> {
|
||||||
|
const result = await this.generateText(messages, model, options);
|
||||||
|
|
||||||
|
for await (const content of result) {
|
||||||
|
if (content) {
|
||||||
|
yield content;
|
||||||
|
if (options.signal?.aborted) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ====== image to image ======
|
// ====== image to image ======
|
||||||
async generateImages(
|
async generateImages(
|
||||||
messages: PromptMessage[],
|
messages: PromptMessage[],
|
||||||
@@ -113,7 +174,6 @@ export class FalProvider
|
|||||||
|
|
||||||
// by default, image prompt assumes there is only one message
|
// by default, image prompt assumes there is only one message
|
||||||
const prompt = this.extractPrompt(messages.pop());
|
const prompt = this.extractPrompt(messages.pop());
|
||||||
|
|
||||||
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
|||||||
@@ -470,7 +470,7 @@ test('should be able to get provider', async t => {
|
|||||||
);
|
);
|
||||||
t.is(
|
t.is(
|
||||||
p?.type.toString(),
|
p?.type.toString(),
|
||||||
'openai',
|
'fal',
|
||||||
'should get provider support image-to-text'
|
'should get provider support image-to-text'
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ export class MockCopilotTestProvider
|
|||||||
{
|
{
|
||||||
override readonly availableModels = [
|
override readonly availableModels = [
|
||||||
'test',
|
'test',
|
||||||
'fast-turbo-diffusion',
|
'fast-sdxl/image-to-image',
|
||||||
'lcm-sd15-i2i',
|
'lcm-sd15-i2i',
|
||||||
'clarity-upscaler',
|
'clarity-upscaler',
|
||||||
'imageutils/rembg',
|
'imageutils/rembg',
|
||||||
|
|||||||
Reference in New Issue
Block a user