Mirror of https://github.com/toeverything/AFFiNE.git, synced 2026-02-14 21:27:20 +00:00
feat(server): faster reranking based on confidence (#12957)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Improved document reranking with a more streamlined and accurate scoring system.
  * Enhanced support for binary ("yes"/"no") document relevance judgments.
* **Improvements**
  * Simplified user prompts and output formats for reranking tasks, making results easier to interpret.
  * Increased reliability and consistency in document ranking results.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
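Context for the release notes above: the rerank path now asks the model for a single "yes"/"no" judgment per candidate document and derives a confidence score from the token log-probabilities, instead of requesting a structured JSON ranking for the whole batch. A minimal sketch of that flow, not the server's actual code — the `judge` callback, the `threshold` cutoff, and the prompt string are illustrative assumptions:

```ts
// Sketch of confidence-based reranking: one yes/no prompt per document,
// one relevance probability back per document, then filter and sort.
interface Candidate {
  targetId: string;
  content: string;
}

// `judge` stands in for the provider call that turns one prompt into the
// probability that the model answers "yes" (see the logprob math further down).
async function rerank(
  query: string,
  candidates: Candidate[],
  judge: (prompt: string) => Promise<number>,
  threshold = 0.5 // assumed cutoff, purely illustrative
): Promise<Array<Candidate & { score: number }>> {
  const scores = await Promise.all(
    candidates.map(c => judge(`<Query>: ${query}\n<Document>: ${c.content}`))
  );
  return candidates
    .map((c, i) => ({ ...c, score: scores[i] }))
    .filter(c => c.score >= threshold)
    .sort((a, b) => b.score - a.score);
}
```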
@@ -17,7 +17,6 @@ import {
 import {
   EMBEDDING_DIMENSIONS,
   EmbeddingClient,
-  getReRankSchema,
   type ReRankResult,
 } from './types';
 
@@ -81,9 +80,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
   }
 
   private getTargetId<T extends ChunkSimilarity>(embedding: T) {
-    return 'docId' in embedding
+    return 'docId' in embedding && typeof embedding.docId === 'string'
       ? embedding.docId
-      : 'fileId' in embedding
+      : 'fileId' in embedding && typeof embedding.fileId === 'string'
         ? embedding.fileId
         : '';
   }
@@ -102,24 +101,19 @@ class ProductionEmbeddingClient extends EmbeddingClient {
       throw new CopilotPromptNotFound({ name: RERANK_PROMPT });
     }
     const provider = await this.getProvider({ modelId: prompt.model });
-    const schema = getReRankSchema(embeddings.length);
 
-    const ranks = await provider.structure(
+    const ranks = await provider.rerank(
       { modelId: prompt.model },
-      prompt.finish({
-        query,
-        results: embeddings.map(e => ({
-          targetId: this.getTargetId(e),
-          chunk: e.chunk,
-          content: e.content,
-        })),
-        schema,
-      }),
-      { maxRetries: 3, signal }
+      embeddings.map(e => prompt.finish({ query, doc: e.content })),
+      { signal }
     );
 
     try {
-      return schema.parse(JSON.parse(ranks)).ranks;
+      return ranks.map((score, i) => ({
+        chunk: embeddings[i].content,
+        targetId: this.getTargetId(embeddings[i]),
+        score,
+      }));
     } catch (error) {
       this.logger.error('Failed to parse rerank results', error);
       // silent error, will fallback to default sorting in parent method
@@ -176,9 +170,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
 
     const highConfidenceChunks = ranks
       .flat()
-      .toSorted((a, b) => b.scores.score - a.scores.score)
-      .filter(r => r.scores.score > 5)
-      .map(r => chunks[`${r.scores.targetId}:${r.scores.chunk}`])
+      .toSorted((a, b) => b.score - a.score)
+      .filter(r => r.score > 5)
+      .map(r => chunks[`${r.targetId}:${r.chunk}`])
       .filter(Boolean);
 
     this.logger.verbose(
@@ -177,11 +177,6 @@ export abstract class EmbeddingClient {
 
 const ReRankItemSchema = z.object({
   scores: z.object({
-    reason: z
-      .string()
-      .describe(
-        'Think step by step, describe in 20 words the reason for giving this score.'
-      ),
     chunk: z.string().describe('The chunk index of the search result.'),
     targetId: z.string().describe('The id of the target.'),
     score: z
@@ -194,11 +189,4 @@ const ReRankItemSchema = z.object({
   }),
 });
 
-export const getReRankSchema = (size: number) =>
-  z.object({
-    ranks: ReRankItemSchema.array().describe(
-      `A array of scores. Make sure to score all ${size} results.`
-    ),
-  });
-
-export type ReRankResult = z.infer<ReturnType<typeof getReRankSchema>>['ranks'];
+export type ReRankResult = z.infer<typeof ReRankItemSchema>['scores'][];
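For reference, after dropping `getReRankSchema`, the exported `ReRankResult` resolves to the inner `scores` object shape. A rough paraphrase of the inferred type, not a replacement definition — the `number` type for `score` is an assumption based on how the scores are sorted and filtered in the earlier hunks:

```ts
// Approximate element shape of the new ReRankResult.
type ReRankItem = {
  chunk: string; // the chunk index of the search result
  targetId: string; // the doc/file id resolved by getTargetId()
  score: number; // relevance score returned by the provider's rerank call (assumed numeric)
};
type ReRankResultSketch = ReRankItem[];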
@@ -342,57 +342,11 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr
     messages: [
       {
         role: 'system',
-        content: `Evaluate and rank search results based on their relevance and quality to the given query by assigning a score from 1 to 10, where 10 denotes the highest relevance.
-
-Consider various factors such as content alignment with the query, source credibility, timeliness, and user intent.
-
-# Steps
-
-1. **Read the Query**: Understand the main intent and specific details of the search query.
-2. **Review Each Result**:
-   - Analyze the content's relevance to the query.
-   - Assess the credibility of the source or website.
-   - Consider the timeliness of the information, ensuring it's current and relevant.
-   - Evaluate the alignment with potential user intent based on the query.
-3. **Scoring**:
-   - Assign a score from 1 to 10 based on the overall relevance and quality, with 10 being the most relevant.
-   - Each chunk returns a score and should not be mixed together.
-
-# Output Format
-
-Return a JSON object for each result in the following format in raw:
-{
-  "scores": [
-    {
-      "reason": "[Reasoning behind the score in 20 words]",
-      "chunk": "[chunk]",
-      "targetId": "[targetId]",
-      "score": [1-10]
-    }
-  ]
-}
-
-# Notes
-
-- Be aware of the potential biases or inaccuracies in the sources.
-- Consider if the content is comprehensive and directly answers the query.
-- Pay attention to the nuances of user intent that might influence relevance.`,
+        content: `Judge whether the Document meets the requirements based on the Query and the Instruct provided. The answer must be "yes" or "no".`,
       },
       {
         role: 'user',
-        content: `
-<query>{{query}}</query>
-<results>
-{{#results}}
-<result>
-<targetId>{{targetId}}</targetId>
-<chunk>{{chunk}}</chunk>
-<content>
-{{content}}
-</content>
-</result>
-{{/results}}
-</results>`,
+        content: `<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {doc}`,
      },
    ],
  },
@@ -440,6 +440,60 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     }
   }
 
+  override async rerank(
+    cond: ModelConditions,
+    chunkMessages: PromptMessage[][],
+    options: CopilotChatOptions = {}
+  ): Promise<number[]> {
+    const fullCond = { ...cond, outputType: ModelOutputType.Text };
+    await this.checkParams({ messages: [], cond: fullCond, options });
+    const model = this.selectModel(fullCond);
+    const instance = this.#instance.responses(model.id);
+
+    const scores = await Promise.all(
+      chunkMessages.map(async messages => {
+        const [system, msgs] = await chatToGPTMessage(messages);
+
+        const { logprobs } = await generateText({
+          model: instance,
+          system,
+          messages: msgs,
+          temperature: 0,
+          maxTokens: 1,
+          providerOptions: {
+            openai: {
+              ...this.getOpenAIOptions(options, model.id),
+              // get the log probability of "yes"/"no"
+              logprobs: 2,
+            },
+          },
+          maxSteps: 1,
+          abortSignal: options.signal,
+        });
+
+        const top = (logprobs?.[0]?.topLogprobs ?? []).reduce(
+          (acc, item) => {
+            acc[item.token] = item.logprob;
+            return acc;
+          },
+          {} as Record<string, number>
+        );
+
+        // OpenAI often includes a leading space, so try matching both ' yes' and 'yes'
+        const logYes = top[' yes'] ?? top['yes'] ?? Number.NEGATIVE_INFINITY;
+        const logNo = top[' no'] ?? top['no'] ?? Number.NEGATIVE_INFINITY;
+
+        const pYes = Math.exp(logYes);
+        const pNo = Math.exp(logNo);
+        const prob = pYes + pNo === 0 ? 0 : pYes / (pYes + pNo);
+
+        return prob;
+      })
+    );
+
+    return scores;
+  }
+
   private async getFullStream(
     model: CopilotProviderModel,
     messages: PromptMessage[],
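For intuition on the log-probability normalization in the hunk above, a small worked example with made-up logprob values (not real model output):

```ts
// Same math as the diff: normalize P(yes) against P(yes) + P(no).
const norm = (logYes: number, logNo: number) => {
  const pYes = Math.exp(logYes);
  const pNo = Math.exp(logNo);
  return pYes + pNo === 0 ? 0 : pYes / (pYes + pNo);
};

console.log(norm(-0.05, -3.0)); // ≈ 0.95: confident "yes"
console.log(norm(-3.0, -0.05)); // ≈ 0.05: confident "no"
console.log(norm(-Infinity, -Infinity)); // 0: neither token appeared in the top logprobs
```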
@@ -295,4 +295,15 @@ export abstract class CopilotProvider<C = any> {
       kind: 'embedding',
     });
   }
+
+  async rerank(
+    _model: ModelConditions,
+    _messages: PromptMessage[][],
+    _options?: CopilotChatOptions
+  ): Promise<number[]> {
+    throw new CopilotProviderNotSupported({
+      provider: this.type,
+      kind: 'rerank',
+    });
+  }
 }