feat(server): update gql endpoint & workspace doc match test (#11104)

This commit is contained in:
darkskygit
2025-03-25 10:09:22 +00:00
parent bf4107feac
commit 1bb324eeed
20 changed files with 355 additions and 139 deletions

View File

@@ -444,7 +444,7 @@ model AiContextEmbedding {
// a file can be divided into multiple chunks and embedded separately.
chunk Int @db.Integer
content String @db.VarChar
embedding Unsupported("vector(512)")
embedding Unsupported("vector(1024)")
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
@@ -462,7 +462,7 @@ model AiWorkspaceEmbedding {
// a doc can be divided into multiple chunks and embedded separately.
chunk Int @db.Integer
content String @db.VarChar
embedding Unsupported("vector(512)")
embedding Unsupported("vector(1024)")
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)

View File

@@ -40,7 +40,10 @@ Generated by [AVA](https://avajs.dev).
[
{
id: 'docId1',
blobId: 'fileId1',
chunkSize: 0,
name: 'sample.pdf',
status: 'processing',
},
]
@@ -48,9 +51,6 @@ Generated by [AVA](https://avajs.dev).
[
{
blobId: 'fileId1',
chunkSize: 0,
name: 'sample.pdf',
status: 'processing',
id: 'docId1',
},
]

View File

@@ -1,6 +1,7 @@
import { randomUUID } from 'node:crypto';
import { ProjectRoot } from '@affine-tools/utils/path';
import { PrismaClient } from '@prisma/client';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
@@ -8,6 +9,7 @@ import Sinon from 'sinon';
import { JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { DocReader } from '../core/doc';
import { WorkspaceModule } from '../core/workspaces';
import { CopilotModule } from '../plugins/copilot';
import {
@@ -41,14 +43,16 @@ import {
chatWithText,
chatWithTextStream,
chatWithWorkflow,
cleanObject,
createCopilotContext,
createCopilotMessage,
createCopilotSession,
forkCopilotSession,
getHistories,
listContext,
listContextFiles,
matchContext,
listContextDocAndFiles,
matchFiles,
matchWorkspaceDocs,
MockCopilotTestProvider,
sse2array,
textToEventStream,
@@ -59,6 +63,7 @@ import {
const test = ava as TestFn<{
auth: AuthService;
app: TestingApp;
db: PrismaClient;
context: CopilotContextService;
jobs: CopilotContextDocJob;
prompt: PromptService;
@@ -92,16 +97,26 @@ test.before(async t => {
tapModule: m => {
// use real JobQueue for testing
m.overrideProvider(JobQueue).useClass(JobQueue);
m.overrideProvider(DocReader).useValue({
getFullDocContent() {
return {
title: '1',
summary: '1',
};
},
});
},
});
const auth = app.get(AuthService);
const db = app.get(PrismaClient);
const context = app.get(CopilotContextService);
const prompt = app.get(PromptService);
const storage = app.get(CopilotStorage);
const jobs = app.get(CopilotContextDocJob);
t.context.app = app;
t.context.db = db;
t.context.auth = auth;
t.context.context = context;
t.context.prompt = prompt;
@@ -513,15 +528,6 @@ test('should be able to retry with api', async t => {
);
}
const cleanObject = (obj: any[]) =>
JSON.parse(
JSON.stringify(obj, (k, v) =>
['id', 'sessionId', 'createdAt'].includes(k) || v === null
? undefined
: v
)
);
// retry chat
{
const { id } = await createWorkspace(app);
@@ -771,6 +777,7 @@ test('should be able to manage context', async t => {
ProjectRoot.join('packages/common/native/fixtures/sample.pdf').toFileUrl()
);
// match files
{
const contextId = await createCopilotContext(app, workspaceId, sessionId);
@@ -781,34 +788,98 @@ test('should be able to manage context', async t => {
'sample.pdf',
buffer
);
await addContextDoc(app, contextId, 'docId1');
const { docs, files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
const { files } =
(await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) ||
{};
t.snapshot(
docs?.map(({ createdAt: _, ...d }) => d),
cleanObject(files, ['id', 'error', 'createdAt']),
'should list context files'
);
t.snapshot(
files?.map(({ createdAt: _, id: __, ...f }) => f),
'should list context docs'
);
// wait for processing
{
let { files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
(await listContextDocAndFiles(
app,
workspaceId,
sessionId,
contextId
)) || {};
while (files?.[0].status !== 'finished') {
await new Promise(resolve => setTimeout(resolve, 1000));
({ files } =
(await listContextFiles(app, workspaceId, sessionId, contextId)) ||
{});
(await listContextDocAndFiles(
app,
workspaceId,
sessionId,
contextId
)) || {});
}
}
const result = (await matchContext(app, contextId, 'test', 1))!;
const result = (await matchFiles(app, contextId, 'test', 1))!;
t.is(result.length, 1, 'should match context');
t.is(result[0].fileId, fileId, 'should match file id');
}
// match docs
{
const sessionId = await createCopilotSession(
app,
workspaceId,
randomUUID(),
promptName
);
const contextId = await createCopilotContext(app, workspaceId, sessionId);
const docId = 'docId1';
await t.context.db.snapshot.create({
data: {
workspaceId: workspaceId,
id: docId,
blob: Buffer.from([1, 1]),
state: Buffer.from([1, 1]),
updatedAt: new Date(),
createdAt: new Date(),
},
});
await addContextDoc(app, contextId, docId);
const { docs } =
(await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) ||
{};
t.snapshot(
cleanObject(docs, ['error', 'createdAt']),
'should list context docs'
);
// wait for processing
{
let { docs } =
(await listContextDocAndFiles(
app,
workspaceId,
sessionId,
contextId
)) || {};
while (docs?.[0].status !== 'finished') {
await new Promise(resolve => setTimeout(resolve, 1000));
({ docs } =
(await listContextDocAndFiles(
app,
workspaceId,
sessionId,
contextId
)) || {});
}
}
const result = (await matchWorkspaceDocs(app, contextId, 'test', 1))!;
t.is(result.length, 1, 'should match context');
t.is(result[0].docId, docId, 'should match doc id');
}
});

View File

@@ -104,14 +104,14 @@ test('should insert embedding by doc id', async t => {
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
embedding: Array.from({ length: 1024 }, () => 1),
},
]
);
{
const ret = await t.context.copilotContext.matchContentEmbedding(
Array.from({ length: 512 }, () => 0.9),
Array.from({ length: 1024 }, () => 0.9),
contextId,
1,
1
@@ -123,7 +123,7 @@ test('should insert embedding by doc id', async t => {
{
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
const ret = await t.context.copilotContext.matchContentEmbedding(
Array.from({ length: 512 }, () => 0.9),
Array.from({ length: 1024 }, () => 0.9),
contextId,
1,
1
@@ -151,7 +151,7 @@ test('should insert embedding by doc id', async t => {
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
embedding: Array.from({ length: 1024 }, () => 1),
},
]
);
@@ -166,7 +166,7 @@ test('should insert embedding by doc id', async t => {
{
const ret = await t.context.copilotContext.matchWorkspaceEmbedding(
Array.from({ length: 512 }, () => 0.9),
Array.from({ length: 1024 }, () => 0.9),
workspace.id,
1,
1

View File

@@ -156,6 +156,16 @@ export class MockCopilotTestProvider
}
}
export const cleanObject = (
obj: any[] | undefined,
condition = ['id', 'status', 'error', 'sessionId', 'createdAt']
) =>
JSON.parse(
JSON.stringify(obj || [], (k, v) =>
condition.includes(k) || v === null ? undefined : v
)
);
export async function createCopilotSession(
app: TestingApp,
workspaceId: string,
@@ -224,7 +234,7 @@ export async function createCopilotContext(
return res.createCopilotContext;
}
export async function matchContext(
export async function matchFiles(
app: TestingApp,
contextId: string,
content: string,
@@ -240,11 +250,11 @@ export async function matchContext(
> {
const res = await app.gql(
`
query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
query matchFiles($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
currentUser {
copilot {
contexts(contextId: $contextId) {
matchContext(content: $content, limit: $limit, threshold: $threshold) {
matchFiles(content: $content, limit: $limit, threshold: $threshold) {
fileId
chunk
content
@@ -258,7 +268,44 @@ export async function matchContext(
{ contextId, content, limit, threshold: 1 }
);
return res.currentUser?.copilot?.contexts?.[0]?.matchContext;
return res.currentUser?.copilot?.contexts?.[0]?.matchFiles;
}
export async function matchWorkspaceDocs(
app: TestingApp,
contextId: string,
content: string,
limit: number
): Promise<
| {
docId: string;
chunk: number;
content: string;
distance: number | null;
}[]
| undefined
> {
const res = await app.gql(
`
query matchWorkspaceDocs($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
currentUser {
copilot {
contexts(contextId: $contextId) {
matchWorkspaceDocs(content: $content, limit: $limit, threshold: $threshold) {
docId
chunk
content
distance
}
}
}
}
}
`,
{ contextId, content, limit, threshold: 1 }
);
return res.currentUser?.copilot?.contexts?.[0]?.matchWorkspaceDocs;
}
export async function listContext(
@@ -376,7 +423,7 @@ export async function removeContextDoc(
return res.removeContextDoc;
}
export async function listContextFiles(
export async function listContextDocAndFiles(
app: TestingApp,
workspaceId: string,
sessionId: string,
@@ -385,6 +432,8 @@ export async function listContextFiles(
| {
docs: {
id: string;
status: string;
error: string | null;
createdAt: number;
}[];
files: {
@@ -393,6 +442,7 @@ export async function listContextFiles(
blobId: string;
chunkSize: number;
status: string;
error: string | null;
createdAt: number;
}[];
}
@@ -405,6 +455,8 @@ export async function listContextFiles(
contexts(sessionId: "${sessionId}", contextId: "${contextId}") {
docs {
id
status
error
createdAt
}
files {
@@ -413,6 +465,7 @@ export async function listContextFiles(
blobId
chunkSize
status
error
createdAt
}
}

View File

@@ -30,7 +30,7 @@ export class MockEmbeddingClient extends EmbeddingClient {
return input.map((_, i) => ({
index: i,
content: input[i],
embedding: Array.from({ length: 512 }, () => Math.random()),
embedding: Array.from({ length: 1024 }, () => Math.random()),
}));
}
}

View File

@@ -656,10 +656,10 @@ export class CopilotContextResolver {
}
@ResolveField(() => [ContextMatchedFileChunk], {
description: 'match file context',
description: 'match file in context',
})
@CallMetric('ai', 'context_file_remove')
async matchContext(
async matchFiles(
@Context() ctx: { req: Request },
@Parent() context: CopilotContextType,
@Args('content') content: string,
@@ -667,16 +667,11 @@ export class CopilotContextResolver {
limit?: number,
@Args('threshold', { type: () => Float, nullable: true })
threshold?: number
) {
): Promise<ContextMatchedFileChunk[]> {
if (!this.context.canEmbedding) {
return [];
}
const lockFlag = `${COPILOT_LOCKER}:context:${context.id}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
const session = await this.context.get(context.id);
try {
@@ -696,18 +691,20 @@ export class CopilotContextResolver {
}
}
@ResolveField(() => ContextMatchedDocChunk, {
description: 'match workspace doc content',
@ResolveField(() => [ContextMatchedDocChunk], {
description: 'match workspace docs',
})
@CallMetric('ai', 'context_match_workspace_doc')
async matchWorkspaceContext(
async matchWorkspaceDocs(
@CurrentUser() user: CurrentUser,
@Context() ctx: { req: Request },
@Parent() context: CopilotContextType,
@Args('content') content: string,
@Args('limit', { type: () => SafeIntResolver, nullable: true })
limit?: number
) {
limit?: number,
@Args('threshold', { type: () => Float, nullable: true })
threshold?: number
): Promise<ContextMatchedDocChunk[]> {
if (!this.context.canEmbedding) {
return [];
}
@@ -723,7 +720,8 @@ export class CopilotContextResolver {
return await session.matchWorkspaceChunks(
content,
limit,
this.getSignal(ctx.req)
this.getSignal(ctx.req),
threshold
);
} catch (e: any) {
throw new CopilotFailedToMatchContext({

View File

@@ -199,7 +199,7 @@ export class ContextSession implements AsyncDisposable {
return this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
this.id,
this.workspaceId,
topK,
threshold
);

View File

@@ -109,11 +109,11 @@ type CopilotContext {
files: [CopilotContextFile!]!
id: ID!
"""match file context"""
matchContext(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!
"""match file in context"""
matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!
"""match workspace doc content"""
matchWorkspaceContext(content: String!, limit: SafeInt): ContextMatchedDocChunk!
"""match workspace docs"""
matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]!
"""list tags in context"""
tags: [CopilotContextCategory!]!