refactor(core): add text stream parser (#12459)

Implements [AI-82](https://linear.app/affine-design/issue/AI-82).

Added a `TextStreamParser` class to standardize formatting of different types of AI stream chunks across providers.

### What changed?

- Created a new `TextStreamParser` class in `utils.ts` that handles formatting of various chunk types (text-delta, reasoning, tool-call, tool-result, error)
- Refactored the Anthropic, Gemini, and OpenAI providers to use this shared parser instead of duplicating formatting logic
- Added comprehensive tests for the new `TextStreamParser` class, including tests for individual chunk types and sequences of chunks
- Defined a common `AITools` type to standardize tool interfaces across providers

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
	- Enhanced formatting and structure for streamed AI responses, including improved handling of callouts, web search, and web crawl results.
- **Refactor**
	- Streamlined and unified the processing of streamed AI response chunks across providers for more consistent output.
- **Bug Fixes**
	- Improved error handling and display for streamed responses.
- **Tests**
	- Added comprehensive tests to ensure correct formatting and handling of various streamed message types.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
akumatus
2025-05-27 07:14:17 +00:00
parent 83caf98618
commit f4cba7d6ee
5 changed files with 355 additions and 224 deletions

View File

@@ -26,7 +26,10 @@ import {
ModelOutputType,
OpenAIProvider,
} from '../plugins/copilot/providers';
import { CitationParser } from '../plugins/copilot/providers/utils';
import {
CitationParser,
TextStreamParser,
} from '../plugins/copilot/providers/utils';
import { ChatSessionService } from '../plugins/copilot/session';
import { CopilotStorage } from '../plugins/copilot/storage';
import { CopilotTranscriptionService } from '../plugins/copilot/transcript';
@@ -1257,6 +1260,213 @@ test('CitationParser should replace openai style reference chunks', t => {
t.is(result, expected);
});
test('TextStreamParser should format different types of chunks correctly', t => {
  // Common shape for every fixture: the raw chunk fed to the parser plus a
  // description used as the assertion message.
  interface BaseFixture {
    chunk: any;
    description: string;
  }

  // Fixture whose chunk should produce a formatted string.
  interface ContentFixture extends BaseFixture {
    expected: string;
  }

  // Fixture whose chunk should make the parser throw.
  interface ErrorFixture extends BaseFixture {
    errorMessage: string;
  }

  type ChunkFixture = ContentFixture | ErrorFixture;

  // One fixture per chunk type handled by TextStreamParser. Note: `chunk`
  // is declared `any`, so the previous `as const`/`as any` assertions on the
  // literals were redundant (`as any` discarded the `as const` narrowing
  // anyway) and have been removed.
  const fixtures: Record<string, ChunkFixture> = {
    textDelta: {
      chunk: {
        type: 'text-delta',
        textDelta: 'Hello world',
      },
      expected: 'Hello world',
      description: 'should format text-delta correctly',
    },
    reasoning: {
      chunk: {
        type: 'reasoning',
        textDelta: 'I need to think about this',
      },
      // Reasoning is wrapped in a markdown callout ("> "-quoted lines).
      expected: '\n> [!]\n> I need to think about this',
      description: 'should format reasoning as callout',
    },
    webSearch: {
      chunk: {
        type: 'tool-call',
        toolName: 'web_search_exa',
        toolCallId: 'test-id-1',
        args: { query: 'test query', mode: 'AUTO' },
      },
      expected: '\n> [!]\n> \n> Searching the web "test query"\n> ',
      description: 'should format web search tool call correctly',
    },
    webCrawl: {
      chunk: {
        type: 'tool-call',
        toolName: 'web_crawl_exa',
        toolCallId: 'test-id-2',
        args: { url: 'https://example.com' },
      },
      expected: '\n> [!]\n> \n> Crawling the web "https://example.com"\n> ',
      description: 'should format web crawl tool call correctly',
    },
    toolResult: {
      chunk: {
        type: 'tool-result',
        toolName: 'web_search_exa',
        toolCallId: 'test-id-1',
        args: { query: 'test query', mode: 'AUTO' },
        result: [
          {
            title: 'Test Title',
            url: 'https://test.com',
            content: 'Test content',
            favicon: undefined,
            publishedDate: undefined,
            author: undefined,
          },
          {
            // A null title falls back to the URL as the link text.
            title: null,
            url: 'https://example.com',
            content: 'Example content',
            favicon: undefined,
            publishedDate: undefined,
            author: undefined,
          },
        ],
      },
      expected:
        '\n> [!]\n> \n> \n> \n> [Test Title](https://test.com)\n> \n> \n> \n> [https://example.com](https://example.com)\n> \n> \n> ',
      description: 'should format tool result correctly',
    },
    error: {
      chunk: {
        type: 'error',
        error: { type: 'testError', message: 'Test error message' },
      },
      errorMessage: 'Test error message',
      description: 'should throw error for error chunks',
    },
  };

  // Each fixture gets a fresh parser so per-instance state (callout prefix,
  // last chunk type) cannot leak between cases.
  Object.entries(fixtures).forEach(([_name, fixture]) => {
    const parser = new TextStreamParser();
    if ('errorMessage' in fixture) {
      t.throws(
        () => parser.parse(fixture.chunk),
        { message: fixture.errorMessage },
        fixture.description
      );
    } else {
      const result = parser.parse(fixture.chunk);
      t.is(result, fixture.expected, fixture.description);
    }
  });
});
test('TextStreamParser should process a sequence of message chunks', t => {
const parser = new TextStreamParser();
// Define test fixtures for mixed chunks sequence
const mixedChunksFixture = {
chunks: [
// Reasoning chunks
{
type: 'reasoning' as const,
textDelta: 'The user is asking about',
} as any,
{
type: 'reasoning' as const,
textDelta: ' recent advances in quantum computing',
} as any,
{
type: 'reasoning' as const,
textDelta: ' and how it might impact',
} as any,
{
type: 'reasoning' as const,
textDelta: ' cryptography and data security.',
} as any,
{
type: 'reasoning' as const,
textDelta:
' I should provide information on quantum supremacy achievements',
} as any,
// Text delta
{
type: 'text-delta' as const,
textDelta:
'Let me search for the latest breakthroughs in quantum computing and their ',
} as any,
// Tool call
{
type: 'tool-call' as const,
toolCallId: 'toolu_01ABCxyz123456789',
toolName: 'web_search_exa' as const,
args: {
query: 'latest quantum computing breakthroughs cryptography impact',
},
} as any,
// Tool result
{
type: 'tool-result' as const,
toolCallId: 'toolu_01ABCxyz123456789',
toolName: 'web_search_exa' as const,
args: {
query: 'latest quantum computing breakthroughs cryptography impact',
},
result: [
{
title: 'IBM Unveils 1000-Qubit Quantum Processor',
url: 'https://example.com/tech/quantum-computing-milestone',
},
],
} as any,
// More text deltas
{
type: 'text-delta' as const,
textDelta: 'implications for security.',
} as any,
{
type: 'text-delta' as const,
textDelta: '\n\nQuantum computing has made ',
} as any,
{
type: 'text-delta' as const,
textDelta: 'remarkable progress in the past year. ',
} as any,
{
type: 'text-delta' as const,
textDelta:
'The development of more stable qubits has accelerated research significantly.',
} as any,
],
expected:
'\n> [!]\n> The user is asking about recent advances in quantum computing and how it might impact cryptography and data security. I should provide information on quantum supremacy achievements\n\nLet me search for the latest breakthroughs in quantum computing and their \n> [!]\n> \n> Searching the web "latest quantum computing breakthroughs cryptography impact"\n> \n> \n> \n> [IBM Unveils 1000-Qubit Quantum Processor](https://example.com/tech/quantum-computing-milestone)\n> \n> \n> \n\nimplications for security.\n\nQuantum computing has made remarkable progress in the past year. The development of more stable qubits has accelerated research significantly.',
description:
'should format the entire stream correctly with proper sequence',
};
// Process all chunks sequentially
let result = '';
for (const chunk of mixedChunksFixture.chunks) {
result += parser.parse(chunk);
}
// Check final processed output
t.is(result, mixedChunksFixture.expected, mixedChunksFixture.description);
});
// ==================== context ====================
test('should be able to manage context', async t => {
const { context, prompt, session, event, jobs, storage } = t.context;

View File

@@ -18,13 +18,11 @@ import type {
PromptMessage,
} from '../types';
import { ModelOutputType } from '../types';
import { chatToGPTMessage } from '../utils';
import { chatToGPTMessage, TextStreamParser } from '../utils';
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
private readonly MAX_STEPS = 20;
private readonly CALLOUT_PREFIX = '\n> [!]\n> ';
protected abstract instance:
| AnthropicSDKProvider
| GoogleVertexAnthropicProvider;
@@ -110,74 +108,14 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
experimental_continueSteps: true,
});
let lastType;
// reasoning, tool-call, tool-result need to mark as callout
let prefix: string | null = this.CALLOUT_PREFIX;
const parser = new TextStreamParser();
for await (const chunk of fullStream) {
switch (chunk.type) {
case 'text-delta': {
if (!prefix) {
prefix = this.CALLOUT_PREFIX;
}
let result = chunk.textDelta;
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield result;
break;
}
case 'reasoning': {
if (prefix) {
yield prefix;
prefix = null;
}
let result = chunk.textDelta;
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield this.markAsCallout(result);
break;
}
case 'tool-call': {
if (prefix) {
yield prefix;
prefix = null;
}
if (chunk.toolName === 'web_search_exa') {
yield this.markAsCallout(
`\nSearching the web "${chunk.args.query}"\n`
);
}
if (chunk.toolName === 'web_crawl_exa') {
yield this.markAsCallout(
`\nCrawling the web "${chunk.args.url}"\n`
);
}
break;
}
case 'tool-result': {
if (
chunk.toolName === 'web_search_exa' &&
Array.isArray(chunk.result)
) {
if (prefix) {
yield prefix;
prefix = null;
}
yield this.markAsCallout(this.getWebSearchLinks(chunk.result));
}
break;
}
case 'error': {
const error = chunk.error as { type: string; message: string };
throw new Error(error.message);
}
}
const result = parser.parse(chunk);
yield result;
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
lastType = chunk.type;
}
} catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
@@ -203,22 +141,6 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
return result;
}
private getWebSearchLinks(
list: {
title: string | null;
url: string;
}[]
): string {
const links = list.reduce((acc, result) => {
return acc + `\n[${result.title ?? result.url}](${result.url})\n\n`;
}, '');
return links;
}
private markAsCallout(text: string) {
return text.replaceAll('\n', '\n> ');
}
private isReasoningModel(model: string) {
// only claude 3.7 sonnet supports reasoning config
return model.startsWith('claude-3-7-sonnet');

View File

@@ -25,7 +25,7 @@ import type {
PromptMessage,
} from '../types';
import { ModelOutputType } from '../types';
import { chatToGPTMessage } from '../utils';
import { chatToGPTMessage, TextStreamParser } from '../utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -37,8 +37,6 @@ export type GeminiConfig = {
export abstract class GeminiProvider<T> extends CopilotProvider<T> {
private readonly MAX_STEPS = 20;
private readonly CALLOUT_PREFIX = '\n> [!]\n> ';
protected abstract instance:
| GoogleGenerativeAIProvider
| GoogleVertexProvider;
@@ -167,42 +165,13 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
},
});
let lastType;
// reasoning, tool-call, tool-result need to mark as callout
let prefix: string | null = this.CALLOUT_PREFIX;
const parser = new TextStreamParser();
for await (const chunk of fullStream) {
if (chunk) {
switch (chunk.type) {
case 'text-delta': {
let result = chunk.textDelta;
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield result;
break;
}
case 'reasoning': {
if (prefix) {
yield prefix;
prefix = null;
}
let result = chunk.textDelta;
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield this.markAsCallout(result);
break;
}
case 'error': {
const error = chunk.error as { type: string; message: string };
throw new Error(error.message);
}
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
lastType = chunk.type;
const result = parser.parse(chunk);
yield result;
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
} catch (e: any) {
@@ -222,10 +191,6 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
return result;
}
private markAsCallout(text: string) {
return text.replaceAll('\n', '\n> ');
}
private isReasoningModel(model: string) {
return model.startsWith('gemini-2.5');
}

View File

@@ -11,6 +11,7 @@ import {
generateObject,
generateText,
streamText,
ToolSet,
} from 'ai';
import {
@@ -30,7 +31,7 @@ import type {
PromptMessage,
} from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import { chatToGPTMessage, CitationParser } from './utils';
import { chatToGPTMessage, CitationParser, TextStreamParser } from './utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -39,12 +40,6 @@ export type OpenAIConfig = {
baseUrl?: string;
};
type OpenAITools = {
web_search_preview: ReturnType<typeof openai.tools.webSearchPreview>;
web_search_exa: ReturnType<typeof createExaSearchTool>;
web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
};
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
readonly type = CopilotProviderType.OpenAI;
@@ -187,8 +182,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
private readonly MAX_STEPS = 20;
private readonly CALLOUT_PREFIX = '\n> [!]\n> ';
#instance!: VercelOpenAIProvider;
override configured(): boolean {
@@ -231,11 +224,8 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}
}
private getTools(
options: CopilotChatOptions,
model: string
): Partial<OpenAITools> {
const tools: Partial<OpenAITools> = {};
private getTools(options: CopilotChatOptions, model: string): ToolSet {
const tools: ToolSet = {};
if (options?.tools?.length) {
for (const tool of options.tools) {
switch (tool) {
@@ -324,82 +314,34 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: this.getTools(options, model.id) as OpenAITools,
tools: this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
abortSignal: options.signal,
});
const parser = new CitationParser();
let lastType;
// reasoning, tool-call, tool-result need to mark as callout
let prefix: string | null = this.CALLOUT_PREFIX;
const citationParser = new CitationParser();
const textParser = new TextStreamParser();
for await (const chunk of fullStream) {
if (chunk) {
switch (chunk.type) {
case 'text-delta': {
let result = parser.parse(chunk.textDelta);
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield result;
break;
}
case 'reasoning': {
if (prefix) {
yield prefix;
prefix = null;
}
let result = chunk.textDelta;
if (lastType !== chunk.type) {
result = '\n\n' + result;
}
yield this.markAsCallout(result);
break;
}
case 'tool-call': {
if (prefix) {
yield prefix;
prefix = null;
}
if (chunk.toolName === 'web_search_exa') {
yield this.markAsCallout(
`\nSearching the web "${chunk.args.query}"\n`
);
}
if (chunk.toolName === 'web_crawl_exa') {
yield this.markAsCallout(
`\nCrawling the web "${chunk.args.url}"\n`
);
}
break;
}
case 'tool-result': {
if (
chunk.toolName === 'web_search_exa' &&
Array.isArray(chunk.result)
) {
yield this.markAsCallout(
`\n${this.getWebSearchLinks(chunk.result)}\n`
);
}
break;
}
case 'finish': {
const result = parser.end();
yield result;
break;
}
case 'error': {
const error = chunk.error as { type: string; message: string };
throw new Error(error.message);
}
}
if (options.signal?.aborted) {
await fullStream.cancel();
switch (chunk.type) {
case 'text-delta': {
let result = textParser.parse(chunk);
result = citationParser.parse(result);
yield result;
break;
}
lastType = chunk.type;
case 'finish': {
const result = citationParser.end();
yield result;
break;
}
default: {
yield textParser.parse(chunk);
break;
}
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
} catch (e: any) {
@@ -539,22 +481,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
return result;
}
private getWebSearchLinks(
list: {
title: string | null;
url: string;
}[]
): string {
const links = list.reduce((acc, result) => {
return acc + `\n[${result.title ?? result.url}](${result.url})\n\n`;
}, '');
return links;
}
private markAsCallout(text: string) {
return text.replaceAll('\n', '\n> ');
}
private isReasoningModel(model: string) {
// o series reasoning models
return model.startsWith('o');

View File

@@ -4,9 +4,12 @@ import {
FilePart,
ImagePart,
TextPart,
TextStreamPart,
ToolSet,
} from 'ai';
import { ZodType } from 'zod';
import { createExaCrawlTool, createExaSearchTool } from '../tools';
import { PromptMessage } from './types';
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
@@ -367,3 +370,108 @@ export class CitationParser {
return footnotes.join('\n');
}
}
/**
 * Tool set shared by the copilot providers: Exa-backed web search and web
 * crawl. Extends the AI SDK `ToolSet` so it can parameterize
 * `TextStreamPart` with the concrete tool names used in stream chunks.
 */
export interface CustomAITools extends ToolSet {
  // NOTE(review): assumes createExaSearchTool/createExaCrawlTool return
  // AI SDK tool definitions — confirm against ../tools.
  web_search_exa: ReturnType<typeof createExaSearchTool>;
  web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
}

// Union of all chunk type tags ('text-delta', 'reasoning', 'tool-call', ...).
type ChunkType = TextStreamPart<CustomAITools>['type'];
/**
 * Stateful formatter that turns AI SDK stream chunks into markdown text.
 *
 * Plain text deltas pass through unchanged, while reasoning, tool-call and
 * tool-result chunks are rendered inside a markdown callout (`> `-quoted
 * block). The parser remembers the previous chunk type so it can insert a
 * blank-line separator on transitions, and keeps a pending callout header so
 * each callout group is opened exactly once. Error chunks are surfaced to
 * the caller by throwing.
 */
export class TextStreamParser {
  // Raw callout opener; run through quoteLines() before being emitted.
  private readonly CALLOUT_PREFIX = '\n[!]\n';

  // Type tag of the previously parsed chunk (used to detect transitions).
  private lastType: ChunkType | undefined;

  // Pending callout header. Consumed (set to null) by the first callout
  // chunk of a group and re-armed when plain text resumes.
  private prefix: string | null = this.CALLOUT_PREFIX;

  /**
   * Format a single stream chunk as markdown.
   *
   * @param chunk - a chunk from the AI SDK `fullStream`
   * @returns the text to append to the output ('' for unhandled chunk types)
   * @throws Error when the chunk is an `error` chunk
   */
  public parse(chunk: TextStreamPart<CustomAITools>) {
    if (chunk.type === 'error') {
      // Surface stream-level failures; lastType is intentionally untouched.
      const { message } = chunk.error as { type: string; message: string };
      throw new Error(message);
    }

    let out = '';
    switch (chunk.type) {
      case 'text-delta': {
        // Plain text closes any open callout; re-arm the header so the next
        // callout group gets a fresh opener.
        if (this.prefix === null) {
          this.prefix = this.CALLOUT_PREFIX;
        }
        out =
          this.lastType && this.lastType !== chunk.type
            ? '\n\n' + chunk.textDelta
            : chunk.textDelta;
        break;
      }
      case 'reasoning': {
        out = this.quoteLines(this.consumePrefix() + chunk.textDelta);
        break;
      }
      case 'tool-call': {
        let label = '';
        if (chunk.toolName === 'web_search_exa') {
          label = `\nSearching the web "${chunk.args.query}"\n`;
        } else if (chunk.toolName === 'web_crawl_exa') {
          label = `\nCrawling the web "${chunk.args.url}"\n`;
        }
        out = this.quoteLines(this.consumePrefix() + label);
        break;
      }
      case 'tool-result': {
        let body = '';
        if (
          chunk.toolName === 'web_search_exa' &&
          Array.isArray(chunk.result)
        ) {
          body = `\n${this.renderLinks(chunk.result)}\n`;
        }
        out = this.quoteLines(this.consumePrefix() + body);
        break;
      }
    }

    this.lastType = chunk.type;
    return out;
  }

  // Return the pending callout header (or '') and mark it consumed.
  private consumePrefix(): string {
    const pending = this.prefix ?? '';
    this.prefix = null;
    return pending;
  }

  // Quote text as a markdown callout body: every newline gets a "> " tail.
  private quoteLines(text: string) {
    return text.replaceAll('\n', '\n> ');
  }

  // Render search hits as markdown links, falling back to the URL when a
  // result has no title.
  private renderLinks(
    hits: {
      title: string | null;
      url: string;
    }[]
  ): string {
    let acc = '';
    for (const { title, url } of hits) {
      acc += `\n\n[${title ?? url}](${url})\n\n`;
    }
    return acc;
  }
}