feat(server): improve rerank performance (#12775)

fix AI-183
This commit is contained in:
DarkSky
2025-06-12 13:31:01 +08:00
committed by GitHub
parent 2d17c265ca
commit ed56f076ed
5 changed files with 75 additions and 36 deletions

View File

@@ -1,5 +1,4 @@
import { Logger } from '@nestjs/common';
import { chunk } from 'lodash-es';
import {
CopilotPromptNotFound,
@@ -63,7 +62,9 @@ export class ProductionEmbeddingClient extends EmbeddingClient {
const provider = await this.getProvider({
outputType: ModelOutputType.Embedding,
});
this.logger.verbose(`Using provider ${provider.type} for embedding`, input);
this.logger.verbose(
`Using provider ${provider.type} for embedding: ${input.join(', ')}`
);
const embeddings = await provider.embedding(
{ inputTypes: [ModelInputType.Text] },
@@ -78,6 +79,14 @@ export class ProductionEmbeddingClient extends EmbeddingClient {
}));
}
private getTargetId<T extends ChunkSimilarity>(embedding: T) {
return 'docId' in embedding
? embedding.docId
: 'fileId' in embedding
? embedding.fileId
: '';
}
private async getEmbeddingRelevance<
Chunk extends ChunkSimilarity = ChunkSimilarity,
>(
@@ -98,11 +107,11 @@ export class ProductionEmbeddingClient extends EmbeddingClient {
{ modelId: prompt.model },
prompt.finish({
query,
results: embeddings.map(e => {
const targetId =
'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
return { targetId, chunk: e.chunk, content: e.content };
}),
results: embeddings.map(e => ({
targetId: this.getTargetId(e),
chunk: e.chunk,
content: e.content,
})),
schema,
}),
{ maxRetries: 3, signal }
@@ -123,7 +132,19 @@ export class ProductionEmbeddingClient extends EmbeddingClient {
topK: number,
signal?: AbortSignal
): Promise<Chunk[]> {
const sortedEmbeddings = embeddings.toSorted(
// search in context and workspace may find same chunks, de-duplicate them
const { deduped: dedupedEmbeddings } = embeddings.reduce(
(acc, e) => {
const key = `${this.getTargetId(e)}:${e.chunk}`;
if (!acc.seen.has(key)) {
acc.seen.add(key);
acc.deduped.push(e);
}
return acc;
},
{ deduped: [] as Chunk[], seen: new Set<string>() }
);
const sortedEmbeddings = dedupedEmbeddings.toSorted(
(a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
);
@@ -137,24 +158,36 @@ export class ProductionEmbeddingClient extends EmbeddingClient {
{} as Record<string, Chunk>
);
const ranks = [];
for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
const rank = await this.getEmbeddingRelevance(query, c, signal);
if (c.length !== rank.length) {
try {
// 4.1 mini's context windows large enough to handle all embeddings
const ranks = await this.getEmbeddingRelevance(
query,
sortedEmbeddings,
signal
);
if (sortedEmbeddings.length !== ranks.length) {
// llm return wrong result, fallback to default sorting
return super.reRank(query, embeddings, topK, signal);
this.logger.warn(
`Batch size mismatch: expected ${sortedEmbeddings.length}, got ${ranks.length}`
);
return await super.reRank(query, dedupedEmbeddings, topK, signal);
}
ranks.push(rank);
const highConfidenceChunks = ranks
.flat()
.toSorted((a, b) => b.scores.score - a.scores.score)
.filter(r => r.scores.score > 5)
.map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`])
.filter(Boolean);
this.logger.verbose(
`ReRank completed: ${highConfidenceChunks.length} high-confidence results found`
);
return highConfidenceChunks.slice(0, topK);
} catch (error) {
this.logger.warn('ReRank failed, falling back to default sorting', error);
return await super.reRank(query, dedupedEmbeddings, topK, signal);
}
const highConfidenceChunks = ranks
.flat()
.toSorted((a, b) => b.scores.score - a.scores.score)
.filter(r => r.scores.score > 5)
.map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`])
.filter(Boolean);
return highConfidenceChunks.slice(0, topK);
}
}

View File

@@ -140,10 +140,9 @@ export class CopilotContextDocJob {
if (enableDocEmbedding) {
const toBeEmbedDocIds =
await this.models.copilotWorkspace.findDocsToEmbed(workspaceId);
this.logger.debug('Trigger embedding for docs', {
workspaceId,
toBeEmbedDocs: toBeEmbedDocIds.length,
});
this.logger.debug(
`Trigger embedding for ${toBeEmbedDocIds.length} docs in workspace ${workspaceId}`
);
for (const docId of toBeEmbedDocIds) {
await this.queue.add(
'copilot.embedding.docs',

View File

@@ -158,14 +158,15 @@ export class CopilotContextService implements OnApplicationBootstrap {
const embedding = await this.embeddingClient.getEmbedding(content, signal);
if (!embedding) return [];
const chunks = await this.models.copilotWorkspace.matchFileEmbedding(
const fileChunks = await this.models.copilotWorkspace.matchFileEmbedding(
workspaceId,
embedding,
topK * 2,
threshold
);
if (!fileChunks.length) return [];
return this.embeddingClient.reRank(content, chunks, topK, signal);
return this.embeddingClient.reRank(content, fileChunks, topK, signal);
}
async matchWorkspaceDocs(
@@ -179,14 +180,16 @@ export class CopilotContextService implements OnApplicationBootstrap {
const embedding = await this.embeddingClient.getEmbedding(content, signal);
if (!embedding) return [];
const workspace = await this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
workspaceId,
topK * 2,
threshold
);
const workspaceChunks =
await this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
workspaceId,
topK * 2,
threshold
);
if (!workspaceChunks.length) return [];
return this.embeddingClient.reRank(content, workspace, topK);
return this.embeddingClient.reRank(content, workspaceChunks, topK, signal);
}
@OnEvent('workspace.doc.embed.failed')

View File

@@ -356,6 +356,7 @@ Consider various factors such as content alignment with the query, source credib
- Evaluate the alignment with potential user intent based on the query.
3. **Scoring**:
- Assign a score from 1 to 10 based on the overall relevance and quality, with 10 being the most relevant.
- Each chunk returns a score and should not be mixed together.
# Output Format