fix(server): abort behavior in sse stream (#12211)

fix AI-121
fix AI-118

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of connection closures and request abortion for
streaming and non-streaming chat endpoints, ensuring session data is
saved appropriately even if the connection is interrupted.
- **Refactor**
- Streamlined internal logic for managing request signals and connection
events, resulting in more robust and explicit session management during
streaming interactions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2025-07-04 14:07:45 +08:00
committed by GitHub
parent 1b9ed2fb6d
commit 5a49d5cd24
5 changed files with 146 additions and 74 deletions

View File

@@ -51,7 +51,7 @@ import { COPILOT_LOCKER, CopilotType } from '../resolver';
import { ChatSessionService } from '../session'; import { ChatSessionService } from '../session';
import { CopilotStorage } from '../storage'; import { CopilotStorage } from '../storage';
import { MAX_EMBEDDABLE_SIZE } from '../types'; import { MAX_EMBEDDABLE_SIZE } from '../types';
import { readStream } from '../utils'; import { getSignal, readStream } from '../utils';
import { CopilotContextService } from './service'; import { CopilotContextService } from './service';
@InputType() @InputType()
@@ -394,16 +394,6 @@ export class CopilotContextResolver {
private readonly storage: CopilotStorage private readonly storage: CopilotStorage
) {} ) {}
private getSignal(req: Request) {
const controller = new AbortController();
req.socket.on('close', hasError => {
if (hasError) {
controller.abort();
}
});
return controller.signal;
}
@ResolveField(() => [CopilotContextCategory], { @ResolveField(() => [CopilotContextCategory], {
description: 'list collections in context', description: 'list collections in context',
}) })
@@ -710,7 +700,7 @@ export class CopilotContextResolver {
context.workspaceId, context.workspaceId,
content, content,
limit, limit,
this.getSignal(ctx.req), getSignal(ctx.req).signal,
threshold threshold
); );
} }
@@ -719,7 +709,7 @@ export class CopilotContextResolver {
return await session.matchFiles( return await session.matchFiles(
content, content,
limit, limit,
this.getSignal(ctx.req), getSignal(ctx.req).signal,
scopedThreshold, scopedThreshold,
threshold threshold
); );
@@ -785,7 +775,7 @@ export class CopilotContextResolver {
context.workspaceId, context.workspaceId,
content, content,
limit, limit,
this.getSignal(ctx.req), getSignal(ctx.req).signal,
threshold threshold
); );
} }
@@ -802,7 +792,7 @@ export class CopilotContextResolver {
const chunks = await session.matchWorkspaceDocs( const chunks = await session.matchWorkspaceDocs(
content, content,
limit, limit,
this.getSignal(ctx.req), getSignal(ctx.req).signal,
scopedThreshold, scopedThreshold,
threshold threshold
); );

View File

@@ -13,22 +13,22 @@ import type { Request, Response } from 'express';
import { import {
BehaviorSubject, BehaviorSubject,
catchError, catchError,
concatMap,
connect, connect,
EMPTY,
filter, filter,
finalize, finalize,
from, from,
ignoreElements,
interval, interval,
lastValueFrom, lastValueFrom,
map, map,
merge, merge,
mergeMap, mergeMap,
Observable, Observable,
reduce,
Subject, Subject,
take, take,
takeUntil, takeUntil,
toArray, tap,
} from 'rxjs'; } from 'rxjs';
import { import {
@@ -50,11 +50,13 @@ import {
CopilotProviderFactory, CopilotProviderFactory,
ModelInputType, ModelInputType,
ModelOutputType, ModelOutputType,
StreamObject,
} from './providers'; } from './providers';
import { StreamObjectParser } from './providers/utils'; import { StreamObjectParser } from './providers/utils';
import { ChatSession, ChatSessionService } from './session'; import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage'; import { CopilotStorage } from './storage';
import { ChatMessage, ChatQuerySchema } from './types'; import { ChatMessage, ChatQuerySchema } from './types';
import { getSignal } from './utils';
import { CopilotWorkflowService, GraphExecutorState } from './workflow'; import { CopilotWorkflowService, GraphExecutorState } from './workflow';
export interface ChatEvent { export interface ChatEvent {
@@ -156,16 +158,6 @@ export class CopilotController implements BeforeApplicationShutdown {
return [latestMessage, session]; return [latestMessage, session];
} }
private getSignal(req: Request) {
const controller = new AbortController();
req.socket.on('close', hasError => {
if (hasError) {
controller.abort();
}
});
return controller.signal;
}
private parseNumber(value: string | string[] | undefined) { private parseNumber(value: string | string[] | undefined) {
if (!value) { if (!value) {
return undefined; return undefined;
@@ -255,7 +247,7 @@ export class CopilotController implements BeforeApplicationShutdown {
const { reasoning, webSearch } = ChatQuerySchema.parse(query); const { reasoning, webSearch } = ChatQuerySchema.parse(query);
const content = await provider.text({ modelId: model }, finalMessage, { const content = await provider.text({ modelId: model }, finalMessage, {
...session.config.promptConfig, ...session.config.promptConfig,
signal: this.getSignal(req), signal: getSignal(req).signal,
user: user.id, user: user.id,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
@@ -306,11 +298,13 @@ export class CopilotController implements BeforeApplicationShutdown {
metrics.ai.counter('chat_stream_calls').add(1, { model }); metrics.ai.counter('chat_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query); const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from( const source$ = from(
provider.streamText({ modelId: model }, finalMessage, { provider.streamText({ modelId: model }, finalMessage, {
...session.config.promptConfig, ...session.config.promptConfig,
signal: this.getSignal(req), signal,
user: user.id, user: user.id,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
@@ -326,16 +320,25 @@ export class CopilotController implements BeforeApplicationShutdown {
), ),
// save the generated text to the session // save the generated text to the session
shared$.pipe( shared$.pipe(
toArray(), reduce((acc, chunk) => acc + chunk, ''),
concatMap(values => { tap(buffer => {
session.push({ onConnectionClosed(isAborted => {
role: 'assistant', session.push({
content: values.join(''), role: 'assistant',
createdAt: new Date(), content: isAborted ? '> Request aborted' : buffer,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
}); });
return from(session.save());
}), }),
mergeMap(() => EMPTY) ignoreElements()
) )
) )
), ),
@@ -380,11 +383,13 @@ export class CopilotController implements BeforeApplicationShutdown {
metrics.ai.counter('chat_object_stream_calls').add(1, { model }); metrics.ai.counter('chat_object_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query); const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from( const source$ = from(
provider.streamObject({ modelId: model }, finalMessage, { provider.streamObject({ modelId: model }, finalMessage, {
...session.config.promptConfig, ...session.config.promptConfig,
signal: this.getSignal(req), signal,
user: user.id, user: user.id,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
@@ -400,20 +405,29 @@ export class CopilotController implements BeforeApplicationShutdown {
), ),
// save the generated text to the session // save the generated text to the session
shared$.pipe( shared$.pipe(
toArray(), reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]),
concatMap(values => { tap(result => {
const parser = new StreamObjectParser(); onConnectionClosed(isAborted => {
const streamObjects = parser.mergeTextDelta(values); const parser = new StreamObjectParser();
const content = parser.mergeContent(streamObjects); const streamObjects = parser.mergeTextDelta(result);
session.push({ const content = parser.mergeContent(streamObjects);
role: 'assistant', session.push({
content, role: 'assistant',
streamObjects, content: isAborted ? '> Request aborted' : content,
createdAt: new Date(), streamObjects: isAborted ? null : streamObjects,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
}); });
return from(session.save());
}), }),
mergeMap(() => EMPTY) ignoreElements()
) )
) )
), ),
@@ -461,10 +475,12 @@ export class CopilotController implements BeforeApplicationShutdown {
}); });
} }
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const source$ = from( const source$ = from(
this.workflow.runGraph(params, session.model, { this.workflow.runGraph(params, session.model, {
...session.config.promptConfig, ...session.config.promptConfig,
signal: this.getSignal(req), signal,
user: user.id, user: user.id,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
@@ -503,19 +519,30 @@ export class CopilotController implements BeforeApplicationShutdown {
), ),
// save the generated text to the session // save the generated text to the session
shared$.pipe( shared$.pipe(
toArray(), reduce((acc, chunk) => {
concatMap(values => { if (chunk.status === GraphExecutorState.EmitContent) {
session.push({ acc += chunk.content;
role: 'assistant', }
content: values return acc;
.filter(v => v.status === GraphExecutorState.EmitContent) }, ''),
.map(v => v.content) tap(content => {
.join(''), onConnectionClosed(isAborted => {
createdAt: new Date(), session.push({
role: 'assistant',
content: isAborted ? '> Request aborted' : content,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
}); });
return from(session.save());
}), }),
mergeMap(() => EMPTY) ignoreElements()
) )
) )
), ),
@@ -575,6 +602,8 @@ export class CopilotController implements BeforeApplicationShutdown {
sessionId sessionId
); );
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const source$ = from( const source$ = from(
provider.streamImages( provider.streamImages(
{ {
@@ -588,7 +617,7 @@ export class CopilotController implements BeforeApplicationShutdown {
...session.config.promptConfig, ...session.config.promptConfig,
quality: params.quality || undefined, quality: params.quality || undefined,
seed: this.parseNumber(params.seed), seed: this.parseNumber(params.seed),
signal: this.getSignal(req), signal,
user: user.id, user: user.id,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
@@ -608,17 +637,26 @@ export class CopilotController implements BeforeApplicationShutdown {
), ),
// save the generated text to the session // save the generated text to the session
shared$.pipe( shared$.pipe(
toArray(), reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
concatMap(attachments => { tap(attachments => {
session.push({ onConnectionClosed(isAborted => {
role: 'assistant', session.push({
content: '', role: 'assistant',
attachments: attachments, content: isAborted ? '> Request aborted' : '',
createdAt: new Date(), attachments: isAborted ? [] : attachments,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
}); });
return from(session.save());
}), }),
mergeMap(() => EMPTY) ignoreElements()
) )
) )
), ),
@@ -656,7 +694,7 @@ export class CopilotController implements BeforeApplicationShutdown {
`https://api.unsplash.com/search/photos?${query}`, `https://api.unsplash.com/search/photos?${query}`,
{ {
headers: { Authorization: `Client-ID ${key}` }, headers: { Authorization: `Client-ID ${key}` },
signal: this.getSignal(req), signal: getSignal(req).signal,
} }
); );

View File

@@ -1,5 +1,7 @@
import { Readable } from 'node:stream'; import { Readable } from 'node:stream';
import type { Request } from 'express';
import { readBufferWithLimit } from '../../base'; import { readBufferWithLimit } from '../../base';
import { MAX_EMBEDDABLE_SIZE } from './types'; import { MAX_EMBEDDABLE_SIZE } from './types';
@@ -9,3 +11,38 @@ export function readStream(
): Promise<Buffer> { ): Promise<Buffer> {
return readBufferWithLimit(readable, maxSize); return readBufferWithLimit(readable, maxSize);
} }
type RequestClosedCallback = (isAborted: boolean) => void;
type SignalReturnType = {
signal: AbortSignal;
onConnectionClosed: (cb: RequestClosedCallback) => void;
};
export function getSignal(req: Request): SignalReturnType {
const controller = new AbortController();
let isAborted = true;
let callback: ((isAborted: boolean) => void) | undefined = undefined;
const onSocketEnd = () => {
isAborted = false;
};
const onSocketClose = (hadError: boolean) => {
req.socket.off('end', onSocketEnd);
req.socket.off('close', onSocketClose);
const aborted = hadError || isAborted;
if (aborted) {
controller.abort();
}
callback?.(aborted);
};
req.socket.on('end', onSocketEnd);
req.socket.on('close', onSocketClose);
return {
signal: controller.signal,
onConnectionClosed: cb => (callback = cb),
};
}

View File

@@ -36,6 +36,9 @@ test.describe('AIAction/GenerateAnImageWithImage', () => {
await expect(answer.getByTestId('ai-answer-image')).toBeVisible(); await expect(answer.getByTestId('ai-answer-image')).toBeVisible();
const insert = answer.getByTestId('answer-insert-below'); const insert = answer.getByTestId('answer-insert-below');
await insert.click(); await insert.click();
await page.waitForSelector('.affine-image-container');
await page.reload();
await utils.chatPanel.waitForHistory(page, [ await utils.chatPanel.waitForHistory(page, [
{ {
role: 'action', role: 'action',

View File

@@ -71,11 +71,15 @@ test.describe('AIAction/GenerateAnImageWithText', () => {
const { answer } = await generateImage(); const { answer } = await generateImage();
const insert = answer.getByTestId('answer-insert-below'); const insert = answer.getByTestId('answer-insert-below');
await insert.click(); await insert.click();
await page.waitForSelector('.affine-image-container');
await page.reload();
await utils.chatPanel.waitForHistory(page, [ await utils.chatPanel.waitForHistory(page, [
{ {
role: 'action', role: 'action',
}, },
]); ]);
const { answer: panelAnswer, actionName } = const { answer: panelAnswer, actionName } =
await utils.chatPanel.getLatestAIActionMessage(page); await utils.chatPanel.getLatestAIActionMessage(page);
await expect( await expect(