feat(core): add stream object api (#12841)

Closes [AI-193](https://linear.app/affine-design/issue/AI-193)

## Summary by CodeRabbit

- **New Features**
  - Added support for streaming structured AI chat responses as objects, enabling richer and more interactive chat experiences.
  - Chat messages now include a new field carrying structured stream objects, such as reasoning steps, text deltas, tool calls, and tool results.
  - Updated GraphQL APIs and queries to expose these structured streaming objects in chat histories.
  - Introduced a new streaming chat endpoint for object-based responses.

- **Bug Fixes**
  - Improved error handling for streaming responses to ensure more robust and informative error reporting.

- **Refactor**
  - Centralized and streamlined session preparation and streaming logic for AI chat providers.
  - Unified streaming setup across multiple AI model providers.

- **Tests**
  - Extended test coverage for streaming object responses to ensure reliability and correctness.

- **Documentation**
  - Updated type definitions and schemas to reflect the new streaming object capabilities in both backend and frontend code.

Co-authored-by: DarkSky <25152247+darkskygit@users.noreply.github.com>
Author: Wu Yue
Date: 2025-06-19 09:13:18 +08:00 (committed by GitHub)
Parent: ce951ec316
Commit: 6169cdab3a

24 changed files with 722 additions and 149 deletions

View File

@@ -1,5 +1,6 @@
import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import { z } from 'zod';
import { ServerFeature, ServerService } from '../core';
import { AuthService } from '../core/auth';
@@ -9,6 +10,8 @@ import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderFactory,
CopilotProviderType,
StreamObject,
StreamObjectSchema,
} from '../plugins/copilot/providers';
import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types';
import {
@@ -183,6 +186,16 @@ const checkUrl = (url: string) => {
}
};
const checkStreamObjects = (result: string) => {
try {
const streamObjects = JSON.parse(result);
z.array(StreamObjectSchema).parse(streamObjects);
return true;
} catch {
return false;
}
};
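
For reference, the verifier accepts any JSON string that parses to an array matching `StreamObjectSchema`. A minimal sketch of a payload that passes `checkStreamObjects` (the values are illustrative):

```ts
// Hypothetical payload — any array whose members match StreamObjectSchema passes.
const sample = JSON.stringify([
  { type: 'reasoning', textDelta: 'Looking this up.' },
  {
    type: 'tool-call',
    toolCallId: 'call_1', // illustrative id
    toolName: 'web_search', // illustrative tool name
    args: { query: 'AFFiNE AI' },
  },
  {
    type: 'tool-result',
    toolCallId: 'call_1',
    toolName: 'web_search',
    args: { query: 'AFFiNE AI' },
    result: ['...'],
  },
  { type: 'text-delta', textDelta: 'AFFiNE AI is ' },
]);
// checkStreamObjects(sample) === true
```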
const retry = async (
action: string,
t: ExecutionContext<Tester>,
@@ -387,6 +400,20 @@ The term **“CRDT”** was first introduced by Marc Shapiro, Nuno Preguiça, Ca
},
type: 'text' as const,
},
{
name: 'stream objects',
promptName: ['Chat With AFFiNE AI'],
messages: [
{
role: 'user' as const,
content: 'what is AFFiNE AI',
},
],
verifier: (t: ExecutionContext<Tester>, result: string) => {
t.truthy(checkStreamObjects(result), 'should be valid stream objects');
},
type: 'object' as const,
},
{
name: 'Should transcribe short audio',
promptName: ['Transcript audio'],
@@ -680,6 +707,27 @@ for (const {
verifier?.(t, result);
break;
}
case 'object': {
const streamObjects: StreamObject[] = [];
for await (const chunk of provider.streamObject(
{ modelId: prompt.model },
[
...prompt.finish(
messages.reduce(
(acc, m) => Object.assign(acc, (m as any).params || {}),
{}
)
),
...messages,
],
finalConfig
)) {
streamObjects.push(chunk);
}
t.truthy(streamObjects, 'should return result');
verifier?.(t, JSON.stringify(streamObjects));
break;
}
case 'image': {
const finalMessage = [...messages];
const params = {};

View File

@@ -39,6 +39,7 @@ import {
array2sse,
audioTranscription,
chatWithImages,
chatWithStreamObject,
chatWithText,
chatWithTextStream,
chatWithWorkflow,
@@ -512,6 +513,28 @@ test('should be able to chat with api', async t => {
);
}
{
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
const ret4 = await chatWithStreamObject(app, sessionId, messageId);
const objects = Array.from('generate text to object stream').map(data =>
JSON.stringify({ type: 'text-delta', textDelta: data })
);
t.is(
ret4,
textToEventStream(objects, messageId),
'should be able to chat with stream object'
);
}
Sinon.restore();
});

View File

@@ -9,6 +9,7 @@ import {
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../../plugins/copilot/providers';
import {
DEFAULT_DIMENSIONS,
@@ -23,7 +24,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
defaultForOutputType: true,
},
],
@@ -43,7 +44,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -52,7 +53,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -61,7 +62,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -70,7 +71,7 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -79,7 +80,11 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -98,7 +103,11 @@ export class MockCopilotProvider extends OpenAIProvider {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -195,4 +204,24 @@ export class MockCopilotProvider extends OpenAIProvider {
await sleep(100);
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ messages, cond: fullCond, options });
// leave a small time gap for the history test case
await sleep(100);
const result = 'generate text to object stream';
for (const data of result) {
yield { type: 'text-delta', textDelta: data } as const;
if (options.signal?.aborted) {
break;
}
}
}
}

View File

@@ -582,6 +582,14 @@ export async function chatWithImages(
return chatWithText(app, sessionId, messageId, '/images');
}
export async function chatWithStreamObject(
app: TestingApp,
sessionId: string,
messageId?: string
) {
return chatWithText(app, sessionId, messageId, '/stream-object');
}
export async function unsplashSearch(
app: TestingApp,
params: Record<string, string> = {}

View File

@@ -51,6 +51,7 @@ import {
ModelInputType,
ModelOutputType,
} from './providers';
import { StreamObjectParser } from './providers/utils';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { ChatMessage, ChatQuerySchema } from './types';
@@ -189,6 +190,45 @@ export class CopilotController implements BeforeApplicationShutdown {
return merge(source$.pipe(finalize(() => subject$.next(null))), ping$);
}
private async prepareChatSession(
user: CurrentUser,
sessionId: string,
query: Record<string, string | string[]>,
outputType: ModelOutputType
) {
let { messageId, retry, modelId, params } = ChatQuerySchema.parse(query);
const { provider, model } = await this.chooseProvider(
outputType,
user.id,
sessionId,
messageId,
modelId
);
const [latestMessage, session] = await this.appendSessionMessage(
sessionId,
messageId,
retry
);
if (latestMessage) {
params = Object.assign({}, params, latestMessage.params, {
content: latestMessage.content,
attachments: latestMessage.attachments,
});
}
const finalMessage = session.finish(params);
return {
provider,
model,
session,
finalMessage,
};
}
@Get('/chat/:sessionId')
@CallMetric('ai', 'chat', { timer: true })
async chat(
@@ -200,36 +240,19 @@ export class CopilotController implements BeforeApplicationShutdown {
const info: any = { sessionId, params: query };
try {
let { messageId, retry, reasoning, webSearch, modelId, params } =
ChatQuerySchema.parse(query);
const { provider, model } = await this.chooseProvider(
ModelOutputType.Text,
user.id,
sessionId,
messageId,
modelId
);
const [latestMessage, session] = await this.appendSessionMessage(
sessionId,
messageId,
retry
);
const { provider, model, session, finalMessage } =
await this.prepareChatSession(
user,
sessionId,
query,
ModelOutputType.Text
);
info.model = model;
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
metrics.ai.counter('chat_calls').add(1, { model });
if (latestMessage) {
params = Object.assign({}, params, latestMessage.params, {
content: latestMessage.content,
attachments: latestMessage.attachments,
});
}
const finalMessage = session.finish(params);
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
const content = await provider.text({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: this.getSignal(req),
@@ -269,37 +292,20 @@ export class CopilotController implements BeforeApplicationShutdown {
const info: any = { sessionId, params: query, throwInStream: false };
try {
let { messageId, retry, reasoning, webSearch, modelId, params } =
ChatQuerySchema.parse(query);
const { provider, model } = await this.chooseProvider(
ModelOutputType.Text,
user.id,
sessionId,
messageId,
modelId
);
const [latestMessage, session] = await this.appendSessionMessage(
sessionId,
messageId,
retry
);
const { provider, model, session, finalMessage } =
await this.prepareChatSession(
user,
sessionId,
query,
ModelOutputType.Text
);
info.model = model;
metrics.ai.counter('chat_stream_calls').add(1, { model });
if (latestMessage) {
params = Object.assign({}, params, latestMessage.params, {
content: latestMessage.content,
attachments: latestMessage.attachments,
});
}
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const finalMessage = session.finish(params);
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
metrics.ai.counter('chat_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from(
provider.streamText({ modelId: model }, finalMessage, {
...session.config.promptConfig,
@@ -348,6 +354,83 @@ export class CopilotController implements BeforeApplicationShutdown {
}
}
@Sse('/chat/:sessionId/stream-object')
@CallMetric('ai', 'chat_object_stream', { timer: true })
async chatStreamObject(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query() query: Record<string, string>
): Promise<Observable<ChatEvent>> {
const info: any = { sessionId, params: query, throwInStream: false };
try {
const { provider, model, session, finalMessage } =
await this.prepareChatSession(
user,
sessionId,
query,
ModelOutputType.Object
);
info.model = model;
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
metrics.ai.counter('chat_object_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from(
provider.streamObject({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: this.getSignal(req),
user: user.id,
workspace: session.config.workspaceId,
reasoning,
webSearch,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
const parser = new StreamObjectParser();
const streamObjects = parser.mergeTextDelta(values);
const content = parser.mergeContent(streamObjects);
session.push({
role: 'assistant',
content,
streamObjects,
createdAt: new Date(),
});
return from(session.save());
}),
mergeMap(() => EMPTY)
)
)
),
catchError(e => {
metrics.ai.counter('chat_object_stream_errors').add(1);
info.throwInStream = true;
return mapSseError(e, info);
}),
finalize(() => {
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
})
);
return this.mergePingStream(messageId || '', source$);
} catch (err) {
metrics.ai.counter('chat_object_stream_errors').add(1, info);
return mapSseError(err, info);
}
}
@Sse('/chat/:sessionId/workflow')
@CallMetric('ai', 'chat_workflow', { timer: true })
async chatWorkflow(
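
A sketch of how a client might consume the new endpoint. The `/api/copilot` prefix and the rendering hooks are assumptions — only the `/chat/:sessionId/stream-object` route and the per-event `StreamObject` payload are shown in this diff:

```ts
import type { StreamObject } from './providers'; // path is illustrative

declare const sessionId: string;
declare const messageId: string;
declare function renderText(text: string): void; // hypothetical UI hook
declare function renderToolCall(name: string, args: unknown): void; // hypothetical

const source = new EventSource(
  `/api/copilot/chat/${sessionId}/stream-object?messageId=${messageId}`
);
source.onmessage = event => {
  // Each SSE 'message' event carries one StreamObject serialized as JSON.
  const chunk = JSON.parse(event.data) as StreamObject;
  if (chunk.type === 'text-delta') {
    renderText(chunk.textDelta);
  } else if (chunk.type === 'tool-call') {
    renderToolCall(chunk.toolName, chunk.args);
  }
};
```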

View File

@@ -13,11 +13,17 @@ import {
import { CopilotProvider } from '../provider';
import type {
CopilotChatOptions,
CopilotProviderModel,
ModelConditions,
PromptMessage,
StreamObject,
} from '../types';
import { ModelOutputType } from '../types';
import { chatToGPTMessage, TextStreamParser } from '../utils';
import {
chatToGPTMessage,
StreamObjectParser,
TextStreamParser,
} from '../utils';
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
private readonly MAX_STEPS = 20;
@@ -92,21 +98,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages, true, true);
const { fullStream } = streamText({
model: this.instance(model.id),
system,
messages: msgs,
abortSignal: options.signal,
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
experimental_continueSteps: true,
});
const fullStream = await this.getFullStream(model, messages, options);
const parser = new TextStreamParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
@@ -122,6 +114,60 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
}
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);
try {
metrics.ai
.counter('chat_object_stream_calls')
.add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new StreamObjectParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
if (result) {
yield result;
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
} catch (e: any) {
metrics.ai
.counter('chat_object_stream_errors')
.add(1, { model: model.id });
throw this.handleError(e);
}
}
private async getFullStream(
model: CopilotProviderModel,
messages: PromptMessage[],
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages, true, true);
const { fullStream } = streamText({
model: this.instance(model.id),
system,
messages: msgs,
abortSignal: options.signal,
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
experimental_continueSteps: true,
});
return fullStream;
}
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
const result: AnthropicProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {
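
All three provider implementations (Anthropic, Gemini, OpenAI) share the same contract: `streamObject` is an async iterable that checks `options.signal` after each chunk and cancels the underlying `fullStream` when it fires. A minimal cancellation sketch, assuming a provider instance with an Object-capable model:

```ts
import type { CopilotProvider } from '../provider'; // path is illustrative
import type { StreamObject } from '../types';

declare const provider: CopilotProvider;

const controller = new AbortController();
setTimeout(() => controller.abort(), 5_000); // stop after 5s, for illustration

const objects: StreamObject[] = [];
for await (const chunk of provider.streamObject(
  { modelId: 'claude-sonnet-4' }, // hypothetical model id
  [{ role: 'user', content: 'what is AFFiNE AI' }],
  { signal: controller.signal }
)) {
  objects.push(chunk);
}
// Once the signal fires, the generator cancels fullStream and the loop
// ends cleanly instead of throwing.
```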

View File

@@ -20,7 +20,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -29,7 +29,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -38,7 +38,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -47,7 +47,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
defaultForOutputType: true,
},
],

View File

@@ -18,7 +18,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -27,7 +27,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -36,7 +36,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -45,7 +45,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
defaultForOutputType: true,
},
],

View File

@@ -21,11 +21,17 @@ import { CopilotProvider } from '../provider';
import type {
CopilotChatOptions,
CopilotImageOptions,
CopilotProviderModel,
ModelConditions,
PromptMessage,
StreamObject,
} from '../types';
import { ModelOutputType } from '../types';
import { chatToGPTMessage, TextStreamParser } from '../utils';
import {
chatToGPTMessage,
StreamObjectParser,
TextStreamParser,
} from '../utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -150,21 +156,7 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages);
const { fullStream } = streamText({
model: this.instance(model.id, {
useSearchGrounding: this.useSearchGrounding(options),
}),
system,
messages: msgs,
abortSignal: options.signal,
maxSteps: this.MAX_STEPS,
providerOptions: {
google: this.getGeminiOptions(options, model.id),
},
});
const fullStream = await this.getFullStream(model, messages, options);
const parser = new TextStreamParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
@@ -180,6 +172,60 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
}
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);
try {
metrics.ai
.counter('chat_object_stream_calls')
.add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new StreamObjectParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
if (result) {
yield result;
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
} catch (e: any) {
metrics.ai
.counter('chat_object_stream_errors')
.add(1, { model: model.id });
throw this.handleError(e);
}
}
private async getFullStream(
model: CopilotProviderModel,
messages: PromptMessage[],
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages);
const { fullStream } = streamText({
model: this.instance(model.id, {
useSearchGrounding: this.useSearchGrounding(options),
}),
system,
messages: msgs,
abortSignal: options.signal,
maxSteps: this.MAX_STEPS,
providerOptions: {
google: this.getGeminiOptions(options, model.id),
},
});
return fullStream;
}
private getGeminiOptions(options: CopilotChatOptions, model: string) {
const result: GoogleGenerativeAIProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {

View File

@@ -25,7 +25,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Image,
ModelInputType.Audio,
],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
defaultForOutputType: true,
},
],
@@ -40,7 +44,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Image,
ModelInputType.Audio,
],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -54,7 +62,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Image,
ModelInputType.Audio,
],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},

View File

@@ -23,7 +23,11 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
ModelInputType.Image,
ModelInputType.Audio,
],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -37,7 +41,11 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
ModelInputType.Image,
ModelInputType.Audio,
],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},

View File

@@ -27,12 +27,19 @@ import type {
CopilotChatTools,
CopilotEmbeddingOptions,
CopilotImageOptions,
CopilotProviderModel,
CopilotStructuredOptions,
ModelConditions,
PromptMessage,
StreamObject,
} from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import { chatToGPTMessage, CitationParser, TextStreamParser } from './utils';
import {
chatToGPTMessage,
CitationParser,
StreamObjectParser,
TextStreamParser,
} from './utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -65,7 +72,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -75,7 +82,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -84,7 +91,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -94,7 +101,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -103,7 +110,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
defaultForOutputType: true,
},
],
@@ -113,7 +124,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -122,7 +137,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -131,7 +150,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Structured],
output: [
ModelOutputType.Text,
ModelOutputType.Object,
ModelOutputType.Structured,
],
},
],
},
@@ -140,7 +163,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -149,7 +172,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -158,7 +181,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
capabilities: [
{
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Object],
},
],
},
@@ -312,26 +335,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance = this.#instance.responses(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
messages: msgs,
frequencyPenalty: options.frequencyPenalty ?? 0,
presencePenalty: options.presencePenalty ?? 0,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
abortSignal: options.signal,
});
const fullStream = await this.getFullStream(model, messages, options);
const citationParser = new CitationParser();
const textParser = new TextStreamParser();
for await (const chunk of fullStream) {
@@ -363,6 +367,39 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);
try {
metrics.ai
.counter('chat_object_stream_calls')
.add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new StreamObjectParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
if (result) {
yield result;
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
} catch (e: any) {
metrics.ai
.counter('chat_object_stream_errors')
.add(1, { model: model.id });
throw this.handleError(e, model.id, options);
}
}
override async structure(
cond: ModelConditions,
messages: PromptMessage[],
@@ -403,6 +440,31 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}
}
private async getFullStream(
model: CopilotProviderModel,
messages: PromptMessage[],
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance = this.#instance.responses(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
messages: msgs,
frequencyPenalty: options.frequencyPenalty ?? 0,
presencePenalty: options.presencePenalty ?? 0,
temperature: options.temperature ?? 0,
maxTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
maxSteps: this.MAX_STEPS,
abortSignal: options.signal,
});
return fullStream;
}
// ====== text to image ======
private async *generateImageWithAttachments(
model: string,

View File

@@ -33,6 +33,7 @@ import {
ModelInputType,
type PromptMessage,
PromptMessageSchema,
StreamObject,
} from './types';
@Injectable()
@@ -225,6 +226,17 @@ export abstract class CopilotProvider<C = any> {
options?: CopilotChatOptions
): AsyncIterable<string>;
streamObject(
_model: ModelConditions,
_messages: PromptMessage[],
_options?: CopilotChatOptions
): AsyncIterable<StreamObject> {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'object',
});
}
structure(
_cond: ModelConditions,
_messages: PromptMessage[],

View File

@@ -118,8 +118,33 @@ export const ChatMessageAttachment = z.union([
}),
]);
export const StreamObjectSchema = z.discriminatedUnion('type', [
z.object({
type: z.literal('text-delta'),
textDelta: z.string(),
}),
z.object({
type: z.literal('reasoning'),
textDelta: z.string(),
}),
z.object({
type: z.literal('tool-call'),
toolCallId: z.string(),
toolName: z.string(),
args: z.record(z.any()),
}),
z.object({
type: z.literal('tool-result'),
toolCallId: z.string(),
toolName: z.string(),
args: z.record(z.any()),
result: z.any(),
}),
]);
export const PureMessageSchema = z.object({
content: z.string(),
streamObjects: z.array(StreamObjectSchema).optional().nullable(),
attachments: z.array(ChatMessageAttachment).optional().nullable(),
params: z.record(z.any()).optional().nullable(),
});
@@ -129,6 +154,7 @@ export const PromptMessageSchema = PureMessageSchema.extend({
}).strict();
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export type StreamObject = z.infer<typeof StreamObjectSchema>;
// ========== options ==========
@@ -187,6 +213,7 @@ export enum ModelInputType {
export enum ModelOutputType {
Text = 'text',
Object = 'object',
Embedding = 'embedding',
Image = 'image',
Structured = 'structured',
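
Because `StreamObjectSchema` is a discriminated union on `type`, consumers of the inferred `StreamObject` type get exhaustive narrowing for free. A small sketch:

```ts
import type { StreamObject } from './types'; // path is illustrative

function describe(obj: StreamObject): string {
  switch (obj.type) {
    case 'text-delta':
      return `text: ${obj.textDelta}`;
    case 'reasoning':
      return `reasoning: ${obj.textDelta}`;
    case 'tool-call':
      return `calling ${obj.toolName}(${JSON.stringify(obj.args)})`;
    case 'tool-result':
      return `${obj.toolName} returned ${JSON.stringify(obj.result)}`;
  }
}
```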

View File

@@ -14,7 +14,7 @@ import {
createExaCrawlTool,
createExaSearchTool,
} from '../tools';
import { PromptMessage } from './types';
import { PromptMessage, StreamObject } from './types';
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
@@ -387,6 +387,22 @@ export interface CustomAITools extends ToolSet {
type ChunkType = TextStreamPart<CustomAITools>['type'];
export function parseUnknownError(error: unknown) {
if (typeof error === 'string') {
throw new Error(error);
} else if (error instanceof Error) {
throw error;
} else if (
typeof error === 'object' &&
error !== null &&
'message' in error
) {
throw new Error(String(error.message));
} else {
throw new Error(JSON.stringify(error));
}
}
export class TextStreamParser {
private readonly CALLOUT_PREFIX = '\n[!]\n';
@@ -446,8 +462,8 @@ export class TextStreamParser {
break;
}
case 'error': {
const error = chunk.error as { type: string; message: string };
throw new Error(error.message);
parseUnknownError(chunk.error);
break;
}
}
this.lastType = chunk.type;
@@ -490,3 +506,54 @@ export class TextStreamParser {
return links;
}
}
export class StreamObjectParser {
public parse(chunk: TextStreamPart<CustomAITools>) {
switch (chunk.type) {
case 'reasoning':
case 'text-delta':
case 'tool-call':
case 'tool-result': {
return chunk;
}
case 'error': {
parseUnknownError(chunk.error);
return null;
}
default: {
return null;
}
}
}
public mergeTextDelta(chunks: StreamObject[]): StreamObject[] {
return chunks.reduce((acc, curr) => {
const prev = acc.at(-1);
switch (curr.type) {
case 'reasoning':
case 'text-delta': {
if (prev && prev.type === curr.type) {
prev.textDelta += curr.textDelta;
} else {
acc.push(curr);
}
break;
}
default: {
acc.push(curr);
break;
}
}
return acc;
}, [] as StreamObject[]);
}
public mergeContent(chunks: StreamObject[]): string {
return chunks.reduce((acc, curr) => {
if (curr.type === 'text-delta') {
acc += curr.textDelta;
}
return acc;
}, '');
}
}
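
`mergeTextDelta` coalesces consecutive chunks of the same delta type, and `mergeContent` flattens only the `text-delta` chunks — which is what the controller persists as the assistant message's `content`. A small sketch of that behavior:

```ts
import { StreamObjectParser } from './providers/utils'; // as imported by the controller

const parser = new StreamObjectParser();
const merged = parser.mergeTextDelta([
  { type: 'reasoning', textDelta: 'thinking ' },
  { type: 'reasoning', textDelta: 'hard' },
  { type: 'text-delta', textDelta: 'AFFiNE ' },
  { type: 'text-delta', textDelta: 'AI' },
]);
// merged:
// [ { type: 'reasoning', textDelta: 'thinking hard' },
//   { type: 'text-delta', textDelta: 'AFFiNE AI' } ]

parser.mergeContent(merged); // => 'AFFiNE AI'
```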

View File

@@ -34,7 +34,7 @@ import { Admin } from '../../core/common';
import { AccessController } from '../../core/permission';
import { UserType } from '../../core/user';
import { PromptService } from './prompt';
import { PromptMessage } from './providers';
import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
@@ -168,6 +168,27 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
// ================== Return Types ==================
@ObjectType('StreamObject')
class StreamObjectType {
@Field(() => String)
type!: string;
@Field(() => String, { nullable: true })
textDelta?: string;
@Field(() => String, { nullable: true })
toolCallId?: string;
@Field(() => String, { nullable: true })
toolName?: string;
@Field(() => GraphQLJSON, { nullable: true })
args?: any;
@Field(() => GraphQLJSON, { nullable: true })
result?: any;
}
@ObjectType('ChatMessage')
class ChatMessageType implements Partial<ChatMessage> {
// id will be null if message is a prompt message
@@ -180,6 +201,9 @@ class ChatMessageType implements Partial<ChatMessage> {
@Field(() => String)
content!: string;
@Field(() => [StreamObjectType], { nullable: true })
streamObjects!: StreamObject[];
@Field(() => [String], { nullable: true })
attachments!: string[];

View File

@@ -282,6 +282,7 @@ export class ChatSessionService {
await tx.aiSessionMessage.createMany({
data: state.messages.map(m => ({
...m,
streamObjects: m.streamObjects || undefined,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
sessionId,
@@ -512,6 +513,7 @@ export class ChatSessionService {
id: true,
role: true,
content: true,
streamObjects: true,
attachments: true,
params: true,
createdAt: true,

View File

@@ -96,6 +96,7 @@ type ChatMessage {
id: ID
params: JSON
role: String!
streamObjects: [StreamObject!]
}
enum ContextCategories {
@@ -1628,6 +1629,15 @@ type SpaceShouldHaveOnlyOneOwnerDataType {
spaceId: String!
}
type StreamObject {
args: JSON
result: JSON
textDelta: String
toolCallId: String
toolName: String
type: String!
}
type SubscriptionAlreadyExistsDataType {
plan: String!
}
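
With these schema additions, chat-history consumers can request the structured chunks alongside plain `content`. A hypothetical query shape — the wrapping `currentUser`/`copilot`/`histories` fields are assumptions; only `ChatMessage.streamObjects` and the `StreamObject` type come from this diff:

```ts
import { gql } from 'graphql-tag';

// Hypothetical wrapper fields; the streamObjects selection matches the schema above.
const CHAT_HISTORY_QUERY = gql`
  query ChatHistory($workspaceId: String!) {
    currentUser {
      copilot(workspaceId: $workspaceId) {
        histories {
          messages {
            role
            content
            streamObjects {
              type
              textDelta
              toolCallId
              toolName
              args
              result
            }
          }
        }
      }
    }
  }
`;
```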