feat(server): rerank for matching (#12039)
fix AI-20, fix AI-77

## Summary by CodeRabbit

- **New Features**
  - Enhanced relevance-based re-ranking for embedding results, improving the accuracy of content suggestions.
  - Added prioritization for workspace content that matches specific document IDs in search results.
  - Introduced a new scoped threshold parameter to refine workspace document matching.
- **Improvements**
  - Increased the default similarity threshold for file chunk matching, resulting in more precise matches.
  - Doubled candidate retrieval for file and workspace chunk matching to improve result quality.
  - Updated sorting to prioritize context-relevant documents in workspace matches.
  - Explicitly included the original input content in re-ranking calls for better relevance assessment.
- **Bug Fixes**
  - Adjusted re-ranking logic to return only highly relevant results based on confidence scores.
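For orientation, here is a minimal sketch of the retrieve-then-rerank flow these notes describe; the helper names are illustrative and not part of this commit. Candidates are over-fetched (2x the requested limit), a small model scores each one against the query on a 0-10 scale, only scores above 5 are kept, and the top K survivors are returned. The actual implementation follows in the diff below.

```ts
// Hypothetical sketch of the retrieve-then-rerank flow described above.
// `matchCandidates` and `scoreRelevance` are illustrative names, not APIs from this commit.
interface Candidate {
  docId: string;
  chunk: number;
  content: string;
  distance?: number;
}

async function searchWithReRank(
  query: string,
  topK: number,
  matchCandidates: (query: string, limit: number) => Promise<Candidate[]>,
  scoreRelevance: (query: string, candidates: Candidate[]) => Promise<number[]> // 0-10 per candidate
): Promise<Candidate[]> {
  // Over-fetch: retrieve 2x the requested results so the re-ranker has room to filter.
  const candidates = await matchCandidates(query, topK * 2);
  const scores = await scoreRelevance(query, candidates);
  return candidates
    .map((c, i) => ({ c, score: scores[i] ?? 0 }))
    .filter(({ score }) => score > 5) // keep only high-confidence matches
    .sort((a, b) => b.score - a.score)
    .map(({ c }) => c)
    .slice(0, topK);
}
```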
@@ -209,12 +209,14 @@ export class CopilotContextModel extends BaseModel {
     embedding: number[],
     workspaceId: string,
     topK: number,
-    threshold: number
+    threshold: number,
+    matchDocIds?: string[]
   ): Promise<DocChunkSimilarity[]> {
     const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
       SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
       FROM "ai_workspace_embeddings"
       WHERE "workspace_id" = ${workspaceId}
+      ${matchDocIds?.length ? Prisma.sql`AND "doc_id" IN (${Prisma.join(matchDocIds)})` : Prisma.empty}
       ORDER BY "distance" ASC
       LIMIT ${topK};
     `;
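Aside: the new optional `matchDocIds` filter is composed with Prisma's SQL template helpers, so a missing or empty list leaves the query untouched. A rough sketch of how that composition behaves, assuming the usual `@prisma/client` import (illustrative, not code from this commit):

```ts
import { Prisma } from '@prisma/client';

// Illustrative only: how the optional doc-id filter composes.
// With matchDocIds = ['doc-a', 'doc-b'] the fragment renders roughly as
//   AND "doc_id" IN ($1, $2)
// and with undefined or [] it contributes nothing to the statement.
function docIdFilter(matchDocIds?: string[]): Prisma.Sql {
  return matchDocIds?.length
    ? Prisma.sql`AND "doc_id" IN (${Prisma.join(matchDocIds)})`
    : Prisma.empty;
}
```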
@@ -2,11 +2,14 @@ import {
   createOpenAI,
   type OpenAIProvider as VercelOpenAIProvider,
 } from '@ai-sdk/openai';
-import { embedMany } from 'ai';
+import { embedMany, generateObject } from 'ai';
+import { chunk } from 'lodash-es';

-import { Embedding } from '../../../models';
+import { ChunkSimilarity, Embedding } from '../../../models';
 import { OpenAIConfig } from '../providers/openai';
-import { EmbeddingClient } from './types';
+import { EmbeddingClient, getReRankSchema, ReRankResult } from './types';
+
+const RERANK_MODEL = 'gpt-4.1-mini';

 export class OpenAIEmbeddingClient extends EmbeddingClient {
   readonly #instance: VercelOpenAIProvider;
@@ -35,6 +38,85 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
       content: input[index],
     }));
   }
+
+  private getRelevancePrompt<Chunk extends ChunkSimilarity = ChunkSimilarity>(
+    query: string,
+    embeddings: Chunk[]
+  ) {
+    const results = embeddings
+      .map(e => {
+        const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
+        // NOTE: not xml, just for the sake of the prompt format
+        return [
+          '<result>',
+          `<targetId>${targetId}</targetId>`,
+          `<chunk>${e.chunk}</chunk>`,
+          `<content>${e.content}</content>`,
+          '</result>',
+        ];
+      })
+      .flat()
+      .join('\n');
+    return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n<results>\n${results}\n</results>`;
+  }
+
+  private async getEmbeddingRelevance<
+    Chunk extends ChunkSimilarity = ChunkSimilarity,
+  >(
+    query: string,
+    embeddings: Chunk[],
+    signal?: AbortSignal
+  ): Promise<ReRankResult> {
+    const prompt = this.getRelevancePrompt(query, embeddings);
+    const modelInstance = this.#instance(RERANK_MODEL);
+
+    const {
+      object: { ranks },
+    } = await generateObject({
+      model: modelInstance,
+      prompt,
+      schema: getReRankSchema(embeddings.length),
+      maxRetries: 3,
+      abortSignal: signal,
+    });
+    return ranks;
+  }
+
+  override async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
+    query: string,
+    embeddings: Chunk[],
+    topK: number,
+    signal?: AbortSignal
+  ): Promise<Chunk[]> {
+    const sortedEmbeddings = embeddings.toSorted(
+      (a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
+    );
+
+    const chunks = sortedEmbeddings.reduce(
+      (acc, e) => {
+        const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
+        const key = `${targetId}:${e.chunk}`;
+        acc[key] = e;
+        return acc;
+      },
+      {} as Record<string, Chunk>
+    );
+
+    const ranks = [];
+    for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
+      const rank = await this.getEmbeddingRelevance(query, c, signal);
+      ranks.push(rank);
+    }
+
+    const highConfidenceChunks = ranks
+      .flat()
+      .toSorted((a, b) => b.scores.score - a.scores.score)
+      .filter(r => r.scores.score > 5)
+      .map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`])
+      .filter(Boolean);
+
+    return highConfidenceChunks.slice(0, topK);
+  }
 }

 export class MockEmbeddingClient extends EmbeddingClient {
@@ -701,6 +701,8 @@ export class CopilotContextResolver {
     @Args('content') content: string,
     @Args('limit', { type: () => SafeIntResolver, nullable: true })
     limit?: number,
+    @Args('scopedThreshold', { type: () => Float, nullable: true })
+    scopedThreshold?: number,
     @Args('threshold', { type: () => Float, nullable: true })
     threshold?: number
   ): Promise<ContextMatchedDocChunk[]> {
@@ -726,6 +728,7 @@ export class CopilotContextResolver {
         content,
         limit,
         this.getSignal(ctx.req),
+        scopedThreshold,
         threshold
       );
     } catch (e: any) {
@@ -55,6 +55,12 @@ export class ContextSession implements AsyncDisposable {
     return this.config.files.map(f => ({ ...f }));
   }
+
+  get docIds() {
+    return Array.from(
+      new Set([this.config.docs, this.config.categories].flat().map(d => d.id))
+    );
+  }

   get sortedList(): ContextList {
     const { docs, files } = this.config;
     return [...docs, ...files].toSorted(
@@ -176,7 +182,7 @@ export class ContextSession implements AsyncDisposable {
     content: string,
     topK: number = 5,
     signal?: AbortSignal,
-    threshold: number = 0.7
+    threshold: number = 0.85
   ): Promise<FileChunkSimilarity[]> {
     const embedding = await this.client
       .getEmbeddings([content], signal)
@@ -187,18 +193,23 @@ export class ContextSession implements AsyncDisposable {
       this.models.copilotContext.matchFileEmbedding(
         embedding,
         this.id,
-        topK,
+        topK * 2,
         threshold
       ),
       this.models.copilotWorkspace.matchFileEmbedding(
         this.workspaceId,
         embedding,
-        topK,
+        topK * 2,
         threshold
       ),
     ]);

-    return this.client.reRank([...context, ...workspace]);
+    return this.client.reRank(
+      content,
+      [...context, ...workspace],
+      topK,
+      signal
+    );
   }

   /**
@@ -213,18 +224,44 @@ export class ContextSession implements AsyncDisposable {
     content: string,
     topK: number = 5,
     signal?: AbortSignal,
-    threshold: number = 0.7
+    scopedThreshold: number = 0.5,
+    threshold: number = 0.85
   ) {
     const embedding = await this.client
       .getEmbeddings([content], signal)
       .then(r => r?.[0]?.embedding);
     if (!embedding) return [];

-    return this.models.copilotContext.matchWorkspaceEmbedding(
-      embedding,
-      this.workspaceId,
-      topK,
-      threshold
-    );
+    const docIds = this.docIds;
+    const [inContext, workspace] = await Promise.all([
+      this.models.copilotContext.matchWorkspaceEmbedding(
+        embedding,
+        this.workspaceId,
+        topK * 2,
+        scopedThreshold,
+        docIds
+      ),
+      this.models.copilotContext.matchWorkspaceEmbedding(
+        embedding,
+        this.workspaceId,
+        topK * 2,
+        threshold
+      ),
+    ]);
+
+    const result = await this.client.reRank(
+      content,
+      [...inContext, ...workspace],
+      topK,
+      signal
+    );
+
+    // sort result, doc recorded in context first
+    const docIdSet = new Set(docIds);
+    return result.toSorted(
+      (a, b) =>
+        (docIdSet.has(a.docId) ? -1 : 1) - (docIdSet.has(b.docId) ? -1 : 1) ||
+        (a.distance || Infinity) - (b.distance || Infinity)
+    );
   }
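The closing sort deserves a small worked example: documents already attached to the context always rank ahead of other workspace matches, and vector distance only breaks ties within each group. A hypothetical illustration (data invented, not from the commit):

```ts
// Hypothetical data: shows the comparator's grouping behavior.
const docIdSet = new Set(['doc-a']);
const matches = [
  { docId: 'doc-b', chunk: 0, content: '...', distance: 0.2 },
  { docId: 'doc-a', chunk: 3, content: '...', distance: 0.4 },
];
const sorted = matches.toSorted(
  (a, b) =>
    (docIdSet.has(a.docId) ? -1 : 1) - (docIdSet.has(b.docId) ? -1 : 1) ||
    (a.distance || Infinity) - (b.distance || Infinity)
);
// sorted[0] is doc-a: it is recorded in the context, so it wins
// even though doc-b has the smaller distance.
```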
@@ -1,5 +1,7 @@
 import { File } from 'node:buffer';

+import { z } from 'zod';
+
 import { CopilotContextFileNotSupported } from '../../../base';
 import { ChunkSimilarity, Embedding } from '../../../models';
 import { parseDoc } from '../../../native';
@@ -115,12 +117,15 @@ export abstract class EmbeddingClient {
   }

   async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
-    embeddings: Chunk[]
+    _query: string,
+    embeddings: Chunk[],
+    topK: number,
+    _signal?: AbortSignal
   ): Promise<Chunk[]> {
     // sort by distance with ascending order
-    return embeddings.sort(
-      (a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)
-    );
+    return embeddings
+      .toSorted((a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity))
+      .slice(0, topK);
   }

   abstract getEmbeddings(
@@ -128,3 +133,31 @@ export abstract class EmbeddingClient {
     signal?: AbortSignal
   ): Promise<Embedding[]>;
 }
+
+const ReRankItemSchema = z.object({
+  scores: z.object({
+    reason: z
+      .string()
+      .describe(
+        'Think step by step, describe in 20 words the reason for giving this score.'
+      ),
+    chunk: z.string().describe('The chunk index of the search result.'),
+    targetId: z.string().describe('The id of the target.'),
+    score: z
+      .number()
+      .min(0)
+      .max(10)
+      .describe(
+        'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.'
+      ),
+  }),
+});
+
+export const getReRankSchema = (size: number) =>
+  z.object({
+    ranks: ReRankItemSchema.array().describe(
+      `A array of scores. Make sure to score all ${size} results.`
+    ),
+  });
+
+export type ReRankResult = z.infer<ReturnType<typeof getReRankSchema>>['ranks'];
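For reference, a model response that satisfies `getReRankSchema(2)` would look roughly like the object below; the values are invented for illustration. `reRank` keeps only entries whose `score` is above 5 and maps them back to candidate chunks via the `targetId:chunk` key.

```ts
// Hypothetical generateObject output conforming to getReRankSchema(2).
// The type comes from './types' (defined in the hunk above); values are invented.
import type { ReRankResult } from './types';

const example: { ranks: ReRankResult } = {
  ranks: [
    {
      scores: {
        reason: 'Chunk directly answers the query about release scheduling.',
        chunk: '3',
        targetId: 'doc-a',
        score: 9, // kept: above the score > 5 cutoff used by reRank
      },
    },
    {
      scores: {
        reason: 'Mentions the topic only in passing; mostly unrelated notes.',
        chunk: '0',
        targetId: 'doc-b',
        score: 2, // dropped: not above the cutoff
      },
    },
  ],
};
```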
@@ -113,7 +113,7 @@ type CopilotContext {
   matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]!

   """match workspace docs"""
-  matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]!
+  matchWorkspaceDocs(content: String!, limit: SafeInt, scopedThreshold: Float, threshold: Float): [ContextMatchedDocChunk!]!

   """list tags in context"""
   tags: [CopilotContextCategory!]!
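For completeness, a caller-side sketch of the updated field. This is hypothetical: the enclosing query path to `CopilotContext` and the exact `ContextMatchedDocChunk` selection set are not part of this diff, so the sub-field names below are assumptions that mirror `DocChunkSimilarity`.

```ts
// Hypothetical GraphQL field selection (not from this commit); the surrounding
// query that resolves a CopilotContext is elided, and the selected sub-fields
// are assumed to mirror DocChunkSimilarity (docId, chunk, content, distance).
const matchWorkspaceDocsSelection = /* GraphQL */ `
  matchWorkspaceDocs(
    content: "how do we schedule the next release?"
    limit: 5
    scopedThreshold: 0.5
    threshold: 0.85
  ) {
    docId
    chunk
    content
    distance
  }
`;
```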