diff --git a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts index 39ecca31c6..b5d5a2c1e9 100644 --- a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts +++ b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts @@ -79,7 +79,7 @@ export class MockCopilotProvider extends OpenAIProvider { capabilities: [ { input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Structured], }, ], }, diff --git a/packages/backend/server/src/plugins/copilot/context/embedding.ts b/packages/backend/server/src/plugins/copilot/context/embedding.ts index 6d628ba23a..06c5d6fd17 100644 --- a/packages/backend/server/src/plugins/copilot/context/embedding.ts +++ b/packages/backend/server/src/plugins/copilot/context/embedding.ts @@ -1,41 +1,75 @@ -import { - createOpenAI, - type OpenAIProvider as VercelOpenAIProvider, -} from '@ai-sdk/openai'; -import { embedMany, generateObject } from 'ai'; +import { Logger } from '@nestjs/common'; import { chunk } from 'lodash-es'; -import { ChunkSimilarity, Embedding } from '../../../models'; -import { OpenAIConfig } from '../providers/openai'; +import { + CopilotPromptNotFound, + CopilotProviderNotSupported, +} from '../../../base'; +import type { ChunkSimilarity, Embedding } from '../../../models'; +import type { PromptService } from '../prompt'; +import { + type CopilotProvider, + type CopilotProviderFactory, + type ModelFullConditions, + ModelInputType, + ModelOutputType, +} from '../providers'; import { EMBEDDING_DIMENSIONS, EmbeddingClient, getReRankSchema, - ReRankResult, + type ReRankResult, } from './types'; -const RERANK_MODEL = 'gpt-4.1-mini'; +const RERANK_PROMPT = 'Rerank results'; -export class OpenAIEmbeddingClient extends EmbeddingClient { - readonly #instance: VercelOpenAIProvider; +export class ProductionEmbeddingClient extends EmbeddingClient { + private readonly logger = new Logger(ProductionEmbeddingClient.name); - constructor(config: OpenAIConfig) { + constructor( + private readonly providerFactory: CopilotProviderFactory, + private readonly prompt: PromptService + ) { super(); - this.#instance = createOpenAI({ - apiKey: config.apiKey, - baseURL: config.baseUrl, + } + + override async configured(): Promise { + const embedding = await this.providerFactory.getProvider({ + outputType: ModelOutputType.Embedding, }); + const result = Boolean(embedding); + if (!result) { + this.logger.warn( + 'Copilot embedding client is not configured properly, please check your configuration.' + ); + } + return result; + } + + private async getProvider( + cond: ModelFullConditions + ): Promise { + const provider = await this.providerFactory.getProvider(cond); + if (!provider) { + throw new CopilotProviderNotSupported({ + provider: 'embedding', + kind: cond.outputType || 'embedding', + }); + } + return provider; } async getEmbeddings(input: string[]): Promise { - const modelInstance = this.#instance.embedding('text-embedding-3-large', { - dimensions: EMBEDDING_DIMENSIONS, + const provider = await this.getProvider({ + outputType: ModelOutputType.Embedding, }); + this.logger.verbose(`Using provider ${provider.type} for embedding`, input); - const { embeddings } = await embedMany({ - model: modelInstance, - values: input, - }); + const embeddings = await provider.embedding( + { inputTypes: [ModelInputType.Text] }, + input, + { dimensions: EMBEDDING_DIMENSIONS } + ); return Array.from(embeddings.entries()).map(([index, embedding]) => ({ index, @@ -44,27 +78,6 @@ export class OpenAIEmbeddingClient extends EmbeddingClient { })); } - private getRelevancePrompt( - query: string, - embeddings: Chunk[] - ) { - const results = embeddings - .map(e => { - const targetId = 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : ''; - // NOTE: not xml, just for the sake of the prompt format - return [ - '', - `${targetId}`, - `${e.chunk}`, - `${e.content}`, - '', - ]; - }) - .flat() - .join('\n'); - return `Generate a score array based on the search results list to measure the likelihood that the information contained in the search results is useful for the report on the following topic: ${query}\n\nHere are the search results:\n\n${results}\n`; - } - private async getEmbeddingRelevance< Chunk extends ChunkSimilarity = ChunkSimilarity, >( @@ -72,19 +85,36 @@ export class OpenAIEmbeddingClient extends EmbeddingClient { embeddings: Chunk[], signal?: AbortSignal ): Promise { - const prompt = this.getRelevancePrompt(query, embeddings); - const modelInstance = this.#instance(RERANK_MODEL); + if (!embeddings.length) return []; - const { - object: { ranks }, - } = await generateObject({ - model: modelInstance, - prompt, - schema: getReRankSchema(embeddings.length), - maxRetries: 3, - abortSignal: signal, - }); - return ranks; + const prompt = await this.prompt.get(RERANK_PROMPT); + if (!prompt) { + throw new CopilotPromptNotFound({ name: RERANK_PROMPT }); + } + const provider = await this.getProvider({ modelId: prompt.model }); + const schema = getReRankSchema(embeddings.length); + + const ranks = await provider.structure( + { modelId: prompt.model }, + prompt.finish({ + query, + results: embeddings.map(e => { + const targetId = + 'docId' in e ? e.docId : 'fileId' in e ? e.fileId : ''; + return { targetId, chunk: e.chunk, content: e.content }; + }), + schema, + }), + { maxRetries: 3, signal } + ); + + try { + return schema.parse(JSON.parse(ranks)).ranks; + } catch (error) { + this.logger.error('Failed to parse rerank results', error); + // silent error, will fallback to default sorting in parent method + return []; + } } override async reRank( @@ -110,6 +140,10 @@ export class OpenAIEmbeddingClient extends EmbeddingClient { const ranks = []; for (const c of chunk(sortedEmbeddings, Math.min(topK, 10))) { const rank = await this.getEmbeddingRelevance(query, c, signal); + if (c.length !== rank.length) { + // llm return wrong result, fallback to default sorting + return super.reRank(query, embeddings, topK, signal); + } ranks.push(rank); } @@ -124,6 +158,21 @@ export class OpenAIEmbeddingClient extends EmbeddingClient { } } +let EMBEDDING_CLIENT: EmbeddingClient | undefined; +export async function getEmbeddingClient( + providerFactory: CopilotProviderFactory, + prompt: PromptService +): Promise { + if (EMBEDDING_CLIENT) { + return EMBEDDING_CLIENT; + } + const client = new ProductionEmbeddingClient(providerFactory, prompt); + if (await client.configured()) { + EMBEDDING_CLIENT = client; + } + return EMBEDDING_CLIENT; +} + export class MockEmbeddingClient extends EmbeddingClient { async getEmbeddings(input: string[]): Promise { return input.map((_, i) => ({ diff --git a/packages/backend/server/src/plugins/copilot/context/job.ts b/packages/backend/server/src/plugins/copilot/context/job.ts index cef903ba77..5244ec1fbf 100644 --- a/packages/backend/server/src/plugins/copilot/context/job.ts +++ b/packages/backend/server/src/plugins/copilot/context/job.ts @@ -4,7 +4,6 @@ import { AFFiNELogger, BlobNotFound, CallMetric, - Config, CopilotContextFileNotSupported, DocNotFound, EventBus, @@ -15,9 +14,11 @@ import { } from '../../../base'; import { DocReader } from '../../../core/doc'; import { Models } from '../../../models'; +import { PromptService } from '../prompt'; +import { CopilotProviderFactory } from '../providers'; import { CopilotStorage } from '../storage'; import { readStream } from '../utils'; -import { OpenAIEmbeddingClient } from './embedding'; +import { getEmbeddingClient } from './embedding'; import type { Chunk, DocFragment } from './types'; import { EMBEDDING_DIMENSIONS, EmbeddingClient } from './types'; @@ -30,11 +31,12 @@ export class CopilotContextDocJob { private client: EmbeddingClient | undefined; constructor( - private readonly config: Config, private readonly doc: DocReader, private readonly event: EventBus, private readonly logger: AFFiNELogger, private readonly models: Models, + private readonly providerFactory: CopilotProviderFactory, + private readonly prompt: PromptService, private readonly queue: JobQueue, private readonly storage: CopilotStorage ) { @@ -54,10 +56,8 @@ export class CopilotContextDocJob { private async setup() { this.supportEmbedding = await this.models.copilotContext.checkEmbeddingAvailable(); - if (this.supportEmbedding && this.config.copilot.providers.openai.apiKey) { - this.client = new OpenAIEmbeddingClient( - this.config.copilot.providers.openai - ); + if (this.supportEmbedding) { + this.client = await getEmbeddingClient(this.providerFactory, this.prompt); } } @@ -89,6 +89,14 @@ export class CopilotContextDocJob { if (!this.supportEmbedding) return; for (const { workspaceId, docId } of docs) { + const jobId = `workspace:embedding:${workspaceId}:${docId}`; + const job = await this.queue.get(jobId, 'copilot.embedding.docs'); + // if the job exists and is older than 5 minute, remove it + if (job && job.timestamp + 5 * 60 * 1000 < Date.now()) { + this.logger.verbose(`Removing old embedding job ${jobId}`); + await this.queue.remove(jobId, 'copilot.embedding.docs'); + } + await this.queue.add( 'copilot.embedding.docs', { @@ -99,6 +107,7 @@ export class CopilotContextDocJob { { jobId: `workspace:embedding:${workspaceId}:${docId}`, priority: options?.priority ?? 1, + timestamp: Date.now(), } ); } @@ -336,6 +345,9 @@ export class CopilotContextDocJob { workspaceId, docId ); + this.logger.verbose( + `Check if doc ${docId} in workspace ${workspaceId} needs embedding: ${needEmbedding}` + ); if (needEmbedding) { if (signal.aborted) return; const fragment = await this.getDocFragment(workspaceId, docId); diff --git a/packages/backend/server/src/plugins/copilot/context/service.ts b/packages/backend/server/src/plugins/copilot/context/service.ts index 2b450ef9a4..60e1870311 100644 --- a/packages/backend/server/src/plugins/copilot/context/service.ts +++ b/packages/backend/server/src/plugins/copilot/context/service.ts @@ -2,7 +2,6 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common'; import { Cache, - Config, CopilotInvalidContext, NoCopilotProviderAvailable, OnEvent, @@ -15,9 +14,11 @@ import { ContextFile, Models, } from '../../../models'; -import { OpenAIEmbeddingClient } from './embedding'; +import { PromptService } from '../prompt'; +import { CopilotProviderFactory } from '../providers'; +import { getEmbeddingClient } from './embedding'; import { ContextSession } from './session'; -import { EmbeddingClient } from './types'; +import type { EmbeddingClient } from './types'; const CONTEXT_SESSION_KEY = 'context-session'; @@ -27,26 +28,24 @@ export class CopilotContextService implements OnApplicationBootstrap { private client: EmbeddingClient | undefined; constructor( - private readonly config: Config, private readonly cache: Cache, - private readonly models: Models + private readonly models: Models, + private readonly providerFactory: CopilotProviderFactory, + private readonly prompt: PromptService ) {} @OnEvent('config.init') - onConfigInit() { - this.setup(); + async onConfigInit() { + await this.setup(); } @OnEvent('config.changed') - onConfigChanged() { - this.setup(); + async onConfigChanged() { + await this.setup(); } - private setup() { - const configure = this.config.copilot.providers.openai; - if (configure.apiKey) { - this.client = new OpenAIEmbeddingClient(configure); - } + private async setup() { + this.client = await getEmbeddingClient(this.providerFactory, this.prompt); } async onApplicationBootstrap() { diff --git a/packages/backend/server/src/plugins/copilot/context/types.ts b/packages/backend/server/src/plugins/copilot/context/types.ts index eb53423826..5de7b2252f 100644 --- a/packages/backend/server/src/plugins/copilot/context/types.ts +++ b/packages/backend/server/src/plugins/copilot/context/types.ts @@ -69,6 +69,10 @@ export type Chunk = { export const EMBEDDING_DIMENSIONS = 1024; export abstract class EmbeddingClient { + async configured() { + return true; + } + async getFileEmbeddings( file: File, chunkMapper: (chunk: Chunk[]) => Chunk[], diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 3074210606..0de8a09fb4 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -335,7 +335,66 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr requireAttachment: true, }, }, + { + name: 'Rerank results', + action: 'Rerank results', + model: 'gpt-4.1-mini', + messages: [ + { + role: 'system', + content: `Evaluate and rank search results based on their relevance and quality to the given query by assigning a score from 1 to 10, where 10 denotes the highest relevance. +Consider various factors such as content alignment with the query, source credibility, timeliness, and user intent. + +# Steps + +1. **Read the Query**: Understand the main intent and specific details of the search query. +2. **Review Each Result**: + - Analyze the content's relevance to the query. + - Assess the credibility of the source or website. + - Consider the timeliness of the information, ensuring it's current and relevant. + - Evaluate the alignment with potential user intent based on the query. +3. **Scoring**: + - Assign a score from 1 to 10 based on the overall relevance and quality, with 10 being the most relevant. + +# Output Format + +Return a JSON object for each result in the following format in raw: +{ + "scores": [ + { + "reason": "[Reasoning behind the score in 20 words]", + "chunk": "[chunk]", + "targetId": "[targetId]", + "score": [1-10] + } + ] +} + +# Notes + +- Be aware of the potential biases or inaccuracies in the sources. +- Consider if the content is comprehensive and directly answers the query. +- Pay attention to the nuances of user intent that might influence relevance.`, + }, + { + role: 'user', + content: ` +{{query}} + +{{#results}} + +{{targetId}} +{{chunk}} + +{{content}} + + +{{/results}} +`, + }, + ], + }, { name: 'Generate a caption', action: 'Generate a caption', diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index f5497573e3..3e3f3a3f74 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -103,7 +103,7 @@ export class OpenAIProvider extends CopilotProvider { capabilities: [ { input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Structured], defaultForOutputType: true, }, ], @@ -113,7 +113,7 @@ export class OpenAIProvider extends CopilotProvider { capabilities: [ { input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Structured], }, ], }, @@ -122,7 +122,16 @@ export class OpenAIProvider extends CopilotProvider { capabilities: [ { input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Structured], + }, + ], + }, + { + id: 'gpt-4.1-nano', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text, ModelOutputType.Structured], }, ], }, @@ -283,8 +292,8 @@ export class OpenAIProvider extends CopilotProvider { model: modelInstance, system, messages: msgs, - temperature: options.temperature || 0, - maxTokens: options.maxTokens || 4096, + temperature: options.temperature ?? 0, + maxTokens: options.maxTokens ?? 4096, providerOptions: { openai: this.getOpenAIOptions(options, model.id), }, @@ -322,10 +331,10 @@ export class OpenAIProvider extends CopilotProvider { model: modelInstance, system, messages: msgs, - frequencyPenalty: options.frequencyPenalty || 0, - presencePenalty: options.presencePenalty || 0, - temperature: options.temperature || 0, - maxTokens: options.maxTokens || 4096, + frequencyPenalty: options.frequencyPenalty ?? 0, + presencePenalty: options.presencePenalty ?? 0, + temperature: options.temperature ?? 0, + maxTokens: options.maxTokens ?? 4096, providerOptions: { openai: this.getOpenAIOptions(options, model.id), }, @@ -388,8 +397,9 @@ export class OpenAIProvider extends CopilotProvider { model: modelInstance, system, messages: msgs, - temperature: ('temperature' in options && options.temperature) || 0, - maxTokens: ('maxTokens' in options && options.maxTokens) || 4096, + temperature: options.temperature ?? 0, + maxTokens: options.maxTokens ?? 4096, + maxRetries: options.maxRetries ?? 3, schema, providerOptions: { openai: options.user ? { user: options.user } : {}, diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts index e5a9c7e797..706f948ae5 100644 --- a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -124,8 +124,8 @@ export class PerplexityProvider extends CopilotProvider { model: modelInstance, system, messages: msgs, - temperature: options.temperature || 0, - maxTokens: options.maxTokens || 4096, + temperature: options.temperature ?? 0, + maxTokens: options.maxTokens ?? 4096, abortSignal: options.signal, }); @@ -164,8 +164,8 @@ export class PerplexityProvider extends CopilotProvider { model: modelInstance, system, messages: msgs, - temperature: options.temperature || 0, - maxTokens: options.maxTokens || 4096, + temperature: options.temperature ?? 0, + maxTokens: options.maxTokens ?? 4096, abortSignal: options.signal, }); diff --git a/packages/backend/server/src/plugins/copilot/providers/provider.ts b/packages/backend/server/src/plugins/copilot/providers/provider.ts index 7f25432398..a8a3150503 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider.ts @@ -172,7 +172,7 @@ export abstract class CopilotProvider { structure( _cond: ModelConditions, _messages: PromptMessage[], - _options: CopilotStructuredOptions + _options?: CopilotStructuredOptions ): Promise { throw new CopilotProviderNotSupported({ provider: this.type, @@ -193,7 +193,7 @@ export abstract class CopilotProvider { embedding( _model: ModelConditions, - _text: string, + _text: string | string[], _options?: CopilotEmbeddingOptions ): Promise { throw new CopilotProviderNotSupported({ diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index bb29e02001..c541351a2a 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -61,6 +61,8 @@ export const PromptConfigStrictSchema = z.object({ // params requirements requireContent: z.boolean().nullable().optional(), requireAttachment: z.boolean().nullable().optional(), + // structure output + maxRetries: z.number().nullable().optional(), // openai frequencyPenalty: z.number().nullable().optional(), presencePenalty: z.number().nullable().optional(), diff --git a/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts b/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts index 52db8a8b22..728d8e19a4 100644 --- a/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts +++ b/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts @@ -168,13 +168,6 @@ test.describe('AISettings/Embedding', () => { 'workspace-embedding-setting-attachment-list' ); - // Uploading - await expect( - attachmentList.getByTestId( - 'workspace-embedding-setting-attachment-uploading-item' - ) - ).toHaveCount(2); - // Persisted await expect( attachmentList.getByTestId('workspace-embedding-setting-attachment-item') diff --git a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts index b59934028f..e3d2ac7d4d 100644 --- a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts +++ b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts @@ -88,6 +88,10 @@ export class SettingsPanelUtils { const fileChooser = await fileChooserPromise; await fileChooser.setFiles(attachment); + + await page + .getByTestId('workspace-embedding-setting-attachment-uploading-item') + .waitFor({ state: 'hidden' }); } }