Mirror of https://github.com/toeverything/AFFiNE.git (synced 2026-02-12 04:18:54 +00:00)
feat: update i2i model (#7041)
@@ -0,0 +1,13 @@
+import { PrismaClient } from '@prisma/client';
+
+import { refreshPrompts } from './utils/prompts';
+
+export class UpdatePrompts1716451792364 {
+  // do the migration
+  static async up(db: PrismaClient) {
+    await refreshPrompts(db);
+  }
+
+  // revert the migration
+  static async down(_db: PrismaClient) {}
+}
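Note: the new migration above only delegates to refreshPrompts. For orientation, here is a minimal sketch of what an upsert-style refresh helper along those lines could look like; the aiPrompt delegate name, the inline prompts list, and the column names are assumptions for illustration, not the actual ./utils/prompts implementation.

```ts
import { PrismaClient } from '@prisma/client';

// Assumed minimal prompt shape, mirroring the fields visible in the diff below.
type Prompt = {
  name: string;
  action?: string;
  model: string;
  messages: Array<{ role: string; content?: string }>;
};

// Illustrative stand-in for the full prompt list that lives next to refreshPrompts;
// this entry is copied from the diff below.
const prompts: Prompt[] = [
  {
    name: 'debug:action:fal-face-to-sticker',
    action: 'image',
    model: 'face-to-sticker',
    messages: [],
  },
];

// Hypothetical refresh helper: upsert each prompt by name so that re-running the
// migration stays idempotent. `aiPrompt` is an assumed Prisma model name.
export async function refreshPrompts(db: PrismaClient) {
  for (const prompt of prompts) {
    await db.aiPrompt.upsert({
      where: { name: prompt.name },
      update: { action: prompt.action, model: prompt.model, messages: prompt.messages },
      create: { name: prompt.name, action: prompt.action, model: prompt.model, messages: prompt.messages },
    });
  }
}
```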
@@ -86,7 +86,7 @@ export const prompts: Prompt[] = [
   {
     name: 'debug:action:fal-sdturbo-clay',
     action: 'image',
-    model: 'fast-turbo-diffusion',
+    model: 'fast-sdxl/image-to-image',
     messages: [
       {
         role: 'user',
@@ -102,7 +102,7 @@ export const prompts: Prompt[] = [
   {
     name: 'debug:action:fal-sdturbo-pixel',
     action: 'image',
-    model: 'fast-turbo-diffusion',
+    model: 'fast-sdxl/image-to-image',
     messages: [
       {
         role: 'user',
@@ -116,7 +116,7 @@ export const prompts: Prompt[] = [
   {
     name: 'debug:action:fal-sdturbo-sketch',
     action: 'image',
-    model: 'fast-turbo-diffusion',
+    model: 'fast-sdxl/image-to-image',
     messages: [
       {
         role: 'user',
@@ -132,7 +132,7 @@ export const prompts: Prompt[] = [
   {
     name: 'debug:action:fal-sdturbo-fantasy',
    action: 'image',
-    model: 'fast-turbo-diffusion',
+    model: 'fast-sdxl/image-to-image',
     messages: [
       {
         role: 'user',
@@ -145,6 +145,24 @@ export const prompts: Prompt[] = [
       },
     ],
   },
+  {
+    name: 'debug:action:fal-face-to-sticker',
+    action: 'image',
+    model: 'face-to-sticker',
+    messages: [],
+  },
+  {
+    name: 'debug:action:fal-summary-caption',
+    action: 'image',
+    model: 'llava-next',
+    messages: [
+      {
+        role: 'user',
+        content:
+          'Please understand this image and generate a short caption. {{content}}',
+      },
+    ],
+  },
   {
     name: 'Summary',
     action: 'Summary',
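Note: the new fal-summary-caption prompt carries a {{content}} placeholder in its message. A minimal sketch of mustache-style substitution is shown below for orientation; renderMessages and its signature are illustrative, not the server's actual prompt-rendering helper.

```ts
type PromptMessage = { role: 'user' | 'system' | 'assistant'; content: string };

// Hypothetical helper: fill `{{name}}` placeholders from a params object,
// leaving unknown placeholders empty.
function renderMessages(
  messages: PromptMessage[],
  params: Record<string, string>
): PromptMessage[] {
  return messages.map(message => ({
    ...message,
    content: message.content.replace(/\{\{(\w+)\}\}/g, (_, key) => params[key] ?? ''),
  }));
}

// Example: the summary-caption prompt with its {{content}} slot filled in.
const [rendered] = renderMessages(
  [
    {
      role: 'user',
      content:
        'Please understand this image and generate a short caption. {{content}}',
    },
  ],
  { content: 'A whiteboard sketch of the new onboarding flow.' }
);
console.log(rendered.content);
```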
@@ -2,6 +2,7 @@ import assert from 'node:assert';
 
 import {
   CopilotCapability,
+  CopilotChatOptions,
   CopilotImageOptions,
   CopilotImageToImageProvider,
   CopilotProviderType,
@@ -21,8 +22,12 @@ export type FalImage = {
 
 export type FalResponse = {
   detail: Array<{ msg: string }> | string;
   // normal sd/sdxl response
   images?: Array<FalImage>;
   // special i2i model response
   image?: FalImage;
   // image2text response
   output: string;
 };
 
 type FalPrompt = {
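Note: FalResponse now distinguishes three payload shapes: the usual images array from sd/sdxl models, a single image from some image-to-image models, and a plain output string for image-to-text. A hedged sketch of normalizing the two image-bearing shapes follows (it assumes FalImage exposes a url field; the real provider does its own extraction).

```ts
// Assumed minimal shape of a fal.ai image payload; the real FalImage type has
// more fields, but `url` is all this helper needs.
type FalImage = { url: string };

// Hypothetical normalizer: prefer the plural `images` array (sd/sdxl models)
// and fall back to the single `image` field returned by some i2i models.
function extractImageUrls(data: { images?: FalImage[]; image?: FalImage }): string[] {
  if (data.images?.length) return data.images.map(img => img.url);
  if (data.image) return [data.image.url];
  return [];
}

// Example: both response shapes normalize to the same kind of result.
console.log(extractImageUrls({ images: [{ url: 'https://example.com/a.png' }] }));
console.log(extractImageUrls({ image: { url: 'https://example.com/b.png' } }));
```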
@@ -38,6 +43,7 @@ export class FalProvider
   static readonly capabilities = [
     CopilotCapability.TextToImage,
     CopilotCapability.ImageToImage,
+    CopilotCapability.ImageToText,
   ];
 
   readonly availableModels = [
@@ -46,7 +52,11 @@ export class FalProvider
     // image to image
     'lcm-sd15-i2i',
     'clarity-upscaler',
+    'face-to-sticker',
     'imageutils/rembg',
+    'fast-sdxl/image-to-image',
+    // image to text
+    'llava-next',
   ];
 
   constructor(private readonly config: FalConfig) {
@@ -96,11 +106,62 @@ export class FalProvider
     ).filter(v => typeof v === 'string' && v.length);
     return {
       image_url: attachments?.[0],
-      prompt: content,
+      prompt: content.trim(),
       lora: lora.length ? lora : undefined,
     };
   }
 
+  async generateText(
+    messages: PromptMessage[],
+    model: string = 'llava-next',
+    options: CopilotChatOptions = {}
+  ): Promise<string> {
+    if (!this.availableModels.includes(model)) {
+      throw new Error(`Invalid model: ${model}`);
+    }
+
+    // by default, image prompt assumes there is only one message
+    const prompt = this.extractPrompt(messages.pop());
+    const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
+      method: 'POST',
+      headers: {
+        Authorization: `key ${this.config.apiKey}`,
+        'Content-Type': 'application/json',
+      },
+      body: JSON.stringify({
+        ...prompt,
+        sync_mode: true,
+        enable_safety_checks: false,
+      }),
+      signal: options.signal,
+    }).then(res => res.json())) as FalResponse;
+
+    if (!data.output) {
+      const error = this.extractError(data);
+      throw new Error(
+        error ? `Failed to generate image: ${error}` : 'No images generated'
+      );
+    }
+    return data.output;
+  }
+
+  async *generateTextStream(
+    messages: PromptMessage[],
+    model: string = 'llava-next',
+    options: CopilotChatOptions = {}
+  ): AsyncIterable<string> {
+    const result = await this.generateText(messages, model, options);
+
+    for await (const content of result) {
+      if (content) {
+        yield content;
+        if (options.signal?.aborted) {
+          break;
+        }
+      }
+    }
+  }
+
   // ====== image to image ======
   async generateImages(
     messages: PromptMessage[],
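Note: for orientation, a hedged sketch of how the new image-to-text path might be exercised. The './fal' import path, the { apiKey } config shape, and the message fields are assumptions read off this diff (the `key ${this.config.apiKey}` header and the content/attachments usage in extractPrompt); any further constructor arguments are omitted.

```ts
// Illustrative only; assumes FalProvider is exported from the provider module
// and that FalConfig is just `{ apiKey: string }`.
import { FalProvider } from './fal';

const provider = new FalProvider({ apiKey: process.env.FAL_API_KEY ?? '' });

const messages = [
  {
    role: 'user' as const,
    content: 'Please understand this image and generate a short caption.',
    attachments: ['https://example.com/whiteboard.png'],
  },
];

// One-shot caption through the new image-to-text path (llava-next by default).
// A copy is passed because generateText pops the last message off the array.
const caption = await provider.generateText([...messages]);
console.log(caption);

// Or stream it: generateTextStream re-yields the generated text piece by piece.
for await (const chunk of provider.generateTextStream([...messages])) {
  process.stdout.write(chunk);
}
```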
@@ -113,7 +174,6 @@ export class FalProvider
 
     // by default, image prompt assumes there is only one message
     const prompt = this.extractPrompt(messages.pop());
-
     const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
       method: 'POST',
       headers: {
@@ -470,7 +470,7 @@ test('should be able to get provider', async t => {
     );
     t.is(
       p?.type.toString(),
-      'openai',
+      'fal',
       'should get provider support image-to-text'
     );
   }
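Note: the assertion above flips from 'openai' to 'fal' because FalProvider now advertises the image-to-text capability. As a conceptual sketch only (the class and method names below are invented for illustration, and the real CopilotProviderService also weighs registration order and configuration), capability-based lookup amounts to something like:

```ts
// Invented names for illustration; not the actual CopilotProviderService API.
enum Capability {
  TextToImage = 'text-to-image',
  ImageToImage = 'image-to-image',
  ImageToText = 'image-to-text',
}

interface Provider {
  readonly type: string;
  readonly capabilities: Capability[];
}

class ProviderRegistry {
  private readonly providers: Provider[] = [];

  register(provider: Provider) {
    this.providers.push(provider);
  }

  // Return the first registered provider that advertises the capability.
  getByCapability(capability: Capability): Provider | undefined {
    return this.providers.find(p => p.capabilities.includes(capability));
  }
}

const registry = new ProviderRegistry();
registry.register({
  type: 'fal',
  capabilities: [Capability.TextToImage, Capability.ImageToImage, Capability.ImageToText],
});
registry.register({ type: 'openai', capabilities: [Capability.ImageToText] });

// With fal registered ahead of openai for image-to-text, the lookup resolves to 'fal'.
console.log(registry.getByCapability(Capability.ImageToText)?.type); // 'fal'
```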
@@ -31,7 +31,7 @@ export class MockCopilotTestProvider
 {
   override readonly availableModels = [
     'test',
     'fast-turbo-diffusion',
+    'fast-sdxl/image-to-image',
     'lcm-sd15-i2i',
     'clarity-upscaler',
     'imageutils/rembg',