mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 04:18:54 +00:00
feat(server): graceful shutdown for AI streams (#10025)
https://github.com/user-attachments/assets/8dd3c4f5-4059-4f03-9f51-68078d7ab5c4
This commit is contained in:
@@ -44,6 +44,11 @@ export async function createApp() {
|
||||
app.useGlobalInterceptors(app.get(CacheInterceptor));
|
||||
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
|
||||
app.use(cookieParser());
|
||||
// only enable shutdown hooks in production
|
||||
// https://docs.nestjs.com/fundamentals/lifecycle-events#application-shutdown
|
||||
if (AFFiNE.NODE_ENV === 'production') {
|
||||
app.enableShutdownHooks();
|
||||
}
|
||||
|
||||
const adapter = new SocketIoAdapter(app);
|
||||
app.useWebSocketAdapter(adapter);
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type { OnModuleDestroy, OnModuleInit } from '@nestjs/common';
|
||||
import type { OnApplicationShutdown, OnModuleInit } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Prisma, PrismaClient } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class PrismaService
|
||||
extends PrismaClient
|
||||
implements OnModuleInit, OnModuleDestroy
|
||||
implements OnModuleInit, OnApplicationShutdown
|
||||
{
|
||||
static INSTANCE: PrismaService | null = null;
|
||||
|
||||
@@ -18,7 +18,7 @@ export class PrismaService
|
||||
await this.$connect();
|
||||
}
|
||||
|
||||
async onModuleDestroy(): Promise<void> {
|
||||
async onApplicationShutdown(): Promise<void> {
|
||||
if (!AFFiNE.node.test) {
|
||||
await this.$disconnect();
|
||||
PrismaService.INSTANCE = null;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
BeforeApplicationShutdown,
|
||||
Controller,
|
||||
Get,
|
||||
Logger,
|
||||
@@ -10,19 +11,22 @@ import {
|
||||
} from '@nestjs/common';
|
||||
import type { Request, Response } from 'express';
|
||||
import {
|
||||
BehaviorSubject,
|
||||
catchError,
|
||||
concatMap,
|
||||
connect,
|
||||
EMPTY,
|
||||
filter,
|
||||
finalize,
|
||||
from,
|
||||
interval,
|
||||
lastValueFrom,
|
||||
map,
|
||||
merge,
|
||||
mergeMap,
|
||||
Observable,
|
||||
Subject,
|
||||
switchMap,
|
||||
take,
|
||||
takeUntil,
|
||||
toArray,
|
||||
} from 'rxjs';
|
||||
@@ -59,8 +63,9 @@ type CheckResult = {
|
||||
const PING_INTERVAL = 5000;
|
||||
|
||||
@Controller('/api/copilot')
|
||||
export class CopilotController {
|
||||
export class CopilotController implements BeforeApplicationShutdown {
|
||||
private readonly logger = new Logger(CopilotController.name);
|
||||
private readonly ongoingStreamCount$ = new BehaviorSubject(0);
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
@@ -70,6 +75,16 @@ export class CopilotController {
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
async beforeApplicationShutdown() {
|
||||
await lastValueFrom(
|
||||
this.ongoingStreamCount$.asObservable().pipe(
|
||||
filter(count => count === 0),
|
||||
take(1)
|
||||
)
|
||||
);
|
||||
this.ongoingStreamCount$.complete();
|
||||
}
|
||||
|
||||
private async checkRequest(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
@@ -241,6 +256,7 @@ export class CopilotController {
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
try {
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
const source$ = from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
...session.config.promptConfig,
|
||||
@@ -265,7 +281,7 @@ export class CopilotController {
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
mergeMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -274,6 +290,9 @@ export class CopilotController {
|
||||
.counter('chat_stream_errors')
|
||||
.add(1, { model: session.model });
|
||||
return mapSseError(e);
|
||||
}),
|
||||
finalize(() => {
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
|
||||
})
|
||||
);
|
||||
|
||||
@@ -306,7 +325,7 @@ export class CopilotController {
|
||||
attachments: latestMessage.attachments,
|
||||
});
|
||||
}
|
||||
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
const source$ = from(
|
||||
this.workflow.runGraph(params, session.model, {
|
||||
...session.config.promptConfig,
|
||||
@@ -359,7 +378,7 @@ export class CopilotController {
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
mergeMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -368,7 +387,10 @@ export class CopilotController {
|
||||
.counter('workflow_errors')
|
||||
.add(1, { model: session.model });
|
||||
return mapSseError(e);
|
||||
})
|
||||
}),
|
||||
finalize(() =>
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
|
||||
)
|
||||
);
|
||||
|
||||
return this.mergePingStream(messageId, source$);
|
||||
@@ -413,7 +435,7 @@ export class CopilotController {
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
const source$ = from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
...session.config.promptConfig,
|
||||
@@ -445,7 +467,7 @@ export class CopilotController {
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
mergeMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -454,7 +476,10 @@ export class CopilotController {
|
||||
.counter('images_stream_errors')
|
||||
.add(1, { model: session.model });
|
||||
return mapSseError(e);
|
||||
})
|
||||
}),
|
||||
finalize(() =>
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
|
||||
)
|
||||
);
|
||||
|
||||
return this.mergePingStream(messageId, source$);
|
||||
|
||||
Reference in New Issue
Block a user