mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
fix(server): query workspace embed files (#11982)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Expanded file chunk matching to include both context and workspace file embeddings, providing broader and more relevant search results. - **Improvements** - Enhanced result ranking by introducing a re-ranking step for combined embedding matches, improving the relevance of returned file chunks. - Adjusted file count reporting to reflect the total number of workspace files instead of ignored documents for more accurate workspace file statistics. - Renamed and streamlined workspace file management methods for clearer and more consistent API usage. - **Bug Fixes** - Prevented embedding similarity queries when embedding is disabled for a workspace, improving system behavior consistency. - **Tests** - Added comprehensive tests to verify workspace embedding management, including enabling, matching, and disabling embedding functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -284,7 +284,15 @@ const actions = [
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
assertCitation(t, result, (t, c) => {
|
||||
t.assert(c.length === 0, 'should not have citation');
|
||||
t.assert(
|
||||
c.length === 0 ||
|
||||
// ignore web search result
|
||||
c
|
||||
.map(c => JSON.parse(c.citationJson).type)
|
||||
.filter(type => ['attachment', 'doc'].includes(type)).length ===
|
||||
0,
|
||||
'should not have citation'
|
||||
);
|
||||
});
|
||||
},
|
||||
type: 'text' as const,
|
||||
@@ -404,8 +412,9 @@ const actions = [
|
||||
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
const cleared = result.toLowerCase();
|
||||
t.assert(
|
||||
result.toLowerCase().includes('single source of truth'),
|
||||
cleared.includes('single source of truth') || cleared.includes('ssot'),
|
||||
'should include original keyword'
|
||||
);
|
||||
},
|
||||
|
||||
@@ -120,7 +120,7 @@ test.before(async t => {
|
||||
t.context.jobs = jobs;
|
||||
});
|
||||
|
||||
const promptName = 'prompt';
|
||||
const textPromptName = 'prompt';
|
||||
test.beforeEach(async t => {
|
||||
Sinon.restore();
|
||||
const { app, prompt } = t.context;
|
||||
@@ -128,7 +128,7 @@ test.beforeEach(async t => {
|
||||
await prompt.onApplicationBootstrap();
|
||||
t.context.u1 = await app.signupV1('u1@affine.pro');
|
||||
|
||||
await prompt.set(promptName, 'test', [
|
||||
await prompt.set(textPromptName, 'test', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
});
|
||||
@@ -150,7 +150,7 @@ test('should create session correctly', async t => {
|
||||
}
|
||||
) => {
|
||||
await asserter(
|
||||
createCopilotSession(app, workspaceId, randomUUID(), promptName)
|
||||
createCopilotSession(app, workspaceId, randomUUID(), textPromptName)
|
||||
);
|
||||
};
|
||||
|
||||
@@ -202,7 +202,7 @@ test('should update session correctly', async t => {
|
||||
t.truthy(await x, error);
|
||||
}
|
||||
) => {
|
||||
await asserter(updateCopilotSession(app, sessionId, promptName));
|
||||
await asserter(updateCopilotSession(app, sessionId, textPromptName));
|
||||
};
|
||||
|
||||
{
|
||||
@@ -212,7 +212,7 @@ test('should update session correctly', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
docId,
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
@@ -225,7 +225,7 @@ test('should update session correctly', async t => {
|
||||
app,
|
||||
randomUUID(),
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
@@ -244,7 +244,7 @@ test('should update session correctly', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
@@ -294,7 +294,7 @@ test('should fork session correctly', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
|
||||
let forkedSessionId: string;
|
||||
@@ -363,7 +363,7 @@ test('should be able to use test provider', async t => {
|
||||
|
||||
const { id } = await createWorkspace(app);
|
||||
t.truthy(
|
||||
await createCopilotSession(app, id, randomUUID(), promptName),
|
||||
await createCopilotSession(app, id, randomUUID(), textPromptName),
|
||||
'failed to create session'
|
||||
);
|
||||
});
|
||||
@@ -379,7 +379,7 @@ test('should create message correctly', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
t.truthy(messageId, 'should be able to create message with valid session');
|
||||
@@ -393,7 +393,7 @@ test('should create message correctly', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId, undefined, [
|
||||
'http://example.com/cat.jpg',
|
||||
@@ -408,7 +408,7 @@ test('should create message correctly', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const smallestPng =
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII';
|
||||
@@ -445,7 +445,7 @@ test('should be able to chat with api', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
const ret = await chatWithText(app, sessionId, messageId);
|
||||
@@ -543,7 +543,7 @@ test('should be able to retry with api', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
// chat 2 times
|
||||
@@ -565,7 +565,7 @@ test('should be able to retry with api', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
@@ -587,7 +587,7 @@ test('should be able to retry with api', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
@@ -614,13 +614,13 @@ test('should reject message from different session', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const anotherSessionId = await createCopilotSession(
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const anotherMessageId = await createCopilotMessage(app, anotherSessionId);
|
||||
await t.throwsAsync(
|
||||
@@ -639,7 +639,7 @@ test('should reject request from different user', async t => {
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
|
||||
// should reject message from different user
|
||||
@@ -677,7 +677,7 @@ test('should be able to list history', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
|
||||
const messageId = await createCopilotMessage(app, sessionId, 'hello');
|
||||
@@ -740,7 +740,7 @@ test('should reject request that user have not permission', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
@@ -777,7 +777,7 @@ test('should be able to manage context', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
|
||||
// use mocked embedding client
|
||||
@@ -859,7 +859,7 @@ test('should be able to manage context', async t => {
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
textPromptName
|
||||
);
|
||||
const contextId = await createCopilotContext(app, workspaceId, sessionId);
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { Readable } from 'node:stream';
|
||||
|
||||
import { ProjectRoot } from '@affine-tools/utils/path';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
@@ -6,11 +7,11 @@ import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { EventBus } from '../base';
|
||||
import { EventBus, JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { ContextCategories } from '../models';
|
||||
import { ContextCategories, WorkspaceModel } from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import {
|
||||
CopilotContextDocJob,
|
||||
@@ -47,6 +48,7 @@ import {
|
||||
} from '../plugins/copilot/workflow/executor';
|
||||
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { WorkflowTestCases } from './utils/copilot';
|
||||
@@ -56,9 +58,11 @@ const test = ava as TestFn<{
|
||||
module: TestingModule;
|
||||
db: PrismaClient;
|
||||
event: EventBus;
|
||||
workspace: WorkspaceModel;
|
||||
context: CopilotContextService;
|
||||
prompt: PromptService;
|
||||
transcript: CopilotTranscriptionService;
|
||||
workspaceEmbedding: CopilotWorkspaceService;
|
||||
factory: CopilotProviderFactory;
|
||||
session: ChatSessionService;
|
||||
jobs: CopilotContextDocJob;
|
||||
@@ -95,6 +99,8 @@ test.before(async t => {
|
||||
CopilotModule,
|
||||
],
|
||||
tapModule: builder => {
|
||||
// use real JobQueue for testing
|
||||
builder.overrideProvider(JobQueue).useClass(JobQueue);
|
||||
builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
|
||||
},
|
||||
});
|
||||
@@ -102,6 +108,7 @@ test.before(async t => {
|
||||
const auth = module.get(AuthService);
|
||||
const db = module.get(PrismaClient);
|
||||
const event = module.get(EventBus);
|
||||
const workspace = module.get(WorkspaceModel);
|
||||
const prompt = module.get(PromptService);
|
||||
const factory = module.get(CopilotProviderFactory);
|
||||
|
||||
@@ -112,11 +119,13 @@ test.before(async t => {
|
||||
const context = module.get(CopilotContextService);
|
||||
const jobs = module.get(CopilotContextDocJob);
|
||||
const transcript = module.get(CopilotTranscriptionService);
|
||||
const workspaceEmbedding = module.get(CopilotWorkspaceService);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.auth = auth;
|
||||
t.context.db = db;
|
||||
t.context.event = event;
|
||||
t.context.workspace = workspace;
|
||||
t.context.prompt = prompt;
|
||||
t.context.factory = factory;
|
||||
t.context.session = session;
|
||||
@@ -125,6 +134,7 @@ test.before(async t => {
|
||||
t.context.context = context;
|
||||
t.context.jobs = jobs;
|
||||
t.context.transcript = transcript;
|
||||
t.context.workspaceEmbedding = workspaceEmbedding;
|
||||
|
||||
t.context.executors = {
|
||||
image: module.get(CopilotChatImageExecutor),
|
||||
@@ -1426,3 +1436,70 @@ test('should be able to manage context', async t => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// ==================== workspace embedding ====================
|
||||
test('should be able to manage workspace embedding', async t => {
|
||||
const { db, jobs, workspace, workspaceEmbedding, context, prompt, session } =
|
||||
t.context;
|
||||
|
||||
// use mocked embedding client
|
||||
Sinon.stub(context, 'embeddingClient').get(() => new MockEmbeddingClient());
|
||||
Sinon.stub(jobs, 'embeddingClient').get(() => new MockEmbeddingClient());
|
||||
|
||||
const ws = await workspace.create(userId);
|
||||
|
||||
// should create workspace embedding
|
||||
{
|
||||
const { blobId, file } = await workspaceEmbedding.addFile(userId, ws.id, {
|
||||
filename: 'test.txt',
|
||||
mimetype: 'text/plain',
|
||||
encoding: 'utf-8',
|
||||
createReadStream: () => {
|
||||
return new Readable({
|
||||
read() {
|
||||
this.push(Buffer.from('content'));
|
||||
this.push(null);
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
await workspaceEmbedding.queueFileEmbedding({
|
||||
userId,
|
||||
workspaceId: ws.id,
|
||||
blobId,
|
||||
fileId: file.fileId,
|
||||
fileName: file.fileName,
|
||||
});
|
||||
|
||||
let ret = 0;
|
||||
while (!ret) {
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
ret = await db.aiWorkspaceFileEmbedding.count({
|
||||
where: { workspaceId: ws.id, fileId: file.fileId },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// should create workspace embedding with file
|
||||
{
|
||||
await prompt.set('prompt', 'model', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: ws.id,
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
const contextSession = await context.create(sessionId);
|
||||
|
||||
const ret = await contextSession.matchFileChunks('test', 1, undefined, 1);
|
||||
t.is(ret.length, 1, 'should match workspace context');
|
||||
t.is(ret[0].content, 'content', 'should match content');
|
||||
|
||||
await workspace.update(ws.id, { enableDocEmbedding: false });
|
||||
|
||||
const ret2 = await contextSession.matchFileChunks('test', 1, undefined, 1);
|
||||
t.is(ret2.length, 0, 'should not match workspace context');
|
||||
}
|
||||
});
|
||||
|
||||
@@ -96,20 +96,16 @@ test('should insert embedding by doc id', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
|
||||
{
|
||||
await t.context.copilotContext.insertContentEmbedding(
|
||||
contextId,
|
||||
'file-id',
|
||||
[
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]
|
||||
);
|
||||
await t.context.copilotContext.insertFileEmbedding(contextId, 'file-id', [
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]);
|
||||
|
||||
{
|
||||
const ret = await t.context.copilotContext.matchContentEmbedding(
|
||||
const ret = await t.context.copilotContext.matchFileEmbedding(
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
contextId,
|
||||
1,
|
||||
@@ -121,7 +117,7 @@ test('should insert embedding by doc id', async t => {
|
||||
|
||||
{
|
||||
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
|
||||
const ret = await t.context.copilotContext.matchContentEmbedding(
|
||||
const ret = await t.context.copilotContext.matchFileEmbedding(
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
contextId,
|
||||
1,
|
||||
|
||||
@@ -110,16 +110,20 @@ test('should insert and search embedding', async t => {
|
||||
mimeType: 'text/plain',
|
||||
size: 1,
|
||||
});
|
||||
await t.context.copilotWorkspace.addFileEmbeddings(workspace.id, fileId, [
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]);
|
||||
await t.context.copilotWorkspace.insertFileEmbeddings(
|
||||
workspace.id,
|
||||
fileId,
|
||||
[
|
||||
{
|
||||
index: 0,
|
||||
content: 'content',
|
||||
embedding: Array.from({ length: 1024 }, () => 1),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
{
|
||||
const ret = await t.context.copilotWorkspace.matchWorkspaceFileEmbedding(
|
||||
const ret = await t.context.copilotWorkspace.matchFileEmbedding(
|
||||
workspace.id,
|
||||
Array.from({ length: 1024 }, () => 0.9),
|
||||
1,
|
||||
|
||||
@@ -153,7 +153,7 @@ export class CopilotContextModel extends BaseModel {
|
||||
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
|
||||
}
|
||||
|
||||
async insertContentEmbedding(
|
||||
async insertFileEmbedding(
|
||||
contextId: string,
|
||||
fileId: string,
|
||||
embeddings: Embedding[]
|
||||
@@ -168,7 +168,7 @@ export class CopilotContextModel extends BaseModel {
|
||||
`;
|
||||
}
|
||||
|
||||
async matchContentEmbedding(
|
||||
async matchFileEmbedding(
|
||||
embedding: number[],
|
||||
contextId: string,
|
||||
topK: number,
|
||||
|
||||
@@ -138,7 +138,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async addFileEmbeddings(
|
||||
async insertFileEmbeddings(
|
||||
workspaceId: string,
|
||||
fileId: string,
|
||||
embeddings: Embedding[]
|
||||
@@ -151,7 +151,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
`;
|
||||
}
|
||||
|
||||
async listWorkspaceFiles(
|
||||
async listFiles(
|
||||
workspaceId: string,
|
||||
options?: {
|
||||
includeRead?: boolean;
|
||||
@@ -168,7 +168,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
return files;
|
||||
}
|
||||
|
||||
async countWorkspaceFiles(workspaceId: string): Promise<number> {
|
||||
async countFiles(workspaceId: string): Promise<number> {
|
||||
const count = await this.db.aiWorkspaceFiles.count({
|
||||
where: {
|
||||
workspaceId,
|
||||
@@ -177,12 +177,16 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
return count;
|
||||
}
|
||||
|
||||
async matchWorkspaceFileEmbedding(
|
||||
async matchFileEmbedding(
|
||||
workspaceId: string,
|
||||
embedding: number[],
|
||||
topK: number,
|
||||
threshold: number
|
||||
): Promise<FileChunkSimilarity[]> {
|
||||
if (!(await this.allowEmbedding(workspaceId))) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const similarityChunks = await this.db.$queryRaw<
|
||||
Array<FileChunkSimilarity>
|
||||
>`
|
||||
@@ -195,7 +199,7 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
return similarityChunks.filter(c => Number(c.distance) <= threshold);
|
||||
}
|
||||
|
||||
async removeWorkspaceFile(workspaceId: string, fileId: string) {
|
||||
async removeFile(workspaceId: string, fileId: string) {
|
||||
// embeddings will be removed by foreign key constraint
|
||||
await this.db.aiWorkspaceFiles.deleteMany({
|
||||
where: {
|
||||
@@ -205,4 +209,8 @@ export class CopilotWorkspaceConfigModel extends BaseModel {
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
private allowEmbedding(workspaceId: string) {
|
||||
return this.models.workspace.allowEmbedding(workspaceId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,24 +124,38 @@ export class CopilotContextDocJob {
|
||||
|
||||
for (const chunk of chunks) {
|
||||
const embeddings = await this.embeddingClient.generateEmbeddings(chunk);
|
||||
await this.models.copilotContext.insertContentEmbedding(
|
||||
contextId,
|
||||
fileId,
|
||||
embeddings
|
||||
);
|
||||
if (contextId) {
|
||||
// for context files
|
||||
await this.models.copilotContext.insertFileEmbedding(
|
||||
contextId,
|
||||
fileId,
|
||||
embeddings
|
||||
);
|
||||
} else {
|
||||
// for workspace files
|
||||
await this.models.copilotWorkspace.insertFileEmbeddings(
|
||||
workspaceId,
|
||||
fileId,
|
||||
embeddings
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.event.emit('workspace.file.embed.finished', {
|
||||
contextId,
|
||||
fileId,
|
||||
chunkSize: total,
|
||||
});
|
||||
if (contextId) {
|
||||
this.event.emit('workspace.file.embed.finished', {
|
||||
contextId,
|
||||
fileId,
|
||||
chunkSize: total,
|
||||
});
|
||||
}
|
||||
} catch (error: any) {
|
||||
this.event.emit('workspace.file.embed.failed', {
|
||||
contextId,
|
||||
fileId,
|
||||
error: mapAnyError(error).message,
|
||||
});
|
||||
if (contextId) {
|
||||
this.event.emit('workspace.file.embed.failed', {
|
||||
contextId,
|
||||
fileId,
|
||||
error: mapAnyError(error).message,
|
||||
});
|
||||
}
|
||||
|
||||
// passthrough error to job queue
|
||||
throw error;
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
ContextEmbedStatus,
|
||||
ContextFile,
|
||||
ContextList,
|
||||
FileChunkSimilarity,
|
||||
Models,
|
||||
} from '../../../models';
|
||||
import { EmbeddingClient } from './types';
|
||||
@@ -176,18 +177,28 @@ export class ContextSession implements AsyncDisposable {
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.7
|
||||
) {
|
||||
): Promise<FileChunkSimilarity[]> {
|
||||
const embedding = await this.client
|
||||
.getEmbeddings([content], signal)
|
||||
.then(r => r?.[0]?.embedding);
|
||||
if (!embedding) return [];
|
||||
|
||||
return this.models.copilotContext.matchContentEmbedding(
|
||||
embedding,
|
||||
this.id,
|
||||
topK,
|
||||
threshold
|
||||
);
|
||||
const [context, workspace] = await Promise.all([
|
||||
this.models.copilotContext.matchFileEmbedding(
|
||||
embedding,
|
||||
this.id,
|
||||
topK,
|
||||
threshold
|
||||
),
|
||||
this.models.copilotWorkspace.matchFileEmbedding(
|
||||
this.workspaceId,
|
||||
embedding,
|
||||
topK,
|
||||
threshold
|
||||
),
|
||||
]);
|
||||
|
||||
return this.client.reRank([...context, ...workspace]);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { File } from 'node:buffer';
|
||||
|
||||
import { CopilotContextFileNotSupported } from '../../../base';
|
||||
import { Embedding } from '../../../models';
|
||||
import { ChunkSimilarity, Embedding } from '../../../models';
|
||||
import { parseDoc } from '../../../native';
|
||||
|
||||
declare global {
|
||||
@@ -36,7 +36,7 @@ declare global {
|
||||
};
|
||||
|
||||
'copilot.embedding.files': {
|
||||
contextId: string;
|
||||
contextId?: string;
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
blobId: string;
|
||||
@@ -114,6 +114,15 @@ export abstract class EmbeddingClient {
|
||||
return embeddings.map(e => ({ ...e, index: chunks[e.index].index }));
|
||||
}
|
||||
|
||||
async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
|
||||
embeddings: Chunk[]
|
||||
): Promise<Chunk[]> {
|
||||
// sort by distance with ascending order
|
||||
return embeddings.sort(
|
||||
(a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
|
||||
);
|
||||
}
|
||||
|
||||
abstract getEmbeddings(
|
||||
input: string[],
|
||||
signal?: AbortSignal
|
||||
|
||||
@@ -135,7 +135,7 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
|
||||
@Parent() config: CopilotWorkspaceConfigType,
|
||||
@Args('pagination', PaginationInput.decode) pagination: PaginationInput
|
||||
): Promise<PaginatedCopilotWorkspaceFileType> {
|
||||
const [files, totalCount] = await this.copilotWorkspace.listWorkspaceFiles(
|
||||
const [files, totalCount] = await this.copilotWorkspace.listFiles(
|
||||
config.workspaceId,
|
||||
pagination
|
||||
);
|
||||
@@ -177,12 +177,12 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
|
||||
}
|
||||
|
||||
try {
|
||||
const { blobId, file } = await this.copilotWorkspace.addWorkspaceFile(
|
||||
const { blobId, file } = await this.copilotWorkspace.addFile(
|
||||
user.id,
|
||||
workspaceId,
|
||||
content
|
||||
);
|
||||
await this.copilotWorkspace.addWorkspaceFileEmbeddingQueue({
|
||||
await this.copilotWorkspace.queueFileEmbedding({
|
||||
userId: user.id,
|
||||
workspaceId,
|
||||
blobId,
|
||||
@@ -219,6 +219,6 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
|
||||
.workspace(workspaceId)
|
||||
.assert('Workspace.Settings.Update');
|
||||
|
||||
return await this.copilotWorkspace.removeWorkspaceFile(workspaceId, fileId);
|
||||
return await this.copilotWorkspace.removeFile(workspaceId, fileId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,11 +53,7 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
|
||||
]);
|
||||
}
|
||||
|
||||
async addWorkspaceFile(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
content: FileUpload
|
||||
) {
|
||||
async addFile(userId: string, workspaceId: string, content: FileUpload) {
|
||||
const fileName = content.filename;
|
||||
const buffer = await readStream(content.createReadStream());
|
||||
const blobId = createHash('sha256').update(buffer).digest('base64url');
|
||||
@@ -70,29 +66,25 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
|
||||
return { blobId, file };
|
||||
}
|
||||
|
||||
async getWorkspaceFile(workspaceId: string, fileId: string) {
|
||||
async getFile(workspaceId: string, fileId: string) {
|
||||
return await this.models.copilotWorkspace.getFile(workspaceId, fileId);
|
||||
}
|
||||
|
||||
async listWorkspaceFiles(
|
||||
async listFiles(
|
||||
workspaceId: string,
|
||||
pagination?: {
|
||||
includeRead?: boolean;
|
||||
} & PaginationInput
|
||||
) {
|
||||
return await Promise.all([
|
||||
this.models.copilotWorkspace.listWorkspaceFiles(workspaceId, pagination),
|
||||
this.models.copilotWorkspace.countIgnoredDocs(workspaceId),
|
||||
this.models.copilotWorkspace.listFiles(workspaceId, pagination),
|
||||
this.models.copilotWorkspace.countFiles(workspaceId),
|
||||
]);
|
||||
}
|
||||
|
||||
async addWorkspaceFileEmbeddingQueue(
|
||||
file: Jobs['copilot.workspace.embedding.files']
|
||||
) {
|
||||
if (!this.supportEmbedding) return;
|
||||
|
||||
async queueFileEmbedding(file: Jobs['copilot.embedding.files']) {
|
||||
const { userId, workspaceId, blobId, fileId, fileName } = file;
|
||||
await this.queue.add('copilot.workspace.embedding.files', {
|
||||
await this.queue.add('copilot.embedding.files', {
|
||||
userId,
|
||||
workspaceId,
|
||||
blobId,
|
||||
@@ -101,10 +93,7 @@ export class CopilotWorkspaceService implements OnApplicationBootstrap {
|
||||
});
|
||||
}
|
||||
|
||||
async removeWorkspaceFile(workspaceId: string, fileId: string) {
|
||||
return await this.models.copilotWorkspace.removeWorkspaceFile(
|
||||
workspaceId,
|
||||
fileId
|
||||
);
|
||||
async removeFile(workspaceId: string, fileId: string) {
|
||||
return await this.models.copilotWorkspace.removeFile(workspaceId, fileId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,15 +13,6 @@ declare global {
|
||||
jobId: string;
|
||||
};
|
||||
}
|
||||
interface Jobs {
|
||||
'copilot.workspace.embedding.files': {
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
blobId: string;
|
||||
fileId: string;
|
||||
fileName: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ObjectType('CopilotWorkspaceIgnoredDoc')
|
||||
|
||||
Reference in New Issue
Block a user