mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
fix(server): abort behavior in sse stream (#12211)
fix AI-121 fix AI-118 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved handling of connection closures and request abortion for streaming and non-streaming chat endpoints, ensuring session data is saved appropriately even if the connection is interrupted. - **Refactor** - Streamlined internal logic for managing request signals and connection events, resulting in more robust and explicit session management during streaming interactions. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -13,22 +13,22 @@ import type { Request, Response } from 'express';
|
||||
import {
|
||||
BehaviorSubject,
|
||||
catchError,
|
||||
concatMap,
|
||||
connect,
|
||||
EMPTY,
|
||||
filter,
|
||||
finalize,
|
||||
from,
|
||||
ignoreElements,
|
||||
interval,
|
||||
lastValueFrom,
|
||||
map,
|
||||
merge,
|
||||
mergeMap,
|
||||
Observable,
|
||||
reduce,
|
||||
Subject,
|
||||
take,
|
||||
takeUntil,
|
||||
toArray,
|
||||
tap,
|
||||
} from 'rxjs';
|
||||
|
||||
import {
|
||||
@@ -50,11 +50,13 @@ import {
|
||||
CopilotProviderFactory,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
StreamObject,
|
||||
} from './providers';
|
||||
import { StreamObjectParser } from './providers/utils';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { ChatMessage, ChatQuerySchema } from './types';
|
||||
import { getSignal } from './utils';
|
||||
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
|
||||
|
||||
export interface ChatEvent {
|
||||
@@ -156,16 +158,6 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
return [latestMessage, session];
|
||||
}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
const controller = new AbortController();
|
||||
req.socket.on('close', hasError => {
|
||||
if (hasError) {
|
||||
controller.abort();
|
||||
}
|
||||
});
|
||||
return controller.signal;
|
||||
}
|
||||
|
||||
private parseNumber(value: string | string[] | undefined) {
|
||||
if (!value) {
|
||||
return undefined;
|
||||
@@ -255,7 +247,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
signal: getSignal(req).signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
@@ -306,11 +298,13 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
|
||||
const { signal, onConnectionClosed } = getSignal(req);
|
||||
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||
|
||||
const source$ = from(
|
||||
provider.streamText({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
@@ -326,16 +320,25 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
reduce((acc, chunk) => acc + chunk, ''),
|
||||
tap(buffer => {
|
||||
onConnectionClosed(isAborted => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: isAborted ? '> Request aborted' : buffer,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
void session
|
||||
.save()
|
||||
.catch(err =>
|
||||
this.logger.error(
|
||||
'Failed to save session in sse stream',
|
||||
err
|
||||
)
|
||||
);
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
mergeMap(() => EMPTY)
|
||||
ignoreElements()
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -380,11 +383,13 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
metrics.ai.counter('chat_object_stream_calls').add(1, { model });
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
|
||||
const { signal, onConnectionClosed } = getSignal(req);
|
||||
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||
|
||||
const source$ = from(
|
||||
provider.streamObject({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
@@ -400,20 +405,29 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
const parser = new StreamObjectParser();
|
||||
const streamObjects = parser.mergeTextDelta(values);
|
||||
const content = parser.mergeContent(streamObjects);
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content,
|
||||
streamObjects,
|
||||
createdAt: new Date(),
|
||||
reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]),
|
||||
tap(result => {
|
||||
onConnectionClosed(isAborted => {
|
||||
const parser = new StreamObjectParser();
|
||||
const streamObjects = parser.mergeTextDelta(result);
|
||||
const content = parser.mergeContent(streamObjects);
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: isAborted ? '> Request aborted' : content,
|
||||
streamObjects: isAborted ? null : streamObjects,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
void session
|
||||
.save()
|
||||
.catch(err =>
|
||||
this.logger.error(
|
||||
'Failed to save session in sse stream',
|
||||
err
|
||||
)
|
||||
);
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
mergeMap(() => EMPTY)
|
||||
ignoreElements()
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -461,10 +475,12 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
});
|
||||
}
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
|
||||
const { signal, onConnectionClosed } = getSignal(req);
|
||||
const source$ = from(
|
||||
this.workflow.runGraph(params, session.model, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
@@ -503,19 +519,30 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values
|
||||
.filter(v => v.status === GraphExecutorState.EmitContent)
|
||||
.map(v => v.content)
|
||||
.join(''),
|
||||
createdAt: new Date(),
|
||||
reduce((acc, chunk) => {
|
||||
if (chunk.status === GraphExecutorState.EmitContent) {
|
||||
acc += chunk.content;
|
||||
}
|
||||
return acc;
|
||||
}, ''),
|
||||
tap(content => {
|
||||
onConnectionClosed(isAborted => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: isAborted ? '> Request aborted' : content,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
void session
|
||||
.save()
|
||||
.catch(err =>
|
||||
this.logger.error(
|
||||
'Failed to save session in sse stream',
|
||||
err
|
||||
)
|
||||
);
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
mergeMap(() => EMPTY)
|
||||
ignoreElements()
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -575,6 +602,8 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
sessionId
|
||||
);
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
|
||||
const { signal, onConnectionClosed } = getSignal(req);
|
||||
const source$ = from(
|
||||
provider.streamImages(
|
||||
{
|
||||
@@ -588,7 +617,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
...session.config.promptConfig,
|
||||
quality: params.quality || undefined,
|
||||
seed: this.parseNumber(params.seed),
|
||||
signal: this.getSignal(req),
|
||||
signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
@@ -608,17 +637,26 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
|
||||
tap(attachments => {
|
||||
onConnectionClosed(isAborted => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: isAborted ? '> Request aborted' : '',
|
||||
attachments: isAborted ? [] : attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
void session
|
||||
.save()
|
||||
.catch(err =>
|
||||
this.logger.error(
|
||||
'Failed to save session in sse stream',
|
||||
err
|
||||
)
|
||||
);
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
mergeMap(() => EMPTY)
|
||||
ignoreElements()
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -656,7 +694,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
`https://api.unsplash.com/search/photos?${query}`,
|
||||
{
|
||||
headers: { Authorization: `Client-ID ${key}` },
|
||||
signal: this.getSignal(req),
|
||||
signal: getSignal(req).signal,
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user