diff --git a/packages/backend/server/src/models/copilot-context.ts b/packages/backend/server/src/models/copilot-context.ts index 3ea446e05e..2a392815a4 100644 --- a/packages/backend/server/src/models/copilot-context.ts +++ b/packages/backend/server/src/models/copilot-context.ts @@ -209,12 +209,14 @@ export class CopilotContextModel extends BaseModel { embedding: number[], workspaceId: string, topK: number, - threshold: number + threshold: number, + matchDocIds?: string[] ): Promise { const similarityChunks = await this.db.$queryRaw>` SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance" FROM "ai_workspace_embeddings" WHERE "workspace_id" = ${workspaceId} + ${matchDocIds?.length ? Prisma.sql`AND "doc_id" IN (${Prisma.join(matchDocIds)})` : Prisma.empty} ORDER BY "distance" ASC LIMIT ${topK}; `; diff --git a/packages/backend/server/src/plugins/copilot/context/embedding.ts b/packages/backend/server/src/plugins/copilot/context/embedding.ts index 845a686991..40054a3b46 100644 --- a/packages/backend/server/src/plugins/copilot/context/embedding.ts +++ b/packages/backend/server/src/plugins/copilot/context/embedding.ts @@ -2,11 +2,14 @@ import { createOpenAI, type OpenAIProvider as VercelOpenAIProvider, } from '@ai-sdk/openai'; -import { embedMany } from 'ai'; +import { embedMany, generateObject } from 'ai'; +import { chunk } from 'lodash-es'; -import { Embedding } from '../../../models'; +import { ChunkSimilarity, Embedding } from '../../../models'; import { OpenAIConfig } from '../providers/openai'; -import { EmbeddingClient } from './types'; +import { EmbeddingClient, getReRankSchema, ReRankResult } from './types'; + +const RERANK_MODEL = 'gpt-4.1-mini'; export class OpenAIEmbeddingClient extends EmbeddingClient { readonly #instance: VercelOpenAIProvider; @@ -35,6 +38,85 @@ export class OpenAIEmbeddingClient extends EmbeddingClient { content: input[index], })); } + + private getRelevancePrompt( + query: string, + embeddings: Chunk[] + ) { + const results = embeddings + .map(e => { + const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : ''; + // NOTE: not xml, just for the sake of the prompt format + return [ + '', + `${targetId}`, + `${e.chunk}`, + `${e.content}`, + '', + ]; + }) + .flat() + .join('\n'); + return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n\n${results}\n`; + } + + private async getEmbeddingRelevance< + Chunk extends ChunkSimilarity = ChunkSimilarity, + >( + query: string, + embeddings: Chunk[], + signal?: AbortSignal + ): Promise { + const prompt = this.getRelevancePrompt(query, embeddings); + const modelInstance = this.#instance(RERANK_MODEL); + + const { + object: { ranks }, + } = await generateObject({ + model: modelInstance, + prompt, + schema: getReRankSchema(embeddings.length), + maxRetries: 3, + abortSignal: signal, + }); + return ranks; + } + + override async reRank( + query: string, + embeddings: Chunk[], + topK: number, + signal?: AbortSignal + ): Promise { + const sortedEmbeddings = embeddings.toSorted( + (a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity) + ); + + const chunks = sortedEmbeddings.reduce( + (acc, e) => { + const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : ''; + const key = `${targetId}:${e.chunk}`; + acc[key] = e; + return acc; + }, + {} as Record + ); + + const ranks = []; + for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) { + const rank = await this.getEmbeddingRelevance(query, c, signal); + ranks.push(rank); + } + + const highConfidenceChunks = ranks + .flat() + .toSorted((a, b) => b.scores.score - a.scores.score) + .filter(r => r.scores.score > 5) + .map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`]) + .filter(Boolean); + + return highConfidenceChunks.slice(0, topK); + } } export class MockEmbeddingClient extends EmbeddingClient { diff --git a/packages/backend/server/src/plugins/copilot/context/resolver.ts b/packages/backend/server/src/plugins/copilot/context/resolver.ts index c14dc0b891..faffed4cf0 100644 --- a/packages/backend/server/src/plugins/copilot/context/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/context/resolver.ts @@ -701,6 +701,8 @@ export class CopilotContextResolver { @Args('content') content: string, @Args('limit', { type: () => SafeIntResolver, nullable: true }) limit?: number, + @Args('scopedThreshold', { type: () => Float, nullable: true }) + scopedThreshold?: number, @Args('threshold', { type: () => Float, nullable: true }) threshold?: number ): Promise { @@ -726,6 +728,7 @@ export class CopilotContextResolver { content, limit, this.getSignal(ctx.req), + scopedThreshold, threshold ); } catch (e: any) { diff --git a/packages/backend/server/src/plugins/copilot/context/session.ts b/packages/backend/server/src/plugins/copilot/context/session.ts index b31c480e61..c4ea10bc4d 100644 --- a/packages/backend/server/src/plugins/copilot/context/session.ts +++ b/packages/backend/server/src/plugins/copilot/context/session.ts @@ -55,6 +55,12 @@ export class ContextSession implements AsyncDisposable { return this.config.files.map(f => ({ ...f })); } + get docIds() { + return Array.from( + new Set([this.config.docs, this.config.categories].flat().map(d => d.id)) + ); + } + get sortedList(): ContextList { const { docs, files } = this.config; return [...docs, ...files].toSorted( @@ -176,7 +182,7 @@ export class ContextSession implements AsyncDisposable { content: string, topK: number = 5, signal?: AbortSignal, - threshold: number = 0.7 + threshold: number = 0.85 ): Promise { const embedding = await this.client .getEmbeddings([content], signal) @@ -187,18 +193,23 @@ export class ContextSession implements AsyncDisposable { this.models.copilotContext.matchFileEmbedding( embedding, this.id, - topK, + topK * 2, threshold ), this.models.copilotWorkspace.matchFileEmbedding( this.workspaceId, embedding, - topK, + topK * 2, threshold ), ]); - return this.client.reRank([...context, ...workspace]); + return this.client.reRank( + content, + [...context, ...workspace], + topK, + signal + ); } /** @@ -213,18 +224,44 @@ export class ContextSession implements AsyncDisposable { content: string, topK: number = 5, signal?: AbortSignal, - threshold: number = 0.7 + scopedThreshold: number = 0.5, + threshold: number = 0.85 ) { const embedding = await this.client .getEmbeddings([content], signal) .then(r => r?.[0]?.embedding); if (!embedding) return []; - return this.models.copilotContext.matchWorkspaceEmbedding( - embedding, - this.workspaceId, + const docIds = this.docIds; + const [inContext, workspace] = await Promise.all([ + this.models.copilotContext.matchWorkspaceEmbedding( + embedding, + this.workspaceId, + topK * 2, + scopedThreshold, + docIds + ), + this.models.copilotContext.matchWorkspaceEmbedding( + embedding, + this.workspaceId, + topK * 2, + threshold + ), + ]); + + const result = await this.client.reRank( + content, + [...inContext, ...workspace], topK, - threshold + signal + ); + + // sort result, doc recorded in context first + const docIdSet = new Set(docIds); + return result.toSorted( + (a, b) => + (docIdSet.has(a.docId) ? -1 : 1) - (docIdSet.has(b.docId) ? -1 : 1) || + (a.distance || Infinity) - (b.distance || Infinity) ); } diff --git a/packages/backend/server/src/plugins/copilot/context/types.ts b/packages/backend/server/src/plugins/copilot/context/types.ts index d29efac8bd..e43ff1ab58 100644 --- a/packages/backend/server/src/plugins/copilot/context/types.ts +++ b/packages/backend/server/src/plugins/copilot/context/types.ts @@ -1,5 +1,7 @@ import { File } from 'node:buffer'; +import { z } from 'zod'; + import { CopilotContextFileNotSupported } from '../../../base'; import { ChunkSimilarity, Embedding } from '../../../models'; import { parseDoc } from '../../../native'; @@ -115,12 +117,15 @@ export abstract class EmbeddingClient { } async reRank( - embeddings: Chunk[] + _query: string, + embeddings: Chunk[], + topK: number, + _signal?: AbortSignal ): Promise { // sort by distance with ascending order - return embeddings.sort( - (a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity) - ); + return embeddings + .toSorted((a, b) => (a.distance ?? Infinity) - (b.distance ?? Infinity)) + .slice(0, topK); } abstract getEmbeddings( @@ -128,3 +133,31 @@ export abstract class EmbeddingClient { signal?: AbortSignal ): Promise; } + +const ReRankItemSchema = z.object({ + scores: z.object({ + reason: z + .string() + .describe( + 'Think step by step, describe in 20 words the reason for giving this score.' + ), + chunk: z.string().describe('The chunk index of the search result.'), + targetId: z.string().describe('The id of the target.'), + score: z + .number() + .min(0) + .max(10) + .describe( + 'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.' + ), + }), +}); + +export const getReRankSchema = (size: number) => + z.object({ + ranks: ReRankItemSchema.array().describe( + `A array of scores. Make sure to score all ${size} results.` + ), + }); + +export type ReRankResult = z.infer>['ranks']; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 29c13fcc4c..e705c710c8 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -113,7 +113,7 @@ type CopilotContext { matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]! """match workspace docs""" - matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]! + matchWorkspaceDocs(content: String!, limit: SafeInt, scopedThreshold: Float, threshold: Float): [ContextMatchedDocChunk!]! """list tags in context""" tags: [CopilotContextCategory!]! diff --git a/packages/common/graphql/src/schema.ts b/packages/common/graphql/src/schema.ts index 232c4f69b0..98bf7ba685 100644 --- a/packages/common/graphql/src/schema.ts +++ b/packages/common/graphql/src/schema.ts @@ -191,6 +191,7 @@ export interface CopilotContextMatchFilesArgs { export interface CopilotContextMatchWorkspaceDocsArgs { content: Scalars['String']['input']; limit?: InputMaybe; + scopedThreshold?: InputMaybe; threshold?: InputMaybe; }