diff --git a/packages/backend/server/src/plugins/copilot/context/service.ts b/packages/backend/server/src/plugins/copilot/context/service.ts index 47c618e81f..d705cda8ca 100644 --- a/packages/backend/server/src/plugins/copilot/context/service.ts +++ b/packages/backend/server/src/plugins/copilot/context/service.ts @@ -1,4 +1,5 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common'; +import { ModuleRef } from '@nestjs/core'; import { Cache, @@ -15,8 +16,6 @@ import { Models, } from '../../../models'; import { type EmbeddingClient, getEmbeddingClient } from '../embedding'; -import { PromptService } from '../prompt'; -import { CopilotProviderFactory } from '../providers'; import { ContextSession } from './session'; const CONTEXT_SESSION_KEY = 'context-session'; @@ -27,10 +26,9 @@ export class CopilotContextService implements OnApplicationBootstrap { private client: EmbeddingClient | undefined; constructor( + private readonly moduleRef: ModuleRef, private readonly cache: Cache, - private readonly models: Models, - private readonly providerFactory: CopilotProviderFactory, - private readonly prompt: PromptService + private readonly models: Models ) {} @OnEvent('config.init') @@ -44,7 +42,7 @@ export class CopilotContextService implements OnApplicationBootstrap { } private async setup() { - this.client = await getEmbeddingClient(this.providerFactory, this.prompt); + this.client = await getEmbeddingClient(this.moduleRef); } async onApplicationBootstrap() { @@ -165,7 +163,7 @@ export class CopilotContextService implements OnApplicationBootstrap { ); if (!fileChunks.length) return []; - return this.embeddingClient.reRank(content, fileChunks, topK, signal); + return await this.embeddingClient.reRank(content, fileChunks, topK, signal); } async matchWorkspaceDocs( @@ -188,7 +186,48 @@ export class CopilotContextService implements OnApplicationBootstrap { ); if (!workspaceChunks.length) return []; - return this.embeddingClient.reRank(content, workspaceChunks, topK, signal); + return await this.embeddingClient.reRank( + content, + workspaceChunks, + topK, + signal + ); + } + + async matchWorkspaceAll( + workspaceId: string, + content: string, + topK: number = 5, + signal?: AbortSignal, + threshold: number = 0.5 + ) { + if (!this.embeddingClient) return []; + const embedding = await this.embeddingClient.getEmbedding(content, signal); + if (!embedding) return []; + + const [fileChunks, workspaceChunks] = await Promise.all([ + this.models.copilotWorkspace.matchFileEmbedding( + workspaceId, + embedding, + topK * 2, + threshold + ), + this.models.copilotContext.matchWorkspaceEmbedding( + embedding, + workspaceId, + topK * 2, + threshold + ), + ]); + + if (!fileChunks.length && !workspaceChunks.length) return []; + + return await this.embeddingClient.reRank( + content, + [...fileChunks, ...workspaceChunks], + topK, + signal + ); } @OnEvent('workspace.doc.embed.failed') diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 6ebfc44d30..a4e2d0e4fa 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -234,6 +234,7 @@ export class CopilotController implements BeforeApplicationShutdown { ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, + workspace: session.config.workspaceId, reasoning, webSearch, }); @@ -304,6 +305,7 @@ export class CopilotController implements BeforeApplicationShutdown { ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, + workspace: session.config.workspaceId, reasoning, webSearch, }) @@ -378,6 +380,7 @@ export class CopilotController implements BeforeApplicationShutdown { ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, + workspace: session.config.workspaceId, }) ).pipe( connect(shared$ => @@ -500,6 +503,7 @@ export class CopilotController implements BeforeApplicationShutdown { seed: this.parseNumber(params.seed), signal: this.getSignal(req), user: user.id, + workspace: session.config.workspaceId, } ) ).pipe( diff --git a/packages/backend/server/src/plugins/copilot/embedding/client.ts b/packages/backend/server/src/plugins/copilot/embedding/client.ts index f2d4c7b6e9..d063b84555 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/client.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/client.ts @@ -1,4 +1,5 @@ import { Logger } from '@nestjs/common'; +import type { ModuleRef } from '@nestjs/core'; import { CopilotPromptNotFound, @@ -193,12 +194,16 @@ class ProductionEmbeddingClient extends EmbeddingClient { let EMBEDDING_CLIENT: EmbeddingClient | undefined; export async function getEmbeddingClient( - providerFactory: CopilotProviderFactory, - prompt: PromptService + moduleRef: ModuleRef ): Promise { if (EMBEDDING_CLIENT) { return EMBEDDING_CLIENT; } + const providerFactory = moduleRef.get(CopilotProviderFactory, { + strict: false, + }); + const prompt = moduleRef.get(PromptService, { strict: false }); + const client = new ProductionEmbeddingClient(providerFactory, prompt); if (await client.configured()) { EMBEDDING_CLIENT = client; diff --git a/packages/backend/server/src/plugins/copilot/embedding/job.ts b/packages/backend/server/src/plugins/copilot/embedding/job.ts index c7784ea07a..cc9381ddd2 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/job.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/job.ts @@ -1,4 +1,5 @@ import { Injectable } from '@nestjs/common'; +import { ModuleRef } from '@nestjs/core'; import { AFFiNELogger, @@ -14,8 +15,6 @@ import { } from '../../../base'; import { DocReader } from '../../../core/doc'; import { Models } from '../../../models'; -import { PromptService } from '../prompt'; -import { CopilotProviderFactory } from '../providers'; import { CopilotStorage } from '../storage'; import { readStream } from '../utils'; import { getEmbeddingClient } from './client'; @@ -31,12 +30,11 @@ export class CopilotEmbeddingJob { private client: EmbeddingClient | undefined; constructor( + private readonly moduleRef: ModuleRef, private readonly doc: DocReader, private readonly event: EventBus, private readonly logger: AFFiNELogger, private readonly models: Models, - private readonly providerFactory: CopilotProviderFactory, - private readonly prompt: PromptService, private readonly queue: JobQueue, private readonly storage: CopilotStorage ) { @@ -57,7 +55,7 @@ export class CopilotEmbeddingJob { this.supportEmbedding = await this.models.copilotContext.checkEmbeddingAvailable(); if (this.supportEmbedding) { - this.client = await getEmbeddingClient(this.providerFactory, this.prompt); + this.client = await getEmbeddingClient(this.moduleRef); } } diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 086feac1a1..b9d9bd7246 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -1791,7 +1791,13 @@ Below is the user's query. Please respond in the user's preferred language witho }, ], config: { - tools: ['webSearch'], + tools: [ + 'readDoc', + 'editDoc', + 'webSearch', + 'keywordSearch', + 'semanticSearch', + ], }, }; diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts index 1fd22d6aac..25fc1f863b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts @@ -10,7 +10,6 @@ import { metrics, UserFriendlyError, } from '../../../../base'; -import { createExaCrawlTool, createExaSearchTool } from '../../tools'; import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, @@ -68,7 +67,7 @@ export abstract class AnthropicProvider extends CopilotProvider { providerOptions: { anthropic: this.getAnthropicOptions(options, model.id), }, - tools: this.getTools(), + tools: await this.getTools(options, model.id), maxSteps: this.MAX_STEPS, experimental_continueSteps: true, }); @@ -103,7 +102,7 @@ export abstract class AnthropicProvider extends CopilotProvider { providerOptions: { anthropic: this.getAnthropicOptions(options, model.id), }, - tools: this.getTools(), + tools: await this.getTools(options, model.id), maxSteps: this.MAX_STEPS, experimental_continueSteps: true, }); @@ -123,13 +122,6 @@ export abstract class AnthropicProvider extends CopilotProvider { } } - private getTools() { - return { - web_search_exa: createExaSearchTool(this.AFFiNEConfig), - web_crawl_exa: createExaCrawlTool(this.AFFiNEConfig), - }; - } - private getAnthropicOptions(options: CopilotChatOptions, model: string) { const result: AnthropicProviderOptions = {}; if (options?.reasoning && this.isReasoningModel(model)) { diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 3e3f3a3f74..a1495663da 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -11,7 +11,7 @@ import { generateObject, generateText, streamText, - ToolSet, + Tool, } from 'ai'; import { z } from 'zod'; @@ -21,10 +21,10 @@ import { metrics, UserFriendlyError, } from '../../../base'; -import { createExaCrawlTool, createExaSearchTool } from '../tools'; import { CopilotProvider } from './provider'; import type { CopilotChatOptions, + CopilotChatTools, CopilotEmbeddingOptions, CopilotImageOptions, CopilotStructuredOptions, @@ -248,25 +248,14 @@ export class OpenAIProvider extends CopilotProvider { } } - private getTools(options: CopilotChatOptions, model: string): ToolSet { - const tools: ToolSet = {}; - if (options?.tools?.length) { - for (const tool of options.tools) { - switch (tool) { - case 'webSearch': { - if (this.isReasoningModel(model)) { - tools.web_search_exa = createExaSearchTool(this.AFFiNEConfig); - tools.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig); - } else { - tools.web_search_preview = openai.tools.webSearchPreview(); - } - break; - } - } - } - return tools; + override getProviderSpecificTools( + toolName: CopilotChatTools, + model: string + ): [string, Tool] | undefined { + if (toolName === 'webSearch' && !this.isReasoningModel(model)) { + return ['web_search_preview', openai.tools.webSearchPreview()]; } - return tools; + return; } async text( @@ -297,7 +286,7 @@ export class OpenAIProvider extends CopilotProvider { providerOptions: { openai: this.getOpenAIOptions(options, model.id), }, - tools: this.getTools(options, model.id), + tools: await this.getTools(options, model.id), maxSteps: this.MAX_STEPS, abortSignal: options.signal, }); @@ -338,7 +327,7 @@ export class OpenAIProvider extends CopilotProvider { providerOptions: { openai: this.getOpenAIOptions(options, model.id), }, - tools: this.getTools(options, model.id), + tools: await this.getTools(options, model.id), maxSteps: this.MAX_STEPS, abortSignal: options.signal, }); diff --git a/packages/backend/server/src/plugins/copilot/providers/provider.ts b/packages/backend/server/src/plugins/copilot/providers/provider.ts index a8a3150503..e5d700d40f 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider.ts @@ -1,4 +1,6 @@ import { Inject, Injectable, Logger } from '@nestjs/common'; +import { ModuleRef } from '@nestjs/core'; +import { Tool, ToolSet } from 'ai'; import { z } from 'zod'; import { @@ -7,9 +9,18 @@ import { CopilotProviderNotSupported, OnEvent, } from '../../../base'; +import { AccessController } from '../../../core/permission'; +import { CopilotContextService } from '../context'; +import { + buildDocSearchGetter, + createExaCrawlTool, + createExaSearchTool, + createSemanticSearchTool, +} from '../tools'; import { CopilotProviderFactory } from './factory'; import { type CopilotChatOptions, + CopilotChatTools, type CopilotEmbeddingOptions, type CopilotImageOptions, CopilotProviderModel, @@ -33,6 +44,7 @@ export abstract class CopilotProvider { @Inject() protected readonly AFFiNEConfig!: Config; @Inject() protected readonly factory!: CopilotProviderFactory; + @Inject() protected readonly moduleRef!: ModuleRef; get config(): C { return this.AFFiNEConfig.copilot.providers[this.type] as C; @@ -98,6 +110,49 @@ export abstract class CopilotProvider { ); } + protected getProviderSpecificTools( + _toolName: CopilotChatTools, + _model: string + ): [string, Tool] | undefined { + return; + } + + // use for tool use, shared between providers + protected async getTools( + options: CopilotChatOptions, + model: string + ): Promise { + const tools: ToolSet = {}; + if (options?.tools?.length) { + for (const tool of options.tools) { + const toolDef = this.getProviderSpecificTools(tool, model); + if (toolDef) { + tools[toolDef[0]] = toolDef[1]; + continue; + } + switch (tool) { + case 'webSearch': { + tools.web_search_exa = createExaSearchTool(this.AFFiNEConfig); + tools.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig); + break; + } + case 'semanticSearch': { + const ac = this.moduleRef.get(AccessController, { strict: false }); + const context = this.moduleRef.get(CopilotContextService, { + strict: false, + }); + const searchDocs = buildDocSearchGetter(ac, context); + tools.semantic_search = createSemanticSearchTool( + searchDocs.bind(null, options) + ); + } + } + } + return tools; + } + return tools; + } + private handleZodError(ret: z.SafeParseReturnType) { if (ret.success) return; const issues = ret.error.issues.map(i => { diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index c541351a2a..eb5bb20c6b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -57,7 +57,21 @@ export const VertexSchema: JSONSchema = { // ========== prompt ========== export const PromptConfigStrictSchema = z.object({ - tools: z.enum(['webSearch']).array().nullable().optional(), + tools: z + .enum([ + // work with morph + 'editDoc', + // work with exa/model internal tools + 'webSearch', + // work with indexer + 'readDoc', + 'keywordSearch', + // work with embeddings + 'semanticSearch', + ]) + .array() + .nullable() + .optional(), // params requirements requireContent: z.boolean().nullable().optional(), requireAttachment: z.boolean().nullable().optional(), @@ -121,6 +135,7 @@ export type PromptParams = NonNullable; const CopilotProviderOptionsSchema = z.object({ signal: z.instanceof(AbortSignal).optional(), user: z.string().optional(), + workspace: z.string().optional(), }); export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge( @@ -133,6 +148,9 @@ export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge( .optional(); export type CopilotChatOptions = z.infer; +export type CopilotChatTools = NonNullable< + NonNullable['tools'] +>[number]; export const CopilotStructuredOptionsSchema = CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional(); diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index f98654c839..25b6312611 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -9,7 +9,11 @@ import { } from 'ai'; import { ZodType } from 'zod'; -import { createExaCrawlTool, createExaSearchTool } from '../tools'; +import { + createExaCrawlTool, + createExaSearchTool, + createSemanticSearchTool, +} from '../tools'; import { PromptMessage } from './types'; type ChatMessage = CoreUserMessage | CoreAssistantMessage; @@ -378,6 +382,7 @@ export class CitationParser { export interface CustomAITools extends ToolSet { web_search_exa: ReturnType; web_crawl_exa: ReturnType; + semantic_search: ReturnType; } type ChunkType = TextStreamPart['type']; @@ -424,6 +429,12 @@ export class TextStreamParser { case 'tool-result': { result = this.addPrefix(result); switch (chunk.toolName) { + case 'semantic_search': { + if (Array.isArray(chunk.result)) { + result += `\nFound ${chunk.result.length} document${chunk.result.length !== 1 ? 's' : ''} related to “${chunk.args.query}”.\n`; + } + break; + } case 'web_search_exa': { if (Array.isArray(chunk.result)) { result += `\n${this.getWebSearchLinks(chunk.result)}\n`; diff --git a/packages/backend/server/src/plugins/copilot/tools/index.ts b/packages/backend/server/src/plugins/copilot/tools/index.ts index af14d9ae45..cde30ce1ed 100644 --- a/packages/backend/server/src/plugins/copilot/tools/index.ts +++ b/packages/backend/server/src/plugins/copilot/tools/index.ts @@ -1 +1,2 @@ +export * from './semantic-search'; export * from './web-search'; diff --git a/packages/backend/server/src/plugins/copilot/tools/semantic-search.ts b/packages/backend/server/src/plugins/copilot/tools/semantic-search.ts new file mode 100644 index 0000000000..9e654c4a92 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/tools/semantic-search.ts @@ -0,0 +1,45 @@ +import { tool } from 'ai'; +import { z } from 'zod'; + +import type { AccessController } from '../../../core/permission'; +import type { ChunkSimilarity } from '../../../models'; +import type { CopilotContextService } from '../context'; +import type { CopilotChatOptions } from '../providers'; + +export const buildDocSearchGetter = ( + ac: AccessController, + context: CopilotContextService +) => { + const searchDocs = async (options: CopilotChatOptions, query?: string) => { + if (!options || !query?.trim() || !options.user || !options.workspace) { + return undefined; + } + const canAccess = await ac + .user(options.user) + .workspace(options.workspace) + .can('Workspace.Read'); + if (!canAccess) return undefined; + const chunks = await context.matchWorkspaceAll(options.workspace, query); + return chunks || undefined; + }; + return searchDocs; +}; + +export const createSemanticSearchTool = ( + searchDocs: (query: string) => Promise +) => { + return tool({ + description: + 'Semantic search for relevant documents in the current workspace', + parameters: z.object({ + query: z.string().describe('The query to search for.'), + }), + execute: async ({ query }) => { + try { + return await searchDocs(query); + } catch { + return 'Failed to search documents.'; + } + }, + }); +};