mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-17 06:16:59 +08:00
feat(server): embedding search tool (#12810)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
|
||||
import {
|
||||
Cache,
|
||||
@@ -15,8 +16,6 @@ import {
|
||||
Models,
|
||||
} from '../../../models';
|
||||
import { type EmbeddingClient, getEmbeddingClient } from '../embedding';
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderFactory } from '../providers';
|
||||
import { ContextSession } from './session';
|
||||
|
||||
const CONTEXT_SESSION_KEY = 'context-session';
|
||||
@@ -27,10 +26,9 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
private client: EmbeddingClient | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly cache: Cache,
|
||||
private readonly models: Models,
|
||||
private readonly providerFactory: CopilotProviderFactory,
|
||||
private readonly prompt: PromptService
|
||||
private readonly models: Models
|
||||
) {}
|
||||
|
||||
@OnEvent('config.init')
|
||||
@@ -44,7 +42,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
}
|
||||
|
||||
private async setup() {
|
||||
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
|
||||
this.client = await getEmbeddingClient(this.moduleRef);
|
||||
}
|
||||
|
||||
async onApplicationBootstrap() {
|
||||
@@ -165,7 +163,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
);
|
||||
if (!fileChunks.length) return [];
|
||||
|
||||
return this.embeddingClient.reRank(content, fileChunks, topK, signal);
|
||||
return await this.embeddingClient.reRank(content, fileChunks, topK, signal);
|
||||
}
|
||||
|
||||
async matchWorkspaceDocs(
|
||||
@@ -188,7 +186,48 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
);
|
||||
if (!workspaceChunks.length) return [];
|
||||
|
||||
return this.embeddingClient.reRank(content, workspaceChunks, topK, signal);
|
||||
return await this.embeddingClient.reRank(
|
||||
content,
|
||||
workspaceChunks,
|
||||
topK,
|
||||
signal
|
||||
);
|
||||
}
|
||||
|
||||
async matchWorkspaceAll(
|
||||
workspaceId: string,
|
||||
content: string,
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.5
|
||||
) {
|
||||
if (!this.embeddingClient) return [];
|
||||
const embedding = await this.embeddingClient.getEmbedding(content, signal);
|
||||
if (!embedding) return [];
|
||||
|
||||
const [fileChunks, workspaceChunks] = await Promise.all([
|
||||
this.models.copilotWorkspace.matchFileEmbedding(
|
||||
workspaceId,
|
||||
embedding,
|
||||
topK * 2,
|
||||
threshold
|
||||
),
|
||||
this.models.copilotContext.matchWorkspaceEmbedding(
|
||||
embedding,
|
||||
workspaceId,
|
||||
topK * 2,
|
||||
threshold
|
||||
),
|
||||
]);
|
||||
|
||||
if (!fileChunks.length && !workspaceChunks.length) return [];
|
||||
|
||||
return await this.embeddingClient.reRank(
|
||||
content,
|
||||
[...fileChunks, ...workspaceChunks],
|
||||
topK,
|
||||
signal
|
||||
);
|
||||
}
|
||||
|
||||
@OnEvent('workspace.doc.embed.failed')
|
||||
|
||||
@@ -234,6 +234,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
workspace: session.config.workspaceId,
|
||||
reasoning,
|
||||
webSearch,
|
||||
});
|
||||
@@ -304,6 +305,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
workspace: session.config.workspaceId,
|
||||
reasoning,
|
||||
webSearch,
|
||||
})
|
||||
@@ -378,6 +380,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
workspace: session.config.workspaceId,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
@@ -500,6 +503,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
seed: this.parseNumber(params.seed),
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
workspace: session.config.workspaceId,
|
||||
}
|
||||
)
|
||||
).pipe(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import type { ModuleRef } from '@nestjs/core';
|
||||
|
||||
import {
|
||||
CopilotPromptNotFound,
|
||||
@@ -193,12 +194,16 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
|
||||
let EMBEDDING_CLIENT: EmbeddingClient | undefined;
|
||||
export async function getEmbeddingClient(
|
||||
providerFactory: CopilotProviderFactory,
|
||||
prompt: PromptService
|
||||
moduleRef: ModuleRef
|
||||
): Promise<EmbeddingClient | undefined> {
|
||||
if (EMBEDDING_CLIENT) {
|
||||
return EMBEDDING_CLIENT;
|
||||
}
|
||||
const providerFactory = moduleRef.get(CopilotProviderFactory, {
|
||||
strict: false,
|
||||
});
|
||||
const prompt = moduleRef.get(PromptService, { strict: false });
|
||||
|
||||
const client = new ProductionEmbeddingClient(providerFactory, prompt);
|
||||
if (await client.configured()) {
|
||||
EMBEDDING_CLIENT = client;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
|
||||
import {
|
||||
AFFiNELogger,
|
||||
@@ -14,8 +15,6 @@ import {
|
||||
} from '../../../base';
|
||||
import { DocReader } from '../../../core/doc';
|
||||
import { Models } from '../../../models';
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderFactory } from '../providers';
|
||||
import { CopilotStorage } from '../storage';
|
||||
import { readStream } from '../utils';
|
||||
import { getEmbeddingClient } from './client';
|
||||
@@ -31,12 +30,11 @@ export class CopilotEmbeddingJob {
|
||||
private client: EmbeddingClient | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly doc: DocReader,
|
||||
private readonly event: EventBus,
|
||||
private readonly logger: AFFiNELogger,
|
||||
private readonly models: Models,
|
||||
private readonly providerFactory: CopilotProviderFactory,
|
||||
private readonly prompt: PromptService,
|
||||
private readonly queue: JobQueue,
|
||||
private readonly storage: CopilotStorage
|
||||
) {
|
||||
@@ -57,7 +55,7 @@ export class CopilotEmbeddingJob {
|
||||
this.supportEmbedding =
|
||||
await this.models.copilotContext.checkEmbeddingAvailable();
|
||||
if (this.supportEmbedding) {
|
||||
this.client = await getEmbeddingClient(this.providerFactory, this.prompt);
|
||||
this.client = await getEmbeddingClient(this.moduleRef);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1791,7 +1791,13 @@ Below is the user's query. Please respond in the user's preferred language witho
|
||||
},
|
||||
],
|
||||
config: {
|
||||
tools: ['webSearch'],
|
||||
tools: [
|
||||
'readDoc',
|
||||
'editDoc',
|
||||
'webSearch',
|
||||
'keywordSearch',
|
||||
'semanticSearch',
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../../base';
|
||||
import { createExaCrawlTool, createExaSearchTool } from '../../tools';
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
@@ -68,7 +67,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: this.getTools(),
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
experimental_continueSteps: true,
|
||||
});
|
||||
@@ -103,7 +102,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: this.getTools(),
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
experimental_continueSteps: true,
|
||||
});
|
||||
@@ -123,13 +122,6 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
}
|
||||
}
|
||||
|
||||
private getTools() {
|
||||
return {
|
||||
web_search_exa: createExaSearchTool(this.AFFiNEConfig),
|
||||
web_crawl_exa: createExaCrawlTool(this.AFFiNEConfig),
|
||||
};
|
||||
}
|
||||
|
||||
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: AnthropicProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText,
|
||||
ToolSet,
|
||||
Tool,
|
||||
} from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -21,10 +21,10 @@ import {
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import { createExaCrawlTool, createExaSearchTool } from '../tools';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotStructuredOptions,
|
||||
@@ -248,25 +248,14 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
}
|
||||
|
||||
private getTools(options: CopilotChatOptions, model: string): ToolSet {
|
||||
const tools: ToolSet = {};
|
||||
if (options?.tools?.length) {
|
||||
for (const tool of options.tools) {
|
||||
switch (tool) {
|
||||
case 'webSearch': {
|
||||
if (this.isReasoningModel(model)) {
|
||||
tools.web_search_exa = createExaSearchTool(this.AFFiNEConfig);
|
||||
tools.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig);
|
||||
} else {
|
||||
tools.web_search_preview = openai.tools.webSearchPreview();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return tools;
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
model: string
|
||||
): [string, Tool] | undefined {
|
||||
if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
|
||||
return ['web_search_preview', openai.tools.webSearchPreview()];
|
||||
}
|
||||
return tools;
|
||||
return;
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -297,7 +286,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: this.getTools(options, model.id),
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
@@ -338,7 +327,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: this.getTools(options, model.id),
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Tool, ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
@@ -7,9 +9,18 @@ import {
|
||||
CopilotProviderNotSupported,
|
||||
OnEvent,
|
||||
} from '../../../base';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { CopilotContextService } from '../context';
|
||||
import {
|
||||
buildDocSearchGetter,
|
||||
createExaCrawlTool,
|
||||
createExaSearchTool,
|
||||
createSemanticSearchTool,
|
||||
} from '../tools';
|
||||
import { CopilotProviderFactory } from './factory';
|
||||
import {
|
||||
type CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
type CopilotEmbeddingOptions,
|
||||
type CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
@@ -33,6 +44,7 @@ export abstract class CopilotProvider<C = any> {
|
||||
|
||||
@Inject() protected readonly AFFiNEConfig!: Config;
|
||||
@Inject() protected readonly factory!: CopilotProviderFactory;
|
||||
@Inject() protected readonly moduleRef!: ModuleRef;
|
||||
|
||||
get config(): C {
|
||||
return this.AFFiNEConfig.copilot.providers[this.type] as C;
|
||||
@@ -98,6 +110,49 @@ export abstract class CopilotProvider<C = any> {
|
||||
);
|
||||
}
|
||||
|
||||
protected getProviderSpecificTools(
|
||||
_toolName: CopilotChatTools,
|
||||
_model: string
|
||||
): [string, Tool] | undefined {
|
||||
return;
|
||||
}
|
||||
|
||||
// use for tool use, shared between providers
|
||||
protected async getTools(
|
||||
options: CopilotChatOptions,
|
||||
model: string
|
||||
): Promise<ToolSet> {
|
||||
const tools: ToolSet = {};
|
||||
if (options?.tools?.length) {
|
||||
for (const tool of options.tools) {
|
||||
const toolDef = this.getProviderSpecificTools(tool, model);
|
||||
if (toolDef) {
|
||||
tools[toolDef[0]] = toolDef[1];
|
||||
continue;
|
||||
}
|
||||
switch (tool) {
|
||||
case 'webSearch': {
|
||||
tools.web_search_exa = createExaSearchTool(this.AFFiNEConfig);
|
||||
tools.web_crawl_exa = createExaCrawlTool(this.AFFiNEConfig);
|
||||
break;
|
||||
}
|
||||
case 'semanticSearch': {
|
||||
const ac = this.moduleRef.get(AccessController, { strict: false });
|
||||
const context = this.moduleRef.get(CopilotContextService, {
|
||||
strict: false,
|
||||
});
|
||||
const searchDocs = buildDocSearchGetter(ac, context);
|
||||
tools.semantic_search = createSemanticSearchTool(
|
||||
searchDocs.bind(null, options)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
|
||||
private handleZodError(ret: z.SafeParseReturnType<any, any>) {
|
||||
if (ret.success) return;
|
||||
const issues = ret.error.issues.map(i => {
|
||||
|
||||
@@ -57,7 +57,21 @@ export const VertexSchema: JSONSchema = {
|
||||
// ========== prompt ==========
|
||||
|
||||
export const PromptConfigStrictSchema = z.object({
|
||||
tools: z.enum(['webSearch']).array().nullable().optional(),
|
||||
tools: z
|
||||
.enum([
|
||||
// work with morph
|
||||
'editDoc',
|
||||
// work with exa/model internal tools
|
||||
'webSearch',
|
||||
// work with indexer
|
||||
'readDoc',
|
||||
'keywordSearch',
|
||||
// work with embeddings
|
||||
'semanticSearch',
|
||||
])
|
||||
.array()
|
||||
.nullable()
|
||||
.optional(),
|
||||
// params requirements
|
||||
requireContent: z.boolean().nullable().optional(),
|
||||
requireAttachment: z.boolean().nullable().optional(),
|
||||
@@ -121,6 +135,7 @@ export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
const CopilotProviderOptionsSchema = z.object({
|
||||
signal: z.instanceof(AbortSignal).optional(),
|
||||
user: z.string().optional(),
|
||||
workspace: z.string().optional(),
|
||||
});
|
||||
|
||||
export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
|
||||
@@ -133,6 +148,9 @@ export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
|
||||
.optional();
|
||||
|
||||
export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
|
||||
export type CopilotChatTools = NonNullable<
|
||||
NonNullable<CopilotChatOptions>['tools']
|
||||
>[number];
|
||||
|
||||
export const CopilotStructuredOptionsSchema =
|
||||
CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional();
|
||||
|
||||
@@ -9,7 +9,11 @@ import {
|
||||
} from 'ai';
|
||||
import { ZodType } from 'zod';
|
||||
|
||||
import { createExaCrawlTool, createExaSearchTool } from '../tools';
|
||||
import {
|
||||
createExaCrawlTool,
|
||||
createExaSearchTool,
|
||||
createSemanticSearchTool,
|
||||
} from '../tools';
|
||||
import { PromptMessage } from './types';
|
||||
|
||||
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
|
||||
@@ -378,6 +382,7 @@ export class CitationParser {
|
||||
export interface CustomAITools extends ToolSet {
|
||||
web_search_exa: ReturnType<typeof createExaSearchTool>;
|
||||
web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
|
||||
semantic_search: ReturnType<typeof createSemanticSearchTool>;
|
||||
}
|
||||
|
||||
type ChunkType = TextStreamPart<CustomAITools>['type'];
|
||||
@@ -424,6 +429,12 @@ export class TextStreamParser {
|
||||
case 'tool-result': {
|
||||
result = this.addPrefix(result);
|
||||
switch (chunk.toolName) {
|
||||
case 'semantic_search': {
|
||||
if (Array.isArray(chunk.result)) {
|
||||
result += `\nFound ${chunk.result.length} document${chunk.result.length !== 1 ? 's' : ''} related to “${chunk.args.query}”.\n`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'web_search_exa': {
|
||||
if (Array.isArray(chunk.result)) {
|
||||
result += `\n${this.getWebSearchLinks(chunk.result)}\n`;
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
export * from './semantic-search';
|
||||
export * from './web-search';
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { AccessController } from '../../../core/permission';
|
||||
import type { ChunkSimilarity } from '../../../models';
|
||||
import type { CopilotContextService } from '../context';
|
||||
import type { CopilotChatOptions } from '../providers';
|
||||
|
||||
export const buildDocSearchGetter = (
|
||||
ac: AccessController,
|
||||
context: CopilotContextService
|
||||
) => {
|
||||
const searchDocs = async (options: CopilotChatOptions, query?: string) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return undefined;
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess) return undefined;
|
||||
const chunks = await context.matchWorkspaceAll(options.workspace, query);
|
||||
return chunks || undefined;
|
||||
};
|
||||
return searchDocs;
|
||||
};
|
||||
|
||||
export const createSemanticSearchTool = (
|
||||
searchDocs: (query: string) => Promise<ChunkSimilarity[] | undefined>
|
||||
) => {
|
||||
return tool({
|
||||
description:
|
||||
'Semantic search for relevant documents in the current workspace',
|
||||
parameters: z.object({
|
||||
query: z.string().describe('The query to search for.'),
|
||||
}),
|
||||
execute: async ({ query }) => {
|
||||
try {
|
||||
return await searchDocs(query);
|
||||
} catch {
|
||||
return 'Failed to search documents.';
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user