mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 13:25:12 +00:00
@@ -1,41 +1,75 @@
|
||||
import {
|
||||
createOpenAI,
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
} from '@ai-sdk/openai';
|
||||
import { embedMany, generateObject } from 'ai';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { chunk } from 'lodash-es';
|
||||
|
||||
import { ChunkSimilarity, Embedding } from '../../../models';
|
||||
import { OpenAIConfig } from '../providers/openai';
|
||||
import {
|
||||
CopilotPromptNotFound,
|
||||
CopilotProviderNotSupported,
|
||||
} from '../../../base';
|
||||
import type { ChunkSimilarity, Embedding } from '../../../models';
|
||||
import type { PromptService } from '../prompt';
|
||||
import {
|
||||
type CopilotProvider,
|
||||
type CopilotProviderFactory,
|
||||
type ModelFullConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
} from '../providers';
|
||||
import {
|
||||
EMBEDDING_DIMENSIONS,
|
||||
EmbeddingClient,
|
||||
getReRankSchema,
|
||||
ReRankResult,
|
||||
type ReRankResult,
|
||||
} from './types';
|
||||
|
||||
// Model id handed to the OpenAI provider instance when building the model used for rerank scoring.
const RERANK_MODEL = 'gpt-4.1-mini';
|
||||
// Name under which the rerank prompt is looked up in the prompt service; also reported in the
// CopilotPromptNotFound error when that lookup returns nothing.
const RERANK_PROMPT = 'Rerank results';
|
||||
|
||||
export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
readonly #instance: VercelOpenAIProvider;
|
||||
export class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
private readonly logger = new Logger(ProductionEmbeddingClient.name);
|
||||
|
||||
constructor(config: OpenAIConfig) {
|
||||
constructor(
|
||||
private readonly providerFactory: CopilotProviderFactory,
|
||||
private readonly prompt: PromptService
|
||||
) {
|
||||
super();
|
||||
this.#instance = createOpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseUrl,
|
||||
}
|
||||
|
||||
/**
 * Reports whether a provider with embedding output can be resolved from the
 * provider factory.
 *
 * @returns `true` when the factory yields a provider; otherwise logs a
 * warning and returns `false`.
 */
override async configured(): Promise<boolean> {
  const provider = await this.providerFactory.getProvider({
    outputType: ModelOutputType.Embedding,
  });

  if (provider) {
    return true;
  }

  // Nothing resolved: surface the misconfiguration in the logs instead of throwing.
  this.logger.warn(
    'Copilot embedding client is not configured properly, please check your configuration.'
  );
  return false;
}
|
||||
|
||||
/**
 * Resolves a provider matching the given conditions through the provider
 * factory.
 *
 * @param cond Conditions forwarded unchanged to the factory lookup.
 * @returns The provider returned by the factory.
 * @throws CopilotProviderNotSupported when the factory returns nothing.
 */
private async getProvider(
  cond: ModelFullConditions
): Promise<CopilotProvider> {
  const resolved = await this.providerFactory.getProvider(cond);
  if (resolved) {
    return resolved;
  }

  // Report the requested output type, defaulting to 'embedding' when the
  // condition carries none.
  const kind = cond.outputType || 'embedding';
  throw new CopilotProviderNotSupported({ provider: 'embedding', kind });
}
|
||||
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
const modelInstance = this.#instance.embedding('text-embedding-3-large', {
|
||||
dimensions: EMBEDDING_DIMENSIONS,
|
||||
const provider = await this.getProvider({
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
this.logger.verbose(`Using provider ${provider.type} for embedding`, input);
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: input,
|
||||
});
|
||||
const embeddings = await provider.embedding(
|
||||
{ inputTypes: [ModelInputType.Text] },
|
||||
input,
|
||||
{ dimensions: EMBEDDING_DIMENSIONS }
|
||||
);
|
||||
|
||||
return Array.from(embeddings.entries()).map(([index, embedding]) => ({
|
||||
index,
|
||||
@@ -44,27 +78,6 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * Builds the text prompt asking for a relevance score array over a list of
 * search results.
 *
 * @param query Topic inserted into the instruction sentence.
 * @param embeddings Search results; each contributes its target id
 * (`docId`, else `fileId`, else an empty string), `chunk` and `content`.
 * @returns The instruction followed by the results wrapped in `<results>` tags.
 */
private getRelevancePrompt<Chunk extends ChunkSimilarity = ChunkSimilarity>(
  query: string,
  embeddings: Chunk[]
) {
  const lines: string[] = [];
  for (const item of embeddings) {
    const targetId =
      'docId' in item ? item.docId : 'fileId' in item ? item.fileId : '';
    // The tags are a prompt-formatting convention only; this is not real XML.
    lines.push(
      '<result>',
      `<targetId>${targetId}</targetId>`,
      `<chunk>${item.chunk}</chunk>`,
      `<content>${item.content}</content>`,
      '</result>'
    );
  }
  const results = lines.join('\n');
  return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n<results>\n${results}\n</results>`;
}
|
||||
|
||||
private async getEmbeddingRelevance<
|
||||
Chunk extends ChunkSimilarity = ChunkSimilarity,
|
||||
>(
|
||||
@@ -72,19 +85,36 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
embeddings: Chunk[],
|
||||
signal?: AbortSignal
|
||||
): Promise<ReRankResult> {
|
||||
const prompt = this.getRelevancePrompt(query, embeddings);
|
||||
const modelInstance = this.#instance(RERANK_MODEL);
|
||||
if (!embeddings.length) return [];
|
||||
|
||||
const {
|
||||
object: { ranks },
|
||||
} = await generateObject({
|
||||
model: modelInstance,
|
||||
prompt,
|
||||
schema: getReRankSchema(embeddings.length),
|
||||
maxRetries: 3,
|
||||
abortSignal: signal,
|
||||
});
|
||||
return ranks;
|
||||
const prompt = await this.prompt.get(RERANK_PROMPT);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: RERANK_PROMPT });
|
||||
}
|
||||
const provider = await this.getProvider({ modelId: prompt.model });
|
||||
const schema = getReRankSchema(embeddings.length);
|
||||
|
||||
const ranks = await provider.structure(
|
||||
{ modelId: prompt.model },
|
||||
prompt.finish({
|
||||
query,
|
||||
results: embeddings.map(e => {
|
||||
const targetId =
|
||||
'docId' in e ? e.docId : 'fileId' in e ? e.fileId : '';
|
||||
return { targetId, chunk: e.chunk, content: e.content };
|
||||
}),
|
||||
schema,
|
||||
}),
|
||||
{ maxRetries: 3, signal }
|
||||
);
|
||||
|
||||
try {
|
||||
return schema.parse(JSON.parse(ranks)).ranks;
|
||||
} catch (error) {
|
||||
this.logger.error('Failed to parse rerank results', error);
|
||||
// silent error, will fallback to default sorting in parent method
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
override async reRank<Chunk extends ChunkSimilarity = ChunkSimilarity>(
|
||||
@@ -110,6 +140,10 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
const ranks = [];
|
||||
for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) {
|
||||
const rank = await this.getEmbeddingRelevance(query, c, signal);
|
||||
if (c.length !== rank.length) {
|
||||
// llm return wrong result, fallback to default sorting
|
||||
return super.reRank(query, embeddings, topK, signal);
|
||||
}
|
||||
ranks.push(rank);
|
||||
}
|
||||
|
||||
@@ -124,6 +158,21 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Module-level cache for the embedding client. It stays undefined until getEmbeddingClient
// creates a client whose configured() check succeeds.
let EMBEDDING_CLIENT: EmbeddingClient | undefined;
|
||||
/**
 * Returns the cached embedding client, creating it on first use.
 *
 * A new ProductionEmbeddingClient is cached only if its `configured()` check
 * passes; otherwise nothing is stored and a later call tries again.
 *
 * @param providerFactory Factory passed to the client constructor.
 * @param prompt Prompt service passed to the client constructor.
 * @returns The cached client, or `undefined` when none is configured.
 */
export async function getEmbeddingClient(
  providerFactory: CopilotProviderFactory,
  prompt: PromptService
): Promise<EmbeddingClient | undefined> {
  if (!EMBEDDING_CLIENT) {
    const candidate = new ProductionEmbeddingClient(providerFactory, prompt);
    const ready = await candidate.configured();
    if (ready) {
      EMBEDDING_CLIENT = candidate;
    }
  }
  return EMBEDDING_CLIENT;
}
|
||||
|
||||
export class MockEmbeddingClient extends EmbeddingClient {
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
return input.map((_, i) => ({
|
||||
|
||||
Reference in New Issue
Block a user