feat(server): context awareness for copilot (#9611)

fix PD-2167
fix PD-2169
fix PD-2190
This commit is contained in:
darkskygit
2025-03-13 11:44:55 +00:00
parent 05f3069efd
commit d8373f66e7
51 changed files with 2101 additions and 294 deletions

View File

@@ -43,3 +43,14 @@ Generated by [AVA](https://avajs.dev).
id: 'docId1',
},
]
> should list context docs
[
{
blobId: 'fileId1',
chunkSize: 0,
name: 'sample.pdf',
status: 'processing',
},
]

View File

@@ -1,5 +1,6 @@
import { randomUUID } from 'node:crypto';
import { ProjectRoot } from '@affine-tools/utils/path';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
@@ -8,7 +9,11 @@ import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { WorkspaceModule } from '../core/workspaces';
import { CopilotModule } from '../plugins/copilot';
import { CopilotContextService } from '../plugins/copilot/context';
import {
CopilotContextDocJob,
CopilotContextService,
} from '../plugins/copilot/context';
import { MockEmbeddingClient } from '../plugins/copilot/context/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderService,
@@ -29,6 +34,7 @@ import {
} from './utils';
import {
addContextDoc,
addContextFile,
array2sse,
chatWithImages,
chatWithText,
@@ -41,6 +47,7 @@ import {
getHistories,
listContext,
listContextFiles,
matchContext,
MockCopilotTestProvider,
sse2array,
textToEventStream,
@@ -52,6 +59,7 @@ const test = ava as TestFn<{
auth: AuthService;
app: TestingApp;
context: CopilotContextService;
jobs: CopilotContextDocJob;
prompt: PromptService;
provider: CopilotProviderService;
storage: CopilotStorage;
@@ -86,12 +94,14 @@ test.before(async t => {
const context = app.get(CopilotContextService);
const prompt = app.get(PromptService);
const storage = app.get(CopilotStorage);
const jobs = app.get(CopilotContextDocJob);
t.context.app = app;
t.context.auth = auth;
t.context.context = context;
t.context.prompt = prompt;
t.context.storage = storage;
t.context.jobs = jobs;
});
const promptName = 'prompt';
@@ -719,7 +729,7 @@ test('should be able to search image from unsplash', async t => {
});
test('should be able to manage context', async t => {
const { app } = t.context;
const { app, context, jobs } = t.context;
const { id: workspaceId } = await createWorkspace(app);
const sessionId = await createCopilotSession(
@@ -729,6 +739,10 @@ test('should be able to manage context', async t => {
promptName
);
// use mocked embedding client
Sinon.stub(context, 'embeddingClient').get(() => new MockEmbeddingClient());
Sinon.stub(jobs, 'embeddingClient').get(() => new MockEmbeddingClient());
{
await t.throwsAsync(
createCopilotContext(app, workspaceId, randomUUID()),
@@ -747,16 +761,49 @@ test('should be able to manage context', async t => {
);
}
const fs = await import('node:fs');
const buffer = fs.readFileSync(
ProjectRoot.join('packages/common/native/fixtures/sample.pdf').toFileUrl()
);
{
const contextId = await createCopilotContext(app, workspaceId, sessionId);
const { id: fileId } = await addContextFile(
app,
contextId,
'fileId1',
'sample.pdf',
buffer
);
await addContextDoc(app, contextId, 'docId1');
const { docs } =
const { docs, files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
t.snapshot(
docs?.map(({ createdAt: _, ...d }) => d),
'should list context files'
);
t.snapshot(
files?.map(({ createdAt: _, id: __, ...f }) => f),
'should list context docs'
);
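// wait for processing: poll once per second until the first file reports 'finished'
// NOTE(review): the loop has no attempt limit; it only ends on 'finished' or a thrown error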
{
let { files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
while (files?.[0].status !== 'finished') {
await new Promise(resolve => setTimeout(resolve, 1000));
({ files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) ||
{});
}
}
const result = (await matchContext(app, contextId, 'test', 1))!;
t.is(result.length, 1, 'should match context');
t.is(result[0].fileId, fileId, 'should match file id');
}
});

View File

@@ -1,14 +1,20 @@
import { randomUUID } from 'node:crypto';
import { ProjectRoot } from '@affine-tools/utils/path';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { EventBus } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { CopilotModule } from '../plugins/copilot';
import { CopilotContextService } from '../plugins/copilot/context';
import {
CopilotContextDocJob,
CopilotContextService,
} from '../plugins/copilot/context';
import { MockEmbeddingClient } from '../plugins/copilot/context/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderService,
@@ -18,6 +24,7 @@ import {
} from '../plugins/copilot/providers';
import { CitationParser } from '../plugins/copilot/providers/perplexity';
import { ChatSessionService } from '../plugins/copilot/session';
import { CopilotStorage } from '../plugins/copilot/storage';
import {
CopilotCapability,
CopilotProviderType,
@@ -47,10 +54,13 @@ import { MockCopilotTestProvider, WorkflowTestCases } from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
module: TestingModule;
event: EventBus;
context: CopilotContextService;
prompt: PromptService;
provider: CopilotProviderService;
session: ChatSessionService;
jobs: CopilotContextDocJob;
storage: CopilotStorage;
workflow: CopilotWorkflowService;
executors: {
image: CopilotChatImageExecutor;
@@ -85,19 +95,25 @@ test.before(async t => {
});
const auth = module.get(AuthService);
const event = module.get(EventBus);
const context = module.get(CopilotContextService);
const prompt = module.get(PromptService);
const provider = module.get(CopilotProviderService);
const session = module.get(ChatSessionService);
const workflow = module.get(CopilotWorkflowService);
const jobs = module.get(CopilotContextDocJob);
const storage = module.get(CopilotStorage);
t.context.module = module;
t.context.auth = auth;
t.context.event = event;
t.context.context = context;
t.context.prompt = prompt;
t.context.provider = provider;
t.context.session = session;
t.context.workflow = workflow;
t.context.jobs = jobs;
t.context.storage = storage;
t.context.executors = {
image: module.get(CopilotChatImageExecutor),
text: module.get(CopilotChatTextExecutor),
@@ -1276,7 +1292,7 @@ test('CitationParser should not replace chunks of citation already with URLs', t
// ==================== context ====================
test('should be able to manage context', async t => {
const { context, prompt, session } = t.context;
const { context, prompt, session, event, jobs, storage } = t.context;
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
@@ -1288,6 +1304,10 @@ test('should be able to manage context', async t => {
promptName: 'prompt',
});
// use mocked embedding client
Sinon.stub(context, 'embeddingClient').get(() => new MockEmbeddingClient());
Sinon.stub(jobs, 'embeddingClient').get(() => new MockEmbeddingClient());
{
await t.throwsAsync(
context.create(randomUUID()),
@@ -1310,9 +1330,45 @@ test('should be able to manage context', async t => {
);
}
const fs = await import('node:fs');
const buffer = fs.readFileSync(
ProjectRoot.join('packages/common/native/fixtures/sample.pdf').toFileUrl()
);
{
const session = await context.create(chatSession);
await storage.put(userId, session.workspaceId, 'blob', buffer);
const file = await session.addFile('blob', 'sample.pdf');
const handler = Sinon.spy(event, 'emit');
await jobs.embedPendingFile({
userId,
workspaceId: session.workspaceId,
contextId: session.id,
blobId: file.blobId,
fileId: file.id,
fileName: file.name,
});
t.deepEqual(handler.lastCall.args, [
'workspace.file.embed.finished',
{
contextId: session.id,
fileId: file.id,
chunkSize: 1,
},
]);
const list = session.listFiles();
t.deepEqual(
list.map(f => f.id),
[file.id],
'should list file id'
);
const docId = randomUUID();
await session.addDocRecord(docId);
const docs = session.listDocs().map(d => d.id);
@@ -1320,5 +1376,9 @@ test('should be able to manage context', async t => {
await session.removeDocRecord(docId);
t.deepEqual(session.listDocs(), [], 'should remove doc id');
const result = await session.matchFileChunks('test', 1, undefined, 1);
t.is(result.length, 1, 'should match context');
t.is(result[0].fileId, file.id, 'should match file id');
}
});

View File

@@ -240,19 +240,25 @@ export async function matchContext(
> {
const res = await app.gql(
`
mutation matchContext($content: String!, $contextId: String!, $limit: SafeInt) {
matchContext(content: $content, contextId: $contextId, limit: $limit) {
fileId
chunk
content
distance
query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
currentUser {
copilot {
contexts(contextId: $contextId) {
matchContext(content: $content, limit: $limit, threshold: $threshold) {
fileId
chunk
content
distance
}
}
}
}
}
`,
{ contextId, content, limit }
{ contextId, content, limit, threshold: 1 }
);
return res.matchContext;
return res.currentUser?.copilot?.contexts?.[0]?.matchContext;
}
export async function listContext(
@@ -287,7 +293,7 @@ export async function addContextFile(
blobId: string,
fileName: string,
content: Buffer
): Promise<{ id: string }[]> {
): Promise<{ id: string }> {
const res = await app
.POST(gql)
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
@@ -303,7 +309,7 @@ export async function addContextFile(
`,
variables: {
content: null,
options: { contextId, blobId, fileName },
options: { contextId, blobId },
},
})
)