mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -829,3 +829,55 @@ for (const { name, content, verifier } of workflows) {
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// ==================== rerank ====================
|
||||
|
||||
test(
|
||||
'should be able to rerank message chunks',
|
||||
runIfCopilotConfigured,
|
||||
async t => {
|
||||
const { factory, prompt } = t.context;
|
||||
|
||||
await retry('rerank', t, async t => {
|
||||
const query = 'Is this content relevant to programming?';
|
||||
const embeddings = [
|
||||
'How to write JavaScript code for web development.',
|
||||
'Today is a beautiful sunny day for walking in the park.',
|
||||
'Python is a popular programming language for data science.',
|
||||
'The weather forecast predicts rain for the weekend.',
|
||||
'JavaScript frameworks like React and Angular are widely used.',
|
||||
'Cooking recipes can be found in many online blogs.',
|
||||
'Machine learning algorithms are essential for AI development.',
|
||||
'The latest smartphone models have impressive camera features.',
|
||||
'Learning to code can open up many career opportunities.',
|
||||
'The stock market is experiencing significant fluctuations.',
|
||||
];
|
||||
|
||||
const p = (await prompt.get('Rerank results'))!;
|
||||
t.assert(p, 'should have prompt for rerank');
|
||||
const provider = (await factory.getProviderByModel(p.model))!;
|
||||
t.assert(provider, 'should have provider for rerank');
|
||||
|
||||
const scores = await provider.rerank(
|
||||
{ modelId: p.model },
|
||||
embeddings.map(e => p.finish({ query, doc: e }))
|
||||
);
|
||||
|
||||
t.is(scores.length, 10, 'should return scores for all chunks');
|
||||
|
||||
for (const score of scores) {
|
||||
t.assert(
|
||||
typeof score === 'number' && score >= 0 && score <= 1,
|
||||
`score should be a number between 0 and 1, got ${score}`
|
||||
);
|
||||
}
|
||||
|
||||
t.log('Rerank scores:', scores);
|
||||
t.is(
|
||||
scores.filter(s => s > 0.5).length,
|
||||
4,
|
||||
'should have 4 related chunks'
|
||||
);
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
@@ -215,7 +215,6 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
topK * 2,
|
||||
threshold
|
||||
),
|
||||
|
||||
this.models.copilotContext.matchWorkspaceEmbedding(
|
||||
embedding,
|
||||
workspaceId,
|
||||
@@ -237,8 +236,9 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
!fileChunks.length &&
|
||||
!workspaceChunks.length &&
|
||||
!scopedWorkspaceChunks?.length
|
||||
)
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return await this.embeddingClient.reRank(
|
||||
content,
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
type ReRankResult,
|
||||
} from './types';
|
||||
|
||||
const EMBEDDING_MODEL = 'gemini-embedding-001';
|
||||
const RERANK_PROMPT = 'Rerank results';
|
||||
|
||||
class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
@@ -34,6 +35,7 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
|
||||
override async configured(): Promise<boolean> {
|
||||
const embedding = await this.providerFactory.getProvider({
|
||||
modelId: EMBEDDING_MODEL,
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
const result = Boolean(embedding);
|
||||
@@ -60,6 +62,7 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
const provider = await this.getProvider({
|
||||
modelId: EMBEDDING_MODEL,
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
this.logger.verbose(
|
||||
@@ -109,9 +112,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
);
|
||||
|
||||
try {
|
||||
return ranks.map((score, i) => ({
|
||||
chunk: embeddings[i].content,
|
||||
targetId: this.getTargetId(embeddings[i]),
|
||||
return ranks.map((score, chunk) => ({
|
||||
chunk,
|
||||
targetId: this.getTargetId(embeddings[chunk]),
|
||||
score,
|
||||
}));
|
||||
} catch (error) {
|
||||
@@ -171,7 +174,7 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
const highConfidenceChunks = ranks
|
||||
.flat()
|
||||
.toSorted((a, b) => b.score - a.score)
|
||||
.filter(r => r.score > 5)
|
||||
.filter(r => r.score > 0.5)
|
||||
.map(r => chunks[`${r.targetId}:${r.chunk}`])
|
||||
.filter(Boolean);
|
||||
|
||||
|
||||
@@ -176,17 +176,15 @@ export abstract class EmbeddingClient {
|
||||
}
|
||||
|
||||
const ReRankItemSchema = z.object({
|
||||
scores: z.object({
|
||||
chunk: z.string().describe('The chunk index of the search result.'),
|
||||
targetId: z.string().describe('The id of the target.'),
|
||||
score: z
|
||||
.number()
|
||||
.min(0)
|
||||
.max(10)
|
||||
.describe(
|
||||
'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.'
|
||||
),
|
||||
}),
|
||||
chunk: z.number().describe('The chunk index of the search result.'),
|
||||
targetId: z.string().describe('The id of the target.'),
|
||||
score: z
|
||||
.number()
|
||||
.min(0)
|
||||
.max(10)
|
||||
.describe(
|
||||
'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.'
|
||||
),
|
||||
});
|
||||
|
||||
export type ReRankResult = z.infer<typeof ReRankItemSchema>['scores'][];
|
||||
export type ReRankResult = z.infer<typeof ReRankItemSchema>[];
|
||||
|
||||
@@ -346,7 +346,7 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: `<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {doc}`,
|
||||
content: `<Instruct>: Given a document search result, determine whether the result is relevant to the query.\n<Query>: {{query}}\n<Document>: {{doc}}`,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -1676,7 +1676,7 @@ This sentence contains information from the first source[^1]. This sentence refe
|
||||
<tool-calling-guidelines>
|
||||
Before starting Tool calling, you need to follow:
|
||||
- DO NOT embed a tool call mid-sentence.
|
||||
- When searching for information, searching web & searching the user's Workspace information.
|
||||
- When searching for unknown information or keyword, prioritize searching the user's workspace.
|
||||
- Depending on the complexity of the question and the information returned by the search tools, you can call different tools multiple times to search.
|
||||
</tool-calling-guidelines>
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import type {
|
||||
import type { GoogleVertexProvider } from '@ai-sdk/google-vertex';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
generateObject,
|
||||
generateText,
|
||||
JSONParseError,
|
||||
@@ -20,6 +21,7 @@ import {
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
ModelConditions,
|
||||
@@ -211,6 +213,40 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
}
|
||||
}
|
||||
|
||||
override async embedding(
|
||||
cond: ModelConditions,
|
||||
messages: string | string[],
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
.add(1, { model: model.id });
|
||||
|
||||
const modelInstance = this.instance.textEmbeddingModel(model.id, {
|
||||
outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
taskType: 'RETRIEVAL_DOCUMENT',
|
||||
});
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: messages,
|
||||
});
|
||||
|
||||
return embeddings.filter(v => v && Array.isArray(v));
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_errors')
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
|
||||
@@ -71,8 +71,8 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Text Embedding 004',
|
||||
id: 'text-embedding-004',
|
||||
name: 'Text Embedding 005',
|
||||
id: 'text-embedding-005',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
@@ -80,6 +80,18 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
},
|
||||
],
|
||||
},
|
||||
// not exists yet
|
||||
// {
|
||||
// name: 'Gemini Embedding',
|
||||
// id: 'gemini-embedding-001',
|
||||
// capabilities: [
|
||||
// {
|
||||
// input: [ModelInputType.Text],
|
||||
// output: [ModelOutputType.Embedding],
|
||||
// defaultForOutputType: true,
|
||||
// },
|
||||
// ],
|
||||
// },
|
||||
];
|
||||
|
||||
protected instance!: GoogleGenerativeAIProvider;
|
||||
|
||||
@@ -49,6 +49,17 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Gemini Embedding',
|
||||
id: 'gemini-embedding-001',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
output: [ModelOutputType.Embedding],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: GoogleVertexProvider;
|
||||
|
||||
@@ -450,7 +450,8 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const instance = this.#instance.responses(model.id);
|
||||
// get the log probability of "yes"/"no"
|
||||
const instance = this.#instance(model.id, { logprobs: 16 });
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
@@ -461,29 +462,37 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: 0,
|
||||
maxTokens: 1,
|
||||
maxTokens: 16,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
...this.getOpenAIOptions(options, model.id),
|
||||
// get the log probability of "yes"/"no"
|
||||
logprobs: 2,
|
||||
},
|
||||
},
|
||||
maxSteps: 1,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
const top = (logprobs?.[0]?.topLogprobs ?? []).reduce(
|
||||
(acc, item) => {
|
||||
acc[item.token] = item.logprob;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, number>
|
||||
const topMap: Record<string, number> = (
|
||||
logprobs?.[0]?.topLogprobs ?? []
|
||||
).reduce<Record<string, number>>(
|
||||
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
|
||||
{}
|
||||
);
|
||||
|
||||
// OpenAI often includes a leading space, so try matching both ' yes' and 'yes'
|
||||
const logYes = top[' yes'] ?? top['yes'] ?? Number.NEGATIVE_INFINITY;
|
||||
const logNo = top[' no'] ?? top['no'] ?? Number.NEGATIVE_INFINITY;
|
||||
const findLogProb = (token: string): number => {
|
||||
// OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes'
|
||||
return [`.${token}`, `_${token}`, ` ${token}`, token]
|
||||
.flatMap(v => [v, v.toLowerCase(), v.toUpperCase()])
|
||||
.reduce<number>(
|
||||
(best, key) =>
|
||||
(topMap[key] ?? Number.NEGATIVE_INFINITY) > best
|
||||
? topMap[key]
|
||||
: best,
|
||||
Number.NEGATIVE_INFINITY
|
||||
);
|
||||
};
|
||||
|
||||
const logYes = findLogProb('Yes');
|
||||
const logNo = findLogProb('No');
|
||||
|
||||
const pYes = Math.exp(logYes);
|
||||
const pNo = Math.exp(logNo);
|
||||
|
||||
@@ -501,6 +501,12 @@ export class TextStreamParser {
|
||||
case 'doc_semantic_search': {
|
||||
if (Array.isArray(chunk.result)) {
|
||||
result += `\nFound ${chunk.result.length} document${chunk.result.length !== 1 ? 's' : ''} related to “${chunk.args.query}”.\n`;
|
||||
} else if (typeof chunk.result === 'string') {
|
||||
result += `\n${chunk.result}\n`;
|
||||
} else {
|
||||
this.logger.warn(
|
||||
`Unexpected result type for doc_semantic_search: ${chunk.result?.message || 'Unknown error'}`
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -39,9 +39,13 @@ export const createDocKeywordSearchTool = (
|
||||
) => {
|
||||
return tool({
|
||||
description:
|
||||
'Search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based lookup is sufficient.',
|
||||
'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.',
|
||||
parameters: z.object({
|
||||
query: z.string().describe('The query to search for'),
|
||||
query: z
|
||||
.string()
|
||||
.describe(
|
||||
'The query to search for, e.g. "meeting notes" or "project plan".'
|
||||
),
|
||||
}),
|
||||
execute: async ({ query }) => {
|
||||
try {
|
||||
|
||||
@@ -19,13 +19,14 @@ export const buildDocSearchGetter = (
|
||||
abortSignal?: AbortSignal
|
||||
) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return undefined;
|
||||
return `Invalid search parameters.`;
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess) return undefined;
|
||||
if (!canAccess)
|
||||
return 'You do not have permission to access this workspace.';
|
||||
const [chunks, contextChunks] = await Promise.all([
|
||||
context.matchWorkspaceAll(options.workspace, query, 10, abortSignal),
|
||||
docContext?.matchFiles(query, 10, abortSignal) ?? [],
|
||||
@@ -42,7 +43,8 @@ export const buildDocSearchGetter = (
|
||||
if (contextChunks.length) {
|
||||
fileChunks.push(...contextChunks);
|
||||
}
|
||||
if (!docChunks.length && !fileChunks.length) return undefined;
|
||||
if (!docChunks.length && !fileChunks.length)
|
||||
return `No results found for "${query}".`;
|
||||
return [...fileChunks, ...docChunks];
|
||||
};
|
||||
return searchDocs;
|
||||
@@ -52,16 +54,16 @@ export const createDocSemanticSearchTool = (
|
||||
searchDocs: (
|
||||
query: string,
|
||||
abortSignal?: AbortSignal
|
||||
) => Promise<ChunkSimilarity[] | undefined>
|
||||
) => Promise<ChunkSimilarity[] | string | undefined>
|
||||
) => {
|
||||
return tool({
|
||||
description:
|
||||
'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; call this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts).',
|
||||
'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts).',
|
||||
parameters: z.object({
|
||||
query: z
|
||||
.string()
|
||||
.describe(
|
||||
'The query statement to search for, e.g. "What is the capital of France?"'
|
||||
'The query statement to search for, e.g. "What is the capital of France?"\nWhen querying specific terms or IDs, you should provide the complete string instead of separating it with delimiters.\nFor example, if a user wants to look up the ID "sicDoe1is", use "What is sicDoe1is" instead of "si code 1is".'
|
||||
),
|
||||
}),
|
||||
execute: async ({ query }, options) => {
|
||||
|
||||
Reference in New Issue
Block a user