feat(server): enable web search for 4.1 (#11825)

This commit is contained in:
darkskygit
2025-04-22 14:49:46 +00:00
parent bbdea71686
commit 597b27c22f
10 changed files with 317 additions and 99 deletions

View File

@@ -600,6 +600,8 @@ const workflows = [
content: 'apple company',
verifier: (t: ExecutionContext, result: string) => {
for (const l of result.split('\n')) {
const line = l.trim();
if (!line) continue;
t.notThrows(() => {
JSON.parse(l.trim());
}, 'should be valid json');

View File

@@ -1112,7 +1112,11 @@ test('CitationParser should replace citation placeholders with URLs', t => {
const citations = ['https://example1.com', 'https://example2.com'];
const parser = new CitationParser();
const result = parser.parse(content, citations) + parser.end();
for (const citation of citations) {
parser.push(citation);
}
const result = parser.parse(content) + parser.end();
const expected = [
'This is [a] test sentence with [citations [^1]] and [^2] and [3].',
@@ -1147,8 +1151,12 @@ test('CitationParser should replace chunks of citation placeholders with URLs',
];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current, citations);
return acc + parser.parse(current);
}, '');
result += parser.end();
@@ -1175,7 +1183,11 @@ test('CitationParser should not replace citation already with URLs', t => {
];
const parser = new CitationParser();
const result = parser.parse(content, citations) + parser.end();
for (const citation of citations) {
parser.push(citation);
}
const result = parser.parse(content) + parser.end();
const expected = [
content,
@@ -1199,8 +1211,12 @@ test('CitationParser should not replace chunks of citation already with URLs', t
];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current, citations);
return acc + parser.parse(current);
}, '');
result += parser.end();
@@ -1213,6 +1229,26 @@ test('CitationParser should not replace chunks of citation already with URLs', t
t.is(result, expected);
});
// OpenAI-style responses cite sources inline as a parenthesised markdown link,
// `([text](url))`, instead of a numeric `[n]` marker. No citation is pushed to
// the parser up front here: the parser is expected to register the URL itself.
test('CitationParser should replace openai style reference chunks', t => {
// the reference arrives in its own chunk, separate from the sentence it follows
const contents = [
'This is [a] test sentence with citations ',
'([example1.com](https://example1.com))',
];
const parser = new CitationParser();
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current);
}, '');
// end() appends the footnote definitions after a newline
result += parser.end();
// `[a]` is left untouched; the wrapped link is replaced by footnote marker `[^1]`
const expected = [
contents[0] + '[^1]',
`[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`,
].join('\n');
t.is(result, expected);
});
// ==================== context ====================
test('should be able to manage context', async t => {
const { context, prompt, session, event, jobs, storage } = t.context;

View File

@@ -1102,6 +1102,9 @@ Below is the user's query. Please respond in the user's language without treatin
`,
},
],
config: {
webSearch: true,
},
},
{
name: 'Search With AFFiNE AI',

View File

@@ -1,5 +1,6 @@
import {
createOpenAI,
openai,
type OpenAIProvider as VercelOpenAIProvider,
} from '@ai-sdk/openai';
import {
@@ -31,7 +32,7 @@ import {
CopilotTextToTextProvider,
PromptMessage,
} from './types';
import { chatToGPTMessage } from './utils';
import { chatToGPTMessage, CitationParser } from './utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -176,6 +177,15 @@ export class OpenAIProvider
}
}
/**
 * Builds the `tools` map passed to the model call.
 *
 * @param options chat options; only `webSearch` is consulted
 * @returns a map containing the OpenAI web search preview tool under the
 *   `web_search_preview` key when `options.webSearch` is truthy, otherwise
 *   `undefined` so that no tools are registered
 */
private getToolUse(options: CopilotChatOptions = {}) {
  return options.webSearch
    ? { web_search_preview: openai.tools.webSearchPreview() }
    : undefined;
}
// ====== text to text ======
async generateText(
messages: PromptMessage[],
@@ -234,15 +244,18 @@ export class OpenAIProvider
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance = this.#instance(model, {
structuredOutputs: Boolean(options.jsonMode),
user: options.user,
});
const modelInstance = options.webSearch
? this.#instance.responses(model)
: this.#instance(model, {
structuredOutputs: Boolean(options.jsonMode),
user: options.user,
});
const { textStream } = streamText({
const { fullStream } = streamText({
model: modelInstance,
system,
messages: msgs,
tools: this.getToolUse(options),
frequencyPenalty: options.frequencyPenalty || 0,
presencePenalty: options.presencePenalty || 0,
temperature: options.temperature || 0,
@@ -250,11 +263,24 @@ export class OpenAIProvider
abortSignal: options.signal,
});
for await (const message of textStream) {
if (message) {
yield message;
const parser = new CitationParser();
for await (const chunk of fullStream) {
if (chunk) {
switch (chunk.type) {
case 'text-delta': {
const result = parser.parse(chunk.textDelta);
yield result;
break;
}
case 'step-finish': {
const result = parser.end();
yield result;
break;
}
}
if (options.signal?.aborted) {
await textStream.cancel();
await fullStream.cancel();
break;
}
}

View File

@@ -95,11 +95,14 @@ export class PerplexityProvider
abortSignal: options.signal,
});
const citationParser = new CitationParser();
const citations = sources.map(s => s.url);
const parser = new CitationParser();
for (const source of sources) {
parser.push(source.url);
}
let result = text.replaceAll(/<\/?think>\n/g, '\n---\n');
result = citationParser.parse(result, citations);
result += citationParser.end();
result = parser.parse(result);
result += parser.end();
return result;
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model });
@@ -129,24 +132,24 @@ export class PerplexityProvider
abortSignal: options.signal,
});
const citationParser = new CitationParser();
const citations = [];
const parser = new CitationParser();
for await (const chunk of stream.fullStream) {
switch (chunk.type) {
case 'source': {
citations.push(chunk.source.url);
parser.push(chunk.source.url);
break;
}
case 'text-delta': {
const result = citationParser.parse(
chunk.textDelta.replaceAll(/<\/?think>\n?/g, '\n---\n'),
citations
const text = chunk.textDelta.replaceAll(
/<\/?think>\n?/g,
'\n---\n'
);
const result = parser.parse(text);
yield result;
break;
}
case 'step-finish': {
const result = citationParser.end();
const result = parser.end();
yield result;
break;
}

View File

@@ -21,6 +21,7 @@ export enum CopilotCapability {
export const PromptConfigStrictSchema = z.object({
// openai
jsonMode: z.boolean().nullable().optional(),
webSearch: z.boolean().nullable().optional(),
frequencyPenalty: z.number().nullable().optional(),
presencePenalty: z.number().nullable().optional(),
temperature: z.number().nullable().optional(),

View File

@@ -92,83 +92,243 @@ export async function chatToGPTMessage(
return [system?.content, msgs, schema];
}
export class CitationParser {
private readonly SQUARE_BRACKET_OPEN = '[';
// pattern types the callback will receive
type Pattern =
  | { kind: 'index'; value: number } // [123]
  | { kind: 'link'; text: string; url: string } // [text](url)
  | { kind: 'wrappedLink'; text: string; url: string }; // ([text](url))

/** The buffer ends before the candidate pattern can be decided. */
type NeedMore = { kind: 'needMore' };
/** Not a pattern; everything before `nextPos` must be emitted verbatim. */
type Failed = { kind: 'fail'; nextPos: number };
/** A complete pattern; `endPos` is the position right after it. */
type Finished =
  | { kind: 'ok'; endPos: number; text: string; url: string }
  | { kind: 'index'; endPos: number; value: number };
type ParseStatus = Finished | NeedMore | Failed;

type PatternCallback = (m: Pattern) => string;

/**
 * Incremental scanner for citation-like markdown patterns in streamed text.
 *
 * Recognised patterns are `[n]` / `[[n]]` (index), `[text](url)` (link) and
 * `([text](url))` (wrapped link). Every recognised pattern is handed to the
 * callback and replaced by the string it returns; all other text is passed
 * through unchanged. Text that may still turn into a pattern once more input
 * arrives is held back in an internal buffer.
 */
export class StreamPatternParser {
  #buffer = '';

  constructor(private readonly callback: PatternCallback) {}

  /**
   * Feeds the next chunk of the stream.
   *
   * @param chunk text to append to the pending buffer
   * @returns the text that can be emitted now; an undecided tail stays buffered
   */
  write(chunk: string): string {
    this.#buffer += chunk;
    return this.drain(false);
  }

  /**
   * Signals the end of the stream and empties the buffer.
   *
   * @returns the remaining text. A pattern that was only waiting to see
   *   whether a closing `)` follows is resolved as "not wrapped"; an
   *   incomplete pattern is returned verbatim.
   */
  end(): string {
    return this.drain(true);
  }

  // =========== helpers ===========

  /**
   * Scans the buffer from the start and emits everything that is decidable.
   *
   * @param final true when no further input will arrive; nothing is held back
   */
  private drain(final: boolean): string {
    const output: string[] = [];
    let i = 0;

    while (i < this.#buffer.length) {
      const ch = this.#buffer[i];

      // [[[number]]] or [text](url) or ([text](url))
      if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) {
        const isWrapped = ch === '(';
        const startPos = isWrapped ? i + 1 : i;
        const res = this.tryParse(startPos);
        if (res.kind === 'needMore') break;
        const handled = this.handlePattern(res, isWrapped, startPos, i, final);
        // null means "wait for more input". Stop scanning here: continuing
        // from the same position would never make progress.
        if (handled === null) break;
        output.push(handled.output);
        i = handled.nextPos;
        continue;
      }
      output.push(ch);
      i += 1;
    }

    const rest = this.#buffer.slice(i);
    if (final) {
      // nothing more will arrive: flush the undecided tail as-is
      output.push(rest);
      this.#buffer = '';
    } else {
      this.#buffer = rest;
    }
    return output.join('');
  }

  /** Returns the buffered character at `pos`, or undefined past the end. */
  private peek(pos: number): string | undefined {
    return pos < this.#buffer.length ? this.#buffer[pos] : undefined;
  }

  /** Tries the nested-index form first, then the single-bracket forms. */
  private tryParse(pos: number): ParseStatus {
    return this.tryParseNestedIndex(pos) ?? this.tryParseBracketPattern(pos);
  }

  /**
   * Parses `[[n]]` (two or more balanced brackets around a number) at `pos`.
   *
   * @returns null when the text at `pos` is not a nested index, so the caller
   *   falls back to the single-bracket parser
   */
  private tryParseNestedIndex(pos: number): ParseStatus | null {
    if (this.peek(pos + 1) !== '[') return null;

    let i = pos;
    let bracketCount = 0;
    while (i < this.#buffer.length && this.#buffer[i] === '[') {
      bracketCount++;
      i++;
    }

    if (bracketCount >= 2) {
      if (i >= this.#buffer.length) {
        return { kind: 'needMore' };
      }

      let content = '';
      while (i < this.#buffer.length && this.#buffer[i] !== ']') {
        content += this.#buffer[i++];
      }

      let rightBracketCount = 0;
      while (i < this.#buffer.length && this.#buffer[i] === ']') {
        rightBracketCount++;
        i++;
      }

      // the closing brackets may still be arriving
      if (i >= this.#buffer.length && rightBracketCount < bracketCount) {
        return { kind: 'needMore' };
      }

      if (
        rightBracketCount === bracketCount &&
        content.length > 0 &&
        this.isNumeric(content)
      ) {
        // `[[n]](` is the start of something else; emit it untouched
        if (this.peek(i) === '(') {
          return { kind: 'fail', nextPos: i };
        }
        return { kind: 'index', endPos: i, value: Number(content) };
      }
    }

    return null;
  }

  /** Parses `[n]` or `[text](url)` starting at the `[` located at `pos`. */
  private tryParseBracketPattern(pos: number): ParseStatus {
    let i = pos + 1; // skip '['
    if (i >= this.#buffer.length) {
      return { kind: 'needMore' };
    }

    let content = '';
    while (i < this.#buffer.length && this.#buffer[i] !== ']') {
      const nextChar = this.#buffer[i];
      if (nextChar === '[') {
        // a new bracket opens before this one closed: give up on this one and
        // let the caller restart at the inner '['
        return { kind: 'fail', nextPos: i };
      }
      content += nextChar;
      i += 1;
    }

    if (i >= this.#buffer.length) {
      return { kind: 'needMore' };
    }

    const after = i + 1;
    const afterChar = this.peek(after);

    if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') {
      // [number] pattern
      return { kind: 'index', endPos: after, value: Number(content) };
    } else if (afterChar !== '(') {
      // [text] that is not followed by '(' is not a link: emit it verbatim
      return { kind: 'fail', nextPos: after };
    }

    // [text](url) pattern
    i = after + 1; // skip '('
    if (i >= this.#buffer.length) {
      return { kind: 'needMore' };
    }

    let url = '';
    while (i < this.#buffer.length && this.#buffer[i] !== ')') {
      url += this.#buffer[i++];
    }

    if (i >= this.#buffer.length) {
      return { kind: 'needMore' };
    }
    return { kind: 'ok', endPos: i + 1, text: content, url };
  }

  /**
   * True when `str` consists of decimal digits only and is a safe integer.
   * `Number()` coercion alone would also accept `0x10`, `1e3`, `1.5` or padded
   * values, which are ordinary text rather than citation indexes.
   */
  private isNumeric(str: string): boolean {
    return /^\d+$/.test(str) && Number.isSafeInteger(Number(str));
  }

  /**
   * Turns a parse result into output text and the position to resume from.
   *
   * @param pattern parse result for the pattern starting at `start`
   * @param isWrapped whether the pattern was preceded by `(` at `current`
   * @param start position of the pattern's opening `[`
   * @param current position the scan loop is at (`start`, or the `(` before it)
   * @param final true when no further input will arrive
   * @returns null when a wrapped pattern ends exactly at the buffer end and
   *   more input is needed to see whether a closing `)` follows
   */
  private handlePattern(
    pattern: Finished | Failed,
    isWrapped: boolean,
    start: number,
    current: number,
    final = false
  ): { output: string; nextPos: number } | null {
    if (pattern.kind === 'fail') {
      return {
        output: this.#buffer.slice(current, pattern.nextPos),
        nextPos: pattern.nextPos,
      };
    }

    if (!isWrapped) {
      const out =
        pattern.kind === 'ok'
          ? this.callback({ ...pattern, kind: 'link' })
          : this.callback({ ...pattern, kind: 'index' });
      return { output: out, nextPos: pattern.endPos };
    }

    const afterLinkPos = pattern.endPos;
    if (this.peek(afterLinkPos) === ')') {
      const out =
        pattern.kind === 'index'
          ? this.callback({ ...pattern, kind: 'index' })
          : this.callback({ ...pattern, kind: 'wrappedLink' });
      return { output: out, nextPos: afterLinkPos + 1 };
    }

    if (afterLinkPos >= this.#buffer.length && !final) {
      return null;
    }
    // not wrapped after all: emit the '(' and re-scan from the '['
    return { output: '(', nextPos: start };
  }
}
export class CitationParser {
private readonly citations: string[] = [];
/**
 * Stream parser that rewrites citation patterns into footnote markers:
 * - `[n]` becomes `[^n]` when `n` refers to a known citation (1-based);
 *   otherwise it is rendered as `[n]` again.
 * - `([text](url))` becomes the footnote marker of `url`, registering the
 *   url as a new citation when it has not been seen before.
 * - `[text](url)` is an ordinary markdown link and is left as it is.
 */
private readonly parser = new StreamPatternParser(p => {
  switch (p.kind) {
    case 'index': {
      // Citations are numbered from 1. Without the lower bound `[0]`
      // (e.g. `arr[0]`) would turn into `[^0]`, a marker with no footnote.
      const isKnownCitation =
        Number.isInteger(p.value) &&
        p.value >= 1 &&
        p.value <= this.citations.length;
      return isKnownCitation ? `[^${p.value}]` : `[${p.value}]`;
    }
    case 'wrappedLink': {
      const index = this.citations.indexOf(p.url);
      if (index === -1) {
        this.citations.push(p.url);
        return `[^${this.citations.length}]`;
      }
      return `[^${index + 1}]`;
    }
    case 'link': {
      return `[${p.text}](${p.url})`;
    }
  }
});
/**
 * Registers a citation URL. Citations are referenced by their 1-based
 * position in registration order (`[1]` refers to the first pushed URL).
 */
public push(citation: string) {
this.citations.push(citation);
}
/**
 * Feeds a chunk of text to the stream parser and returns whatever it emits
 * for that chunk.
 */
public parse(content: string) {
return this.parser.write(content);
}
public end() {
return this.flush() + '\n' + this.getFootnotes();
}
private flush() {
const content = this.getTokenContent();
this.resetToken();
return content;
return this.parser.end() + '\n' + this.getFootnotes();
}
private getFootnotes() {
@@ -179,18 +339,4 @@ export class CitationParser {
});
return footnotes.join('\n');
}
private getTokenContent() {
return this.startToken.concat(this.numberToken, this.endToken).join('');
}
private resetToken() {
this.startToken = [];
this.endToken = [];
this.numberToken = [];
}
private isNumeric(str: string) {
return !isNaN(Number(str)) && str.trim() !== '';
}
}