mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
feat(server): update gql endpoint & workspace doc match test (#11104)
This commit is contained in:
@@ -40,7 +40,10 @@ Generated by [AVA](https://avajs.dev).
|
||||
|
||||
[
|
||||
{
|
||||
id: 'docId1',
|
||||
blobId: 'fileId1',
|
||||
chunkSize: 0,
|
||||
name: 'sample.pdf',
|
||||
status: 'processing',
|
||||
},
|
||||
]
|
||||
|
||||
@@ -48,9 +51,6 @@ Generated by [AVA](https://avajs.dev).
|
||||
|
||||
[
|
||||
{
|
||||
blobId: 'fileId1',
|
||||
chunkSize: 0,
|
||||
name: 'sample.pdf',
|
||||
status: 'processing',
|
||||
id: 'docId1',
|
||||
},
|
||||
]
|
||||
|
||||
Binary file not shown.
@@ -1,6 +1,7 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { ProjectRoot } from '@affine-tools/utils/path';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
@@ -8,6 +9,7 @@ import Sinon from 'sinon';
|
||||
import { JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { DocReader } from '../core/doc';
|
||||
import { WorkspaceModule } from '../core/workspaces';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import {
|
||||
@@ -41,14 +43,16 @@ import {
|
||||
chatWithText,
|
||||
chatWithTextStream,
|
||||
chatWithWorkflow,
|
||||
cleanObject,
|
||||
createCopilotContext,
|
||||
createCopilotMessage,
|
||||
createCopilotSession,
|
||||
forkCopilotSession,
|
||||
getHistories,
|
||||
listContext,
|
||||
listContextFiles,
|
||||
matchContext,
|
||||
listContextDocAndFiles,
|
||||
matchFiles,
|
||||
matchWorkspaceDocs,
|
||||
MockCopilotTestProvider,
|
||||
sse2array,
|
||||
textToEventStream,
|
||||
@@ -59,6 +63,7 @@ import {
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
app: TestingApp;
|
||||
db: PrismaClient;
|
||||
context: CopilotContextService;
|
||||
jobs: CopilotContextDocJob;
|
||||
prompt: PromptService;
|
||||
@@ -92,16 +97,26 @@ test.before(async t => {
|
||||
tapModule: m => {
|
||||
// use real JobQueue for testing
|
||||
m.overrideProvider(JobQueue).useClass(JobQueue);
|
||||
m.overrideProvider(DocReader).useValue({
|
||||
getFullDocContent() {
|
||||
return {
|
||||
title: '1',
|
||||
summary: '1',
|
||||
};
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const auth = app.get(AuthService);
|
||||
const db = app.get(PrismaClient);
|
||||
const context = app.get(CopilotContextService);
|
||||
const prompt = app.get(PromptService);
|
||||
const storage = app.get(CopilotStorage);
|
||||
const jobs = app.get(CopilotContextDocJob);
|
||||
|
||||
t.context.app = app;
|
||||
t.context.db = db;
|
||||
t.context.auth = auth;
|
||||
t.context.context = context;
|
||||
t.context.prompt = prompt;
|
||||
@@ -513,15 +528,6 @@ test('should be able to retry with api', async t => {
|
||||
);
|
||||
}
|
||||
|
||||
const cleanObject = (obj: any[]) =>
|
||||
JSON.parse(
|
||||
JSON.stringify(obj, (k, v) =>
|
||||
['id', 'sessionId', 'createdAt'].includes(k) || v === null
|
||||
? undefined
|
||||
: v
|
||||
)
|
||||
);
|
||||
|
||||
// retry chat
|
||||
{
|
||||
const { id } = await createWorkspace(app);
|
||||
@@ -771,6 +777,7 @@ test('should be able to manage context', async t => {
|
||||
ProjectRoot.join('packages/common/native/fixtures/sample.pdf').toFileUrl()
|
||||
);
|
||||
|
||||
// match files
|
||||
{
|
||||
const contextId = await createCopilotContext(app, workspaceId, sessionId);
|
||||
|
||||
@@ -781,34 +788,98 @@ test('should be able to manage context', async t => {
|
||||
'sample.pdf',
|
||||
buffer
|
||||
);
|
||||
await addContextDoc(app, contextId, 'docId1');
|
||||
|
||||
const { docs, files } =
|
||||
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
|
||||
const { files } =
|
||||
(await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) ||
|
||||
{};
|
||||
t.snapshot(
|
||||
docs?.map(({ createdAt: _, ...d }) => d),
|
||||
cleanObject(files, ['id', 'error', 'createdAt']),
|
||||
'should list context files'
|
||||
);
|
||||
t.snapshot(
|
||||
files?.map(({ createdAt: _, id: __, ...f }) => f),
|
||||
'should list context docs'
|
||||
);
|
||||
|
||||
// wait for processing
|
||||
{
|
||||
let { files } =
|
||||
(await listContextFiles(app, workspaceId, sessionId, contextId)) || {};
|
||||
(await listContextDocAndFiles(
|
||||
app,
|
||||
workspaceId,
|
||||
sessionId,
|
||||
contextId
|
||||
)) || {};
|
||||
|
||||
while (files?.[0].status !== 'finished') {
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
({ files } =
|
||||
(await listContextFiles(app, workspaceId, sessionId, contextId)) ||
|
||||
{});
|
||||
(await listContextDocAndFiles(
|
||||
app,
|
||||
workspaceId,
|
||||
sessionId,
|
||||
contextId
|
||||
)) || {});
|
||||
}
|
||||
}
|
||||
|
||||
const result = (await matchContext(app, contextId, 'test', 1))!;
|
||||
const result = (await matchFiles(app, contextId, 'test', 1))!;
|
||||
t.is(result.length, 1, 'should match context');
|
||||
t.is(result[0].fileId, fileId, 'should match file id');
|
||||
}
|
||||
|
||||
// match docs
|
||||
{
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const contextId = await createCopilotContext(app, workspaceId, sessionId);
|
||||
|
||||
const docId = 'docId1';
|
||||
await t.context.db.snapshot.create({
|
||||
data: {
|
||||
workspaceId: workspaceId,
|
||||
id: docId,
|
||||
blob: Buffer.from([1, 1]),
|
||||
state: Buffer.from([1, 1]),
|
||||
updatedAt: new Date(),
|
||||
createdAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
await addContextDoc(app, contextId, docId);
|
||||
|
||||
const { docs } =
|
||||
(await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) ||
|
||||
{};
|
||||
t.snapshot(
|
||||
cleanObject(docs, ['error', 'createdAt']),
|
||||
'should list context docs'
|
||||
);
|
||||
|
||||
// wait for processing
|
||||
{
|
||||
let { docs } =
|
||||
(await listContextDocAndFiles(
|
||||
app,
|
||||
workspaceId,
|
||||
sessionId,
|
||||
contextId
|
||||
)) || {};
|
||||
|
||||
while (docs?.[0].status !== 'finished') {
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
({ docs } =
|
||||
(await listContextDocAndFiles(
|
||||
app,
|
||||
workspaceId,
|
||||
sessionId,
|
||||
contextId
|
||||
)) || {});
|
||||
}
|
||||
}
|
||||
|
||||
const result = (await matchWorkspaceDocs(app, contextId, 'test', 1))!;
|
||||
t.is(result.length, 1, 'should match context');
|
||||
t.is(result[0].docId, docId, 'should match doc id');
|
||||
}
|
||||
});
|
||||
|
||||
@@ -104,14 +104,14 @@ test('should insert embedding by doc id', async t => {
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 512 }, () => 1),
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
{
|
||||
const ret = await t.context.copilotContext.matchContentEmbedding(
|
||||
Array.from({ length: 512 }, () => 0.9),
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
contextId,
|
||||
1,
|
||||
1
|
||||
@@ -123,7 +123,7 @@ test('should insert embedding by doc id', async t => {
|
||||
{
|
||||
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
|
||||
const ret = await t.context.copilotContext.matchContentEmbedding(
|
||||
Array.from({ length: 512 }, () => 0.9),
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
contextId,
|
||||
1,
|
||||
1
|
||||
@@ -151,7 +151,7 @@ test('should insert embedding by doc id', async t => {
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 512 }, () => 1),
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]
|
||||
);
|
||||
@@ -166,7 +166,7 @@ test('should insert embedding by doc id', async t => {
|
||||
|
||||
{
|
||||
const ret = await t.context.copilotContext.matchWorkspaceEmbedding(
|
||||
Array.from({ length: 512 }, () => 0.9),
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
workspace.id,
|
||||
1,
|
||||
1
|
||||
|
||||
@@ -156,6 +156,16 @@ export class MockCopilotTestProvider
|
||||
}
|
||||
}
|
||||
|
||||
export const cleanObject = (
|
||||
obj: any[] | undefined,
|
||||
condition = ['id', 'status', 'error', 'sessionId', 'createdAt']
|
||||
) =>
|
||||
JSON.parse(
|
||||
JSON.stringify(obj || [], (k, v) =>
|
||||
condition.includes(k) || v === null ? undefined : v
|
||||
)
|
||||
);
|
||||
|
||||
export async function createCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
@@ -224,7 +234,7 @@ export async function createCopilotContext(
|
||||
return res.createCopilotContext;
|
||||
}
|
||||
|
||||
export async function matchContext(
|
||||
export async function matchFiles(
|
||||
app: TestingApp,
|
||||
contextId: string,
|
||||
content: string,
|
||||
@@ -240,11 +250,11 @@ export async function matchContext(
|
||||
> {
|
||||
const res = await app.gql(
|
||||
`
|
||||
query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
|
||||
query matchFiles($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
|
||||
currentUser {
|
||||
copilot {
|
||||
contexts(contextId: $contextId) {
|
||||
matchContext(content: $content, limit: $limit, threshold: $threshold) {
|
||||
matchFiles(content: $content, limit: $limit, threshold: $threshold) {
|
||||
fileId
|
||||
chunk
|
||||
content
|
||||
@@ -258,7 +268,44 @@ export async function matchContext(
|
||||
{ contextId, content, limit, threshold: 1 }
|
||||
);
|
||||
|
||||
return res.currentUser?.copilot?.contexts?.[0]?.matchContext;
|
||||
return res.currentUser?.copilot?.contexts?.[0]?.matchFiles;
|
||||
}
|
||||
|
||||
export async function matchWorkspaceDocs(
|
||||
app: TestingApp,
|
||||
contextId: string,
|
||||
content: string,
|
||||
limit: number
|
||||
): Promise<
|
||||
| {
|
||||
docId: string;
|
||||
chunk: number;
|
||||
content: string;
|
||||
distance: number | null;
|
||||
}[]
|
||||
| undefined
|
||||
> {
|
||||
const res = await app.gql(
|
||||
`
|
||||
query matchWorkspaceDocs($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) {
|
||||
currentUser {
|
||||
copilot {
|
||||
contexts(contextId: $contextId) {
|
||||
matchWorkspaceDocs(content: $content, limit: $limit, threshold: $threshold) {
|
||||
docId
|
||||
chunk
|
||||
content
|
||||
distance
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`,
|
||||
{ contextId, content, limit, threshold: 1 }
|
||||
);
|
||||
|
||||
return res.currentUser?.copilot?.contexts?.[0]?.matchWorkspaceDocs;
|
||||
}
|
||||
|
||||
export async function listContext(
|
||||
@@ -376,7 +423,7 @@ export async function removeContextDoc(
|
||||
return res.removeContextDoc;
|
||||
}
|
||||
|
||||
export async function listContextFiles(
|
||||
export async function listContextDocAndFiles(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
sessionId: string,
|
||||
@@ -385,6 +432,8 @@ export async function listContextFiles(
|
||||
| {
|
||||
docs: {
|
||||
id: string;
|
||||
status: string;
|
||||
error: string | null;
|
||||
createdAt: number;
|
||||
}[];
|
||||
files: {
|
||||
@@ -393,6 +442,7 @@ export async function listContextFiles(
|
||||
blobId: string;
|
||||
chunkSize: number;
|
||||
status: string;
|
||||
error: string | null;
|
||||
createdAt: number;
|
||||
}[];
|
||||
}
|
||||
@@ -405,6 +455,8 @@ export async function listContextFiles(
|
||||
contexts(sessionId: "${sessionId}", contextId: "${contextId}") {
|
||||
docs {
|
||||
id
|
||||
status
|
||||
error
|
||||
createdAt
|
||||
}
|
||||
files {
|
||||
@@ -413,6 +465,7 @@ export async function listContextFiles(
|
||||
blobId
|
||||
chunkSize
|
||||
status
|
||||
error
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ export class MockEmbeddingClient extends EmbeddingClient {
|
||||
return input.map((_, i) => ({
|
||||
index: i,
|
||||
content: input[i],
|
||||
embedding: Array.from({ length: 512 }, () => Math.random()),
|
||||
embedding: Array.from({ length: 1024 }, () => Math.random()),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -656,10 +656,10 @@ export class CopilotContextResolver {
|
||||
}
|
||||
|
||||
@ResolveField(() => [ContextMatchedFileChunk], {
|
||||
description: 'match file context',
|
||||
description: 'match file in context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_remove')
|
||||
async matchContext(
|
||||
async matchFiles(
|
||||
@Context() ctx: { req: Request },
|
||||
@Parent() context: CopilotContextType,
|
||||
@Args('content') content: string,
|
||||
@@ -667,16 +667,11 @@ export class CopilotContextResolver {
|
||||
limit?: number,
|
||||
@Args('threshold', { type: () => Float, nullable: true })
|
||||
threshold?: number
|
||||
) {
|
||||
): Promise<ContextMatchedFileChunk[]> {
|
||||
if (!this.context.canEmbedding) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${context.id}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.context.get(context.id);
|
||||
|
||||
try {
|
||||
@@ -696,18 +691,20 @@ export class CopilotContextResolver {
|
||||
}
|
||||
}
|
||||
|
||||
@ResolveField(() => ContextMatchedDocChunk, {
|
||||
description: 'match workspace doc content',
|
||||
@ResolveField(() => [ContextMatchedDocChunk], {
|
||||
description: 'match workspace docs',
|
||||
})
|
||||
@CallMetric('ai', 'context_match_workspace_doc')
|
||||
async matchWorkspaceContext(
|
||||
async matchWorkspaceDocs(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Context() ctx: { req: Request },
|
||||
@Parent() context: CopilotContextType,
|
||||
@Args('content') content: string,
|
||||
@Args('limit', { type: () => SafeIntResolver, nullable: true })
|
||||
limit?: number
|
||||
) {
|
||||
limit?: number,
|
||||
@Args('threshold', { type: () => Float, nullable: true })
|
||||
threshold?: number
|
||||
): Promise<ContextMatchedDocChunk[]> {
|
||||
if (!this.context.canEmbedding) {
|
||||
return [];
|
||||
}
|
||||
@@ -723,7 +720,8 @@ export class CopilotContextResolver {
|
||||
return await session.matchWorkspaceChunks(
|
||||
content,
|
||||
limit,
|
||||
this.getSignal(ctx.req)
|
||||
this.getSignal(ctx.req),
|
||||
threshold
|
||||
);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToMatchContext({
|
||||
|
||||
@@ -199,7 +199,7 @@ export class ContextSession implements AsyncDisposable {
|
||||
|
||||
return this.models.copilotContext.matchWorkspaceEmbedding(
|
||||
embedding,
|
||||
this.id,
|
||||
this.workspaceId,
|
||||
topK,
|
||||
threshold
|
||||
);
|
||||
|
||||
@@ -109,11 +109,11 @@ type CopilotContext {
|
||||
files: [CopilotContextFile!]!
|
||||
id: ID!
|
||||
|
||||
"""match file context"""
|
||||
matchContext(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!
|
||||
"""match file in context"""
|
||||
matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!
|
||||
|
||||
"""match workspace doc content"""
|
||||
matchWorkspaceContext(content: String!, limit: SafeInt): ContextMatchedDocChunk!
|
||||
"""match workspace docs"""
|
||||
matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]!
|
||||
|
||||
"""list tags in context"""
|
||||
tags: [CopilotContextCategory!]!
|
||||
|
||||
Reference in New Issue
Block a user