mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat(server): add real-world prompt test for copilot apis (#8629)
fix AF-1432, PD-1176
This commit is contained in:
492
packages/backend/server/tests/copilot-provider.e2e.ts
Normal file
492
packages/backend/server/tests/copilot-provider.e2e.ts
Normal file
@@ -0,0 +1,492 @@
|
||||
/// <reference types="../src/global.d.ts" />
|
||||
|
||||
import { TestingModule } from '@nestjs/testing';
|
||||
import type { ExecutionContext, TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { QuotaModule } from '../src/core/quota';
|
||||
import { ConfigModule } from '../src/fundamentals/config';
|
||||
import { CopilotModule } from '../src/plugins/copilot';
|
||||
import { prompts, PromptService } from '../src/plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
registerCopilotProvider,
|
||||
unregisterCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
GraphExecutorState,
|
||||
} from '../src/plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
CopilotCheckJsonExecutor,
|
||||
} from '../src/plugins/copilot/workflow/executor';
|
||||
import { createTestingModule } from './utils';
|
||||
import { TestAssets } from './utils/copilot';
|
||||
|
||||
type Tester = {
|
||||
auth: AuthService;
|
||||
module: TestingModule;
|
||||
prompt: PromptService;
|
||||
provider: CopilotProviderService;
|
||||
workflow: CopilotWorkflowService;
|
||||
executors: {
|
||||
image: CopilotChatImageExecutor;
|
||||
text: CopilotChatTextExecutor;
|
||||
html: CopilotCheckHtmlExecutor;
|
||||
json: CopilotCheckJsonExecutor;
|
||||
};
|
||||
};
|
||||
const test = ava as TestFn<Tester>;
|
||||
|
||||
const isCopilotConfigured =
|
||||
!!process.env.COPILOT_OPENAI_API_KEY &&
|
||||
!!process.env.COPILOT_FAL_API_KEY &&
|
||||
process.env.COPILOT_OPENAI_API_KEY !== '1' &&
|
||||
process.env.COPILOT_FAL_API_KEY !== '1';
|
||||
const runIfCopilotConfigured = test.macro(
|
||||
async (
|
||||
t,
|
||||
callback: (t: ExecutionContext<Tester>) => Promise<void> | void
|
||||
) => {
|
||||
if (isCopilotConfigured) {
|
||||
await callback(t);
|
||||
} else {
|
||||
t.log('Skip test because copilot is not configured');
|
||||
t.pass();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
test.beforeEach(async t => {
|
||||
const module = await createTestingModule({
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
plugins: {
|
||||
copilot: {
|
||||
openai: {
|
||||
apiKey: process.env.COPILOT_OPENAI_API_KEY,
|
||||
},
|
||||
fal: {
|
||||
apiKey: process.env.COPILOT_FAL_API_KEY,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
QuotaModule,
|
||||
CopilotModule,
|
||||
],
|
||||
});
|
||||
|
||||
const auth = module.get(AuthService);
|
||||
const prompt = module.get(PromptService);
|
||||
const provider = module.get(CopilotProviderService);
|
||||
const workflow = module.get(CopilotWorkflowService);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.auth = auth;
|
||||
t.context.prompt = prompt;
|
||||
t.context.provider = provider;
|
||||
t.context.workflow = workflow;
|
||||
t.context.executors = {
|
||||
image: module.get(CopilotChatImageExecutor),
|
||||
text: module.get(CopilotChatTextExecutor),
|
||||
html: module.get(CopilotCheckHtmlExecutor),
|
||||
json: module.get(CopilotCheckJsonExecutor),
|
||||
};
|
||||
});
|
||||
|
||||
test.beforeEach(async t => {
|
||||
const { prompt, executors } = t.context;
|
||||
|
||||
executors.image.register();
|
||||
executors.text.register();
|
||||
executors.html.register();
|
||||
executors.json.register();
|
||||
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
registerCopilotProvider(FalProvider);
|
||||
|
||||
for (const name of await prompt.listNames()) {
|
||||
await prompt.delete(name);
|
||||
}
|
||||
|
||||
for (const p of prompts) {
|
||||
await prompt.set(p.name, p.model, p.messages, p.config);
|
||||
}
|
||||
});
|
||||
|
||||
test.afterEach.always(async _ => {
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
unregisterCopilotProvider(FalProvider.type);
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
const assertNotWrappedInCodeBlock = (
|
||||
t: ExecutionContext<Tester>,
|
||||
result: string
|
||||
) => {
|
||||
t.assert(
|
||||
!result.replaceAll('\n', '').trim().startsWith('```') &&
|
||||
!result.replaceAll('\n', '').trim().endsWith('```'),
|
||||
'should not wrap in code block'
|
||||
);
|
||||
};
|
||||
|
||||
const checkMDList = (text: string) => {
|
||||
const lines = text.split('\n');
|
||||
const listItemRegex = /^( {2})*(-|\*|\+) .+$/;
|
||||
let prevIndent = null;
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.trim() === '') continue;
|
||||
if (!listItemRegex.test(line)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain
|
||||
const currentIndent = line.match(/^( *)/)?.[0].length!;
|
||||
if (Number.isNaN(currentIndent) || currentIndent % 2 !== 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (prevIndent !== null && currentIndent > 0) {
|
||||
const indentDiff = currentIndent - prevIndent;
|
||||
// allow 1 level of indentation difference
|
||||
if (indentDiff > 2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
prevIndent = currentIndent;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
const checkUrl = (url: string) => {
|
||||
try {
|
||||
new URL(url);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const retry = async (
|
||||
action: string,
|
||||
t: ExecutionContext<Tester>,
|
||||
callback: (t: ExecutionContext<Tester>) => void
|
||||
) => {
|
||||
let i = 3;
|
||||
while (i--) {
|
||||
const ret = await t.try(callback);
|
||||
if (ret.passed) {
|
||||
ret.commit();
|
||||
break;
|
||||
} else {
|
||||
ret.discard();
|
||||
t.log(ret.errors.map(e => e.message).join('\n'));
|
||||
t.log(`retrying ${action} ${3 - i}/3 ...`);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ==================== utils ====================
|
||||
|
||||
test('should validate markdown list', t => {
|
||||
t.true(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 2
|
||||
`)
|
||||
);
|
||||
t.true(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 1.1
|
||||
- item 2
|
||||
`)
|
||||
);
|
||||
t.true(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 1.1
|
||||
- item 1.1.1
|
||||
- item 2
|
||||
`)
|
||||
);
|
||||
t.true(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 1.1
|
||||
- item 1.1.1
|
||||
- item 1.1.2
|
||||
- item 2
|
||||
`)
|
||||
);
|
||||
t.true(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 1.1
|
||||
- item 1.1.1
|
||||
- item 1.2
|
||||
`)
|
||||
);
|
||||
t.false(
|
||||
checkMDList(`
|
||||
- item 1
|
||||
- item 1.1
|
||||
- item 1.1.1.1
|
||||
`)
|
||||
);
|
||||
});
|
||||
|
||||
// ==================== action ====================
|
||||
|
||||
const actions = [
|
||||
{
|
||||
promptName: [
|
||||
'Summary',
|
||||
'Explain this',
|
||||
'Write an article about this',
|
||||
'Write a twitter about this',
|
||||
'Write a poem about this',
|
||||
'Write a blog post about this',
|
||||
'Write outline',
|
||||
'Change tone to',
|
||||
'Improve writing for it',
|
||||
'Improve grammar for it',
|
||||
'Fix spelling for it',
|
||||
'Create headings',
|
||||
'Make it longer',
|
||||
'Make it shorter',
|
||||
'Continue writing',
|
||||
],
|
||||
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(
|
||||
result.toLowerCase().includes('single source of truth'),
|
||||
'should include original keyword'
|
||||
);
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: ['Brainstorm ideas about this', 'Brainstorm mindmap'],
|
||||
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(checkMDList(result), 'should be a markdown list');
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: 'Expand mind map',
|
||||
messages: [{ role: 'user' as const, content: '- Single source of truth' }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(checkMDList(result), 'should be a markdown list');
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: 'Find action items from it',
|
||||
messages: [{ role: 'user' as const, content: TestAssets.TODO }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(checkMDList(result), 'should be a markdown list');
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: ['Explain this code', 'Check code error'],
|
||||
messages: [{ role: 'user' as const, content: TestAssets.Code }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(
|
||||
result.toLowerCase().includes('distance'),
|
||||
'explain code result should include keyword'
|
||||
);
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: 'Translate to',
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: TestAssets.SSOT,
|
||||
params: { language: 'Simplified Chinese' },
|
||||
},
|
||||
],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(
|
||||
result.toLowerCase().includes('单一事实来源'),
|
||||
'explain code result should include keyword'
|
||||
);
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: ['Generate a caption', 'Explain this image'],
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: '',
|
||||
attachments: [
|
||||
'https://cdn.affine.pro/copilot-test/Qgqy9qZT3VGIEuMIotJYoCCH.jpg',
|
||||
],
|
||||
},
|
||||
],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
const content = result.toLowerCase();
|
||||
t.assert(
|
||||
content.includes('classroom') ||
|
||||
content.includes('school') ||
|
||||
content.includes('sky'),
|
||||
'explain code result should include keyword'
|
||||
);
|
||||
},
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
promptName: [
|
||||
'debug:action:fal-face-to-sticker',
|
||||
'debug:action:fal-remove-bg',
|
||||
'debug:action:fal-sd15',
|
||||
'debug:action:fal-upscaler',
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: '',
|
||||
attachments: [
|
||||
'https://cdn.affine.pro/copilot-test/Zkas098lkjdf-908231.jpg',
|
||||
],
|
||||
},
|
||||
],
|
||||
verifier: (t: ExecutionContext<Tester>, link: string) => {
|
||||
t.truthy(checkUrl(link), 'should be a valid url');
|
||||
},
|
||||
type: 'image' as const,
|
||||
},
|
||||
];
|
||||
for (const { promptName, messages, verifier, type } of actions) {
|
||||
const prompts = Array.isArray(promptName) ? promptName : [promptName];
|
||||
for (const promptName of prompts) {
|
||||
test(
|
||||
`should be able to run action: ${promptName}`,
|
||||
runIfCopilotConfigured,
|
||||
async t => {
|
||||
const { provider: providerService, prompt: promptService } = t.context;
|
||||
const prompt = (await promptService.get(promptName))!;
|
||||
t.truthy(prompt, 'should have prompt');
|
||||
const provider = (await providerService.getProviderByModel(
|
||||
prompt.model
|
||||
))!;
|
||||
t.truthy(provider, 'should have provider');
|
||||
await retry(`action: ${promptName}`, t, async t => {
|
||||
if (type === 'text' && 'generateText' in provider) {
|
||||
const result = await provider.generateText(
|
||||
[
|
||||
...prompt.finish(
|
||||
messages.reduce(
|
||||
// @ts-expect-error
|
||||
(acc, m) => Object.assign(acc, m.params),
|
||||
{}
|
||||
)
|
||||
),
|
||||
...messages,
|
||||
],
|
||||
prompt.model
|
||||
);
|
||||
t.truthy(result, 'should return result');
|
||||
verifier?.(t, result);
|
||||
} else if (type === 'image' && 'generateImages' in provider) {
|
||||
const result = await provider.generateImages(
|
||||
[
|
||||
...prompt.finish(
|
||||
messages.reduce(
|
||||
// @ts-expect-error
|
||||
(acc, m) => Object.assign(acc, m.params),
|
||||
{}
|
||||
)
|
||||
),
|
||||
...messages,
|
||||
],
|
||||
prompt.model
|
||||
);
|
||||
t.truthy(result.length, 'should return result');
|
||||
for (const r of result) {
|
||||
verifier?.(t, r);
|
||||
}
|
||||
} else {
|
||||
t.fail('unsupported provider type');
|
||||
}
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== workflow ====================
|
||||
|
||||
const workflows = [
|
||||
{
|
||||
name: 'brainstorm',
|
||||
content: 'apple company',
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
t.assert(checkMDList(result), 'should be a markdown list');
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'presentation',
|
||||
content: 'apple company',
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
for (const l of result.split('\n')) {
|
||||
t.notThrows(() => {
|
||||
JSON.parse(l.trim());
|
||||
}, 'should be valid json');
|
||||
}
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
for (const { name, content, verifier } of workflows) {
|
||||
test(
|
||||
`should be able to run workflow: ${name}`,
|
||||
runIfCopilotConfigured,
|
||||
async t => {
|
||||
const { workflow } = t.context;
|
||||
|
||||
await retry(`workflow: ${name}`, t, async t => {
|
||||
let result = '';
|
||||
for await (const ret of workflow.runGraph({ content }, name)) {
|
||||
if (ret.status === GraphExecutorState.EnterNode) {
|
||||
console.log('enter node:', ret.node.name);
|
||||
} else if (ret.status === GraphExecutorState.ExitNode) {
|
||||
console.log('exit node:', ret.node.name);
|
||||
} else if (ret.status === GraphExecutorState.EmitAttachment) {
|
||||
console.log('stream attachment:', ret);
|
||||
} else {
|
||||
result += ret.content;
|
||||
}
|
||||
}
|
||||
t.truthy(result, 'should return result');
|
||||
verifier?.(t, result);
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user