From a3f3d097641e8af30d9206d5688fed1f87285a38 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Thu, 16 May 2024 11:09:33 +0000 Subject: [PATCH] feat: add upscaler & bg remover (#6967) --- .../src/data/migrations/utils/prompts.ts | 17 ++++++ .../server/src/plugins/copilot/controller.ts | 1 + .../src/plugins/copilot/providers/fal.ts | 10 +++- .../src/plugins/copilot/providers/index.ts | 11 ++++ packages/backend/server/tests/copilot.e2e.ts | 57 ++++++++++++++++++- .../backend/server/tests/utils/copilot.ts | 13 ++++- .../block-suite-editor/ai/prompt.ts | 2 + 7 files changed, 105 insertions(+), 6 deletions(-) diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index 287c7cdf90..c1aa495466 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -66,6 +66,23 @@ export const prompts: Prompt[] = [ model: 'fast-turbo-diffusion', messages: [], }, + { + name: 'debug:action:fal-upscaler', + action: 'image', + model: 'clarity-upscaler', + messages: [ + { + role: 'user', + content: 'best quality, 8K resolution, highres, clarity, {{content}}', + }, + ], + }, + { + name: 'debug:action:fal-remove-bg', + action: 'image', + model: 'imageutils/rembg', + messages: [], + }, { name: 'Summary', action: 'Summary', diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 62e3d14056..097fc28d79 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -127,6 +127,7 @@ export class CopilotController { if (err instanceof HttpException) { ret.status = err.getStatus(); } + return ret; } return err; } diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index cfb13d879a..42a52eecc3 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -14,7 +14,7 @@ export type FalConfig = { }; export type FalResponse = { - detail: Array<{ msg: string }>; + detail: Array<{ msg: string }> | string; images: Array<{ url: string }>; }; @@ -32,6 +32,8 @@ export class FalProvider 'fast-turbo-diffusion', // image to image 'lcm-sd15-i2i', + 'clarity-upscaler', + 'imageutils/rembg', ]; constructor(private readonly config: FalConfig) { @@ -87,7 +89,11 @@ export class FalProvider }).then(res => res.json())) as FalResponse; if (!data.images?.length) { - const error = data.detail?.[0]?.msg; + const error = Array.isArray(data.detail) + ? data.detail[0]?.msg + : typeof data.detail === 'string' + ? data.detail + : ''; throw new Error( error ? `Invalid message: ${error}` : 'No images generated' ); diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts index 675e09e307..d48a911ad2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/index.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -77,6 +77,17 @@ export function registerCopilotProvider< }); } +export function unregisterCopilotProvider(type: CopilotProviderType) { + COPILOT_PROVIDER.delete(type); + ASSERT_CONFIG.delete(type); + for (const providers of PROVIDER_CAPABILITY_MAP.values()) { + const index = providers.indexOf(type); + if (index !== -1) { + providers.splice(index, 1); + } + } +} + /// Asserts that the config is valid for any registered providers export function assertProvidersConfigs(config: Config) { return ( diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index 4653330fa5..22ce793c5c 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -9,12 +9,16 @@ import Sinon from 'sinon'; import { AuthService } from '../src/core/auth'; import { WorkspaceModule } from '../src/core/workspaces'; +import { prompts } from '../src/data/migrations/utils/prompts'; import { ConfigModule } from '../src/fundamentals/config'; import { CopilotModule } from '../src/plugins/copilot'; import { PromptService } from '../src/plugins/copilot/prompt'; import { CopilotProviderService, + FalProvider, + OpenAIProvider, registerCopilotProvider, + unregisterCopilotProvider, } from '../src/plugins/copilot/providers'; import { CopilotStorage } from '../src/plugins/copilot/storage'; import { @@ -80,11 +84,17 @@ test.beforeEach(async t => { const user = await signUp(app, 'test', 'darksky@affine.pro', '123456'); token = user.token.token; + unregisterCopilotProvider(OpenAIProvider.type); + unregisterCopilotProvider(FalProvider.type); registerCopilotProvider(MockCopilotTestProvider); await prompt.set(promptName, 'test', [ { role: 'system', content: 'hello {{word}}' }, ]); + + for (const p of prompts) { + await prompt.set(p.name, p.model, p.messages); + } }); test.afterEach.always(async t => { @@ -218,7 +228,7 @@ test('should be able to chat with api', async t => { t.is( ret3, textToEventStream( - ['https://example.com/image.jpg'], + ['https://example.com/test.jpg', 'generate text to text stream'], messageId, 'attachment' ), @@ -228,6 +238,51 @@ test('should be able to chat with api', async t => { Sinon.restore(); }); +test('should be able to chat with special image model', async t => { + const { app, storage } = t.context; + + Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2); + + const { id } = await createWorkspace(app, token); + + const testWithModel = async (promptName: string, finalPrompt: string) => { + const model = prompts.find(p => p.name === promptName)?.model; + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const messageId = await createCopilotMessage( + app, + token, + sessionId, + 'some-tag', + [`https://example.com/${promptName}.jpg`] + ); + const ret3 = await chatWithImages(app, token, sessionId, messageId); + t.is( + ret3, + textToEventStream( + [`https://example.com/${model}.jpg`, finalPrompt], + messageId, + 'attachment' + ), + 'should be able to chat with images' + ); + }; + + await testWithModel('debug:action:fal-sd15', 'some-tag'); + await testWithModel( + 'debug:action:fal-upscaler', + 'best quality, 8K resolution, highres, clarity, some-tag' + ); + await testWithModel('debug:action:fal-remove-bg', 'some-tag'); + + Sinon.restore(); +}); + test('should be able to retry with api', async t => { const { app, storage } = t.context; diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index b28bb2bcbd..d349a7671e 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -29,7 +29,13 @@ export class MockCopilotTestProvider CopilotImageToImageProvider, CopilotImageToTextProvider { - override readonly availableModels = ['test']; + override readonly availableModels = [ + 'test', + 'fast-turbo-diffusion', + 'lcm-sd15-i2i', + 'clarity-upscaler', + 'imageutils/rembg', + ]; static override readonly capabilities = [ CopilotCapability.TextToText, CopilotCapability.TextToEmbedding, @@ -107,7 +113,7 @@ export class MockCopilotTestProvider // ====== text to image ====== override async generateImages( messages: PromptMessage[], - _model: string = 'test', + model: string = 'test', _options: { signal?: AbortSignal; user?: string; @@ -118,7 +124,8 @@ export class MockCopilotTestProvider throw new Error('Prompt is required'); } - return ['https://example.com/image.jpg']; + // just let test case can easily verify the final prompt + return [`https://example.com/${model}.jpg`, prompt]; } override async *generateImagesStream( diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/prompt.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/prompt.ts index 2d56fd2bd7..e85df73e37 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/prompt.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/prompt.ts @@ -6,6 +6,8 @@ export const promptKeys = [ 'debug:action:vision4', 'debug:action:dalle3', 'debug:action:fal-sd15', + 'debug:action:fal-upscaler', + 'debug:action:fal-rembg', 'chat:gpt4', 'Summary', 'Summary the webpage',