feat(server): graceful shutdown for AI streams (#10025)

https://github.com/user-attachments/assets/8dd3c4f5-4059-4f03-9f51-68078d7ab5c4
This commit is contained in:
Brooooooklyn
2025-02-07 11:25:02 +00:00
parent 0df94b8e35
commit 4e00ddd5f1
3 changed files with 42 additions and 12 deletions

View File

@@ -44,6 +44,11 @@ export async function createApp() {
app.useGlobalInterceptors(app.get(CacheInterceptor)); app.useGlobalInterceptors(app.get(CacheInterceptor));
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter())); app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
app.use(cookieParser()); app.use(cookieParser());
// only enable shutdown hooks in production
// https://docs.nestjs.com/fundamentals/lifecycle-events#application-shutdown
if (AFFiNE.NODE_ENV === 'production') {
app.enableShutdownHooks();
}
const adapter = new SocketIoAdapter(app); const adapter = new SocketIoAdapter(app);
app.useWebSocketAdapter(adapter); app.useWebSocketAdapter(adapter);

View File

@@ -1,11 +1,11 @@
import type { OnModuleDestroy, OnModuleInit } from '@nestjs/common'; import type { OnApplicationShutdown, OnModuleInit } from '@nestjs/common';
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { Prisma, PrismaClient } from '@prisma/client'; import { Prisma, PrismaClient } from '@prisma/client';
@Injectable() @Injectable()
export class PrismaService export class PrismaService
extends PrismaClient extends PrismaClient
implements OnModuleInit, OnModuleDestroy implements OnModuleInit, OnApplicationShutdown
{ {
static INSTANCE: PrismaService | null = null; static INSTANCE: PrismaService | null = null;
@@ -18,7 +18,7 @@ export class PrismaService
await this.$connect(); await this.$connect();
} }
async onModuleDestroy(): Promise<void> { async onApplicationShutdown(): Promise<void> {
if (!AFFiNE.node.test) { if (!AFFiNE.node.test) {
await this.$disconnect(); await this.$disconnect();
PrismaService.INSTANCE = null; PrismaService.INSTANCE = null;

View File

@@ -1,4 +1,5 @@
import { import {
BeforeApplicationShutdown,
Controller, Controller,
Get, Get,
Logger, Logger,
@@ -10,19 +11,22 @@ import {
} from '@nestjs/common'; } from '@nestjs/common';
import type { Request, Response } from 'express'; import type { Request, Response } from 'express';
import { import {
BehaviorSubject,
catchError, catchError,
concatMap, concatMap,
connect, connect,
EMPTY, EMPTY,
filter,
finalize, finalize,
from, from,
interval, interval,
lastValueFrom,
map, map,
merge, merge,
mergeMap, mergeMap,
Observable, Observable,
Subject, Subject,
switchMap, take,
takeUntil, takeUntil,
toArray, toArray,
} from 'rxjs'; } from 'rxjs';
@@ -59,8 +63,9 @@ type CheckResult = {
const PING_INTERVAL = 5000; const PING_INTERVAL = 5000;
@Controller('/api/copilot') @Controller('/api/copilot')
export class CopilotController { export class CopilotController implements BeforeApplicationShutdown {
private readonly logger = new Logger(CopilotController.name); private readonly logger = new Logger(CopilotController.name);
private readonly ongoingStreamCount$ = new BehaviorSubject(0);
constructor( constructor(
private readonly config: Config, private readonly config: Config,
@@ -70,6 +75,16 @@ export class CopilotController {
private readonly storage: CopilotStorage private readonly storage: CopilotStorage
) {} ) {}
async beforeApplicationShutdown() {
await lastValueFrom(
this.ongoingStreamCount$.asObservable().pipe(
filter(count => count === 0),
take(1)
)
);
this.ongoingStreamCount$.complete();
}
private async checkRequest( private async checkRequest(
userId: string, userId: string,
sessionId: string, sessionId: string,
@@ -241,6 +256,7 @@ export class CopilotController {
const session = await this.appendSessionMessage(sessionId, messageId); const session = await this.appendSessionMessage(sessionId, messageId);
try { try {
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model }); metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const source$ = from( const source$ = from(
provider.generateTextStream(session.finish(params), session.model, { provider.generateTextStream(session.finish(params), session.model, {
...session.config.promptConfig, ...session.config.promptConfig,
@@ -265,7 +281,7 @@ export class CopilotController {
}); });
return from(session.save()); return from(session.save());
}), }),
switchMap(() => EMPTY) mergeMap(() => EMPTY)
) )
) )
), ),
@@ -274,6 +290,9 @@ export class CopilotController {
.counter('chat_stream_errors') .counter('chat_stream_errors')
.add(1, { model: session.model }); .add(1, { model: session.model });
return mapSseError(e); return mapSseError(e);
}),
finalize(() => {
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
}) })
); );
@@ -306,7 +325,7 @@ export class CopilotController {
attachments: latestMessage.attachments, attachments: latestMessage.attachments,
}); });
} }
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const source$ = from( const source$ = from(
this.workflow.runGraph(params, session.model, { this.workflow.runGraph(params, session.model, {
...session.config.promptConfig, ...session.config.promptConfig,
@@ -359,7 +378,7 @@ export class CopilotController {
}); });
return from(session.save()); return from(session.save());
}), }),
switchMap(() => EMPTY) mergeMap(() => EMPTY)
) )
) )
), ),
@@ -368,7 +387,10 @@ export class CopilotController {
.counter('workflow_errors') .counter('workflow_errors')
.add(1, { model: session.model }); .add(1, { model: session.model });
return mapSseError(e); return mapSseError(e);
}) }),
finalize(() =>
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
)
); );
return this.mergePingStream(messageId, source$); return this.mergePingStream(messageId, source$);
@@ -413,7 +435,7 @@ export class CopilotController {
user.id, user.id,
sessionId sessionId
); );
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const source$ = from( const source$ = from(
provider.generateImagesStream(session.finish(params), session.model, { provider.generateImagesStream(session.finish(params), session.model, {
...session.config.promptConfig, ...session.config.promptConfig,
@@ -445,7 +467,7 @@ export class CopilotController {
}); });
return from(session.save()); return from(session.save());
}), }),
switchMap(() => EMPTY) mergeMap(() => EMPTY)
) )
) )
), ),
@@ -454,7 +476,10 @@ export class CopilotController {
.counter('images_stream_errors') .counter('images_stream_errors')
.add(1, { model: session.model }); .add(1, { model: session.model });
return mapSseError(e); return mapSseError(e);
}) }),
finalize(() =>
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
)
); );
return this.mergePingStream(messageId, source$); return this.mergePingStream(messageId, source$);