diff --git a/packages/backend/server/src/app.ts b/packages/backend/server/src/app.ts index 824eb29afd..33a6510509 100644 --- a/packages/backend/server/src/app.ts +++ b/packages/backend/server/src/app.ts @@ -44,6 +44,11 @@ export async function createApp() { app.useGlobalInterceptors(app.get(CacheInterceptor)); app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter())); app.use(cookieParser()); + // only enable shutdown hooks in production + // https://docs.nestjs.com/fundamentals/lifecycle-events#application-shutdown + if (AFFiNE.NODE_ENV === 'production') { + app.enableShutdownHooks(); + } const adapter = new SocketIoAdapter(app); app.useWebSocketAdapter(adapter); diff --git a/packages/backend/server/src/base/prisma/service.ts b/packages/backend/server/src/base/prisma/service.ts index a4e53e20fc..281b67b8a7 100644 --- a/packages/backend/server/src/base/prisma/service.ts +++ b/packages/backend/server/src/base/prisma/service.ts @@ -1,11 +1,11 @@ -import type { OnModuleDestroy, OnModuleInit } from '@nestjs/common'; +import type { OnApplicationShutdown, OnModuleInit } from '@nestjs/common'; import { Injectable } from '@nestjs/common'; import { Prisma, PrismaClient } from '@prisma/client'; @Injectable() export class PrismaService extends PrismaClient - implements OnModuleInit, OnModuleDestroy + implements OnModuleInit, OnApplicationShutdown { static INSTANCE: PrismaService | null = null; @@ -18,7 +18,7 @@ export class PrismaService await this.$connect(); } - async onModuleDestroy(): Promise { + async onApplicationShutdown(): Promise { if (!AFFiNE.node.test) { await this.$disconnect(); PrismaService.INSTANCE = null; diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index fba8f0dee1..5ac2735ee8 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -1,4 +1,5 @@ import { + BeforeApplicationShutdown, Controller, Get, Logger, @@ -10,19 +11,22 @@ import { } from '@nestjs/common'; import type { Request, Response } from 'express'; import { + BehaviorSubject, catchError, concatMap, connect, EMPTY, + filter, finalize, from, interval, + lastValueFrom, map, merge, mergeMap, Observable, Subject, - switchMap, + take, takeUntil, toArray, } from 'rxjs'; @@ -59,8 +63,9 @@ type CheckResult = { const PING_INTERVAL = 5000; @Controller('/api/copilot') -export class CopilotController { +export class CopilotController implements BeforeApplicationShutdown { private readonly logger = new Logger(CopilotController.name); + private readonly ongoingStreamCount$ = new BehaviorSubject(0); constructor( private readonly config: Config, @@ -70,6 +75,16 @@ export class CopilotController { private readonly storage: CopilotStorage ) {} + async beforeApplicationShutdown() { + await lastValueFrom( + this.ongoingStreamCount$.asObservable().pipe( + filter(count => count === 0), + take(1) + ) + ); + this.ongoingStreamCount$.complete(); + } + private async checkRequest( userId: string, sessionId: string, @@ -241,6 +256,7 @@ export class CopilotController { const session = await this.appendSessionMessage(sessionId, messageId); try { metrics.ai.counter('chat_stream_calls').add(1, { model: session.model }); + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); const source$ = from( provider.generateTextStream(session.finish(params), session.model, { ...session.config.promptConfig, @@ -265,7 +281,7 @@ export class CopilotController { }); return from(session.save()); }), - switchMap(() => EMPTY) + mergeMap(() => EMPTY) ) ) ), @@ -274,6 +290,9 @@ export class CopilotController { .counter('chat_stream_errors') .add(1, { model: session.model }); return mapSseError(e); + }), + finalize(() => { + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1); }) ); @@ -306,7 +325,7 @@ export class CopilotController { attachments: latestMessage.attachments, }); } - + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); const source$ = from( this.workflow.runGraph(params, session.model, { ...session.config.promptConfig, @@ -359,7 +378,7 @@ export class CopilotController { }); return from(session.save()); }), - switchMap(() => EMPTY) + mergeMap(() => EMPTY) ) ) ), @@ -368,7 +387,10 @@ export class CopilotController { .counter('workflow_errors') .add(1, { model: session.model }); return mapSseError(e); - }) + }), + finalize(() => + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1) + ) ); return this.mergePingStream(messageId, source$); @@ -413,7 +435,7 @@ export class CopilotController { user.id, sessionId ); - + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1); const source$ = from( provider.generateImagesStream(session.finish(params), session.model, { ...session.config.promptConfig, @@ -445,7 +467,7 @@ export class CopilotController { }); return from(session.save()); }), - switchMap(() => EMPTY) + mergeMap(() => EMPTY) ) ) ), @@ -454,7 +476,10 @@ export class CopilotController { .counter('images_stream_errors') .add(1, { model: session.model }); return mapSseError(e); - }) + }), + finalize(() => + this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1) + ) ); return this.mergePingStream(messageId, source$);