mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-15 05:37:32 +00:00
feat(core): add stream object api (#12841)
Close [AI-193](https://linear.app/affine-design/issue/AI-193)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Added support for streaming structured AI chat responses as objects, enabling richer and more interactive chat experiences.
  - Chat messages now include a new field displaying structured stream objects, such as reasoning steps, text deltas, tool calls, and tool results.
  - GraphQL APIs and queries updated to expose these structured streaming objects in chat histories.
  - Introduced a new streaming chat endpoint for object-based responses.
- **Bug Fixes**
  - Improved error handling for streaming responses to ensure more robust and informative error reporting.
- **Refactor**
  - Centralized and streamlined session preparation and streaming logic for AI chat providers.
  - Unified streaming setup across multiple AI model providers.
- **Tests**
  - Extended test coverage for streaming object responses to ensure reliability and correctness.
- **Documentation**
  - Updated type definitions and schemas to reflect new streaming object capabilities in both backend and frontend code.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: DarkSky <25152247+darkskygit@users.noreply.github.com>
This commit is contained in:
@@ -51,6 +51,7 @@ import {
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
} from './providers';
|
||||
import { StreamObjectParser } from './providers/utils';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { ChatMessage, ChatQuerySchema } from './types';
|
||||
@@ -189,6 +190,45 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
return merge(source$.pipe(finalize(() => subject$.next(null))), ping$);
|
||||
}
|
||||
|
||||
/**
 * Common setup shared by the chat endpoints.
 *
 * Validates the query with `ChatQuerySchema`, asks `chooseProvider` for a
 * provider/model matching `outputType`, calls `appendSessionMessage`, and
 * renders the prompt through `session.finish`.
 *
 * @param user - the current user; only `user.id` is read here
 * @param sessionId - id of the chat session
 * @param query - raw request query string values
 * @param outputType - output capability forwarded to `chooseProvider`
 * @returns the chosen provider and model, the session, and the messages
 *   returned by `session.finish`
 */
private async prepareChatSession(
  user: CurrentUser,
  sessionId: string,
  query: Record<string, string | string[]>,
  outputType: ModelOutputType
) {
  const parsedQuery = ChatQuerySchema.parse(query);
  const { messageId, retry, modelId } = parsedQuery;

  const chosen = await this.chooseProvider(
    outputType,
    user.id,
    sessionId,
    messageId,
    modelId
  );

  const [latestMessage, session] = await this.appendSessionMessage(
    sessionId,
    messageId,
    retry
  );

  // When a latest message exists, its params plus its content/attachments
  // are layered on top of the params taken from the query.
  const promptParams = latestMessage
    ? Object.assign({}, parsedQuery.params, latestMessage.params, {
        content: latestMessage.content,
        attachments: latestMessage.attachments,
      })
    : parsedQuery.params;

  const finalMessage = session.finish(promptParams);

  return {
    provider: chosen.provider,
    model: chosen.model,
    session,
    finalMessage,
  };
}
|
||||
|
||||
@Get('/chat/:sessionId')
|
||||
@CallMetric('ai', 'chat', { timer: true })
|
||||
async chat(
|
||||
@@ -200,36 +240,19 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const info: any = { sessionId, params: query };
|
||||
|
||||
try {
|
||||
let { messageId, retry, reasoning, webSearch, modelId, params } =
|
||||
ChatQuerySchema.parse(query);
|
||||
|
||||
const { provider, model } = await this.chooseProvider(
|
||||
ModelOutputType.Text,
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
const [latestMessage, session] = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
messageId,
|
||||
retry
|
||||
);
|
||||
const { provider, model, session, finalMessage } =
|
||||
await this.prepareChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
query,
|
||||
ModelOutputType.Text
|
||||
);
|
||||
|
||||
info.model = model;
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
metrics.ai.counter('chat_calls').add(1, { model });
|
||||
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
content: latestMessage.content,
|
||||
attachments: latestMessage.attachments,
|
||||
});
|
||||
}
|
||||
|
||||
const finalMessage = session.finish(params);
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
|
||||
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
@@ -269,37 +292,20 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const info: any = { sessionId, params: query, throwInStream: false };
|
||||
|
||||
try {
|
||||
let { messageId, retry, reasoning, webSearch, modelId, params } =
|
||||
ChatQuerySchema.parse(query);
|
||||
|
||||
const { provider, model } = await this.chooseProvider(
|
||||
ModelOutputType.Text,
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
const [latestMessage, session] = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
messageId,
|
||||
retry
|
||||
);
|
||||
const { provider, model, session, finalMessage } =
|
||||
await this.prepareChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
query,
|
||||
ModelOutputType.Text
|
||||
);
|
||||
|
||||
info.model = model;
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
||||
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
content: latestMessage.content,
|
||||
attachments: latestMessage.attachments,
|
||||
});
|
||||
}
|
||||
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
const finalMessage = session.finish(params);
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
|
||||
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||
const source$ = from(
|
||||
provider.streamText({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
@@ -348,6 +354,83 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * SSE endpoint that streams a chat answer as structured objects.
 *
 * Every object yielded by `provider.streamObject` is forwarded as a `message`
 * event whose `id` is the query's `messageId`. Once the provider stream
 * completes, the collected objects are merged with `StreamObjectParser`,
 * pushed to the session as an assistant message, and the session is saved.
 *
 * @param user - the current user; `user.id` is forwarded to the provider
 * @param req - the incoming request, used to derive the abort signal
 * @param sessionId - id of the chat session
 * @param query - raw query string values, validated by `ChatQuerySchema`
 * @returns the event stream combined with the ping stream, or the stream
 *   produced by `mapSseError` when setup fails
 */
@Sse('/chat/:sessionId/stream-object')
@CallMetric('ai', 'chat_object_stream', { timer: true })
async chatStreamObject(
  @CurrentUser() user: CurrentUser,
  @Req() req: Request,
  @Param('sessionId') sessionId: string,
  @Query() query: Record<string, string>
): Promise<Observable<ChatEvent>> {
  const info: any = { sessionId, params: query, throwInStream: false };

  try {
    const { provider, model, session, finalMessage } =
      await this.prepareChatSession(
        user,
        sessionId,
        query,
        ModelOutputType.Object
      );

    info.model = model;
    info.finalMessage = finalMessage.filter(m => m.role !== 'system');
    metrics.ai.counter('chat_object_stream_calls').add(1, { model });

    const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);

    // Create the provider stream BEFORE bumping the ongoing-stream counter.
    // The default `CopilotProvider.streamObject` throws synchronously, and an
    // exception raised after the increment but before `finalize` is attached
    // would leave the counter permanently too high.
    const objects = provider.streamObject({ modelId: model }, finalMessage, {
      ...session.config.promptConfig,
      signal: this.getSignal(req),
      user: user.id,
      workspace: session.config.workspaceId,
      reasoning,
      webSearch,
    });

    // From here on the `finalize` below owns the matching decrement.
    this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);

    const source$ = from(objects).pipe(
      connect(shared$ =>
        merge(
          // actual chat event stream
          shared$.pipe(
            map(data => ({ type: 'message' as const, id: messageId, data }))
          ),
          // save the generated objects and text to the session
          shared$.pipe(
            toArray(),
            concatMap(values => {
              const parser = new StreamObjectParser();
              const streamObjects = parser.mergeTextDelta(values);
              const content = parser.mergeContent(streamObjects);
              session.push({
                role: 'assistant',
                content,
                streamObjects,
                createdAt: new Date(),
              });
              return from(session.save());
            }),
            // the save result is not part of the event stream
            mergeMap(() => EMPTY)
          )
        )
      ),
      catchError(e => {
        metrics.ai.counter('chat_object_stream_errors').add(1);
        info.throwInStream = true;
        return mapSseError(e, info);
      }),
      finalize(() => {
        this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
      })
    );

    return this.mergePingStream(messageId || '', source$);
  } catch (err) {
    metrics.ai.counter('chat_object_stream_errors').add(1, info);
    return mapSseError(err, info);
  }
}
|
||||
|
||||
@Sse('/chat/:sessionId/workflow')
|
||||
@CallMetric('ai', 'chat_workflow', { timer: true })
|
||||
async chatWorkflow(
|
||||
|
||||
@@ -13,11 +13,17 @@ import {
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderModel,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { ModelOutputType } from '../types';
|
||||
import { chatToGPTMessage, TextStreamParser } from '../utils';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from '../utils';
|
||||
|
||||
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
private readonly MAX_STEPS = 20;
|
||||
@@ -92,21 +98,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const [system, msgs] = await chatToGPTMessage(messages, true, true);
|
||||
|
||||
const { fullStream } = streamText({
|
||||
model: this.instance(model.id),
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
experimental_continueSteps: true,
|
||||
});
|
||||
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
@@ -122,6 +114,60 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Streams the model response as structured objects.
 *
 * Parameters are checked and a model is selected with the output type forced
 * to `ModelOutputType.Object`. Each chunk of the full stream goes through
 * `StreamObjectParser.parse`, and chunks it returns as falsy are dropped.
 *
 * @param cond - model selection conditions
 * @param messages - prompt messages to send
 * @param options - chat options; `options.signal` is checked after every chunk
 * @throws whatever `this.handleError` returns for an error raised while streaming
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const objectCond = { ...cond, outputType: ModelOutputType.Object };
  await this.checkParams({ cond: objectCond, messages, options });
  const model = this.selectModel(objectCond);

  try {
    metrics.ai
      .counter('chat_object_stream_calls')
      .add(1, { model: model.id });

    const parts = await this.getFullStream(model, messages, options);
    const objectParser = new StreamObjectParser();

    for await (const part of parts) {
      const parsed = objectParser.parse(part);
      if (parsed) {
        yield parsed;
      }

      // Checked after the yield: the consumer may have aborted meanwhile.
      if (!options.signal?.aborted) {
        continue;
      }
      await parts.cancel();
      break;
    }
  } catch (err: any) {
    metrics.ai
      .counter('chat_object_stream_errors')
      .add(1, { model: model.id });
    throw this.handleError(err);
  }
}
|
||||
|
||||
/**
 * Starts a `streamText` call for the given model and returns its `fullStream`.
 * Used by both the text and the object streaming paths.
 *
 * @param model - the selected provider model; only `model.id` is read
 * @param messages - prompt messages, converted with `chatToGPTMessage`
 * @param options - chat options (abort signal, provider and tool options)
 */
private async getFullStream(
  model: CopilotProviderModel,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
) {
  const [system, msgs] = await chatToGPTMessage(messages, true, true);

  // Same evaluation order as the inline form: instance, options, tools.
  const modelInstance = this.instance(model.id);
  const anthropic = this.getAnthropicOptions(options, model.id);
  const tools = await this.getTools(options, model.id);

  return streamText({
    model: modelInstance,
    system,
    messages: msgs,
    abortSignal: options.signal,
    providerOptions: { anthropic },
    tools,
    maxSteps: this.MAX_STEPS,
    experimental_continueSteps: true,
  }).fullStream;
}
|
||||
|
||||
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: AnthropicProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
|
||||
@@ -20,7 +20,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -29,7 +29,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -38,7 +38,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -47,7 +47,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
|
||||
@@ -18,7 +18,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -27,7 +27,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -36,7 +36,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -45,7 +45,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
|
||||
@@ -21,11 +21,17 @@ import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { ModelOutputType } from '../types';
|
||||
import { chatToGPTMessage, TextStreamParser } from '../utils';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from '../utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -150,21 +156,7 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const { fullStream } = streamText({
|
||||
model: this.instance(model.id, {
|
||||
useSearchGrounding: this.useSearchGrounding(options),
|
||||
}),
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
maxSteps: this.MAX_STEPS,
|
||||
providerOptions: {
|
||||
google: this.getGeminiOptions(options, model.id),
|
||||
},
|
||||
});
|
||||
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
@@ -180,6 +172,60 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Streams the model response as structured objects.
 *
 * Parameters are checked and a model is selected with the output type forced
 * to `ModelOutputType.Object`. Each chunk of the full stream goes through
 * `StreamObjectParser.parse`, and chunks it returns as falsy are dropped.
 *
 * @param cond - model selection conditions
 * @param messages - prompt messages to send
 * @param options - chat options; `options.signal` is checked after every chunk
 * @throws whatever `this.handleError` returns for an error raised while streaming
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const objectCond = { ...cond, outputType: ModelOutputType.Object };
  await this.checkParams({ cond: objectCond, messages, options });
  const model = this.selectModel(objectCond);

  try {
    metrics.ai
      .counter('chat_object_stream_calls')
      .add(1, { model: model.id });

    const parts = await this.getFullStream(model, messages, options);
    const objectParser = new StreamObjectParser();

    for await (const part of parts) {
      const parsed = objectParser.parse(part);
      if (parsed) {
        yield parsed;
      }

      // Checked after the yield: the consumer may have aborted meanwhile.
      if (!options.signal?.aborted) {
        continue;
      }
      await parts.cancel();
      break;
    }
  } catch (err: any) {
    metrics.ai
      .counter('chat_object_stream_errors')
      .add(1, { model: model.id });
    throw this.handleError(err);
  }
}
|
||||
|
||||
/**
 * Starts a `streamText` call for the given model and returns its `fullStream`.
 * Used by both the text and the object streaming paths.
 *
 * @param model - the selected provider model; only `model.id` is read
 * @param messages - prompt messages, converted with `chatToGPTMessage`
 * @param options - chat options (abort signal, search grounding, provider options)
 */
