feat: add upscaler & bg remover (#6967)

This commit is contained in:
darkskygit
2024-05-16 11:09:33 +00:00
parent f37bbb0784
commit a3f3d09764
7 changed files with 105 additions and 6 deletions

View File

@@ -9,12 +9,16 @@ import Sinon from 'sinon';
import { AuthService } from '../src/core/auth';
import { WorkspaceModule } from '../src/core/workspaces';
import { prompts } from '../src/data/migrations/utils/prompts';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../src/plugins/copilot/providers';
import { CopilotStorage } from '../src/plugins/copilot/storage';
import {
@@ -80,11 +84,17 @@ test.beforeEach(async t => {
const user = await signUp(app, 'test', 'darksky@affine.pro', '123456');
token = user.token.token;
unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set(promptName, 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages);
}
});
test.afterEach.always(async t => {
@@ -218,7 +228,7 @@ test('should be able to chat with api', async t => {
t.is(
ret3,
textToEventStream(
['https://example.com/image.jpg'],
['https://example.com/test.jpg', 'generate text to text stream'],
messageId,
'attachment'
),
@@ -228,6 +238,51 @@ test('should be able to chat with api', async t => {
Sinon.restore();
});
test('should be able to chat with special image model', async t => {
const { app, storage } = t.context;
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
const { id } = await createWorkspace(app, token);
const testWithModel = async (promptName: string, finalPrompt: string) => {
const model = prompts.find(p => p.name === promptName)?.model;
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(
app,
token,
sessionId,
'some-tag',
[`https://example.com/${promptName}.jpg`]
);
const ret3 = await chatWithImages(app, token, sessionId, messageId);
t.is(
ret3,
textToEventStream(
[`https://example.com/${model}.jpg`, finalPrompt],
messageId,
'attachment'
),
'should be able to chat with images'
);
};
await testWithModel('debug:action:fal-sd15', 'some-tag');
await testWithModel(
'debug:action:fal-upscaler',
'best quality, 8K resolution, highres, clarity, some-tag'
);
await testWithModel('debug:action:fal-remove-bg', 'some-tag');
Sinon.restore();
});
test('should be able to retry with api', async t => {
const { app, storage } = t.context;

View File

@@ -29,7 +29,13 @@ export class MockCopilotTestProvider
CopilotImageToImageProvider,
CopilotImageToTextProvider
{
override readonly availableModels = ['test'];
override readonly availableModels = [
'test',
'fast-turbo-diffusion',
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',
];
static override readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
@@ -107,7 +113,7 @@ export class MockCopilotTestProvider
// ====== text to image ======
override async generateImages(
messages: PromptMessage[],
_model: string = 'test',
model: string = 'test',
_options: {
signal?: AbortSignal;
user?: string;
@@ -118,7 +124,8 @@ export class MockCopilotTestProvider
throw new Error('Prompt is required');
}
return ['https://example.com/image.jpg'];
// just let test case can easily verify the final prompt
return [`https://example.com/${model}.jpg`, prompt];
}
override async *generateImagesStream(