feat(server): improve embedding & rerank speed (#12666)

fix AI-109
This commit is contained in:
darkskygit
2025-06-03 11:12:34 +00:00
parent 2288cbe54d
commit 44e1eb503f
12 changed files with 232 additions and 100 deletions

View File

@@ -1,41 +1,75 @@
import {
createOpenAI,
type OpenAIProvider as VercelOpenAIProvider,
} from '@ai-sdk/openai';
import { embedMany, generateObject } from 'ai';
import { Logger } from '@nestjs/common';
import { chunk } from 'lodash-es';
import { ChunkSimilarity, Embedding } from '../../../models';
import { OpenAIConfig } from '../providers/openai';
import {
CopilotPromptNotFound,
CopilotProviderNotSupported,
} from '../../../base';
import type { ChunkSimilarity, Embedding } from '../../../models';
import type { PromptService } from '../prompt';
import {
type CopilotProvider,
type CopilotProviderFactory,
type ModelFullConditions,
ModelInputType,
ModelOutputType,
} from '../providers';
import {
EMBEDDING_DIMENSIONS,
EmbeddingClient,
getReRankSchema,
ReRankResult,
type ReRankResult,
} from './types';
const RERANK_MODEL = 'gpt-4.1-mini';
const RERANK_PROMPT = 'Rerank results';
export class OpenAIEmbeddingClient extends EmbeddingClient {
readonly #instance: VercelOpenAIProvider;
export class ProductionEmbeddingClient extends EmbeddingClient {
private readonly logger = new Logger(ProductionEmbeddingClient.name);
constructor(config: OpenAIConfig) {
constructor(
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService
) {
super();
this.#instance = createOpenAI({
apiKey: config.apiKey,
baseURL: config.baseUrl,
}
override async configured(): Promise<boolean> {
const embedding = await this.providerFactory.getProvider({
outputType: ModelOutputType.Embedding,
});
const result = Boolean(embedding);
if (!result) {
this.logger.warn(
'Copilot embedding client is not configured properly, please check your configuration.'
);
}
return result;
}
private async getProvider(
cond: ModelFullConditions
): Promise<CopilotProvider> {
const provider = await this.providerFactory.getProvider(cond);
if (!provider) {
throw new CopilotProviderNotSupported({
provider: 'embedding',
kind: cond.outputType || 'embedding',
});
}
return provider;
}
async getEmbeddings(input: string[]): Promise<Embedding[]> {
const modelInstance = this.#instance.embedding('text-embedding-3-large', {
dimensions: EMBEDDING_DIMENSIONS,
const provider = await this.getProvider({
outputType: ModelOutputType.Embedding,
});
this.logger.verbose(`Using provider ${provider.type} for embedding`, input);
const { embeddings } = await embedMany({
model: modelInstance,
values: input,
});
const embeddings = await provider.embedding(
{ inputTypes: [ModelInputType.Text] },
input,
{ dimensions: EMBEDDING_DIMENSIONS }
);
return Array.from(embeddings.entries()).map(([index, embedding]) => ({
index,
@@ -44,27 +78,6 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
}));
}
private getRelevancePrompt<Chunk extends ChunkSimilarity = ChunkSimilarity>(
query: string,
embeddings: Chunk[]
) {
const results = embeddings
.map(e => {
const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
// NOTE: not xml, just for the sake of the prompt format
return [
'<result>',
`<targetId>${targetId}</targetId>`,
`<chunk>${e.chunk}</chunk>`,
`<content>${e.content}</content>`,
'</result>',
];
})
.flat()
.join('\n');
return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n<results>\n${results}\n</results>`;
}
private async getEmbeddingRelevance<
Chunk extends ChunkSimilarity = ChunkSimilarity,
>(
@@ -72,19 +85,36 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
embeddings: Chunk[],
signal?: AbortSignal
): Promise<ReRankResult> {
const prompt = this.getRelevancePrompt(query, embeddings);
const modelInstance = this.#instance(RERANK_MODEL);
if (!embeddings.length) return [];
const {
object: { ranks },
} = await generateObject({
model: modelInstance,
prompt,
schema: getReRankSchema(embeddings.length),
maxRetries: 3,
abortSignal: signal,
});
return ranks;
const prompt = await this.prompt.get(RERANK_PROMPT);
if (!prompt) {
throw new CopilotPromptNotFound({ name: RERANK_PROMPT });
}
const provider = await this.getProvider({ modelId: prompt.model });
const schema = getReRankSchema(embeddings.length);
const ranks = await provider.structure(
{ modelId: prompt.model },
prompt.finish({
query,
results: embeddings.map(e => {
const targetId =
'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
return { targetId, chunk: e.chunk, content: e.content };
}),
schema,
}),
{ maxRetries: 3, signal }
);
try {
return schema.parse(JSON.parse(ranks)).ranks;
} catch (error) {
this.logger.error('Failed to parse rerank results', error);
// silent error, will fallback to default sorting in parent method
return [];
}
}
override async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
@@ -110,6 +140,10 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
const ranks = [];
for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
const rank = await this.getEmbeddingRelevance(query, c, signal);
if (c.length !== rank.length) {
// llm return wrong result, fallback to default sorting
return super.reRank(query, embeddings, topK, signal);
}
ranks.push(rank);
}
@@ -124,6 +158,21 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
}
}
let EMBEDDING_CLIENT: EmbeddingClient | undefined;
export async function getEmbeddingClient(
providerFactory: CopilotProviderFactory,
prompt: PromptService
): Promise<EmbeddingClient | undefined> {
if (EMBEDDING_CLIENT) {
return EMBEDDING_CLIENT;
}
const client = new ProductionEmbeddingClient(providerFactory, prompt);
if (await client.configured()) {
EMBEDDING_CLIENT = client;
}
return EMBEDDING_CLIENT;
}
export class MockEmbeddingClient extends EmbeddingClient {
async getEmbeddings(input: string[]): Promise<Embedding[]> {
return input.map((_, i) => ({