mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-11 20:08:37 +00:00
feat: add upscaler & bg remover (#6967)
This commit is contained in:
@@ -9,12 +9,16 @@ import Sinon from 'sinon';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { WorkspaceModule } from '../src/core/workspaces';
|
||||
import { prompts } from '../src/data/migrations/utils/prompts';
|
||||
import { ConfigModule } from '../src/fundamentals/config';
|
||||
import { CopilotModule } from '../src/plugins/copilot';
|
||||
import { PromptService } from '../src/plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
registerCopilotProvider,
|
||||
unregisterCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../src/plugins/copilot/storage';
|
||||
import {
|
||||
@@ -80,11 +84,17 @@ test.beforeEach(async t => {
|
||||
const user = await signUp(app, 'test', 'darksky@affine.pro', '123456');
|
||||
token = user.token.token;
|
||||
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
unregisterCopilotProvider(FalProvider.type);
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
await prompt.set(promptName, 'test', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
|
||||
for (const p of prompts) {
|
||||
await prompt.set(p.name, p.model, p.messages);
|
||||
}
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
@@ -218,7 +228,7 @@ test('should be able to chat with api', async t => {
|
||||
t.is(
|
||||
ret3,
|
||||
textToEventStream(
|
||||
['https://example.com/image.jpg'],
|
||||
['https://example.com/test.jpg', 'generate text to text stream'],
|
||||
messageId,
|
||||
'attachment'
|
||||
),
|
||||
@@ -228,6 +238,51 @@ test('should be able to chat with api', async t => {
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test('should be able to chat with special image model', async t => {
|
||||
const { app, storage } = t.context;
|
||||
|
||||
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
|
||||
|
||||
const { id } = await createWorkspace(app, token);
|
||||
|
||||
const testWithModel = async (promptName: string, finalPrompt: string) => {
|
||||
const model = prompts.find(p => p.name === promptName)?.model;
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(
|
||||
app,
|
||||
token,
|
||||
sessionId,
|
||||
'some-tag',
|
||||
[`https://example.com/${promptName}.jpg`]
|
||||
);
|
||||
const ret3 = await chatWithImages(app, token, sessionId, messageId);
|
||||
t.is(
|
||||
ret3,
|
||||
textToEventStream(
|
||||
[`https://example.com/${model}.jpg`, finalPrompt],
|
||||
messageId,
|
||||
'attachment'
|
||||
),
|
||||
'should be able to chat with images'
|
||||
);
|
||||
};
|
||||
|
||||
await testWithModel('debug:action:fal-sd15', 'some-tag');
|
||||
await testWithModel(
|
||||
'debug:action:fal-upscaler',
|
||||
'best quality, 8K resolution, highres, clarity, some-tag'
|
||||
);
|
||||
await testWithModel('debug:action:fal-remove-bg', 'some-tag');
|
||||
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test('should be able to retry with api', async t => {
|
||||
const { app, storage } = t.context;
|
||||
|
||||
|
||||
@@ -29,7 +29,13 @@ export class MockCopilotTestProvider
|
||||
CopilotImageToImageProvider,
|
||||
CopilotImageToTextProvider
|
||||
{
|
||||
override readonly availableModels = ['test'];
|
||||
override readonly availableModels = [
|
||||
'test',
|
||||
'fast-turbo-diffusion',
|
||||
'lcm-sd15-i2i',
|
||||
'clarity-upscaler',
|
||||
'imageutils/rembg',
|
||||
];
|
||||
static override readonly capabilities = [
|
||||
CopilotCapability.TextToText,
|
||||
CopilotCapability.TextToEmbedding,
|
||||
@@ -107,7 +113,7 @@ export class MockCopilotTestProvider
|
||||
// ====== text to image ======
|
||||
override async generateImages(
|
||||
messages: PromptMessage[],
|
||||
_model: string = 'test',
|
||||
model: string = 'test',
|
||||
_options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
@@ -118,7 +124,8 @@ export class MockCopilotTestProvider
|
||||
throw new Error('Prompt is required');
|
||||
}
|
||||
|
||||
return ['https://example.com/image.jpg'];
|
||||
// just let test case can easily verify the final prompt
|
||||
return [`https://example.com/${model}.jpg`, prompt];
|
||||
}
|
||||
|
||||
override async *generateImagesStream(
|
||||
|
||||
Reference in New Issue
Block a user