diff --git a/packages/backend/server/src/plugins/copilot/context/resolver.ts b/packages/backend/server/src/plugins/copilot/context/resolver.ts index 0585d10a33..484ecf8ba7 100644 --- a/packages/backend/server/src/plugins/copilot/context/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/context/resolver.ts @@ -51,7 +51,7 @@ import { COPILOT_LOCKER, CopilotType } from '../resolver'; import { ChatSessionService } from '../session'; import { CopilotStorage } from '../storage'; import { MAX_EMBEDDABLE_SIZE } from '../types'; -import { readStream } from '../utils'; +import { getSignal, readStream } from '../utils'; import { CopilotContextService } from './service'; @InputType() @@ -394,16 +394,6 @@ export class CopilotContextResolver { private readonly storage: CopilotStorage ) {} - private getSignal(req: Request) { - const controller = new AbortController(); - req.socket.on('close', hasError => { - if (hasError) { - controller.abort(); - } - }); - return controller.signal; - } - @ResolveField(() => [CopilotContextCategory], { description: 'list collections in context', }) @@ -710,7 +700,7 @@ export class CopilotContextResolver { context.workspaceId, content, limit, - this.getSignal(ctx.req), + getSignal(ctx.req).signal, threshold ); } @@ -719,7 +709,7 @@ export class CopilotContextResolver { return await session.matchFiles( content, limit, - this.getSignal(ctx.req), + getSignal(ctx.req).signal, scopedThreshold, threshold ); @@ -785,7 +775,7 @@ export class CopilotContextResolver { context.workspaceId, content, limit, - this.getSignal(ctx.req), + getSignal(ctx.req).signal, threshold ); } @@ -802,7 +792,7 @@ export class CopilotContextResolver { const chunks = await session.matchWorkspaceDocs( content, limit, - this.getSignal(ctx.req), + getSignal(ctx.req).signal, scopedThreshold, threshold ); diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 65797a5bc2..7697ca24f1 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -13,22 +13,22 @@ import type { Request, Response } from 'express'; import { BehaviorSubject, catchError, - concatMap, connect, - EMPTY, filter, finalize, from, + ignoreElements, interval, lastValueFrom, map, merge, mergeMap, Observable, + reduce, Subject, take, takeUntil, - toArray, + tap, } from 'rxjs'; import { @@ -50,11 +50,13 @@ import { CopilotProviderFactory, ModelInputType, ModelOutputType, + StreamObject, } from './providers'; import { StreamObjectParser } from './providers/utils'; import { ChatSession, ChatSessionService } from './session'; import { CopilotStorage } from './storage'; import { ChatMessage, ChatQuerySchema } from './types'; +import { getSignal } from './utils'; import { CopilotWorkflowService, GraphExecutorState } from './workflow'; export interface ChatEvent { @@ -156,16 +158,6 @@ export class CopilotController implements BeforeApplicationShutdown { return [latestMessage, session]; } - private getSignal(req: Request) { - const controller = new AbortController(); - req.socket.on('close', hasError => { - if (hasError) { - controller.abort(); - } - }); - return controller.signal; - } - private parseNumber(value: string | string[] | undefined) { if (!value) { return undefined; @@ -255,7 +247,7 @@ export class CopilotController implements BeforeApplicationShutdown { const { reasoning, webSearch } = ChatQuerySchema.parse(query); const content = await provider.text({ modelId: model }, finalMessage, { ...session.config.promptConfig, - signal: this.getSignal(req), + signal: getSignal(req).signal, user: user.id, session: session.config.sessionId, workspace: session.config.workspaceId, @@ -306,11 +298,13 @@ export class CopilotController implements BeforeApplicationShutdown { metrics.ai.counter('chat_stream_calls').add(1, { model }); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); + const { signal, onConnectionClosed } = getSignal(req); const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query); + const source$ = from( provider.streamText({ modelId: model }, finalMessage, { ...session.config.promptConfig, - signal: this.getSignal(req), + signal, user: user.id, session: session.config.sessionId, workspace: session.config.workspaceId, @@ -326,16 +320,25 @@ export class CopilotController implements BeforeApplicationShutdown { ), // save the generated text to the session shared$.pipe( - toArray(), - concatMap(values => { - session.push({ - role: 'assistant', - content: values.join(''), - createdAt: new Date(), + reduce((acc, chunk) => acc + chunk, ''), + tap(buffer => { + onConnectionClosed(isAborted => { + session.push({ + role: 'assistant', + content: isAborted ? '> Request aborted' : buffer, + createdAt: new Date(), + }); + void session + .save() + .catch(err => + this.logger.error( + 'Failed to save session in sse stream', + err + ) + ); }); - return from(session.save()); }), - mergeMap(() => EMPTY) + ignoreElements() ) ) ), @@ -380,11 +383,13 @@ export class CopilotController implements BeforeApplicationShutdown { metrics.ai.counter('chat_object_stream_calls').add(1, { model }); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); + const { signal, onConnectionClosed } = getSignal(req); const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query); + const source$ = from( provider.streamObject({ modelId: model }, finalMessage, { ...session.config.promptConfig, - signal: this.getSignal(req), + signal, user: user.id, session: session.config.sessionId, workspace: session.config.workspaceId, @@ -400,20 +405,29 @@ export class CopilotController implements BeforeApplicationShutdown { ), // save the generated text to the session shared$.pipe( - toArray(), - concatMap(values => { - const parser = new StreamObjectParser(); - const streamObjects = parser.mergeTextDelta(values); - const content = parser.mergeContent(streamObjects); - session.push({ - role: 'assistant', - content, - streamObjects, - createdAt: new Date(), + reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]), + tap(result => { + onConnectionClosed(isAborted => { + const parser = new StreamObjectParser(); + const streamObjects = parser.mergeTextDelta(result); + const content = parser.mergeContent(streamObjects); + session.push({ + role: 'assistant', + content: isAborted ? '> Request aborted' : content, + streamObjects: isAborted ? null : streamObjects, + createdAt: new Date(), + }); + void session + .save() + .catch(err => + this.logger.error( + 'Failed to save session in sse stream', + err + ) + ); }); - return from(session.save()); }), - mergeMap(() => EMPTY) + ignoreElements() ) ) ), @@ -461,10 +475,12 @@ export class CopilotController implements BeforeApplicationShutdown { }); } this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); + + const { signal, onConnectionClosed } = getSignal(req); const source$ = from( this.workflow.runGraph(params, session.model, { ...session.config.promptConfig, - signal: this.getSignal(req), + signal, user: user.id, session: session.config.sessionId, workspace: session.config.workspaceId, @@ -503,19 +519,30 @@ export class CopilotController implements BeforeApplicationShutdown { ), // save the generated text to the session shared$.pipe( - toArray(), - concatMap(values => { - session.push({ - role: 'assistant', - content: values - .filter(v => v.status === GraphExecutorState.EmitContent) - .map(v => v.content) - .join(''), - createdAt: new Date(), + reduce((acc, chunk) => { + if (chunk.status === GraphExecutorState.EmitContent) { + acc += chunk.content; + } + return acc; + }, ''), + tap(content => { + onConnectionClosed(isAborted => { + session.push({ + role: 'assistant', + content: isAborted ? '> Request aborted' : content, + createdAt: new Date(), + }); + void session + .save() + .catch(err => + this.logger.error( + 'Failed to save session in sse stream', + err + ) + ); }); - return from(session.save()); }), - mergeMap(() => EMPTY) + ignoreElements() ) ) ), @@ -575,6 +602,8 @@ export class CopilotController implements BeforeApplicationShutdown { sessionId ); this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); + + const { signal, onConnectionClosed } = getSignal(req); const source$ = from( provider.streamImages( { @@ -588,7 +617,7 @@ export class CopilotController implements BeforeApplicationShutdown { ...session.config.promptConfig, quality: params.quality || undefined, seed: this.parseNumber(params.seed), - signal: this.getSignal(req), + signal, user: user.id, session: session.config.sessionId, workspace: session.config.workspaceId, @@ -608,17 +637,26 @@ export class CopilotController implements BeforeApplicationShutdown { ), // save the generated text to the session shared$.pipe( - toArray(), - concatMap(attachments => { - session.push({ - role: 'assistant', - content: '', - attachments: attachments, - createdAt: new Date(), + reduce((acc, chunk) => acc.concat([chunk]), [] as string[]), + tap(attachments => { + onConnectionClosed(isAborted => { + session.push({ + role: 'assistant', + content: isAborted ? '> Request aborted' : '', + attachments: isAborted ? [] : attachments, + createdAt: new Date(), + }); + void session + .save() + .catch(err => + this.logger.error( + 'Failed to save session in sse stream', + err + ) + ); }); - return from(session.save()); }), - mergeMap(() => EMPTY) + ignoreElements() ) ) ), @@ -656,7 +694,7 @@ export class CopilotController implements BeforeApplicationShutdown { `https://api.unsplash.com/search/photos?${query}`, { headers: { Authorization: `Client-ID ${key}` }, - signal: this.getSignal(req), + signal: getSignal(req).signal, } ); diff --git a/packages/backend/server/src/plugins/copilot/utils.ts b/packages/backend/server/src/plugins/copilot/utils.ts index 8ba8f138b7..9e8ff43439 100644 --- a/packages/backend/server/src/plugins/copilot/utils.ts +++ b/packages/backend/server/src/plugins/copilot/utils.ts @@ -1,5 +1,7 @@ import { Readable } from 'node:stream'; +import type { Request } from 'express'; + import { readBufferWithLimit } from '../../base'; import { MAX_EMBEDDABLE_SIZE } from './types'; @@ -9,3 +11,38 @@ export function readStream( ): Promise { return readBufferWithLimit(readable, maxSize); } + +type RequestClosedCallback = (isAborted: boolean) => void; +type SignalReturnType = { + signal: AbortSignal; + onConnectionClosed: (cb: RequestClosedCallback) => void; +}; + +export function getSignal(req: Request): SignalReturnType { + const controller = new AbortController(); + + let isAborted = true; + let callback: ((isAborted: boolean) => void) | undefined = undefined; + + const onSocketEnd = () => { + isAborted = false; + }; + const onSocketClose = (hadError: boolean) => { + req.socket.off('end', onSocketEnd); + req.socket.off('close', onSocketClose); + const aborted = hadError || isAborted; + if (aborted) { + controller.abort(); + } + + callback?.(aborted); + }; + + req.socket.on('end', onSocketEnd); + req.socket.on('close', onSocketClose); + + return { + signal: controller.signal, + onConnectionClosed: cb => (callback = cb), + }; +} diff --git a/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-image.spec.ts b/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-image.spec.ts index 6f1aa8c887..9165d50c3e 100644 --- a/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-image.spec.ts +++ b/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-image.spec.ts @@ -36,6 +36,9 @@ test.describe('AIAction/GenerateAnImageWithImage', () => { await expect(answer.getByTestId('ai-answer-image')).toBeVisible(); const insert = answer.getByTestId('answer-insert-below'); await insert.click(); + await page.waitForSelector('.affine-image-container'); + await page.reload(); + await utils.chatPanel.waitForHistory(page, [ { role: 'action', diff --git a/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-text.spec.ts b/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-text.spec.ts index 8081b800b4..5677b4de2b 100644 --- a/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-text.spec.ts +++ b/tests/affine-cloud-copilot/e2e/ai-action/generate-an-image-with-text.spec.ts @@ -71,11 +71,15 @@ test.describe('AIAction/GenerateAnImageWithText', () => { const { answer } = await generateImage(); const insert = answer.getByTestId('answer-insert-below'); await insert.click(); + await page.waitForSelector('.affine-image-container'); + await page.reload(); + await utils.chatPanel.waitForHistory(page, [ { role: 'action', }, ]); + const { answer: panelAnswer, actionName } = await utils.chatPanel.getLatestAIActionMessage(page); await expect(