Files
AFFiNE-Mirror/packages/backend/server/src/plugins/copilot/controller.ts
DarkSky b79439b01d fix(server): sse abort behavior (#13153)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved handling of aborted client connections during streaming,
ensuring that session messages accurately reflect if a request was
aborted.
* Enhanced consistency and reliability across all streaming endpoints
when saving session messages after streaming.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-07-11 04:46:55 +00:00

777 lines
22 KiB
TypeScript

import {
BeforeApplicationShutdown,
Controller,
Get,
Logger,
Param,
Query,
Req,
Res,
Sse,
} from '@nestjs/common';
import type { Request, Response } from 'express';
import {
BehaviorSubject,
catchError,
connect,
filter,
finalize,
from,
ignoreElements,
interval,
lastValueFrom,
map,
merge,
mergeMap,
Observable,
reduce,
Subject,
take,
takeUntil,
tap,
} from 'rxjs';
import {
BlobNotFound,
CallMetric,
Config,
CopilotFailedToGenerateText,
CopilotSessionNotFound,
InternalServerError,
mapAnyError,
mapSseError,
metrics,
NoCopilotProviderAvailable,
UnsplashIsNotConfigured,
} from '../../base';
import { CurrentUser, Public } from '../../core/auth';
import {
CopilotProvider,
CopilotProviderFactory,
ModelInputType,
ModelOutputType,
StreamObject,
} from './providers';
import { StreamObjectParser } from './providers/utils';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { ChatMessage, ChatQuerySchema } from './types';
import { getSignal } from './utils';
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
/** One server-sent event emitted by the copilot SSE endpoints. */
export interface ChatEvent {
  /** Event category; `ping` frames carry no payload and only keep the connection alive. */
  type: 'message' | 'attachment' | 'event' | 'error' | 'ping';
  /** Id of the message the event belongs to, when one was supplied. */
  id?: string;
  /** Payload: a plain string or a structured object, depending on `type`. */
  data: object | string;
}
/** Milliseconds between keep-alive `ping` events merged into SSE streams. */
const PING_INTERVAL = 5_000;
@Controller('/api/copilot')
export class CopilotController implements BeforeApplicationShutdown {
  /** Stored as the assistant reply when the client dropped the connection mid-stream. */
  private static readonly ABORTED_CONTENT = '> Request aborted';

  private readonly logger = new Logger(CopilotController.name);
  /** Number of SSE streams in flight; shutdown waits until it drops to 0. */
  private readonly ongoingStreamCount$ = new BehaviorSubject(0);

  constructor(
    private readonly config: Config,
    private readonly chatSession: ChatSessionService,
    private readonly provider: CopilotProviderFactory,
    private readonly workflow: CopilotWorkflowService,
    private readonly storage: CopilotStorage
  ) {}

  /**
   * Nest shutdown hook: resolves once every in-flight SSE stream has
   * finished, then completes the counter subject.
   */
  async beforeApplicationShutdown() {
    await lastValueFrom(
      this.ongoingStreamCount$.asObservable().pipe(
        filter(count => count === 0),
        take(1)
      )
    );
    this.ongoingStreamCount$.complete();
  }

  /** Register one more in-flight SSE stream (balanced by `endStream`). */
  private beginStream() {
    this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
  }

  /** Unregister a finished SSE stream. */
  private endStream() {
    this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
  }

  /**
   * Check the user's quota and that the session exists and belongs to them.
   *
   * @returns the loaded session
   * @throws CopilotSessionNotFound if the session is missing or owned by another user
   */
  private async checkSessionAccess(
    userId: string,
    sessionId: string
  ): Promise<ChatSession> {
    const [, session] = await Promise.all([
      this.chatSession.checkQuota(userId),
      this.chatSession.get(sessionId),
    ]);
    if (!session || session.config.userId !== userId) {
      throw new CopilotSessionNotFound();
    }
    return session;
  }

  /**
   * Resolve which provider/model should serve a request for `sessionId`.
   *
   * `modelId` is honoured only if it is one of the session's optional models;
   * otherwise the session's default model is used.
   *
   * @throws CopilotSessionNotFound if the user may not access the session
   * @throws NoCopilotProviderAvailable if no provider supports the output type/model
   */
  private async chooseProvider(
    outputType: ModelOutputType,
    userId: string,
    sessionId: string,
    messageId?: string,
    modelId?: string
  ): Promise<{
    provider: CopilotProvider;
    model: string;
    hasAttachment: boolean;
  }> {
    const session = await this.checkSessionAccess(userId, sessionId);
    const model =
      modelId && session.optionalModels.includes(modelId)
        ? modelId
        : session.model;
    const hasAttachment = messageId
      ? !!(await session.getMessageById(messageId)).attachments?.length
      : false;
    const provider = await this.provider.getProvider({
      outputType,
      modelId: model,
    });
    if (!provider) {
      throw new NoCopilotProviderAvailable();
    }
    return { provider, model, hasAttachment };
  }

  /**
   * Load the session and stage the user message for this turn.
   *
   * Without `messageId` (or with `retry`), the latest assistant reply is
   * reverted (plus the latest user message when `messageId` is given). With
   * no `messageId`, the latest user message is returned so its params can be
   * reused.
   *
   * @throws CopilotSessionNotFound if the session does not exist
   */
  private async appendSessionMessage(
    sessionId: string,
    messageId?: string,
    retry = false
  ): Promise<[ChatMessage | undefined, ChatSession]> {
    const session = await this.chatSession.get(sessionId);
    if (!session) {
      throw new CopilotSessionNotFound();
    }
    let latestMessage: ChatMessage | undefined;
    if (!messageId || retry) {
      // revert the latest message generated by the assistant
      // if messageId is provided, we will also revert latest user message
      await this.chatSession.revertLatestMessage(sessionId, !!messageId);
      session.revertLatestMessage(!!messageId);
      if (!messageId) {
        latestMessage = session.latestUserMessage;
      }
    }
    if (messageId) {
      await session.pushByMessageId(messageId);
    }
    return [latestMessage, session];
  }

  /** Parse the first value of a query parameter as a base-10 integer, or `undefined`. */
  private parseNumber(value: string | string[] | undefined) {
    if (!value) {
      return undefined;
    }
    const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
    if (Number.isNaN(num)) {
      return undefined;
    }
    return num;
  }

  /**
   * Interleave keep-alive `ping` events into `source$` every PING_INTERVAL ms
   * until `source$` finishes (completes, errors or is unsubscribed).
   */
  private mergePingStream(
    messageId: string,
    source$: Observable<ChatEvent>
  ): Observable<ChatEvent> {
    const done$ = new Subject<void>();
    const ping$ = interval(PING_INTERVAL).pipe(
      map(() => ({ type: 'ping' as const, id: messageId, data: '' })),
      takeUntil(done$)
    );
    // ping$ is subscribed first so its takeUntil is already listening even if
    // source$ were to finish synchronously during subscription.
    return merge(
      ping$,
      source$.pipe(finalize(() => done$.next()))
    );
  }

  /**
   * Get the request's abort signal and remember whether the client
   * disconnected before the response was finished.
   */
  private trackConnection(req: Request) {
    const { signal, onConnectionClosed } = getSignal(req);
    let aborted = false;
    onConnectionClosed(isAborted => {
      if (isAborted) {
        aborted = true;
      }
    });
    return { signal, isAborted: () => aborted };
  }

  /** Persist the session without blocking the stream; failures are only logged. */
  private saveSessionInBackground(session: ChatSession) {
    void session
      .save()
      .catch(err =>
        this.logger.error('Failed to save session in sse stream', err)
      );
  }

  /**
   * Attributes for error counters. Deliberately limited to the model name:
   * `info` also carries the session id and prompt payload, and using the
   * session id as a label would give every session its own series.
   */
  private errorMetricAttrs(info: { model?: string }): Record<string, string> {
    return info.model ? { model: info.model } : {};
  }

  /**
   * Shared setup for chat endpoints: pick a provider, stage the user message
   * and render the final prompt messages for the model.
   */
  private async prepareChatSession(
    user: CurrentUser,
    sessionId: string,
    query: Record<string, string | string[]>,
    outputType: ModelOutputType
  ) {
    let { messageId, retry, modelId, params } = ChatQuerySchema.parse(query);
    const { provider, model } = await this.chooseProvider(
      outputType,
      user.id,
      sessionId,
      messageId,
      modelId
    );
    const [latestMessage, session] = await this.appendSessionMessage(
      sessionId,
      messageId,
      retry
    );
    if (latestMessage) {
      params = Object.assign({}, params, latestMessage.params, {
        content: latestMessage.content,
        attachments: latestMessage.attachments,
      });
    }
    const finalMessage = session.finish(params);
    return {
      provider,
      model,
      session,
      finalMessage,
    };
  }

  /** Non-streaming text completion; saves the assistant reply and returns it. */
  @Get('/chat/:sessionId')
  @CallMetric('ai', 'chat', { timer: true })
  async chat(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<string> {
    const info: any = { sessionId, params: query };
    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Text
        );
      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_calls').add(1, { model });

      const { reasoning, webSearch } = ChatQuerySchema.parse(query);
      const content = await provider.text({ modelId: model }, finalMessage, {
        ...session.config.promptConfig,
        signal: getSignal(req).signal,
        user: user.id,
        session: session.config.sessionId,
        workspace: session.config.workspaceId,
        reasoning,
        webSearch,
      });

      session.push({
        role: 'assistant',
        content,
        createdAt: new Date(),
      });
      await session.save();
      return content;
    } catch (e: any) {
      metrics.ai.counter('chat_errors').add(1, this.errorMetricAttrs(info));
      let error = mapAnyError(e);
      if (error instanceof InternalServerError) {
        error = new CopilotFailedToGenerateText(e.message);
      }
      error.log('CopilotChat', info);
      throw error;
    }
  }

  /**
   * Stream a text completion over SSE. The concatenated text is saved to
   * the session once the stream ends (or a placeholder if the client aborted).
   */
  @Sse('/chat/:sessionId/stream')
  @CallMetric('ai', 'chat_stream', { timer: true })
  async chatStream(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };
    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Text
        );
      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_stream_calls').add(1, { model });

      const { signal, isAborted } = this.trackConnection(req);
      const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
      const source$ = from(
        provider.streamText({ modelId: model }, finalMessage, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
          reasoning,
          webSearch,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => ({ type: 'message' as const, id: messageId, data }))
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce((acc, chunk) => acc + chunk, ''),
              tap(buffer => {
                session.push({
                  role: 'assistant',
                  content: isAborted()
                    ? CopilotController.ABORTED_CONTENT
                    : buffer,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('chat_stream_errors')
            .add(1, this.errorMetricAttrs(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        }),
        finalize(() => this.endStream())
      );
      // Count the stream only once nothing above can throw any more, so a
      // setup failure never leaves the counter incremented without a finalize.
      this.beginStream();
      return this.mergePingStream(messageId || '', source$);
    } catch (err) {
      metrics.ai
        .counter('chat_stream_errors')
        .add(1, this.errorMetricAttrs(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Stream structured objects over SSE. The merged objects and their text
   * content are saved to the session once the stream ends.
   */
  @Sse('/chat/:sessionId/stream-object')
  @CallMetric('ai', 'chat_object_stream', { timer: true })
  async chatStreamObject(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };
    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Object
        );
      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_object_stream_calls').add(1, { model });

      const { signal, isAborted } = this.trackConnection(req);
      const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
      const source$ = from(
        provider.streamObject({ modelId: model }, finalMessage, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
          reasoning,
          webSearch,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => ({ type: 'message' as const, id: messageId, data }))
            ),
            // save the generated objects to the session
            shared$.pipe(
              reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]),
              tap(result => {
                const parser = new StreamObjectParser();
                const streamObjects = parser.mergeTextDelta(result);
                const content = parser.mergeContent(streamObjects);
                const aborted = isAborted();
                session.push({
                  role: 'assistant',
                  content: aborted
                    ? CopilotController.ABORTED_CONTENT
                    : content,
                  streamObjects: aborted ? null : streamObjects,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('chat_object_stream_errors')
            .add(1, this.errorMetricAttrs(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        }),
        finalize(() => this.endStream())
      );
      this.beginStream();
      return this.mergePingStream(messageId || '', source$);
    } catch (err) {
      metrics.ai
        .counter('chat_object_stream_errors')
        .add(1, this.errorMetricAttrs(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Run the session's workflow graph and stream its content, attachment and
   * status events over SSE. Emitted content is saved to the session at the end.
   */
  @Sse('/chat/:sessionId/workflow')
  @CallMetric('ai', 'chat_workflow', { timer: true })
  async chatWorkflow(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };
    try {
      let { messageId, params } = ChatQuerySchema.parse(query);
      // Apply the same quota + ownership guard the other generation endpoints
      // get via chooseProvider, before the session is mutated below.
      await this.checkSessionAccess(user.id, sessionId);
      const [, session] = await this.appendSessionMessage(sessionId, messageId);
      info.model = session.model;
      metrics.ai.counter('workflow_calls').add(1, { model: session.model });
      const latestMessage = session.stashMessages.findLast(
        m => m.role === 'user'
      );
      if (latestMessage) {
        params = Object.assign({}, params, latestMessage.params, {
          content: latestMessage.content,
          attachments: latestMessage.attachments,
        });
      }

      const { signal, isAborted } = this.trackConnection(req);
      const source$ = from(
        this.workflow.runGraph(params, session.model, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => {
                switch (data.status) {
                  case GraphExecutorState.EmitContent:
                    return {
                      type: 'message' as const,
                      id: messageId,
                      data: data.content,
                    };
                  case GraphExecutorState.EmitAttachment:
                    return {
                      type: 'attachment' as const,
                      id: messageId,
                      data: data.attachment,
                    };
                  default:
                    return {
                      type: 'event' as const,
                      id: messageId,
                      data: {
                        status: data.status,
                        id: data.node.id,
                        type: data.node.config.nodeType,
                      } as any,
                    };
                }
              })
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce((acc, chunk) => {
                if (chunk.status === GraphExecutorState.EmitContent) {
                  acc += chunk.content;
                }
                return acc;
              }, ''),
              tap(content => {
                session.push({
                  role: 'assistant',
                  content: isAborted()
                    ? CopilotController.ABORTED_CONTENT
                    : content,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('workflow_errors')
            .add(1, this.errorMetricAttrs(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        }),
        finalize(() => this.endStream())
      );
      this.beginStream();
      return this.mergePingStream(messageId || '', source$);
    } catch (err) {
      metrics.ai.counter('workflow_errors').add(1, this.errorMetricAttrs(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Generate images over SSE. Each image is passed through
   * `storage.handleRemoteLink` before being emitted as an `attachment` event.
   * The resulting links are saved to the session at the end.
   */
  @Sse('/chat/:sessionId/images')
  @CallMetric('ai', 'chat_images', { timer: true })
  async chatImagesStream(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };
    try {
      let { messageId, params } = ChatQuerySchema.parse(query);
      const { provider, model, hasAttachment } = await this.chooseProvider(
        ModelOutputType.Image,
        user.id,
        sessionId,
        messageId
      );
      const [latestMessage, session] = await this.appendSessionMessage(
        sessionId,
        messageId
      );
      info.model = model;
      metrics.ai.counter('images_stream_calls').add(1, { model });
      if (latestMessage) {
        params = Object.assign({}, params, latestMessage.params, {
          content: latestMessage.content,
          attachments: latestMessage.attachments,
        });
      }
      const handleRemoteLink = this.storage.handleRemoteLink.bind(
        this.storage,
        user.id,
        sessionId
      );

      const { signal, isAborted } = this.trackConnection(req);
      const source$ = from(
        provider.streamImages(
          {
            modelId: model,
            inputTypes: hasAttachment
              ? [ModelInputType.Image]
              : [ModelInputType.Text],
          },
          // session.finish / parseNumber run synchronously here and may throw;
          // the stream is therefore only counted after this expression.
          session.finish(params),
          {
            ...session.config.promptConfig,
            quality: params.quality || undefined,
            seed: this.parseNumber(params.seed),
            signal,
            user: user.id,
            session: session.config.sessionId,
            workspace: session.config.workspaceId,
          }
        )
      ).pipe(
        mergeMap(handleRemoteLink),
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(attachment => ({
                type: 'attachment' as const,
                id: messageId,
                data: attachment,
              }))
            ),
            // save the generated attachments to the session
            shared$.pipe(
              reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
              tap(attachments => {
                const aborted = isAborted();
                session.push({
                  role: 'assistant',
                  content: aborted ? CopilotController.ABORTED_CONTENT : '',
                  attachments: aborted ? [] : attachments,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('images_stream_errors')
            .add(1, this.errorMetricAttrs(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        }),
        finalize(() => this.endStream())
      );
      this.beginStream();
      return this.mergePingStream(messageId || '', source$);
    } catch (err) {
      metrics.ai
        .counter('images_stream_errors')
        .add(1, this.errorMetricAttrs(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Proxy an Unsplash photo search using the server-side API key.
   * Status, body and a few headers are relayed unchanged.
   *
   * @throws UnsplashIsNotConfigured if no Unsplash key is configured
   */
  @Get('/unsplash/photos')
  @CallMetric('ai', 'unsplash')
  async unsplashPhotos(
    @Req() req: Request,
    @Res() res: Response,
    @Query() params: Record<string, string>
  ) {
    const { key } = this.config.copilot.unsplash;
    if (!key) {
      throw new UnsplashIsNotConfigured();
    }

    const query = new URLSearchParams(params);
    const response = await fetch(
      `https://api.unsplash.com/search/photos?${query}`,
      {
        headers: { Authorization: `Client-ID ${key}` },
        signal: getSignal(req).signal,
      }
    );

    // Forward only headers Unsplash actually sent; a missing header is `null`,
    // which Express would otherwise stringify into the literal value "null".
    const forwarded = [
      'Content-Type',
      'Content-Length',
      'X-Ratelimit-Limit',
      'X-Ratelimit-Remaining',
    ];
    const headers: Record<string, string> = {};
    for (const name of forwarded) {
      const value = response.headers.get(name);
      if (value !== null) {
        headers[name] = value;
      }
    }
    res.set(headers);
    // Relay the body verbatim: upstream error bodies may not be JSON, and
    // response.json() would turn them into an opaque 500.
    res.status(response.status).send(await response.text());
  }

  /**
   * Serve a stored copilot blob, redirecting to a signed URL when the
   * storage backend provides one.
   *
   * @throws BlobNotFound if the blob has no body
   */
  @Public()
  @Get('/blob/:userId/:workspaceId/:key')
  async getBlob(
    @Res() res: Response,
    @Param('userId') userId: string,
    @Param('workspaceId') workspaceId: string,
    @Param('key') key: string
  ) {
    const { body, metadata, redirectUrl } = await this.storage.get(
      userId,
      workspaceId,
      key,
      true
    );

    if (redirectUrl) {
      // redirect to signed url
      return res.redirect(redirectUrl);
    }

    if (!body) {
      throw new BlobNotFound({
        spaceId: workspaceId,
        blobId: key,
      });
    }

    // metadata should always exists if body is not null
    if (metadata) {
      res.setHeader('content-type', metadata.contentType);
      res.setHeader('last-modified', metadata.lastModified.toUTCString());
      res.setHeader('content-length', metadata.contentLength);
    } else {
      this.logger.warn(`Blob ${workspaceId}/${key} has no metadata`);
    }

    res.setHeader('cache-control', 'public, max-age=2592000, immutable');
    // pipe() does not propagate source errors; without this handler a failing
    // storage stream would leave the response open indefinitely.
    body.on('error', err => {
      this.logger.error(`Failed to stream blob ${workspaceId}/${key}`, err);
      res.destroy(err);
    });
    body.pipe(res);
  }
}