From ed56f076ed9329c8b01f8fd71cfd59144c5d50cf Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:31:01 +0800 Subject: [PATCH] feat(server): improve rerank performance (#12775) fix AI-183 --- .../src/plugins/copilot/context/embedding.ts | 79 +++++++++++++------ .../server/src/plugins/copilot/context/job.ts | 7 +- .../src/plugins/copilot/context/service.ts | 21 ++--- .../src/plugins/copilot/prompt/prompts.ts | 1 + .../e2e/utils/editor-utils.ts | 3 + 5 files changed, 75 insertions(+), 36 deletions(-) diff --git a/packages/backend/server/src/plugins/copilot/context/embedding.ts b/packages/backend/server/src/plugins/copilot/context/embedding.ts index 06c5d6fd17..240d0c26c9 100644 --- a/packages/backend/server/src/plugins/copilot/context/embedding.ts +++ b/packages/backend/server/src/plugins/copilot/context/embedding.ts @@ -1,5 +1,4 @@ import { Logger } from '@nestjs/common'; -import { chunk } from 'lodash-es'; import { CopilotPromptNotFound, @@ -63,7 +62,9 @@ export class ProductionEmbeddingClient extends EmbeddingClient { const provider = await this.getProvider({ outputType: ModelOutputType.Embedding, }); - this.logger.verbose(`Using provider ${provider.type} for embedding`, input); + this.logger.verbose( + `Using provider ${provider.type} for embedding: ${input.join(', ')}` + ); const embeddings = await provider.embedding( { inputTypes: [ModelInputType.Text] }, @@ -78,6 +79,14 @@ export class ProductionEmbeddingClient extends EmbeddingClient { })); } + private getTargetId(embedding: T) { + return 'docId' in embedding + ? embedding.docId + : 'fileId' in embedding + ? embedding.fileId + : ''; + } + private async getEmbeddingRelevance< Chunk extends ChunkSimilarity = ChunkSimilarity, >( @@ -98,11 +107,11 @@ export class ProductionEmbeddingClient extends EmbeddingClient { { modelId: prompt.model }, prompt.finish({ query, - results: embeddings.map(e => { - const targetId = - 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : ''; - return { targetId, chunk: e.chunk, content: e.content }; - }), + results: embeddings.map(e => ({ + targetId: this.getTargetId(e), + chunk: e.chunk, + content: e.content, + })), schema, }), { maxRetries: 3, signal } @@ -123,7 +132,19 @@ export class ProductionEmbeddingClient extends EmbeddingClient { topK: number, signal?: AbortSignal ): Promise { - const sortedEmbeddings = embeddings.toSorted( + // search in context and workspace may find same chunks, de-duplicate them + const { deduped: dedupedEmbeddings } = embeddings.reduce( + (acc, e) => { + const key = `${this.getTargetId(e)}:${e.chunk}`; + if (!acc.seen.has(key)) { + acc.seen.add(key); + acc.deduped.push(e); + } + return acc; + }, + { deduped: [] as Chunk[], seen: new Set() } + ); + const sortedEmbeddings = dedupedEmbeddings.toSorted( (a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity) ); @@ -137,24 +158,36 @@ export class ProductionEmbeddingClient extends EmbeddingClient { {} as Record ); - const ranks = []; - for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) { - const rank = await this.getEmbeddingRelevance(query, c, signal); - if (c.length !== rank.length) { + try { + // 4.1 mini's context windows large enough to handle all embeddings + const ranks = await this.getEmbeddingRelevance( + query, + sortedEmbeddings, + signal + ); + if (sortedEmbeddings.length !== ranks.length) { // llm return wrong result, fallback to default sorting - return super.reRank(query, embeddings, topK, signal); + this.logger.warn( + `Batch size mismatch: expected ${sortedEmbeddings.length}, got ${ranks.length}` + ); + return await super.reRank(query, dedupedEmbeddings, topK, signal); } - ranks.push(rank); + + const highConfidenceChunks = ranks + .flat() + .toSorted((a, b) => b.scores.score - a.scores.score) + .filter(r => r.scores.score > 5) + .map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`]) + .filter(Boolean); + + this.logger.verbose( + `ReRank completed: ${highConfidenceChunks.length} high-confidence results found` + ); + return highConfidenceChunks.slice(0, topK); + } catch (error) { + this.logger.warn('ReRank failed, falling back to default sorting', error); + return await super.reRank(query, dedupedEmbeddings, topK, signal); } - - const highConfidenceChunks = ranks - .flat() - .toSorted((a, b) => b.scores.score - a.scores.score) - .filter(r => r.scores.score > 5) - .map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`]) - .filter(Boolean); - - return highConfidenceChunks.slice(0, topK); } } diff --git a/packages/backend/server/src/plugins/copilot/context/job.ts b/packages/backend/server/src/plugins/copilot/context/job.ts index 5244ec1fbf..9562b786ee 100644 --- a/packages/backend/server/src/plugins/copilot/context/job.ts +++ b/packages/backend/server/src/plugins/copilot/context/job.ts @@ -140,10 +140,9 @@ export class CopilotContextDocJob { if (enableDocEmbedding) { const toBeEmbedDocIds = await this.models.copilotWorkspace.findDocsToEmbed(workspaceId); - this.logger.debug('Trigger embedding for docs', { - workspaceId, - toBeEmbedDocs: toBeEmbedDocIds.length, - }); + this.logger.debug( + `Trigger embedding for ${toBeEmbedDocIds.length} docs in workspace ${workspaceId}` + ); for (const docId of toBeEmbedDocIds) { await this.queue.add( 'copilot.embedding.docs', diff --git a/packages/backend/server/src/plugins/copilot/context/service.ts b/packages/backend/server/src/plugins/copilot/context/service.ts index 60e1870311..ab727712e0 100644 --- a/packages/backend/server/src/plugins/copilot/context/service.ts +++ b/packages/backend/server/src/plugins/copilot/context/service.ts @@ -158,14 +158,15 @@ export class CopilotContextService implements OnApplicationBootstrap { const embedding = await this.embeddingClient.getEmbedding(content, signal); if (!embedding) return []; - const chunks = await this.models.copilotWorkspace.matchFileEmbedding( + const fileChunks = await this.models.copilotWorkspace.matchFileEmbedding( workspaceId, embedding, topK * 2, threshold ); + if (!fileChunks.length) return []; - return this.embeddingClient.reRank(content, chunks, topK, signal); + return this.embeddingClient.reRank(content, fileChunks, topK, signal); } async matchWorkspaceDocs( @@ -179,14 +180,16 @@ export class CopilotContextService implements OnApplicationBootstrap { const embedding = await this.embeddingClient.getEmbedding(content, signal); if (!embedding) return []; - const workspace = await this.models.copilotContext.matchWorkspaceEmbedding( - embedding, - workspaceId, - topK * 2, - threshold - ); + const workspaceChunks = + await this.models.copilotContext.matchWorkspaceEmbedding( + embedding, + workspaceId, + topK * 2, + threshold + ); + if (!workspaceChunks.length) return []; - return this.embeddingClient.reRank(content, workspace, topK); + return this.embeddingClient.reRank(content, workspaceChunks, topK, signal); } @OnEvent('workspace.doc.embed.failed') diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 0de8a09fb4..a1e15da517 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -356,6 +356,7 @@ Consider various factors such as content alignment with the query, source credib - Evaluate the alignment with potential user intent based on the query. 3. **Scoring**: - Assign a score from 1 to 10 based on the overall relevance and quality, with 10 being the most relevant. + - Each chunk returns a score and should not be mixed together. # Output Format diff --git a/tests/affine-cloud-copilot/e2e/utils/editor-utils.ts b/tests/affine-cloud-copilot/e2e/utils/editor-utils.ts index af079ef75c..b7c80cdb0d 100644 --- a/tests/affine-cloud-copilot/e2e/utils/editor-utils.ts +++ b/tests/affine-cloud-copilot/e2e/utils/editor-utils.ts @@ -98,6 +98,9 @@ export class EditorUtils { const responsesMenu = answer.getByTestId('answer-responses'); await responsesMenu.isVisible(); await responsesMenu.scrollIntoViewIfNeeded({ timeout: 60000 }); + await responsesMenu + .getByTestId('answer-insert-below-loading') + .waitFor({ state: 'hidden' }); if (await responsesMenu.getByTestId('answer-insert-below').isVisible()) { responses.add('insert-below');