feat(server): rerank for matching (#12039)

fix AI-20
fix AI-77

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Enhanced relevance-based re-ranking for embedding results, improving the accuracy of content suggestions.
  - Added prioritization for workspace content that matches specific document IDs in search results.
  - Introduced a new scoped threshold parameter to refine workspace document matching.

- **Improvements**
  - Increased default similarity threshold for file chunk matching, resulting in more precise matches.
  - Doubled candidate retrieval for file and workspace chunk matching to improve result quality.
  - Updated sorting to prioritize context-relevant documents in workspace matches.
  - Explicitly included original input content in re-ranking calls for better relevance assessment.

- **Bug Fixes**
  - Adjusted re-ranking logic to return only highly relevant results based on confidence scores.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
darkskygit
2025-05-09 03:59:03 +00:00
parent c24fde7168
commit cb49ab0f69
7 changed files with 176 additions and 18 deletions

View File

@@ -209,12 +209,14 @@ export class CopilotContextModel extends BaseModel {
embedding: number[],
workspaceId: string,
topK: number,
threshold: number
threshold: number,
matchDocIds?: string[]
): Promise<DocChunkSimilarity[]> {
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_workspace_embeddings"
WHERE "workspace_id" = ${workspaceId}
${matchDocIds?.length ? Prisma.sql`AND "doc_id" IN (${Prisma.join(matchDocIds)})` : Prisma.empty}
ORDER BY "distance" ASC
LIMIT ${topK};
`;

View File

@@ -2,11 +2,14 @@ import {
createOpenAI,
type OpenAIProvider as VercelOpenAIProvider,
} from '@ai-sdk/openai';
import { embedMany } from 'ai';
import { embedMany, generateObject } from 'ai';
import { chunk } from 'lodash-es';
import { Embedding } from '../../../models';
import { ChunkSimilarity, Embedding } from '../../../models';
import { OpenAIConfig } from '../providers/openai';
import { EmbeddingClient } from './types';
import { EmbeddingClient, getReRankSchema, ReRankResult } from './types';
const RERANK_MODEL = 'gpt-4.1-mini';
export class OpenAIEmbeddingClient extends EmbeddingClient {
readonly #instance: VercelOpenAIProvider;
@@ -35,6 +38,85 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
content: input[index],
}));
}
/**
 * Build the re-ranking prompt for a batch of candidate chunks.
 *
 * Each candidate is rendered as a pseudo-XML `<result>` entry carrying its
 * owning doc/file id, chunk index, and content. The tags are not real XML —
 * they only give the prompt a predictable, easy-to-parse structure.
 */
private getRelevancePrompt<Chunk extends ChunkSimilarity = ChunkSimilarity>(
  query: string,
  embeddings: Chunk[]
) {
  const lines: string[] = [];
  for (const e of embeddings) {
    // A chunk belongs to either a doc or a file; fall back to '' if neither.
    const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
    lines.push(
      '<result>',
      `<targetId>${targetId}</targetId>`,
      `<chunk>${e.chunk}</chunk>`,
      `<content>${e.content}</content>`,
      '</result>'
    );
  }
  const results = lines.join('\n');
  return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n<results>\n${results}\n</results>`;
}
/**
 * Ask the re-rank model to score one batch of candidate chunks against the
 * query, returning the structured `ranks` array (see `getReRankSchema`).
 *
 * The schema is sized to `embeddings.length` so the model is nudged to score
 * every candidate; up to 3 retries are allowed before the call fails.
 */
private async getEmbeddingRelevance<
  Chunk extends ChunkSimilarity = ChunkSimilarity,
>(
  query: string,
  embeddings: Chunk[],
  signal?: AbortSignal
): Promise<ReRankResult> {
  const result = await generateObject({
    model: this.#instance(RERANK_MODEL),
    prompt: this.getRelevancePrompt(query, embeddings),
    schema: getReRankSchema(embeddings.length),
    maxRetries: 3,
    abortSignal: signal,
  });
  return result.object.ranks;
}
override async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
query: string,
embeddings: Chunk[],
topK: number,
signal?: AbortSignal
): Promise<Chunk[]> {
const sortedEmbeddings = embeddings.toSorted(
(a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
);
const chunks = sortedEmbeddings.reduce(
(acc, e) => {
const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
const key = `${targetId}:${e.chunk}`;
acc[key] = e;
return acc;
},
{} as Record<string, Chunk>
);
const ranks = [];
for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
const rank = await this.getEmbeddingRelevance(query, c, signal);
ranks.push(rank);
}
const highConfidenceChunks = ranks
.flat()
.toSorted((a, b) => b.scores.score - a.scores.score)
.filter(r => r.scores.score > 5)
.map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`])
.filter(Boolean);
return highConfidenceChunks.slice(0, topK);
}
}
export class MockEmbeddingClient extends EmbeddingClient {

View File

@@ -701,6 +701,8 @@ export class CopilotContextResolver {
@Args('content') content: string,
@Args('limit', { type: () => SafeIntResolver, nullable: true })
limit?: number,
@Args('scopedThreshold', { type: () => Float, nullable: true })
scopedThreshold?: number,
@Args('threshold', { type: () => Float, nullable: true })
threshold?: number
): Promise<ContextMatchedDocChunk[]> {
@@ -726,6 +728,7 @@ export class CopilotContextResolver {
content,
limit,
this.getSignal(ctx.req),
scopedThreshold,
threshold
);
} catch (e: any) {

View File

@@ -55,6 +55,12 @@ export class ContextSession implements AsyncDisposable {
return this.config.files.map(f => ({ ...f }));
}
// Unique ids of all docs and categories attached to this context, in
// insertion order (docs first, then categories; duplicates removed).
get docIds() {
  const all = [...this.config.docs, ...this.config.categories];
  return [...new Set(all.map(d => d.id))];
}
get sortedList(): ContextList {
const { docs, files } = this.config;
return [...docs, ...files].toSorted(
@@ -176,7 +182,7 @@ export class ContextSession implements AsyncDisposable {
content: string,
topK: number = 5,
signal?: AbortSignal,
threshold: number = 0.7
threshold: number = 0.85
): Promise<FileChunkSimilarity[]> {
const embedding = await this.client
.getEmbeddings([content], signal)
@@ -187,18 +193,23 @@ export class ContextSession implements AsyncDisposable {
this.models.copilotContext.matchFileEmbedding(
embedding,
this.id,
topK,
topK * 2,
threshold
),
this.models.copilotWorkspace.matchFileEmbedding(
this.workspaceId,
embedding,
topK,
topK * 2,
threshold
),
]);
return this.client.reRank([...context, ...workspace]);
return this.client.reRank(
content,
[...context, ...workspace],
topK,
signal
);
}
/**
@@ -213,18 +224,44 @@ export class ContextSession implements AsyncDisposable {
content: string,
topK: number = 5,
signal?: AbortSignal,
threshold: number = 0.7
scopedThreshold: number = 0.5,
threshold: number = 0.85
) {
const embedding = await this.client
.getEmbeddings([content], signal)
.then(r => r?.[0]?.embedding);
if (!embedding) return [];
return this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
this.workspaceId,
const docIds = this.docIds;
const [inContext, workspace] = await Promise.all([
this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
this.workspaceId,
topK * 2,
scopedThreshold,
docIds
),
this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
this.workspaceId,
topK * 2,
threshold
),
]);
const result = await this.client.reRank(
content,
[...inContext, ...workspace],
topK,
threshold
signal
);
// sort result, doc recorded in context first
const docIdSet = new Set(docIds);
return result.toSorted(
(a, b) =>
(docIdSet.has(a.docId) ? -1 : 1) - (docIdSet.has(b.docId) ? -1 : 1) ||
(a.distance || Infinity) - (b.distance || Infinity)
);
}

View File

@@ -1,5 +1,7 @@
import { File } from 'node:buffer';
import { z } from 'zod';
import { CopilotContextFileNotSupported } from '../../../base';
import { ChunkSimilarity, Embedding } from '../../../models';
import { parseDoc } from '../../../native';
@@ -115,12 +117,15 @@ export abstract class EmbeddingClient {
}
async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
embeddings: Chunk[]
_query: string,
embeddings: Chunk[],
topK: number,
_signal?: AbortSignal
): Promise<Chunk[]> {
// sort by distance with ascending order
return embeddings.sort(
(a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
);
return embeddings
.toSorted((a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity))
.slice(0, topK);
}
abstract getEmbeddings(
@@ -128,3 +133,31 @@ export abstract class EmbeddingClient {
signal?: AbortSignal
): Promise<Embedding[]>;
}
// Structured-output schema for a single re-rank judgement from the LLM.
// NOTE(review): `reason` is declared before `score` — presumably so the model
// writes its rationale before committing to a number; confirm before reordering.
const ReRankItemSchema = z.object({
  scores: z.object({
    reason: z
      .string()
      .describe(
        'Think step by step, describe in 20 words the reason for giving this score.'
      ),
    // Echoed back by the model and used to key the result to its source chunk.
    chunk: z.string().describe('The chunk index of the search result.'),
    targetId: z.string().describe('The id of the target.'),
    score: z
      .number()
      .min(0)
      .max(10)
      .describe(
        'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.'
      ),
  }),
});

/**
 * Build the response schema for a batch of `size` candidates. The expected
 * count is embedded in the description to push the model to score every
 * result rather than a subset.
 */
export const getReRankSchema = (size: number) =>
  z.object({
    ranks: ReRankItemSchema.array().describe(
      // Fixed grammar: "A array" -> "An array".
      `An array of scores. Make sure to score all ${size} results.`
    ),
  });

// The `ranks` payload of one `generateObject` call: an array of scored items.
export type ReRankResult = z.infer<ReturnType<typeof getReRankSchema>>['ranks'];

View File

@@ -113,7 +113,7 @@ type CopilotContext {
matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!
"""match workspace docs"""
matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]!
matchWorkspaceDocs(content: String!, limit: SafeInt, scopedThreshold: Float, threshold: Float): [ContextMatchedDocChunk!]!
"""list tags in context"""
tags: [CopilotContextCategory!]!