feat(server): embedding search tool (#12810)

This commit is contained in:
DarkSky
2025-06-17 09:22:56 +08:00
committed by GitHub
parent 0785438cfe
commit cdaaa52845
12 changed files with 213 additions and 50 deletions

View File

@@ -1,4 +1,5 @@
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import {
Cache,
@@ -15,8 +16,6 @@ import {
Models,
} from '../../../models';
import { type EmbeddingClient, getEmbeddingClient } from '../embedding';
import { PromptService } from '../prompt';
import { CopilotProviderFactory } from '../providers';
import { ContextSession } from './session';
const CONTEXT_SESSION_KEY = 'context-session';
@@ -27,10 +26,9 @@ export class CopilotContextService implements OnApplicationBootstrap {
private client: EmbeddingClient | undefined;
constructor(
private readonly moduleRef: ModuleRef,
private readonly cache: Cache,
private readonly models: Models,
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService
private readonly models: Models
) {}
@OnEvent('config.init')
@@ -44,7 +42,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
}
private async setup() {
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
this.client = await getEmbeddingClient(this.moduleRef);
}
async onApplicationBootstrap() {
@@ -165,7 +163,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
);
if (!fileChunks.length) return [];
return this.embeddingClient.reRank(content, fileChunks, topK, signal);
return await this.embeddingClient.reRank(content, fileChunks, topK, signal);
}
async matchWorkspaceDocs(
@@ -188,7 +186,48 @@ export class CopilotContextService implements OnApplicationBootstrap {
);
if (!workspaceChunks.length) return [];
return this.embeddingClient.reRank(content, workspaceChunks, topK, signal);
return await this.embeddingClient.reRank(
content,
workspaceChunks,
topK,
signal
);
}
/**
 * Looks up chunks similar to `content` in both embedding stores of a
 * workspace (file embeddings and workspace embeddings) and re-ranks the
 * merged candidates.
 *
 * @param workspaceId - workspace whose embeddings are queried
 * @param content - query text; embedded once and reused for both lookups
 * @param topK - result count handed to the re-ranker (default 5)
 * @param signal - optional abort signal passed to the embedding client
 * @param threshold - value forwarded unchanged to both lookups (default 0.5)
 * @returns the re-ranked chunks, or `[]` when there is no embedding client,
 *   no embedding could be produced, or neither lookup returned a candidate
 */
async matchWorkspaceAll(
  workspaceId: string,
  content: string,
  topK: number = 5,
  signal?: AbortSignal,
  threshold: number = 0.5
) {
  if (!this.embeddingClient) return [];

  const queryVector = await this.embeddingClient.getEmbedding(content, signal);
  if (!queryVector) return [];

  // Each store is asked for twice `topK` candidates; the merged list is then
  // handed to `reRank` together with the caller's `topK`.
  const candidateLimit = topK * 2;
  const [fileHits, docHits] = await Promise.all([
    this.models.copilotWorkspace.matchFileEmbedding(
      workspaceId,
      queryVector,
      candidateLimit,
      threshold
    ),
    this.models.copilotContext.matchWorkspaceEmbedding(
      queryVector,
      workspaceId,
      candidateLimit,
      threshold
    ),
  ]);

  const candidates = [...fileHits, ...docHits];
  if (candidates.length === 0) return [];

  return await this.embeddingClient.reRank(content, candidates, topK, signal);
}
@OnEvent('workspace.doc.embed.failed')

View File

@@ -234,6 +234,7 @@ export class CopilotController implements BeforeApplicationShutdown {
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
workspace: session.config.workspaceId,
reasoning,
webSearch,
});
@@ -304,6 +305,7 @@ export class CopilotController implements BeforeApplicationShutdown {
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
workspace: session.config.workspaceId,
reasoning,
webSearch,
})
@@ -378,6 +380,7 @@ export class CopilotController implements BeforeApplicationShutdown {
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
workspace: session.config.workspaceId,
})
).pipe(
connect(shared$ =>
@@ -500,6 +503,7 @@ export class CopilotController implements BeforeApplicationShutdown {
seed: this.parseNumber(params.seed),
signal: this.getSignal(req),
user: user.id,
workspace: session.config.workspaceId,
}
)
).pipe(

View File

@@ -1,4 +1,5 @@
import { Logger } from '@nestjs/common';
import type { ModuleRef } from '@nestjs/core';
import {
CopilotPromptNotFound,
@@ -193,12 +194,16 @@ class ProductionEmbeddingClient extends EmbeddingClient {
let EMBEDDING_CLIENT: EmbeddingClient | undefined;
export async function getEmbeddingClient(
providerFactory: CopilotProviderFactory,
prompt: PromptService
moduleRef: ModuleRef
): Promise<EmbeddingClient | undefined> {
if (EMBEDDING_CLIENT) {
return EMBEDDING_CLIENT;
}
const providerFactory = moduleRef.get(CopilotProviderFactory, {
strict: false,
});
const prompt = moduleRef.get(PromptService, { strict: false });
const client = new ProductionEmbeddingClient(providerFactory, prompt);
if (await client.configured()) {
EMBEDDING_CLIENT = client;

View File

@@ -1,4 +1,5 @@
import { Injectable } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import {
AFFiNELogger,
@@ -14,8 +15,6 @@ import {
} from '../../../base';
import { DocReader } from '../../../core/doc';
import { Models } from '../../../models';
import { PromptService } from '../prompt';
import { CopilotProviderFactory } from '../providers';
import { CopilotStorage } from '../storage';
import { readStream } from '../utils';
import { getEmbeddingClient } from './client';
@@ -31,12 +30,11 @@ export class CopilotEmbeddingJob {
private client: EmbeddingClient | undefined;
constructor(
private readonly moduleRef: ModuleRef,
private readonly doc: DocReader,
private readonly event: EventBus,
private readonly logger: AFFiNELogger,
private readonly models: Models,
private readonly providerFactory: CopilotProviderFactory,
private readonly prompt: PromptService,
private readonly queue: JobQueue,
private readonly storage: CopilotStorage
) {
@@ -57,7 +55,7 @@ export class CopilotEmbeddingJob {
this.supportEmbedding =
await this.models.copilotContext.checkEmbeddingAvailable();
if (this.supportEmbedding) {
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
this.client = await getEmbeddingClient(this.moduleRef);
}
}

View File

@@ -1791,7 +1791,13 @@ Below is the user's query. Please respond in the user's preferred language witho
},
],
config: {
tools: ['webSearch'],
tools: [
'readDoc',
'editDoc',
'webSearch',
'keywordSearch',
'semanticSearch',
],
},
};

View File

@@ -10,7 +10,6 @@ import {
metrics,
UserFriendlyError,
} from '../../../../base';
import { createExaCrawlTool, createExaSearchTool } from '../../tools';
import { CopilotProvider } from '../provider';
import type {
CopilotChatOptions,
@@ -68,7 +67,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: this.getTools(),
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
experimental_continueSteps: true,
});
@@ -103,7 +102,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: this.getTools(),
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
experimental_continueSteps: true,
});
@@ -123,13 +122,6 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
}
}
private getTools() {
return {
web_search_exa: createExaSearchTool(this.AFFiNEConfig),
web_crawl_exa: createExaCrawlTool(this.AFFiNEConfig),
};
}
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
const result: AnthropicProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {

View File

@@ -11,7 +11,7 @@ import {
generateObject,
generateText,
streamText,
ToolSet,
Tool,
} from 'ai';
import { z } from 'zod';
@@ -21,10 +21,10 @@ import {
metrics,
UserFriendlyError,
} from '../../../base';
import { createExaCrawlTool, createExaSearchTool } from '../tools';
import { CopilotProvider } from './provider';
import type {
CopilotChatOptions,
CopilotChatTools,
CopilotEmbeddingOptions,
CopilotImageOptions,
CopilotStructuredOptions,
@@ -248,25 +248,14 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}
}
private getTools(options: CopilotChatOptions, model: string): ToolSet {
const tools: ToolSet = {};
if (options?.tools?.length) {
for (const tool of options.tools) {
switch (tool) {
case 'webSearch': {
if (this.isReasoningModel(model)) {
tools.web_search_exa = createExaSearchTool(this.AFFiNEConfig);
tools.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig);
} else {
tools.web_search_preview = openai.tools.webSearchPreview();
}
break;
}
}
}
return tools;
override getProviderSpecificTools(
toolName: CopilotChatTools,
model: string
): [string, Tool] | undefined {
if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
return ['web_search_preview', openai.tools.webSearchPreview()];
}
return tools;
return;
}
async text(
@@ -297,7 +286,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: this.getTools(options, model.id),
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
abortSignal: options.signal,
});
@@ -338,7 +327,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: this.getTools(options, model.id),
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
abortSignal: options.signal,
});

