mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-25 18:26:05 +08:00
fix(server): abort behavior in sse stream (#12211)
fix AI-121 fix AI-118 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved handling of connection closures and request abortion for streaming and non-streaming chat endpoints, ensuring session data is saved appropriately even if the connection is interrupted. - **Refactor** - Streamlined internal logic for managing request signals and connection events, resulting in more robust and explicit session management during streaming interactions. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -51,7 +51,7 @@ import { COPILOT_LOCKER, CopilotType } from '../resolver';
|
|||||||
import { ChatSessionService } from '../session';
|
import { ChatSessionService } from '../session';
|
||||||
import { CopilotStorage } from '../storage';
|
import { CopilotStorage } from '../storage';
|
||||||
import { MAX_EMBEDDABLE_SIZE } from '../types';
|
import { MAX_EMBEDDABLE_SIZE } from '../types';
|
||||||
import { readStream } from '../utils';
|
import { getSignal, readStream } from '../utils';
|
||||||
import { CopilotContextService } from './service';
|
import { CopilotContextService } from './service';
|
||||||
|
|
||||||
@InputType()
|
@InputType()
|
||||||
@@ -394,16 +394,6 @@ export class CopilotContextResolver {
|
|||||||
private readonly storage: CopilotStorage
|
private readonly storage: CopilotStorage
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
private getSignal(req: Request) {
|
|
||||||
const controller = new AbortController();
|
|
||||||
req.socket.on('close', hasError => {
|
|
||||||
if (hasError) {
|
|
||||||
controller.abort();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return controller.signal;
|
|
||||||
}
|
|
||||||
|
|
||||||
@ResolveField(() => [CopilotContextCategory], {
|
@ResolveField(() => [CopilotContextCategory], {
|
||||||
description: 'list collections in context',
|
description: 'list collections in context',
|
||||||
})
|
})
|
||||||
@@ -710,7 +700,7 @@ export class CopilotContextResolver {
|
|||||||
context.workspaceId,
|
context.workspaceId,
|
||||||
content,
|
content,
|
||||||
limit,
|
limit,
|
||||||
this.getSignal(ctx.req),
|
getSignal(ctx.req).signal,
|
||||||
threshold
|
threshold
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -719,7 +709,7 @@ export class CopilotContextResolver {
|
|||||||
return await session.matchFiles(
|
return await session.matchFiles(
|
||||||
content,
|
content,
|
||||||
limit,
|
limit,
|
||||||
this.getSignal(ctx.req),
|
getSignal(ctx.req).signal,
|
||||||
scopedThreshold,
|
scopedThreshold,
|
||||||
threshold
|
threshold
|
||||||
);
|
);
|
||||||
@@ -785,7 +775,7 @@ export class CopilotContextResolver {
|
|||||||
context.workspaceId,
|
context.workspaceId,
|
||||||
content,
|
content,
|
||||||
limit,
|
limit,
|
||||||
this.getSignal(ctx.req),
|
getSignal(ctx.req).signal,
|
||||||
threshold
|
threshold
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -802,7 +792,7 @@ export class CopilotContextResolver {
|
|||||||
const chunks = await session.matchWorkspaceDocs(
|
const chunks = await session.matchWorkspaceDocs(
|
||||||
content,
|
content,
|
||||||
limit,
|
limit,
|
||||||
this.getSignal(ctx.req),
|
getSignal(ctx.req).signal,
|
||||||
scopedThreshold,
|
scopedThreshold,
|
||||||
threshold
|
threshold
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -13,22 +13,22 @@ import type { Request, Response } from 'express';
|
|||||||
import {
|
import {
|
||||||
BehaviorSubject,
|
BehaviorSubject,
|
||||||
catchError,
|
catchError,
|
||||||
concatMap,
|
|
||||||
connect,
|
connect,
|
||||||
EMPTY,
|
|
||||||
filter,
|
filter,
|
||||||
finalize,
|
finalize,
|
||||||
from,
|
from,
|
||||||
|
ignoreElements,
|
||||||
interval,
|
interval,
|
||||||
lastValueFrom,
|
lastValueFrom,
|
||||||
map,
|
map,
|
||||||
merge,
|
merge,
|
||||||
mergeMap,
|
mergeMap,
|
||||||
Observable,
|
Observable,
|
||||||
|
reduce,
|
||||||
Subject,
|
Subject,
|
||||||
take,
|
take,
|
||||||
takeUntil,
|
takeUntil,
|
||||||
toArray,
|
tap,
|
||||||
} from 'rxjs';
|
} from 'rxjs';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@@ -50,11 +50,13 @@ import {
|
|||||||
CopilotProviderFactory,
|
CopilotProviderFactory,
|
||||||
ModelInputType,
|
ModelInputType,
|
||||||
ModelOutputType,
|
ModelOutputType,
|
||||||
|
StreamObject,
|
||||||
} from './providers';
|
} from './providers';
|
||||||
import { StreamObjectParser } from './providers/utils';
|
import { StreamObjectParser } from './providers/utils';
|
||||||
import { ChatSession, ChatSessionService } from './session';
|
import { ChatSession, ChatSessionService } from './session';
|
||||||
import { CopilotStorage } from './storage';
|
import { CopilotStorage } from './storage';
|
||||||
import { ChatMessage, ChatQuerySchema } from './types';
|
import { ChatMessage, ChatQuerySchema } from './types';
|
||||||
|
import { getSignal } from './utils';
|
||||||
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
|
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
|
||||||
|
|
||||||
export interface ChatEvent {
|
export interface ChatEvent {
|
||||||
@@ -156,16 +158,6 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
return [latestMessage, session];
|
return [latestMessage, session];
|
||||||
}
|
}
|
||||||
|
|
||||||
private getSignal(req: Request) {
|
|
||||||
const controller = new AbortController();
|
|
||||||
req.socket.on('close', hasError => {
|
|
||||||
if (hasError) {
|
|
||||||
controller.abort();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return controller.signal;
|
|
||||||
}
|
|
||||||
|
|
||||||
private parseNumber(value: string | string[] | undefined) {
|
private parseNumber(value: string | string[] | undefined) {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return undefined;
|
return undefined;
|
||||||
@@ -255,7 +247,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
|
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
signal: this.getSignal(req),
|
signal: getSignal(req).signal,
|
||||||
user: user.id,
|
user: user.id,
|
||||||
session: session.config.sessionId,
|
session: session.config.sessionId,
|
||||||
workspace: session.config.workspaceId,
|
workspace: session.config.workspaceId,
|
||||||
@@ -306,11 +298,13 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
||||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
|
|
||||||
|
const { signal, onConnectionClosed } = getSignal(req);
|
||||||
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||||
|
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
provider.streamText({ modelId: model }, finalMessage, {
|
provider.streamText({ modelId: model }, finalMessage, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
signal: this.getSignal(req),
|
signal,
|
||||||
user: user.id,
|
user: user.id,
|
||||||
session: session.config.sessionId,
|
session: session.config.sessionId,
|
||||||
workspace: session.config.workspaceId,
|
workspace: session.config.workspaceId,
|
||||||
@@ -326,16 +320,25 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
),
|
),
|
||||||
// save the generated text to the session
|
// save the generated text to the session
|
||||||
shared$.pipe(
|
shared$.pipe(
|
||||||
toArray(),
|
reduce((acc, chunk) => acc + chunk, ''),
|
||||||
concatMap(values => {
|
tap(buffer => {
|
||||||
session.push({
|
onConnectionClosed(isAborted => {
|
||||||
role: 'assistant',
|
session.push({
|
||||||
content: values.join(''),
|
role: 'assistant',
|
||||||
createdAt: new Date(),
|
content: isAborted ? '> Request aborted' : buffer,
|
||||||
|
createdAt: new Date(),
|
||||||
|
});
|
||||||
|
void session
|
||||||
|
.save()
|
||||||
|
.catch(err =>
|
||||||
|
this.logger.error(
|
||||||
|
'Failed to save session in sse stream',
|
||||||
|
err
|
||||||
|
)
|
||||||
|
);
|
||||||
});
|
});
|
||||||
return from(session.save());
|
|
||||||
}),
|
}),
|
||||||
mergeMap(() => EMPTY)
|
ignoreElements()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -380,11 +383,13 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
metrics.ai.counter('chat_object_stream_calls').add(1, { model });
|
metrics.ai.counter('chat_object_stream_calls').add(1, { model });
|
||||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
|
|
||||||
|
const { signal, onConnectionClosed } = getSignal(req);
|
||||||
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
|
||||||
|
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
provider.streamObject({ modelId: model }, finalMessage, {
|
provider.streamObject({ modelId: model }, finalMessage, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
signal: this.getSignal(req),
|
signal,
|
||||||
user: user.id,
|
user: user.id,
|
||||||
session: session.config.sessionId,
|
session: session.config.sessionId,
|
||||||
workspace: session.config.workspaceId,
|
workspace: session.config.workspaceId,
|
||||||
@@ -400,20 +405,29 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
),
|
),
|
||||||
// save the generated text to the session
|
// save the generated text to the session
|
||||||
shared$.pipe(
|
shared$.pipe(
|
||||||
toArray(),
|
reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]),
|
||||||
concatMap(values => {
|
tap(result => {
|
||||||
const parser = new StreamObjectParser();
|
onConnectionClosed(isAborted => {
|
||||||
const streamObjects = parser.mergeTextDelta(values);
|
const parser = new StreamObjectParser();
|
||||||
const content = parser.mergeContent(streamObjects);
|
const streamObjects = parser.mergeTextDelta(result);
|
||||||
session.push({
|
const content = parser.mergeContent(streamObjects);
|
||||||
role: 'assistant',
|
session.push({
|
||||||
content,
|
role: 'assistant',
|
||||||
streamObjects,
|
content: isAborted ? '> Request aborted' : content,
|
||||||
createdAt: new Date(),
|
streamObjects: isAborted ? null : streamObjects,
|
||||||
|
createdAt: new Date(),
|
||||||
|
});
|
||||||
|
void session
|
||||||
|
.save()
|
||||||
|
.catch(err =>
|
||||||
|
this.logger.error(
|
||||||
|
'Failed to save session in sse stream',
|
||||||
|
err
|
||||||
|
)
|
||||||
|
);
|
||||||
});
|
});
|
||||||
return from(session.save());
|
|
||||||
}),
|
}),
|
||||||
mergeMap(() => EMPTY)
|
ignoreElements()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -461,10 +475,12 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
|
|
||||||
|
const { signal, onConnectionClosed } = getSignal(req);
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
this.workflow.runGraph(params, session.model, {
|
this.workflow.runGraph(params, session.model, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
signal: this.getSignal(req),
|
signal,
|
||||||
user: user.id,
|
user: user.id,
|
||||||
session: session.config.sessionId,
|
session: session.config.sessionId,
|
||||||
workspace: session.config.workspaceId,
|
workspace: session.config.workspaceId,
|
||||||
@@ -503,19 +519,30 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
),
|
),
|
||||||
// save the generated text to the session
|
// save the generated text to the session
|
||||||
shared$.pipe(
|
shared$.pipe(
|
||||||
toArray(),
|
reduce((acc, chunk) => {
|
||||||
concatMap(values => {
|
if (chunk.status === GraphExecutorState.EmitContent) {
|
||||||
session.push({
|
acc += chunk.content;
|
||||||
role: 'assistant',
|
}
|
||||||
content: values
|
return acc;
|
||||||
.filter(v => v.status === GraphExecutorState.EmitContent)
|
}, ''),
|
||||||
.map(v => v.content)
|
tap(content => {
|
||||||
.join(''),
|
onConnectionClosed(isAborted => {
|
||||||
createdAt: new Date(),
|
session.push({
|
||||||
|
role: 'assistant',
|
||||||
|
content: isAborted ? '> Request aborted' : content,
|
||||||
|
createdAt: new Date(),
|
||||||
|
});
|
||||||
|
void session
|
||||||
|
.save()
|
||||||
|
.catch(err =>
|
||||||
|
this.logger.error(
|
||||||
|
'Failed to save session in sse stream',
|
||||||
|
err
|
||||||
|
)
|
||||||
|
);
|
||||||
});
|
});
|
||||||
return from(session.save());
|
|
||||||
}),
|
}),
|
||||||
mergeMap(() => EMPTY)
|
ignoreElements()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -575,6 +602,8 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
sessionId
|
sessionId
|
||||||
);
|
);
|
||||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
|
|
||||||
|
const { signal, onConnectionClosed } = getSignal(req);
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
provider.streamImages(
|
provider.streamImages(
|
||||||
{
|
{
|
||||||
@@ -588,7 +617,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
quality: params.quality || undefined,
|
quality: params.quality || undefined,
|
||||||
seed: this.parseNumber(params.seed),
|
seed: this.parseNumber(params.seed),
|
||||||
signal: this.getSignal(req),
|
signal,
|
||||||
user: user.id,
|
user: user.id,
|
||||||
session: session.config.sessionId,
|
session: session.config.sessionId,
|
||||||
workspace: session.config.workspaceId,
|
workspace: session.config.workspaceId,
|
||||||
@@ -608,17 +637,26 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
),
|
),
|
||||||
// save the generated text to the session
|
// save the generated text to the session
|
||||||
shared$.pipe(
|
shared$.pipe(
|
||||||
toArray(),
|
reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
|
||||||
concatMap(attachments => {
|
tap(attachments => {
|
||||||
session.push({
|
onConnectionClosed(isAborted => {
|
||||||
role: 'assistant',
|
session.push({
|
||||||
content: '',
|
role: 'assistant',
|
||||||
attachments: attachments,
|
content: isAborted ? '> Request aborted' : '',
|
||||||
createdAt: new Date(),
|
attachments: isAborted ? [] : attachments,
|
||||||
|
createdAt: new Date(),
|
||||||
|
});
|
||||||
|
void session
|
||||||
|
.save()
|
||||||
|
.catch(err =>
|
||||||
|
this.logger.error(
|
||||||
|
'Failed to save session in sse stream',
|
||||||
|
err
|
||||||
|
)
|
||||||
|
);
|
||||||
});
|
});
|
||||||
return from(session.save());
|
|
||||||
}),
|
}),
|
||||||
mergeMap(() => EMPTY)
|
ignoreElements()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -656,7 +694,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
|||||||
`https://api.unsplash.com/search/photos?${query}`,
|
`https://api.unsplash.com/search/photos?${query}`,
|
||||||
{
|
{
|
||||||
headers: { Authorization: `Client-ID ${key}` },
|
headers: { Authorization: `Client-ID ${key}` },
|
||||||
signal: this.getSignal(req),
|
signal: getSignal(req).signal,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import { Readable } from 'node:stream';
|
import { Readable } from 'node:stream';
|
||||||
|
|
||||||
|
import type { Request } from 'express';
|
||||||
|
|
||||||
import { readBufferWithLimit } from '../../base';
|
import { readBufferWithLimit } from '../../base';
|
||||||
import { MAX_EMBEDDABLE_SIZE } from './types';
|
import { MAX_EMBEDDABLE_SIZE } from './types';
|
||||||
|
|
||||||
@@ -9,3 +11,38 @@ export function readStream(
|
|||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
return readBufferWithLimit(readable, maxSize);
|
return readBufferWithLimit(readable, maxSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RequestClosedCallback = (isAborted: boolean) => void;
|
||||||
|
type SignalReturnType = {
|
||||||
|
signal: AbortSignal;
|
||||||
|
onConnectionClosed: (cb: RequestClosedCallback) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function getSignal(req: Request): SignalReturnType {
|
||||||
|
const controller = new AbortController();
|
||||||
|
|
||||||
|
let isAborted = true;
|
||||||
|
let callback: ((isAborted: boolean) => void) | undefined = undefined;
|
||||||
|
|
||||||
|
const onSocketEnd = () => {
|
||||||
|
isAborted = false;
|
||||||
|
};
|
||||||
|
const onSocketClose = (hadError: boolean) => {
|
||||||
|
req.socket.off('end', onSocketEnd);
|
||||||
|
req.socket.off('close', onSocketClose);
|
||||||
|
const aborted = hadError || isAborted;
|
||||||
|
if (aborted) {
|
||||||
|
controller.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
callback?.(aborted);
|
||||||
|
};
|
||||||
|
|
||||||
|
req.socket.on('end', onSocketEnd);
|
||||||
|
req.socket.on('close', onSocketClose);
|
||||||
|
|
||||||
|
return {
|
||||||
|
signal: controller.signal,
|
||||||
|
onConnectionClosed: cb => (callback = cb),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ test.describe('AIAction/GenerateAnImageWithImage', () => {
|
|||||||
await expect(answer.getByTestId('ai-answer-image')).toBeVisible();
|
await expect(answer.getByTestId('ai-answer-image')).toBeVisible();
|
||||||
const insert = answer.getByTestId('answer-insert-below');
|
const insert = answer.getByTestId('answer-insert-below');
|
||||||
await insert.click();
|
await insert.click();
|
||||||
|
await page.waitForSelector('.affine-image-container');
|
||||||
|
await page.reload();
|
||||||
|
|
||||||
await utils.chatPanel.waitForHistory(page, [
|
await utils.chatPanel.waitForHistory(page, [
|
||||||
{
|
{
|
||||||
role: 'action',
|
role: 'action',
|
||||||
|
|||||||
@@ -71,11 +71,15 @@ test.describe('AIAction/GenerateAnImageWithText', () => {
|
|||||||
const { answer } = await generateImage();
|
const { answer } = await generateImage();
|
||||||
const insert = answer.getByTestId('answer-insert-below');
|
const insert = answer.getByTestId('answer-insert-below');
|
||||||
await insert.click();
|
await insert.click();
|
||||||
|
await page.waitForSelector('.affine-image-container');
|
||||||
|
await page.reload();
|
||||||
|
|
||||||
await utils.chatPanel.waitForHistory(page, [
|
await utils.chatPanel.waitForHistory(page, [
|
||||||
{
|
{
|
||||||
role: 'action',
|
role: 'action',
|
||||||
},
|
},
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const { answer: panelAnswer, actionName } =
|
const { answer: panelAnswer, actionName } =
|
||||||
await utils.chatPanel.getLatestAIActionMessage(page);
|
await utils.chatPanel.getLatestAIActionMessage(page);
|
||||||
await expect(
|
await expect(
|
||||||
|
|||||||
Reference in New Issue
Block a user