feat(server): improve embedding & rerank speed (#12666)

fix AI-109
This commit is contained in:
darkskygit
2025-06-03 11:12:34 +00:00
parent 2288cbe54d
commit 44e1eb503f
12 changed files with 232 additions and 100 deletions

View File

@@ -79,7 +79,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Structured],
},
],
},

View File

@@ -1,41 +1,75 @@
import {
createOpenAI,
type OpenAIProvider as VercelOpenAIProvider,
} from '@ai-sdk/openai';
import { embedMany, generateObject } from 'ai';
import { Logger } from '@nestjs/common';
import { chunk } from 'lodash-es';
import { ChunkSimilarity, Embedding } from '../../../models';
import { OpenAIConfig } from '../providers/openai';
import {
CopilotPromptNotFound,
CopilotProviderNotSupported,
} from '../../../base';
import type { ChunkSimilarity, Embedding } from '../../../models';
import type { PromptService } from '../prompt';
import {
type CopilotProvider,
type CopilotProviderFactory,
type ModelFullConditions,
ModelInputType,
ModelOutputType,
} from '../providers';
import {
EMBEDDING_DIMENSIONS,
EmbeddingClient,
getReRankSchema,
ReRankResult,
type ReRankResult,
} from './types';
const RERANK_MODEL = 'gpt-4.1-mini';
const RERANK_PROMPT = 'Rerank results';
export class OpenAIEmbeddingClient extends EmbeddingClient {
readonly #instance: VercelOpenAIProvider;
export class ProductionEmbeddingClient extends EmbeddingClient {
private readonly logger = new Logger(ProductionEmbeddingClient.name);
constructor(config: OpenAIConfig) {
constructor(
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService
) {
super();
this.#instance = createOpenAI({
apiKey: config.apiKey,
baseURL: config.baseUrl,
}
/**
 * Probe whether an embedding-capable provider is available.
 *
 * @returns true when the provider factory resolves a provider for
 *   `ModelOutputType.Embedding`; false (after logging a warning) otherwise.
 */
override async configured(): Promise<boolean> {
  const provider = await this.providerFactory.getProvider({
    outputType: ModelOutputType.Embedding,
  });
  if (provider) {
    return true;
  }
  this.logger.warn(
    'Copilot embedding client is not configured properly, please check your configuration.'
  );
  return false;
}
/**
 * Resolve a copilot provider matching the given model conditions.
 *
 * @param cond model conditions (model id and/or output type) to match
 * @throws CopilotProviderNotSupported when no provider satisfies `cond`
 */
private async getProvider(
  cond: ModelFullConditions
): Promise<CopilotProvider> {
  const resolved = await this.providerFactory.getProvider(cond);
  if (resolved) {
    return resolved;
  }
  throw new CopilotProviderNotSupported({
    provider: 'embedding',
    kind: cond.outputType || 'embedding',
  });
}
async getEmbeddings(input: string[]): Promise<Embedding[]> {
const modelInstance = this.#instance.embedding('text-embedding-3-large', {
dimensions: EMBEDDING_DIMENSIONS,
const provider = await this.getProvider({
outputType: ModelOutputType.Embedding,
});
this.logger.verbose(`Using provider ${provider.type} for embedding`, input);
const { embeddings } = await embedMany({
model: modelInstance,
values: input,
});
const embeddings = await provider.embedding(
{ inputTypes: [ModelInputType.Text] },
input,
{ dimensions: EMBEDDING_DIMENSIONS }
);
return Array.from(embeddings.entries()).map(([index, embedding]) => ({
index,
@@ -44,27 +78,6 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
}));
}
/**
 * Build the rerank prompt for a query and its candidate chunks.
 *
 * Serializes every chunk into a pseudo-XML `<result>` entry (owning doc or
 * file id plus the chunk's `chunk` and `content` fields) and embeds the list
 * in a scoring instruction for the rerank model.
 */
private getRelevancePrompt<Chunk extends ChunkSimilarity = ChunkSimilarity>(
  query: string,
  embeddings: Chunk[]
) {
  const results = embeddings
    .map(e => {
      // a chunk belongs to either a doc or a file; fall back to '' when neither id is present
      const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
      // NOTE: not xml, just for the sake of the prompt format
      return [
        '<result>',
        `<targetId>${targetId}</targetId>`,
        `<chunk>${e.chunk}</chunk>`,
        `<content>${e.content}</content>`,
        '</result>',
      ];
    })
    .flat()
    .join('\n');
  return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n<results>\n${results}\n</results>`;
}
private async getEmbeddingRelevance<
Chunk extends ChunkSimilarity = ChunkSimilarity,
>(
@@ -72,19 +85,36 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
embeddings: Chunk[],
signal?: AbortSignal
): Promise<ReRankResult> {
const prompt = this.getRelevancePrompt(query, embeddings);
const modelInstance = this.#instance(RERANK_MODEL);
if (!embeddings.length) return [];
const {
object: { ranks },
} = await generateObject({
model: modelInstance,
prompt,
schema: getReRankSchema(embeddings.length),
maxRetries: 3,
abortSignal: signal,
});
return ranks;
const prompt = await this.prompt.get(RERANK_PROMPT);
if (!prompt) {
throw new CopilotPromptNotFound({ name: RERANK_PROMPT });
}
const provider = await this.getProvider({ modelId: prompt.model });
const schema = getReRankSchema(embeddings.length);
const ranks = await provider.structure(
{ modelId: prompt.model },
prompt.finish({
query,
results: embeddings.map(e => {
const targetId =
'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
return { targetId, chunk: e.chunk, content: e.content };
}),
schema,
}),
{ maxRetries: 3, signal }
);
try {
return schema.parse(JSON.parse(ranks)).ranks;
} catch (error) {
this.logger.error('Failed to parse rerank results', error);
// silent error, will fallback to default sorting in parent method
return [];
}
}
override async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
@@ -110,6 +140,10 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
const ranks = [];
for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
const rank = await this.getEmbeddingRelevance(query, c, signal);
if (c.length !== rank.length) {
// llm return wrong result, fallback to default sorting
return super.reRank(query, embeddings, topK, signal);
}
ranks.push(rank);
}
@@ -124,6 +158,21 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
}
}
let EMBEDDING_CLIENT: EmbeddingClient | undefined;
/**
 * Lazily create and memoize the process-wide embedding client.
 *
 * Returns the cached client when one was already constructed; otherwise
 * builds a ProductionEmbeddingClient and caches it only if it reports
 * itself configured, so an unconfigured deployment is re-probed on the
 * next call rather than cached as permanently unavailable.
 */
export async function getEmbeddingClient(
  providerFactory: CopilotProviderFactory,
  prompt: PromptService
): Promise<EmbeddingClient | undefined> {
  if (!EMBEDDING_CLIENT) {
    const candidate = new ProductionEmbeddingClient(providerFactory, prompt);
    if (await candidate.configured()) {
      EMBEDDING_CLIENT = candidate;
    }
  }
  return EMBEDDING_CLIENT;
}
export class MockEmbeddingClient extends EmbeddingClient {
async getEmbeddings(input: string[]): Promise<Embedding[]> {
return input.map((_, i) => ({

View File

@@ -4,7 +4,6 @@ import {
AFFiNELogger,
BlobNotFound,
CallMetric,
Config,
CopilotContextFileNotSupported,
DocNotFound,
EventBus,
@@ -15,9 +14,11 @@ import {
} from '../../../base';
import { DocReader } from '../../../core/doc';
import { Models } from '../../../models';
import { PromptService } from '../prompt';
import { CopilotProviderFactory } from '../providers';
import { CopilotStorage } from '../storage';
import { readStream } from '../utils';
import { OpenAIEmbeddingClient } from './embedding';
import { getEmbeddingClient } from './embedding';
import type { Chunk, DocFragment } from './types';
import { EMBEDDING_DIMENSIONS, EmbeddingClient } from './types';
@@ -30,11 +31,12 @@ export class CopilotContextDocJob {
private client: EmbeddingClient | undefined;
constructor(
private readonly config: Config,
private readonly doc: DocReader,
private readonly event: EventBus,
private readonly logger: AFFiNELogger,
private readonly models: Models,
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService,
private readonly queue: JobQueue,
private readonly storage: CopilotStorage
) {
@@ -54,10 +56,8 @@ export class CopilotContextDocJob {
private async setup() {
this.supportEmbedding =
await this.models.copilotContext.checkEmbeddingAvailable();
if (this.supportEmbedding && this.config.copilot.providers.openai.apiKey) {
this.client = new OpenAIEmbeddingClient(
this.config.copilot.providers.openai
);
if (this.supportEmbedding) {
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
}
}
@@ -89,6 +89,14 @@ export class CopilotContextDocJob {
if (!this.supportEmbedding) return;
for (const { workspaceId, docId } of docs) {
const jobId = `workspace:embedding:${workspaceId}:${docId}`;
const job = await this.queue.get(jobId, 'copilot.embedding.docs');
// if the job exists and is older than 5 minute, remove it
if (job && job.timestamp + 5 * 60 * 1000 < Date.now()) {
this.logger.verbose(`Removing old embedding job ${jobId}`);
await this.queue.remove(jobId, 'copilot.embedding.docs');
}
await this.queue.add(
'copilot.embedding.docs',
{
@@ -99,6 +107,7 @@ export class CopilotContextDocJob {
{
jobId: `workspace:embedding:${workspaceId}:${docId}`,
priority: options?.priority ?? 1,
timestamp: Date.now(),
}
);
}
@@ -336,6 +345,9 @@ export class CopilotContextDocJob {
workspaceId,
docId
);
this.logger.verbose(
`Check if doc ${docId} in workspace ${workspaceId} needs embedding: ${needEmbedding}`
);
if (needEmbedding) {
if (signal.aborted) return;
const fragment = await this.getDocFragment(workspaceId, docId);

View File

@@ -2,7 +2,6 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import {
Cache,
Config,
CopilotInvalidContext,
NoCopilotProviderAvailable,
OnEvent,
@@ -15,9 +14,11 @@ import {
ContextFile,
Models,
} from '../../../models';
import { OpenAIEmbeddingClient } from './embedding';
import { PromptService } from '../prompt';
import { CopilotProviderFactory } from '../providers';
import { getEmbeddingClient } from './embedding';
import { ContextSession } from './session';
import { EmbeddingClient } from './types';
import type { EmbeddingClient } from './types';
const CONTEXT_SESSION_KEY = 'context-session';
@@ -27,26 +28,24 @@ export class CopilotContextService implements OnApplicationBootstrap {
private client: EmbeddingClient | undefined;
constructor(
private readonly config: Config,
private readonly cache: Cache,
private readonly models: Models
private readonly models: Models,
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService
) {}
@OnEvent('config.init')
onConfigInit() {
this.setup();
async onConfigInit() {
await this.setup();
}
@OnEvent('config.changed')
onConfigChanged() {
this.setup();
async onConfigChanged() {
await this.setup();
}
private setup() {
const configure = this.config.copilot.providers.openai;
if (configure.apiKey) {
this.client = new OpenAIEmbeddingClient(configure);
}
private async setup() {
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
}
async onApplicationBootstrap() {

View File

@@ -69,6 +69,10 @@ export type Chunk = {
export const EMBEDDING_DIMENSIONS = 1024;
export abstract class EmbeddingClient {
/**
 * Whether this embedding client is ready to serve requests.
 * The base implementation always reports ready; subclasses may override
 * it to probe their backing provider before being put into service.
 */
async configured() {
  return true;
}
async getFileEmbeddings(
file: File,
chunkMapper: (chunk: Chunk[]) => Chunk[],

View File

@@ -335,7 +335,66 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr
requireAttachment: true,
},
},
{
name: 'Rerank results',
action: 'Rerank results',
model: 'gpt-4.1-mini',
messages: [
{
role: 'system',
content: `Evaluate and rank search results based on their relevance and quality to the given query by assigning a score from 1 to 10, where 10 denotes the highest relevance.
Consider various factors such as content alignment with the query, source credibility, timeliness, and user intent.
# Steps
1. **Read the Query**: Understand the main intent and specific details of the search query.
2. **Review Each Result**:
- Analyze the content's relevance to the query.
- Assess the credibility of the source or website.
- Consider the timeliness of the information, ensuring it's current and relevant.
- Evaluate the alignment with potential user intent based on the query.
3. **Scoring**:
- Assign a score from 1 to 10 based on the overall relevance and quality, with 10 being the most relevant.
# Output Format
Return a JSON object for each result in the following format in raw:
{
"scores": [
{
"reason": "[Reasoning behind the score in 20 words]",
"chunk": "[chunk]",
"targetId": "[targetId]",
"score": [1-10]
}
]
}
# Notes
- Be aware of the potential biases or inaccuracies in the sources.
- Consider if the content is comprehensive and directly answers the query.
- Pay attention to the nuances of user intent that might influence relevance.`,
},
{
role: 'user',
content: `
<query>{{query}}</query>
<results>
{{#results}}
<result>
<targetId>{{targetId}}</targetId>
<chunk>{{chunk}}</chunk>
<content>
{{content}}
</content>
</result>
{{/results}}
</results>`,
},
],
},
{
name: 'Generate a caption',
action: 'Generate a caption',

View File

@@ -103,7 +103,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Structured],
defaultForOutputType: true,
},
],
@@ -113,7 +113,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Structured],
},
],
},
@@ -122,7 +122,16 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Structured],
},
],
},
{
id: 'gpt-4.1-nano',
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
},
],
},
@@ -283,8 +292,8 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature || 0,
maxTokens: options.maxTokens || 4096,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
@@ -322,10 +331,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
model: modelInstance,
system,
messages: msgs,
frequencyPenalty: options.frequencyPenalty || 0,
presencePenalty: options.presencePenalty || 0,
temperature: options.temperature || 0,
maxTokens: options.maxTokens || 4096,
frequencyPenalty: options.frequencyPenalty ?? 0,
presencePenalty: options.presencePenalty ?? 0,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
@@ -388,8 +397,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
model: modelInstance,
system,
messages: msgs,
temperature: ('temperature' in options && options.temperature) || 0,
maxTokens: ('maxTokens' in options && options.maxTokens) || 4096,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
maxRetries: options.maxRetries ?? 3,
schema,
providerOptions: {
openai: options.user ? { user: options.user } : {},

View File

@@ -124,8 +124,8 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature || 0,
maxTokens: options.maxTokens || 4096,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
abortSignal: options.signal,
});
@@ -164,8 +164,8 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature || 0,
maxTokens: options.maxTokens || 4096,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
abortSignal: options.signal,
});

View File

@@ -172,7 +172,7 @@ export abstract class CopilotProvider<C = any> {
structure(
_cond: ModelConditions,
_messages: PromptMessage[],
_options: CopilotStructuredOptions
_options?: CopilotStructuredOptions
): Promise<string> {
throw new CopilotProviderNotSupported({
provider: this.type,
@@ -193,7 +193,7 @@ export abstract class CopilotProvider<C = any> {
embedding(
_model: ModelConditions,
_text: string,
_text: string | string[],
_options?: CopilotEmbeddingOptions
): Promise<number[][]> {
throw new CopilotProviderNotSupported({

View File

@@ -61,6 +61,8 @@ export const PromptConfigStrictSchema = z.object({
// params requirements
requireContent: z.boolean().nullable().optional(),
requireAttachment: z.boolean().nullable().optional(),
// structure output
maxRetries: z.number().nullable().optional(),
// openai
frequencyPenalty: z.number().nullable().optional(),
presencePenalty: z.number().nullable().optional(),