private async getFullStream(
  model: CopilotProviderModel,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
) {
  const [system, msgs] = await chatToGPTMessage(messages);

  // Same evaluation order as the inline form: grounding flag, instance, options.
  const useSearchGrounding = this.useSearchGrounding(options);
  const modelInstance = this.instance(model.id, { useSearchGrounding });
  const google = this.getGeminiOptions(options, model.id);

  return streamText({
    model: modelInstance,
    system,
    messages: msgs,
    abortSignal: options.signal,
    maxSteps: this.MAX_STEPS,
    providerOptions: { google },
  }).fullStream;
}
|
||||
|
||||
private getGeminiOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: GoogleGenerativeAIProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
|
||||
@@ -25,7 +25,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
@@ -40,7 +44,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -54,7 +62,11 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -23,7 +23,11 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -37,7 +41,11 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -27,12 +27,19 @@ import type {
|
||||
CopilotChatTools,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
CopilotStructuredOptions,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { chatToGPTMessage, CitationParser, TextStreamParser } from './utils';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationParser,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -65,7 +72,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -75,7 +82,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -84,7 +91,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -94,7 +101,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -103,7 +110,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
@@ -113,7 +124,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -122,7 +137,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -131,7 +150,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Structured],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -140,7 +163,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -149,7 +172,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -158,7 +181,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -312,26 +335,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
frequencyPenalty: options.frequencyPenalty ?? 0,
|
||||
presencePenalty: options.presencePenalty ?? 0,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxTokens: options.maxTokens ?? 4096,
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
maxSteps: this.MAX_STEPS,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const citationParser = new CitationParser();
|
||||
const textParser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
@@ -363,6 +367,39 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Streams the model response as structured objects.
 *
 * Parameters are checked and a model is selected with the output type forced
 * to `ModelOutputType.Object`. Each chunk of the full stream goes through
 * `StreamObjectParser.parse`, and chunks it returns as falsy are dropped.
 *
 * @param cond - model selection conditions
 * @param messages - prompt messages to send
 * @param options - chat options; `options.signal` is checked after every chunk
 * @throws whatever `this.handleError` returns for an error raised while streaming
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const objectCond = { ...cond, outputType: ModelOutputType.Object };
  await this.checkParams({ cond: objectCond, messages, options });
  const model = this.selectModel(objectCond);

  try {
    metrics.ai
      .counter('chat_object_stream_calls')
      .add(1, { model: model.id });

    const parts = await this.getFullStream(model, messages, options);
    const objectParser = new StreamObjectParser();

    for await (const part of parts) {
      const parsed = objectParser.parse(part);
      if (parsed) {
        yield parsed;
      }

      // Checked after the yield: the consumer may have aborted meanwhile.
      if (!options.signal?.aborted) {
        continue;
      }
      await parts.cancel();
      break;
    }
  } catch (err: any) {
    metrics.ai
      .counter('chat_object_stream_errors')
      .add(1, { model: model.id });
    throw this.handleError(err, model.id, options);
  }
}
|
||||
|
||||
override async structure(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
@@ -403,6 +440,31 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Starts a `streamText` call against the responses model for `model.id` and
 * returns its `fullStream`. Used by both the text and the object streaming paths.
 *
 * Sampling options missing from `options` fall back to 0 (penalties and
 * temperature) and 4096 (`maxTokens`).
 *
 * @param model - the selected provider model; only `model.id` is read
 * @param messages - prompt messages, converted with `chatToGPTMessage`
 * @param options - chat options (sampling, abort signal, provider and tool options)
 */
private async getFullStream(
  model: CopilotProviderModel,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
) {
  const [system, msgs] = await chatToGPTMessage(messages);
  const modelInstance = this.#instance.responses(model.id);

  // Same evaluation order as the inline form: provider options, then tools.
  const openai = this.getOpenAIOptions(options, model.id);
  const tools = await this.getTools(options, model.id);

  return streamText({
    model: modelInstance,
    system,
    messages: msgs,
    frequencyPenalty: options.frequencyPenalty ?? 0,
    presencePenalty: options.presencePenalty ?? 0,
    temperature: options.temperature ?? 0,
    maxTokens: options.maxTokens ?? 4096,
    providerOptions: { openai },
    tools,
    maxSteps: this.MAX_STEPS,
    abortSignal: options.signal,
  }).fullStream;
}
|
||||
|
||||
// ====== text to image ======
|
||||
private async *generateImageWithAttachments(
|
||||
model: string,
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
ModelInputType,
|
||||
type PromptMessage,
|
||||
PromptMessageSchema,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
|
||||
@Injectable()
|
||||
@@ -225,6 +226,17 @@ export abstract class CopilotProvider<C = any> {
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string>;
|
||||
|
||||
/**
 * Default object-streaming implementation for providers that do not override it.
 *
 * This is not a generator, so the error is thrown synchronously at call time
 * rather than on first iteration.
 *
 * @throws CopilotProviderNotSupported always, with `kind: 'object'`
 */
streamObject(
  _model: ModelConditions,
  _messages: PromptMessage[],
  _options?: CopilotChatOptions
): AsyncIterable<StreamObject> {
  const unsupported = new CopilotProviderNotSupported({
    provider: this.type,
    kind: 'object',
  });
  throw unsupported;
}
|
||||
|
||||
structure(
|
||||
_cond: ModelConditions,
|
||||
_messages: PromptMessage[],
|
||||
|
||||
@@ -118,8 +118,33 @@ export const ChatMessageAttachment = z.union([
|
||||
}),
|
||||
]);
|
||||
|
||||
// Fields shared by the `tool-call` and `tool-result` variants.
const toolInvocationShape = {
  toolCallId: z.string(),
  toolName: z.string(),
  args: z.record(z.any()),
};

/**
 * One element of a structured chat stream, discriminated by `type`:
 * a text delta, a reasoning delta, a tool call, or a tool result.
 */
export const StreamObjectSchema = z.discriminatedUnion('type', [
  z.object({ type: z.literal('text-delta'), textDelta: z.string() }),
  z.object({ type: z.literal('reasoning'), textDelta: z.string() }),
  z.object({ type: z.literal('tool-call'), ...toolInvocationShape }),
  z.object({
    type: z.literal('tool-result'),
    ...toolInvocationShape,
    result: z.any(),
  }),
]);
|
||||
|
||||
/**
 * Base shape of a chat message. `content` is required; every other field may
 * be omitted or `null`.
 */
export const PureMessageSchema = z.object({
  // plain text of the message
  content: z.string(),
  // structured stream objects attached to the message
  streamObjects: z.array(StreamObjectSchema).optional().nullable(),
  // attachments validated by `ChatMessageAttachment`
  attachments: z.array(ChatMessageAttachment).optional().nullable(),
  // free-form key/value params
  params: z.record(z.any()).optional().nullable(),
});
|
||||
@@ -129,6 +154,7 @@ export const PromptMessageSchema = PureMessageSchema.extend({
|
||||
}).strict();
|
||||
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
|
||||
export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
export type StreamObject = z.infer<typeof StreamObjectSchema>;
|
||||
|
||||
// ========== options ==========
|
||||
|
||||
@@ -187,6 +213,7 @@ export enum ModelInputType {
|
||||
|
||||
export enum ModelOutputType {
|
||||
Text = 'text',
|
||||
Object = 'object',
|
||||
Embedding = 'embedding',
|
||||
Image = 'image',
|
||||
Structured = 'structured',
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
createExaCrawlTool,
|
||||
createExaSearchTool,
|
||||
} from '../tools';
|
||||
import { PromptMessage } from './types';
|
||||
import { PromptMessage, StreamObject } from './types';
|
||||
|
||||
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
|
||||
|
||||
@@ -387,6 +387,22 @@ export interface CustomAITools extends ToolSet {
|
||||
|
||||
type ChunkType = TextStreamPart<CustomAITools>['type'];
|
||||
|
||||
/**
 * Converts an arbitrary error value into an `Error` and throws it.
 *
 * Despite its name, this function never returns normally. The return type is
 * left implicit so that call sites which keep a `break`/`return` after the
 * call are not reported as unreachable code.
 *
 * - `Error` instances are rethrown unchanged.
 * - Strings become the message of a new `Error`.
 * - Non-null objects with a `message` property use `String(message)`.
 * - Anything else is serialized with `JSON.stringify`. If that yields
 *   `undefined` (e.g. `undefined`, functions, symbols) or throws (circular
 *   structures, BigInt), `String(error)` is used, so the thrown error always
 *   has a message and serialization failures never mask the real error.
 *
 * @param error - the value to convert
 * @throws Error always
 */
export function parseUnknownError(error: unknown) {
  if (error instanceof Error) {
    throw error;
  }
  if (typeof error === 'string') {
    throw new Error(error);
  }
  if (typeof error === 'object' && error !== null && 'message' in error) {
    throw new Error(String(error.message));
  }

  let serialized: string | undefined;
  try {
    // Returns `undefined` at runtime for values JSON cannot represent.
    serialized = JSON.stringify(error);
  } catch {
    // Not serializable (circular reference, BigInt, throwing toJSON, ...).
    serialized = undefined;
  }
  throw new Error(serialized ?? String(error));
}
|
||||
|
||||
export class TextStreamParser {
|
||||
private readonly CALLOUT_PREFIX = '\n[!]\n';
|
||||
|
||||
@@ -446,8 +462,8 @@ export class TextStreamParser {
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
const error = chunk.error as { type: string; message: string };
|
||||
throw new Error(error.message);
|
||||
parseUnknownError(chunk.error);
|
||||
break;
|
||||
}
|
||||
}
|
||||
this.lastType = chunk.type;
|
||||
@@ -490,3 +506,54 @@ export class TextStreamParser {
|
||||
return links;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Helpers for turning AI SDK stream chunks into `StreamObject`s and for
 * condensing a recorded stream before it is stored.
 */
export class StreamObjectParser {
  /**
   * Filters a raw stream chunk.
   *
   * @param chunk - a chunk from the provider's full stream
   * @returns the chunk itself for `reasoning`, `text-delta`, `tool-call` and
   *   `tool-result`; `null` for every other chunk type
   * @throws for `error` chunks, through `parseUnknownError`
   */
  public parse(chunk: TextStreamPart<CustomAITools>) {
    switch (chunk.type) {
      case 'reasoning':
      case 'text-delta':
      case 'tool-call':
      case 'tool-result': {
        return chunk;
      }
      case 'error': {
        parseUnknownError(chunk.error);
        return null;
      }
      default: {
        return null;
      }
    }
  }

  /**
   * Collapses runs of adjacent `text-delta` (or adjacent `reasoning`) chunks
   * into a single chunk whose `textDelta` is the concatenation of the run.
   * Other chunk types are kept as they are, in order.
   *
   * The input array and its elements are not modified: the chunk that starts
   * a run is copied before text is appended to it. Without the copy, the
   * caller's objects would be rewritten in place.
   *
   * @param chunks - stream objects in emission order
   * @returns a new array with adjacent text/reasoning deltas merged
   */
  public mergeTextDelta(chunks: StreamObject[]): StreamObject[] {
    const merged: StreamObject[] = [];
    for (const curr of chunks) {
      const prev = merged[merged.length - 1];
      switch (curr.type) {
        case 'reasoning':
        case 'text-delta': {
          if (prev && prev.type === curr.type) {
            // `prev` is always a copy owned by `merged`, so this is safe.
            prev.textDelta += curr.textDelta;
          } else {
            merged.push({ ...curr });
          }
          break;
        }
        default: {
          merged.push(curr);
          break;
        }
      }
    }
    return merged;
  }

  /**
   * Concatenates the text of all `text-delta` chunks, ignoring reasoning and
   * tool chunks.
   *
   * @param chunks - stream objects in emission order
   * @returns the plain text content; `''` when there are no text deltas
   */
  public mergeContent(chunks: StreamObject[]): string {
    let content = '';
    for (const chunk of chunks) {
      if (chunk.type === 'text-delta') {
        content += chunk.textDelta;
      }
    }
    return content;
  }
}
|
||||
|
||||
@@ -34,7 +34,7 @@ import { Admin } from '../../core/common';
|
||||
import { AccessController } from '../../core/permission';
|
||||
import { UserType } from '../../core/user';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage } from './providers';
|
||||
import { PromptMessage, StreamObject } from './providers';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import {
|
||||
@@ -168,6 +168,27 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
|
||||
|
||||
// ================== Return Types ==================
|
||||
|
||||
/**
 * GraphQL object type `StreamObject`: a flattened view of one stream object.
 * Only `type` is always present; every other field is nullable.
 */
@ObjectType('StreamObject')
class StreamObjectType {
  /** Discriminator of the stream object. */
  @Field(() => String)
  type!: string;

  /** Text delta, when the object carries one. */
  @Field(() => String, { nullable: true })
  textDelta?: string;

  /** Tool call id, when the object relates to a tool invocation. */
  @Field(() => String, { nullable: true })
  toolCallId?: string;

  /** Tool name, when the object relates to a tool invocation. */
  @Field(() => String, { nullable: true })
  toolName?: string;

  /** Tool arguments, exposed as JSON. */
  @Field(() => GraphQLJSON, { nullable: true })
  args?: any;

  /** Tool result, exposed as JSON. */
  @Field(() => GraphQLJSON, { nullable: true })
  result?: any;
}
|
||||
|
||||
@ObjectType('ChatMessage')
|
||||
class ChatMessageType implements Partial<ChatMessage> {
|
||||
// id will be null if message is a prompt message
|
||||
@@ -180,6 +201,9 @@ class ChatMessageType implements Partial<ChatMessage> {
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => [StreamObjectType], { nullable: true })
|
||||
streamObjects!: StreamObject[];
|
||||
|
||||
@Field(() => [String], { nullable: true })
|
||||
attachments!: string[];
|
||||
|
||||
|
||||
@@ -282,6 +282,7 @@ export class ChatSessionService {
|
||||
await tx.aiSessionMessage.createMany({
|
||||
data: state.messages.map(m => ({
|
||||
...m,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
sessionId,
|
||||
@@ -512,6 +513,7 @@ export class ChatSessionService {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
streamObjects: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
|
||||
Reference in New Issue
Block a user