From 597b27c22f08c68d30e2a29bedeecb1a5b6b9dd8 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Tue, 22 Apr 2025 14:49:46 +0000 Subject: [PATCH] feat(server): enable web search for 4.1 (#11825) --- .../src/__tests__/copilot-provider.spec.ts | 2 + .../server/src/__tests__/copilot.spec.ts | 44 ++- .../src/plugins/copilot/prompt/prompts.ts | 3 + .../src/plugins/copilot/providers/openai.ts | 46 ++- .../plugins/copilot/providers/perplexity.ts | 25 +- .../src/plugins/copilot/providers/types.ts | 1 + .../src/plugins/copilot/providers/utils.ts | 288 +++++++++++++----- .../Intelligents/Backend/Model/Prompt.swift | 1 - .../core/src/blocksuite/ai/provider/prompt.ts | 1 - .../e2e/chat-with/image-block.spec.ts | 5 +- 10 files changed, 317 insertions(+), 99 deletions(-) diff --git a/packages/backend/server/src/__tests__/copilot-provider.spec.ts b/packages/backend/server/src/__tests__/copilot-provider.spec.ts index a03d685915..0da47bd6ad 100644 --- a/packages/backend/server/src/__tests__/copilot-provider.spec.ts +++ b/packages/backend/server/src/__tests__/copilot-provider.spec.ts @@ -600,6 +600,8 @@ const workflows = [ content: 'apple company', verifier: (t: ExecutionContext, result: string) => { for (const l of result.split('\n')) { + const line = l.trim(); + if (!line) continue; t.notThrows(() => { JSON.parse(l.trim()); }, 'should be valid json'); diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index cc3c648b1d..f50cfdd45c 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -1112,7 +1112,11 @@ test('CitationParser should replace citation placeholders with URLs', t => { const citations = ['https://example1.com', 'https://example2.com']; const parser = new CitationParser(); - const result = parser.parse(content, citations) + parser.end(); + for (const citation of citations) { + parser.push(citation); + } + + const result = parser.parse(content) + parser.end(); const expected = [ 'This is [a] test sentence with [citations [^1]] and [^2] and [3].', @@ -1147,8 +1151,12 @@ test('CitationParser should replace chunks of citation placeholders with URLs', ]; const parser = new CitationParser(); + for (const citation of citations) { + parser.push(citation); + } + let result = contents.reduce((acc, current) => { - return acc + parser.parse(current, citations); + return acc + parser.parse(current); }, ''); result += parser.end(); @@ -1175,7 +1183,11 @@ test('CitationParser should not replace citation already with URLs', t => { ]; const parser = new CitationParser(); - const result = parser.parse(content, citations) + parser.end(); + for (const citation of citations) { + parser.push(citation); + } + + const result = parser.parse(content) + parser.end(); const expected = [ content, @@ -1199,8 +1211,12 @@ test('CitationParser should not replace chunks of citation already with URLs', t ]; const parser = new CitationParser(); + for (const citation of citations) { + parser.push(citation); + } + let result = contents.reduce((acc, current) => { - return acc + parser.parse(current, citations); + return acc + parser.parse(current); }, ''); result += parser.end(); @@ -1213,6 +1229,26 @@ test('CitationParser should not replace chunks of citation already with URLs', t t.is(result, expected); }); +test('CitationParser should replace openai style reference chunks', t => { + const contents = [ + 'This is [a] test sentence with citations ', + '([example1.com](https://example1.com))', + ]; + + const parser = new CitationParser(); + + let result = contents.reduce((acc, current) => { + return acc + parser.parse(current); + }, ''); + result += parser.end(); + + const expected = [ + contents[0] + '[^1]', + `[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`, + ].join('\n'); + t.is(result, expected); +}); + // ==================== context ==================== test('should be able to manage context', async t => { const { context, prompt, session, event, jobs, storage } = t.context; diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index a866aeb52e..8ee918e25a 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -1102,6 +1102,9 @@ Below is the user's query. Please respond in the user's language without treatin `, }, ], + config: { + webSearch: true, + }, }, { name: 'Search With AFFiNE AI', diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 8d6c5510dc..01f2c771c7 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -1,5 +1,6 @@ import { createOpenAI, + openai, type OpenAIProvider as VercelOpenAIProvider, } from '@ai-sdk/openai'; import { @@ -31,7 +32,7 @@ import { CopilotTextToTextProvider, PromptMessage, } from './types'; -import { chatToGPTMessage } from './utils'; +import { chatToGPTMessage, CitationParser } from './utils'; export const DEFAULT_DIMENSIONS = 256; @@ -176,6 +177,15 @@ export class OpenAIProvider } } + private getToolUse(options: CopilotChatOptions = {}) { + if (options.webSearch) { + return { + web_search_preview: openai.tools.webSearchPreview(), + }; + } + return undefined; + } + // ====== text to text ====== async generateText( messages: PromptMessage[], @@ -234,15 +244,18 @@ export class OpenAIProvider const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = this.#instance(model, { - structuredOutputs: Boolean(options.jsonMode), - user: options.user, - }); + const modelInstance = options.webSearch + ? this.#instance.responses(model) + : this.#instance(model, { + structuredOutputs: Boolean(options.jsonMode), + user: options.user, + }); - const { textStream } = streamText({ + const { fullStream } = streamText({ model: modelInstance, system, messages: msgs, + tools: this.getToolUse(options), frequencyPenalty: options.frequencyPenalty || 0, presencePenalty: options.presencePenalty || 0, temperature: options.temperature || 0, @@ -250,11 +263,24 @@ export class OpenAIProvider abortSignal: options.signal, }); - for await (const message of textStream) { - if (message) { - yield message; + const parser = new CitationParser(); + for await (const chunk of fullStream) { + if (chunk) { + switch (chunk.type) { + case 'text-delta': { + const result = parser.parse(chunk.textDelta); + yield result; + break; + } + case 'step-finish': { + const result = parser.end(); + yield result; + break; + } + } + if (options.signal?.aborted) { - await textStream.cancel(); + await fullStream.cancel(); break; } } diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts index 617678f958..dfb6f2217f 100644 --- a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -95,11 +95,14 @@ export class PerplexityProvider abortSignal: options.signal, }); - const citationParser = new CitationParser(); - const citations = sources.map(s => s.url); + const parser = new CitationParser(); + for (const source of sources) { + parser.push(source.url); + } + let result = text.replaceAll(/<\/?think>\n/g, '\n---\n'); - result = citationParser.parse(result, citations); - result += citationParser.end(); + result = parser.parse(result); + result += parser.end(); return result; } catch (e: any) { metrics.ai.counter('chat_text_errors').add(1, { model }); @@ -129,24 +132,24 @@ export class PerplexityProvider abortSignal: options.signal, }); - const citationParser = new CitationParser(); - const citations = []; + const parser = new CitationParser(); for await (const chunk of stream.fullStream) { switch (chunk.type) { case 'source': { - citations.push(chunk.source.url); + parser.push(chunk.source.url); break; } case 'text-delta': { - const result = citationParser.parse( - chunk.textDelta.replaceAll(/<\/?think>\n?/g, '\n---\n'), - citations + const text = chunk.textDelta.replaceAll( + /<\/?think>\n?/g, + '\n---\n' ); + const result = parser.parse(text); yield result; break; } case 'step-finish': { - const result = citationParser.end(); + const result = parser.end(); yield result; break; } diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index 786e1413f4..23188701c9 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -21,6 +21,7 @@ export enum CopilotCapability { export const PromptConfigStrictSchema = z.object({ // openai jsonMode: z.boolean().nullable().optional(), + webSearch: z.boolean().nullable().optional(), frequencyPenalty: z.number().nullable().optional(), presencePenalty: z.number().nullable().optional(), temperature: z.number().nullable().optional(), diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index 7d12bac50f..9797ce0e6f 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -92,83 +92,243 @@ export async function chatToGPTMessage( return [system?.content, msgs, schema]; } -export class CitationParser { - private readonly SQUARE_BRACKET_OPEN = '['; +// pattern types the callback will receive +type Pattern = + | { kind: 'index'; value: number } // [123] + | { kind: 'link'; text: string; url: string } // [text](url) + | { kind: 'wrappedLink'; text: string; url: string }; // ([text](url)) - private readonly SQUARE_BRACKET_CLOSE = ']'; +type NeedMore = { kind: 'needMore' }; +type Failed = { kind: 'fail'; nextPos: number }; +type Finished = + | { kind: 'ok'; endPos: number; text: string; url: string } + | { kind: 'index'; endPos: number; value: number }; +type ParseStatus = Finished | NeedMore | Failed; - private readonly PARENTHESES_OPEN = '('; +type PatternCallback = (m: Pattern) => string; - private startToken: string[] = []; +export class StreamPatternParser { + #buffer = ''; - private endToken: string[] = []; + constructor(private readonly callback: PatternCallback) {} - private numberToken: string[] = []; + write(chunk: string): string { + this.#buffer += chunk; + const output: string[] = []; + let i = 0; - private citations: string[] = []; + while (i < this.#buffer.length) { + const ch = this.#buffer[i]; - public parse(content: string, citations: string[]) { - this.citations = citations; - let result = ''; - const contentArray = content.split(''); - for (const [index, char] of contentArray.entries()) { - if (char === this.SQUARE_BRACKET_OPEN) { - if (this.numberToken.length === 0) { - this.startToken.push(char); - } else { - result += this.flush() + char; - } + // [[[number]]] or [text](url) or ([text](url)) + if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) { + const isWrapped = ch === '('; + const startPos = isWrapped ? i + 1 : i; + const res = this.tryParse(startPos); + if (res.kind === 'needMore') break; + const { output: out, nextPos } = this.handlePattern( + res, + isWrapped, + startPos, + i + ); + output.push(out); + i = nextPos; continue; } + output.push(ch); + i += 1; + } - if (char === this.SQUARE_BRACKET_CLOSE) { - this.endToken.push(char); - if (this.startToken.length === this.endToken.length) { - const cIndex = Number(this.numberToken.join('').trim()); - if ( - cIndex > 0 && - cIndex <= citations.length && - contentArray[index + 1] !== this.PARENTHESES_OPEN - ) { - const content = `[^${cIndex}]`; - result += content; - this.resetToken(); - } else { - result += this.flush(); - } - } else if (this.startToken.length < this.endToken.length) { - result += this.flush(); - } - continue; + this.#buffer = this.#buffer.slice(i); + return output.join(''); + } + + end(): string { + const rest = this.#buffer; + this.#buffer = ''; + return rest; + } + + // =========== helpers =========== + + private peek(pos: number): string | undefined { + return pos < this.#buffer.length ? this.#buffer[pos] : undefined; + } + + private tryParse(pos: number): ParseStatus { + const nestedRes = this.tryParseNestedIndex(pos); + if (nestedRes) return nestedRes; + return this.tryParseBracketPattern(pos); + } + + private tryParseNestedIndex(pos: number): ParseStatus | null { + if (this.peek(pos + 1) !== '[') return null; + + let i = pos; + let bracketCount = 0; + + while (i < this.#buffer.length && this.#buffer[i] === '[') { + bracketCount++; + i++; + } + + if (bracketCount >= 2) { + if (i >= this.#buffer.length) { + return { kind: 'needMore' }; } - if (this.isNumeric(char)) { - if (this.startToken.length > 0) { - this.numberToken.push(char); - } else { - result += this.flush() + char; - } - continue; + let content = ''; + while (i < this.#buffer.length && this.#buffer[i] !== ']') { + content += this.#buffer[i++]; } - if (this.startToken.length > 0) { - result += this.flush() + char; - } else { - result += char; + let rightBracketCount = 0; + while (i < this.#buffer.length && this.#buffer[i] === ']') { + rightBracketCount++; + i++; + } + + if (i >= this.#buffer.length && rightBracketCount < bracketCount) { + return { kind: 'needMore' }; + } + + if ( + rightBracketCount === bracketCount && + content.length > 0 && + this.isNumeric(content) + ) { + if (this.peek(i) === '(') { + return { kind: 'fail', nextPos: i }; + } + return { kind: 'index', endPos: i, value: Number(content) }; } } - return result; + return null; + } + + private tryParseBracketPattern(pos: number): ParseStatus { + let i = pos + 1; // skip '[' + if (i >= this.#buffer.length) { + return { kind: 'needMore' }; + } + + let content = ''; + while (i < this.#buffer.length && this.#buffer[i] !== ']') { + const nextChar = this.#buffer[i]; + if (nextChar === '[') { + return { kind: 'fail', nextPos: i }; + } + content += nextChar; + i += 1; + } + + if (i >= this.#buffer.length) { + return { kind: 'needMore' }; + } + const after = i + 1; + const afterChar = this.peek(after); + + if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') { + // [number] pattern + return { kind: 'index', endPos: after, value: Number(content) }; + } else if (afterChar !== '(') { + // [text](url) pattern + return { kind: 'fail', nextPos: after }; + } + + i = after + 1; // skip '(' + if (i >= this.#buffer.length) { + return { kind: 'needMore' }; + } + + let url = ''; + while (i < this.#buffer.length && this.#buffer[i] !== ')') { + url += this.#buffer[i++]; + } + if (i >= this.#buffer.length) { + return { kind: 'needMore' }; + } + return { kind: 'ok', endPos: i + 1, text: content, url }; + } + + private isNumeric(str: string): boolean { + return !Number.isNaN(Number(str)) && str.trim() !== ''; + } + + private handlePattern( + pattern: Finished | Failed, + isWrapped: boolean, + start: number, + current: number + ): { output: string; nextPos: number } { + if (pattern.kind === 'fail') { + return { + output: this.#buffer.slice(current, pattern.nextPos), + nextPos: pattern.nextPos, + }; + } + + if (isWrapped) { + const afterLinkPos = pattern.endPos; + if (this.peek(afterLinkPos) !== ')') { + if (afterLinkPos >= this.#buffer.length) { + return { output: '', nextPos: current }; + } + return { output: '(', nextPos: start }; + } + + const out = + pattern.kind === 'index' + ? this.callback({ ...pattern, kind: 'index' }) + : this.callback({ ...pattern, kind: 'wrappedLink' }); + return { output: out, nextPos: afterLinkPos + 1 }; + } else { + const out = + pattern.kind === 'ok' + ? this.callback({ ...pattern, kind: 'link' }) + : this.callback({ ...pattern, kind: 'index' }); + return { output: out, nextPos: pattern.endPos }; + } + } +} + +export class CitationParser { + private readonly citations: string[] = []; + + private readonly parser = new StreamPatternParser(p => { + switch (p.kind) { + case 'index': { + if (p.value <= this.citations.length) { + return `[^${p.value}]`; + } + return `[${p.value}]`; + } + case 'wrappedLink': { + const index = this.citations.indexOf(p.url); + if (index === -1) { + this.citations.push(p.url); + return `[^${this.citations.length}]`; + } + return `[^${index + 1}]`; + } + case 'link': { + return `[${p.text}](${p.url})`; + } + } + }); + + public push(citation: string) { + this.citations.push(citation); + } + + public parse(content: string) { + return this.parser.write(content); } public end() { - return this.flush() + '\n' + this.getFootnotes(); - } - - private flush() { - const content = this.getTokenContent(); - this.resetToken(); - return content; + return this.parser.end() + '\n' + this.getFootnotes(); } private getFootnotes() { @@ -179,18 +339,4 @@ export class CitationParser { }); return footnotes.join('\n'); } - - private getTokenContent() { - return this.startToken.concat(this.numberToken, this.endToken).join(''); - } - - private resetToken() { - this.startToken = []; - this.endToken = []; - this.numberToken = []; - } - - private isNumeric(str: string) { - return !isNaN(Number(str)) && str.trim() !== ''; - } } diff --git a/packages/frontend/apps/ios/App/Packages/Intelligents/Sources/Intelligents/Backend/Model/Prompt.swift b/packages/frontend/apps/ios/App/Packages/Intelligents/Sources/Intelligents/Backend/Model/Prompt.swift index d5ba23e98c..6bbac1b907 100644 --- a/packages/frontend/apps/ios/App/Packages/Intelligents/Sources/Intelligents/Backend/Model/Prompt.swift +++ b/packages/frontend/apps/ios/App/Packages/Intelligents/Sources/Intelligents/Backend/Model/Prompt.swift @@ -9,7 +9,6 @@ import Foundation enum Prompt: String { #if DEBUG - case debug_chat_gpt4 = "debug:chat:gpt4" case debug_action_dalle3 = "debug:action:dalle3" case debug_action_fal_sd15 = "debug:action:fal-sd15" case debug_action_fal_upscaler = "debug:action:fal-upscaler" diff --git a/packages/frontend/core/src/blocksuite/ai/provider/prompt.ts b/packages/frontend/core/src/blocksuite/ai/provider/prompt.ts index 00cc036eac..85957f0172 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/prompt.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/prompt.ts @@ -1,7 +1,6 @@ // manually synced with packages/backend/server/src/data/migrations/utils/prompts.ts // TODO(@Peng): automate this export const promptKeys = [ - 'debug:chat:gpt4', 'debug:action:dalle3', 'debug:action:fal-sd15', 'debug:action:fal-upscaler', diff --git a/tests/affine-cloud-copilot/e2e/chat-with/image-block.spec.ts b/tests/affine-cloud-copilot/e2e/chat-with/image-block.spec.ts index a3dfe2577f..2e5596f213 100644 --- a/tests/affine-cloud-copilot/e2e/chat-with/image-block.spec.ts +++ b/tests/affine-cloud-copilot/e2e/chat-with/image-block.spec.ts @@ -30,6 +30,9 @@ test.describe('AIChatWith/Image', () => { const imageBlock = await page.locator('affine-image'); const captionBlock = await imageBlock.locator('block-caption-editor'); await expect(captionBlock).toBeVisible(); - await expect(captionBlock.locator('textarea')).toHaveValue(caption); + const captionText = await captionBlock.locator('textarea'); + expect(await captionText.inputValue().then(t => t.trim())).toBe( + caption.trim() + ); }); });