View File

@@ -1,4 +1,6 @@
import { Inject, Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { Tool, ToolSet } from 'ai';
import { z } from 'zod';
import {
@@ -7,9 +9,18 @@ import {
CopilotProviderNotSupported,
OnEvent,
} from '../../../base';
import { AccessController } from '../../../core/permission';
import { CopilotContextService } from '../context';
import {
buildDocSearchGetter,
createExaCrawlTool,
createExaSearchTool,
createSemanticSearchTool,
} from '../tools';
import { CopilotProviderFactory } from './factory';
import {
type CopilotChatOptions,
CopilotChatTools,
type CopilotEmbeddingOptions,
type CopilotImageOptions,
CopilotProviderModel,
@@ -33,6 +44,7 @@ export abstract class CopilotProvider<C = any> {
@Inject() protected readonly AFFiNEConfig!: Config;
@Inject() protected readonly factory!: CopilotProviderFactory;
@Inject() protected readonly moduleRef!: ModuleRef;
get config(): C {
return this.AFFiNEConfig.copilot.providers[this.type] as C;
@@ -98,6 +110,49 @@ export abstract class CopilotProvider<C = any> {
);
}
/**
 * Extension point for providers that ship their own implementation of a
 * requested tool. `getTools` asks this hook first and only falls back to the
 * shared implementations when it returns `undefined`.
 *
 * The base implementation provides no overrides.
 *
 * @param _toolName - tool requested through the chat options
 * @param _model - id of the model the request is sent to
 * @returns a `[toolSetKey, tool]` pair to register, or `undefined` for "no override"
 */
protected getProviderSpecificTools(
  _toolName: CopilotChatTools,
  _model: string
): [string, Tool] | undefined {
  return undefined;
}
/**
 * Builds the tool set for a chat request; shared between providers.
 *
 * For every tool name in `options.tools` the provider override
 * (`getProviderSpecificTools`) wins. Without an override, `webSearch`
 * registers the two Exa tools and `semanticSearch` registers the workspace
 * semantic search tool. Any other tool name is ignored here.
 *
 * @param options - chat options; `tools` lists the requested tool names
 * @param model - model id, forwarded to the provider override hook
 * @returns the assembled tool set (empty when no tools were requested)
 */
protected async getTools(
  options: CopilotChatOptions,
  model: string
): Promise<ToolSet> {
  const toolSet: ToolSet = {};
  const requested = options?.tools;
  if (!requested?.length) return toolSet;

  for (const name of requested) {
    const override = this.getProviderSpecificTools(name, model);
    if (override) {
      const [key, impl] = override;
      toolSet[key] = impl;
      continue;
    }

    if (name === 'webSearch') {
      toolSet.web_search_exa = createExaSearchTool(this.AFFiNEConfig);
      toolSet.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig);
    } else if (name === 'semanticSearch') {
      // Resolved lazily through ModuleRef instead of constructor injection.
      const ac = this.moduleRef.get(AccessController, { strict: false });
      const context = this.moduleRef.get(CopilotContextService, {
        strict: false,
      });
      const searchDocs = buildDocSearchGetter(ac, context);
      // Pre-bind the request options so the tool only has to supply the query.
      toolSet.semantic_search = createSemanticSearchTool(
        searchDocs.bind(null, options)
      );
    }
  }

  return toolSet;
}
private handleZodError(ret: z.SafeParseReturnType<any, any>) {
if (ret.success) return;
const issues = ret.error.issues.map(i => {

View File

@@ -57,7 +57,21 @@ export const VertexSchema: JSONSchema = {
// ========== prompt ==========
export const PromptConfigStrictSchema = z.object({
tools: z.enum(['webSearch']).array().nullable().optional(),
tools: z
.enum([
// work with morph
'editDoc',
// work with exa/model internal tools
'webSearch',
// work with indexer
'readDoc',
'keywordSearch',
// work with embeddings
'semanticSearch',
])
.array()
.nullable()
.optional(),
// params requirements
requireContent: z.boolean().nullable().optional(),
requireAttachment: z.boolean().nullable().optional(),
@@ -121,6 +135,7 @@ export type PromptParams = NonNullable<PromptMessage['params']>;
const CopilotProviderOptionsSchema = z.object({
signal: z.instanceof(AbortSignal).optional(),
user: z.string().optional(),
workspace: z.string().optional(),
});
export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
@@ -133,6 +148,9 @@ export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
.optional();
export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
/**
 * A single tool name as accepted in `CopilotChatOptions['tools']`, i.e. the
 * element type of that array with the optional/nullable wrappers stripped.
 */
export type CopilotChatTools = NonNullable<
NonNullable<CopilotChatOptions>['tools']
>[number];
export const CopilotStructuredOptionsSchema =
CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional();

View File

@@ -9,7 +9,11 @@ import {
} from 'ai';
import { ZodType } from 'zod';
import { createExaCrawlTool, createExaSearchTool } from '../tools';
import {
createExaCrawlTool,
createExaSearchTool,
createSemanticSearchTool,
} from '../tools';
import { PromptMessage } from './types';
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
@@ -378,6 +382,7 @@ export class CitationParser {
/**
 * Tool set used to type the text stream chunks (`TextStreamPart<CustomAITools>`).
 * Each property name is the tool name that the stream parser switches on via
 * `chunk.toolName`.
 */
export interface CustomAITools extends ToolSet {
// Exa-backed web search tool
web_search_exa: ReturnType<typeof createExaSearchTool>;
// Exa-backed web crawl tool
web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
// workspace semantic search tool
semantic_search: ReturnType<typeof createSemanticSearchTool>;
}
type ChunkType = TextStreamPart<CustomAITools>['type'];
@@ -424,6 +429,12 @@ export class TextStreamParser {
case 'tool-result': {
result = this.addPrefix(result);
switch (chunk.toolName) {
case 'semantic_search': {
if (Array.isArray(chunk.result)) {
result += `\nFound ${chunk.result.length} document${chunk.result.length !== 1 ? 's' : ''} related to “${chunk.args.query}”.\n`;
}
break;
}
case 'web_search_exa': {
if (Array.isArray(chunk.result)) {
result += `\n${this.getWebSearchLinks(chunk.result)}\n`;

View File

@@ -1 +1,2 @@
export * from './semantic-search';
export * from './web-search';

View File

@@ -0,0 +1,45 @@
import { tool } from 'ai';
import { z } from 'zod';
import type { AccessController } from '../../../core/permission';
import type { ChunkSimilarity } from '../../../models';
import type { CopilotContextService } from '../context';
import type { CopilotChatOptions } from '../providers';
/**
 * Creates the document search callback used by the semantic search tool.
 *
 * The returned `searchDocs(options, query)`:
 * - returns `undefined` without searching when `options`, a non-blank
 *   `query`, `options.user` or `options.workspace` is missing;
 * - returns `undefined` when the user lacks `Workspace.Read` on the workspace;
 * - otherwise returns the chunks matched by
 *   `CopilotContextService.matchWorkspaceAll`.
 *
 * The request's abort signal (`options.signal`) is forwarded to the search so
 * that a cancelled chat request can also cancel the lookup.
 *
 * @param ac - access controller used for the permission check
 * @param context - context service that performs the embedding search
 * @returns the `searchDocs` callback
 */
export const buildDocSearchGetter = (
  ac: AccessController,
  context: CopilotContextService
) => {
  const searchDocs = async (options: CopilotChatOptions, query?: string) => {
    if (!options || !query?.trim() || !options.user || !options.workspace) {
      return undefined;
    }

    const canAccess = await ac
      .user(options.user)
      .workspace(options.workspace)
      .can('Workspace.Read');
    if (!canAccess) return undefined;

    // `undefined` keeps the default `topK`; the fourth argument is the abort signal.
    const chunks = await context.matchWorkspaceAll(
      options.workspace,
      query,
      undefined,
      options.signal
    );
    return chunks || undefined;
  };
  return searchDocs;
};
/**
 * Wraps a document search callback as an AI SDK tool that the model can call
 * with a single `query` string.
 *
 * @param searchDocs - callback that resolves to the matching chunks, or
 *   `undefined` when nothing can be searched
 * @returns the tool definition; its `execute` resolves to whatever
 *   `searchDocs` returns, or to a fixed failure message if the search throws
 */
export const createSemanticSearchTool = (
  searchDocs: (query: string) => Promise<ChunkSimilarity[] | undefined>
) => {
  const parameters = z.object({
    query: z.string().describe('The query to search for.'),
  });

  return tool({
    description:
      'Semantic search for relevant documents in the current workspace',
    parameters,
    execute: async args => {
      try {
        const hits = await searchDocs(args.query);
        return hits;
      } catch {
        // Best effort: report the failure to the model instead of aborting the run.
        return 'Failed to search documents.';
      }
    },
  });
};