fix(server): query workspace embed files (#11982)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
	- Expanded file chunk matching to include both context and workspace file embeddings, providing broader and more relevant search results.
- **Improvements**
	- Enhanced result ranking by introducing a re-ranking step for combined embedding matches, improving the relevance of returned file chunks.
	- Adjusted file count reporting to reflect the total number of workspace files instead of ignored documents for more accurate workspace file statistics.
	- Renamed and streamlined workspace file management methods for clearer and more consistent API usage.
- **Bug Fixes**
	- Prevented embedding similarity queries when embedding is disabled for a workspace, improving system behavior consistency.
- **Tests**
	- Added comprehensive tests to verify workspace embedding management, including enabling, matching, and disabling embedding functionality.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
darkskygit
2025-04-25 08:32:32 +00:00
parent 0abe65653b
commit 49c57ca649
13 changed files with 220 additions and 112 deletions

View File

@@ -284,7 +284,15 @@ const actions = [
verifier: (t: ExecutionContext<Tester>, result: string) => {
assertNotWrappedInCodeBlock(t, result);
assertCitation(t, result, (t, c) => {
t.assert(c.length === 0, 'should not have citation');
t.assert(
c.length === 0 ||
// ignore web search result
c
.map(c => JSON.parse(c.citationJson).type)
.filter(type => ['attachment', 'doc'].includes(type)).length ===
0,
'should not have citation'
);
});
},
type: 'text' as const,
@@ -404,8 +412,9 @@ const actions = [
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
verifier: (t: ExecutionContext<Tester>, result: string) => {
assertNotWrappedInCodeBlock(t, result);
const cleared = result.toLowerCase();
t.assert(
result.toLowerCase().includes('single source of truth'),
cleared.includes('single source of truth') || cleared.includes('ssot'),
'should include original keyword'
);
},

View File

@@ -120,7 +120,7 @@ test.before(async t => {
t.context.jobs = jobs;
});
const promptName = 'prompt';
const textPromptName = 'prompt';
test.beforeEach(async t => {
Sinon.restore();
const { app, prompt } = t.context;
@@ -128,7 +128,7 @@ test.beforeEach(async t => {
await prompt.onApplicationBootstrap();
t.context.u1 = await app.signupV1('u1@affine.pro');
await prompt.set(promptName, 'test', [
await prompt.set(textPromptName, 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
});
@@ -150,7 +150,7 @@ test('should create session correctly', async t => {
}
) => {
await asserter(
createCopilotSession(app, workspaceId, randomUUID(), promptName)
createCopilotSession(app, workspaceId, randomUUID(), textPromptName)
);
};
@@ -202,7 +202,7 @@ test('should update session correctly', async t => {
t.truthy(await x, error);
}
) => {
await asserter(updateCopilotSession(app, sessionId, promptName));
await asserter(updateCopilotSession(app, sessionId, textPromptName));
};
{
@@ -212,7 +212,7 @@ test('should update session correctly', async t => {
app,
workspaceId,
docId,
promptName
textPromptName
);
await assertUpdateSession(
sessionId,
@@ -225,7 +225,7 @@ test('should update session correctly', async t => {
app,
randomUUID(),
randomUUID(),
promptName
textPromptName
);
await assertUpdateSession(
sessionId,
@@ -244,7 +244,7 @@ test('should update session correctly', async t => {
app,
workspaceId,
randomUUID(),
promptName
textPromptName
);
await assertUpdateSession(
sessionId,
@@ -294,7 +294,7 @@ test('should fork session correctly', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
let forkedSessionId: string;
@@ -363,7 +363,7 @@ test('should be able to use test provider', async t => {
const { id } = await createWorkspace(app);
t.truthy(
await createCopilotSession(app, id, randomUUID(), promptName),
await createCopilotSession(app, id, randomUUID(), textPromptName),
'failed to create session'
);
});
@@ -379,7 +379,7 @@ test('should create message correctly', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
t.truthy(messageId, 'should be able to create message with valid session');
@@ -393,7 +393,7 @@ test('should create message correctly', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId, undefined, [
'http://example.com/cat.jpg',
@@ -408,7 +408,7 @@ test('should create message correctly', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const smallestPng =
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII';
@@ -445,7 +445,7 @@ test('should be able to chat with api', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
const ret = await chatWithText(app, sessionId, messageId);
@@ -543,7 +543,7 @@ test('should be able to retry with api', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
// chat 2 times
@@ -565,7 +565,7 @@ test('should be able to retry with api', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, messageId);
@@ -587,7 +587,7 @@ test('should be able to retry with api', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, messageId);
@@ -614,13 +614,13 @@ test('should reject message from different session', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
const anotherSessionId = await createCopilotSession(
app,
id,
randomUUID(),
promptName
textPromptName
);
const anotherMessageId = await createCopilotMessage(app, anotherSessionId);
await t.throwsAsync(
@@ -639,7 +639,7 @@ test('should reject request from different user', async t => {
app,
id,
randomUUID(),
promptName
textPromptName
);
// should reject message from different user
@@ -677,7 +677,7 @@ test('should be able to list history', async t => {
app,
workspaceId,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId, 'hello');
@@ -740,7 +740,7 @@ test('should reject request that user have not permission', async t => {
app,
workspaceId,
randomUUID(),
promptName
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
@@ -777,7 +777,7 @@ test('should be able to manage context', async t => {
app,
workspaceId,
randomUUID(),
promptName
textPromptName
);
// use mocked embedding client
@@ -859,7 +859,7 @@ test('should be able to manage context', async t => {
app,
workspaceId,
randomUUID(),
promptName
textPromptName
);
const contextId = await createCopilotContext(app, workspaceId, sessionId);

View File

@@ -1,4 +1,5 @@
import { randomUUID } from 'node:crypto';
import { Readable } from 'node:stream';
import { ProjectRoot } from '@affine-tools/utils/path';
import { PrismaClient } from '@prisma/client';
@@ -6,11 +7,11 @@ import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { EventBus } from '../base';
import { EventBus, JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { ContextCategories } from '../models';
import { ContextCategories, WorkspaceModel } from '../models';
import { CopilotModule } from '../plugins/copilot';
import {
CopilotContextDocJob,
@@ -47,6 +48,7 @@ import {
} from '../plugins/copilot/workflow/executor';
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
import { MockCopilotProvider } from './mocks';
import { createTestingModule, TestingModule } from './utils';
import { WorkflowTestCases } from './utils/copilot';
@@ -56,9 +58,11 @@ const test = ava as TestFn<{
module: TestingModule;
db: PrismaClient;
event: EventBus;
workspace: WorkspaceModel;
context: CopilotContextService;
prompt: PromptService;
transcript: CopilotTranscriptionService;
workspaceEmbedding: CopilotWorkspaceService;
factory: CopilotProviderFactory;
session: ChatSessionService;
jobs: CopilotContextDocJob;
@@ -95,6 +99,8 @@ test.before(async t => {
CopilotModule,
],
tapModule: builder => {
// use real JobQueue for testing
builder.overrideProvider(JobQueue).useClass(JobQueue);
builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
},
});
@@ -102,6 +108,7 @@ test.before(async t => {
const auth = module.get(AuthService);
const db = module.get(PrismaClient);
const event = module.get(EventBus);
const workspace = module.get(WorkspaceModel);
const prompt = module.get(PromptService);
const factory = module.get(CopilotProviderFactory);
@@ -112,11 +119,13 @@ test.before(async t => {
const context = module.get(CopilotContextService);
const jobs = module.get(CopilotContextDocJob);
const transcript = module.get(CopilotTranscriptionService);
const workspaceEmbedding = module.get(CopilotWorkspaceService);
t.context.module = module;
t.context.auth = auth;
t.context.db = db;
t.context.event = event;
t.context.workspace = workspace;
t.context.prompt = prompt;
t.context.factory = factory;
t.context.session = session;
@@ -125,6 +134,7 @@ test.before(async t => {
t.context.context = context;
t.context.jobs = jobs;
t.context.transcript = transcript;
t.context.workspaceEmbedding = workspaceEmbedding;
t.context.executors = {
image: module.get(CopilotChatImageExecutor),
@@ -1426,3 +1436,70 @@ test('should be able to manage context', async t => {
}
}
});
// ==================== workspace embedding ====================
test('should be able to manage workspace embedding', async t => {
const { db, jobs, workspace, workspaceEmbedding, context, prompt, session } =
t.context;
// use mocked embedding client
Sinon.stub(context, 'embeddingClient').get(() => new MockEmbeddingClient());
Sinon.stub(jobs, 'embeddingClient').get(() => new MockEmbeddingClient());
const ws = await workspace.create(userId);
// should create workspace embedding
{
const { blobId, file } = await workspaceEmbedding.addFile(userId, ws.id, {
filename: 'test.txt',
mimetype: 'text/plain',
encoding: 'utf-8',
createReadStream: () => {
return new Readable({
read() {
this.push(Buffer.from('content'));
this.push(null);
},
});
},
});
await workspaceEmbedding.queueFileEmbedding({
userId,
workspaceId: ws.id,
blobId,
fileId: file.fileId,
fileName: file.fileName,
});
let ret = 0;
while (!ret) {
await new Promise(resolve => setTimeout(resolve, 1000));
ret = await db.aiWorkspaceFileEmbedding.count({
where: { workspaceId: ws.id, fileId: file.fileId },
});
}
}
// should create workspace embedding with file
{
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
]);
const sessionId = await session.create({
docId: 'test',
workspaceId: ws.id,
userId,
promptName: 'prompt',
});
const contextSession = await context.create(sessionId);
const ret = await contextSession.matchFileChunks('test', 1, undefined, 1);
t.is(ret.length, 1, 'should match workspace context');
t.is(ret[0].content, 'content', 'should match content');
await workspace.update(ws.id, { enableDocEmbedding: false });
const ret2 = await contextSession.matchFileChunks('test', 1, undefined, 1);
t.is(ret2.length, 0, 'should not match workspace context');
}
});

View File

@@ -96,20 +96,16 @@ test('should insert embedding by doc id', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
{
await t.context.copilotContext.insertContentEmbedding(
contextId,
'file-id',
[
{
index: 0,
content: 'content',
embedding: Array.from({ length: 1024 }, () => 1),
},
]
);
await t.context.copilotContext.insertFileEmbedding(contextId, 'file-id', [
{
index: 0,
content: 'content',
embedding: Array.from({ length: 1024 }, () => 1),
},
]);
{
const ret = await t.context.copilotContext.matchContentEmbedding(
const ret = await t.context.copilotContext.matchFileEmbedding(
Array.from({ length: 1024 }, () => 0.9),
contextId,
1,
@@ -121,7 +117,7 @@ test('should insert embedding by doc id', async t => {
{
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
const ret = await t.context.copilotContext.matchContentEmbedding(
const ret = await t.context.copilotContext.matchFileEmbedding(
Array.from({ length: 1024 }, () => 0.9),
contextId,
1,

View File

@@ -110,16 +110,20 @@ test('should insert and search embedding', async t => {
mimeType: 'text/plain',
size: 1,
});
await t.context.copilotWorkspace.addFileEmbeddings(workspace.id, fileId, [
{
index: 0,
content: 'content',
embedding: Array.from({ length: 1024 }, () => 1),
},
]);
await t.context.copilotWorkspace.insertFileEmbeddings(
workspace.id,
fileId,
[
{
index: 0,
content: 'content',
embedding: Array.from({ length: 1024 }, () => 1),
},
]
);
{
const ret = await t.context.copilotWorkspace.matchWorkspaceFileEmbedding(
const ret = await t.context.copilotWorkspace.matchFileEmbedding(
workspace.id,
Array.from({ length: 1024 }, () => 0.9),
1,

View File

@@ -153,7 +153,7 @@ export class CopilotContextModel extends BaseModel {
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
}
async insertContentEmbedding(
async insertFileEmbedding(
contextId: string,
fileId: string,
embeddings: Embedding[]
@@ -168,7 +168,7 @@ export class CopilotContextModel extends BaseModel {
`;
}
async matchContentEmbedding(
async matchFileEmbedding(
embedding: number[],
contextId: string,
topK: number,

View File

@@ -138,7 +138,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
}
@Transactional()
async addFileEmbeddings(
async insertFileEmbeddings(
workspaceId: string,
fileId: string,
embeddings: Embedding[]
@@ -151,7 +151,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
`;
}
async listWorkspaceFiles(
async listFiles(
workspaceId: string,
options?: {
includeRead?: boolean;
@@ -168,7 +168,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
return files;
}
async countWorkspaceFiles(workspaceId: string): Promise<number> {
async countFiles(workspaceId: string): Promise<number> {
const count = await this.db.aiWorkspaceFiles.count({
where: {
workspaceId,
@@ -177,12 +177,16 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
return count;
}
async matchWorkspaceFileEmbedding(
async matchFileEmbedding(
workspaceId: string,
embedding: number[],
topK: number,
threshold: number
): Promise<FileChunkSimilarity[]> {
if (!(await this.allowEmbedding(workspaceId))) {
return [];
}
const similarityChunks = await this.db.$queryRaw<
Array<FileChunkSimilarity>
>`
@@ -195,7 +199,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
return similarityChunks.filter(c => Number(c.distance) <= threshold);
}
async removeWorkspaceFile(workspaceId: string, fileId: string) {
async removeFile(workspaceId: string, fileId: string) {
// embeddings will be removed by foreign key constraint
await this.db.aiWorkspaceFiles.deleteMany({
where: {
@@ -205,4 +209,8 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
});
return true;
}
private allowEmbedding(workspaceId: string) {
return this.models.workspace.allowEmbedding(workspaceId);
}
}

View File

@@ -124,24 +124,38 @@ export class CopilotContextDocJob {
for (const chunk of chunks) {
const embeddings = await this.embeddingClient.generateEmbeddings(chunk);
await this.models.copilotContext.insertContentEmbedding(
contextId,
fileId,
embeddings
);
if (contextId) {
// for context files
await this.models.copilotContext.insertFileEmbedding(
contextId,
fileId,
embeddings
);
} else {
// for workspace files
await this.models.copilotWorkspace.insertFileEmbeddings(
workspaceId,
fileId,
embeddings
);
}
}
this.event.emit('workspace.file.embed.finished', {
contextId,
fileId,
chunkSize: total,
});
if (contextId) {
this.event.emit('workspace.file.embed.finished', {
contextId,
fileId,
chunkSize: total,
});
}
} catch (error: any) {
this.event.emit('workspace.file.embed.failed', {
contextId,
fileId,
error: mapAnyError(error).message,
});
if (contextId) {
this.event.emit('workspace.file.embed.failed', {
contextId,
fileId,
error: mapAnyError(error).message,
});
}
// passthrough error to job queue
throw error;

View File

@@ -8,6 +8,7 @@ import {
ContextEmbedStatus,
ContextFile,
ContextList,
FileChunkSimilarity,
Models,
} from '../../../models';
import { EmbeddingClient } from './types';
@@ -176,18 +177,28 @@ export class ContextSession implements AsyncDisposable {
topK: number = 5,
signal?: AbortSignal,
threshold: number = 0.7
) {
): Promise<FileChunkSimilarity[]> {
const embedding = await this.client
.getEmbeddings([content], signal)
.then(r => r?.[0]?.embedding);
if (!embedding) return [];
return this.models.copilotContext.matchContentEmbedding(
embedding,
this.id,
topK,
threshold
);
const [context, workspace] = await Promise.all([
this.models.copilotContext.matchFileEmbedding(
embedding,
this.id,
topK,
threshold
),
this.models.copilotWorkspace.matchFileEmbedding(
this.workspaceId,
embedding,
topK,
threshold
),
]);
return this.client.reRank([...context, ...workspace]);
}
/**

View File

@@ -1,7 +1,7 @@
import { File } from 'node:buffer';
import { CopilotContextFileNotSupported } from '../../../base';
import { Embedding } from '../../../models';
import { ChunkSimilarity, Embedding } from '../../../models';
import { parseDoc } from '../../../native';
declare global {
@@ -36,7 +36,7 @@ declare global {
};
'copilot.embedding.files': {
contextId: string;
contextId?: string;
userId: string;
workspaceId: string;
blobId: string;
@@ -114,6 +114,15 @@ export abstract class EmbeddingClient {
return embeddings.map(e => ({ ...e, index: chunks[e.index].index }));
}
async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
embeddings: Chunk[]
): Promise<Chunk[]> {
// sort by distance with ascending order
return embeddings.sort(
(a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
);
}
abstract getEmbeddings(
input: string[],
signal?: AbortSignal

View File

@@ -135,7 +135,7 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
@Parent() config: CopilotWorkspaceConfigType,
@Args('pagination', PaginationInput.decode) pagination: PaginationInput
): Promise<PaginatedCopilotWorkspaceFileType> {
const [files, totalCount] = await this.copilotWorkspace.listWorkspaceFiles(
const [files, totalCount] = await this.copilotWorkspace.listFiles(
config.workspaceId,
pagination
);
@@ -177,12 +177,12 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
}
try {
const { blobId, file } = await this.copilotWorkspace.addWorkspaceFile(
const { blobId, file } = await this.copilotWorkspace.addFile(
user.id,
workspaceId,
content
);
await this.copilotWorkspace.addWorkspaceFileEmbeddingQueue({
await this.copilotWorkspace.queueFileEmbedding({
userId: user.id,
workspaceId,
blobId,
@@ -219,6 +219,6 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
.workspace(workspaceId)
.assert('Workspace.Settings.Update');
return await this.copilotWorkspace.removeWorkspaceFile(workspaceId, fileId);
return await this.copilotWorkspace.removeFile(workspaceId, fileId);
}
}

View File

@@ -53,11 +53,7 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
]);
}
async addWorkspaceFile(
userId: string,
workspaceId: string,
content: FileUpload
) {
async addFile(userId: string, workspaceId: string, content: FileUpload) {
const fileName = content.filename;
const buffer = await readStream(content.createReadStream());
const blobId = createHash('sha256').update(buffer).digest('base64url');
@@ -70,29 +66,25 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
return { blobId, file };
}
async getWorkspaceFile(workspaceId: string, fileId: string) {
async getFile(workspaceId: string, fileId: string) {
return await this.models.copilotWorkspace.getFile(workspaceId, fileId);
}
async listWorkspaceFiles(
async listFiles(
workspaceId: string,
pagination?: {
includeRead?: boolean;
} & PaginationInput
) {
return await Promise.all([
this.models.copilotWorkspace.listWorkspaceFiles(workspaceId, pagination),
this.models.copilotWorkspace.countIgnoredDocs(workspaceId),
this.models.copilotWorkspace.listFiles(workspaceId, pagination),
this.models.copilotWorkspace.countFiles(workspaceId),
]);
}
async addWorkspaceFileEmbeddingQueue(
file: Jobs['copilot.workspace.embedding.files']
) {
if (!this.supportEmbedding) return;
async queueFileEmbedding(file: Jobs['copilot.embedding.files']) {
const { userId, workspaceId, blobId, fileId, fileName } = file;
await this.queue.add('copilot.workspace.embedding.files', {
await this.queue.add('copilot.embedding.files', {
userId,
workspaceId,
blobId,
@@ -101,10 +93,7 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
});
}
async removeWorkspaceFile(workspaceId: string, fileId: string) {
return await this.models.copilotWorkspace.removeWorkspaceFile(
workspaceId,
fileId
);
async removeFile(workspaceId: string, fileId: string) {
return await this.models.copilotWorkspace.removeFile(workspaceId, fileId);
}
}

View File

@@ -13,15 +13,6 @@ declare global {
jobId: string;
};
}
interface Jobs {
'copilot.workspace.embedding.files': {
userId: string;
workspaceId: string;
blobId: string;
fileId: string;
fileName: string;
};
}
}
@ObjectType('CopilotWorkspaceIgnoredDoc')