mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat(server): graceful shutdown for AI streams (#10025)
https://github.com/user-attachments/assets/8dd3c4f5-4059-4f03-9f51-68078d7ab5c4
This commit is contained in:
@@ -44,6 +44,11 @@ export async function createApp() {
|
|||||||
app.useGlobalInterceptors(app.get(CacheInterceptor));
|
app.useGlobalInterceptors(app.get(CacheInterceptor));
|
||||||
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
|
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
|
||||||
app.use(cookieParser());
|
app.use(cookieParser());
|
||||||
|
// only enable shutdown hooks in production
|
||||||
|
// https://docs.nestjs.com/fundamentals/lifecycle-events#application-shutdown
|
||||||
|
if (AFFiNE.NODE_ENV === 'production') {
|
||||||
|
app.enableShutdownHooks();
|
||||||
|
}
|
||||||
|
|
||||||
const adapter = new SocketIoAdapter(app);
|
const adapter = new SocketIoAdapter(app);
|
||||||
app.useWebSocketAdapter(adapter);
|
app.useWebSocketAdapter(adapter);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import type { OnModuleDestroy, OnModuleInit } from '@nestjs/common';
|
import type { OnApplicationShutdown, OnModuleInit } from '@nestjs/common';
|
||||||
import { Injectable } from '@nestjs/common';
|
import { Injectable } from '@nestjs/common';
|
||||||
import { Prisma, PrismaClient } from '@prisma/client';
|
import { Prisma, PrismaClient } from '@prisma/client';
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class PrismaService
|
export class PrismaService
|
||||||
extends PrismaClient
|
extends PrismaClient
|
||||||
implements OnModuleInit, OnModuleDestroy
|
implements OnModuleInit, OnApplicationShutdown
|
||||||
{
|
{
|
||||||
static INSTANCE: PrismaService | null = null;
|
static INSTANCE: PrismaService | null = null;
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ export class PrismaService
|
|||||||
await this.$connect();
|
await this.$connect();
|
||||||
}
|
}
|
||||||
|
|
||||||
async onModuleDestroy(): Promise<void> {
|
async onApplicationShutdown(): Promise<void> {
|
||||||
if (!AFFiNE.node.test) {
|
if (!AFFiNE.node.test) {
|
||||||
await this.$disconnect();
|
await this.$disconnect();
|
||||||
PrismaService.INSTANCE = null;
|
PrismaService.INSTANCE = null;
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
BeforeApplicationShutdown,
|
||||||
Controller,
|
Controller,
|
||||||
Get,
|
Get,
|
||||||
Logger,
|
Logger,
|
||||||
@@ -10,19 +11,22 @@ import {
|
|||||||
} from '@nestjs/common';
|
} from '@nestjs/common';
|
||||||
import type { Request, Response } from 'express';
|
import type { Request, Response } from 'express';
|
||||||
import {
|
import {
|
||||||
|
BehaviorSubject,
|
||||||
catchError,
|
catchError,
|
||||||
concatMap,
|
concatMap,
|
||||||
connect,
|
connect,
|
||||||
EMPTY,
|
EMPTY,
|
||||||
|
filter,
|
||||||
finalize,
|
finalize,
|
||||||
from,
|
from,
|
||||||
interval,
|
interval,
|
||||||
|
lastValueFrom,
|
||||||
map,
|
map,
|
||||||
merge,
|
merge,
|
||||||
mergeMap,
|
mergeMap,
|
||||||
Observable,
|
Observable,
|
||||||
Subject,
|
Subject,
|
||||||
switchMap,
|
take,
|
||||||
takeUntil,
|
takeUntil,
|
||||||
toArray,
|
toArray,
|
||||||
} from 'rxjs';
|
} from 'rxjs';
|
||||||
@@ -59,8 +63,9 @@ type CheckResult = {
|
|||||||
const PING_INTERVAL = 5000;
|
const PING_INTERVAL = 5000;
|
||||||
|
|
||||||
@Controller('/api/copilot')
|
@Controller('/api/copilot')
|
||||||
export class CopilotController {
|
export class CopilotController implements BeforeApplicationShutdown {
|
||||||
private readonly logger = new Logger(CopilotController.name);
|
private readonly logger = new Logger(CopilotController.name);
|
||||||
|
private readonly ongoingStreamCount$ = new BehaviorSubject(0);
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly config: Config,
|
private readonly config: Config,
|
||||||
@@ -70,6 +75,16 @@ export class CopilotController {
|
|||||||
private readonly storage: CopilotStorage
|
private readonly storage: CopilotStorage
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
|
async beforeApplicationShutdown() {
|
||||||
|
await lastValueFrom(
|
||||||
|
this.ongoingStreamCount$.asObservable().pipe(
|
||||||
|
filter(count => count === 0),
|
||||||
|
take(1)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
this.ongoingStreamCount$.complete();
|
||||||
|
}
|
||||||
|
|
||||||
private async checkRequest(
|
private async checkRequest(
|
||||||
userId: string,
|
userId: string,
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
@@ -241,6 +256,7 @@ export class CopilotController {
|
|||||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||||
try {
|
try {
|
||||||
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });
|
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
provider.generateTextStream(session.finish(params), session.model, {
|
provider.generateTextStream(session.finish(params), session.model, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
@@ -265,7 +281,7 @@ export class CopilotController {
|
|||||||
});
|
});
|
||||||
return from(session.save());
|
return from(session.save());
|
||||||
}),
|
}),
|
||||||
switchMap(() => EMPTY)
|
mergeMap(() => EMPTY)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -274,6 +290,9 @@ export class CopilotController {
|
|||||||
.counter('chat_stream_errors')
|
.counter('chat_stream_errors')
|
||||||
.add(1, { model: session.model });
|
.add(1, { model: session.model });
|
||||||
return mapSseError(e);
|
return mapSseError(e);
|
||||||
|
}),
|
||||||
|
finalize(() => {
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -306,7 +325,7 @@ export class CopilotController {
|
|||||||
attachments: latestMessage.attachments,
|
attachments: latestMessage.attachments,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
this.workflow.runGraph(params, session.model, {
|
this.workflow.runGraph(params, session.model, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
@@ -359,7 +378,7 @@ export class CopilotController {
|
|||||||
});
|
});
|
||||||
return from(session.save());
|
return from(session.save());
|
||||||
}),
|
}),
|
||||||
switchMap(() => EMPTY)
|
mergeMap(() => EMPTY)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -368,7 +387,10 @@ export class CopilotController {
|
|||||||
.counter('workflow_errors')
|
.counter('workflow_errors')
|
||||||
.add(1, { model: session.model });
|
.add(1, { model: session.model });
|
||||||
return mapSseError(e);
|
return mapSseError(e);
|
||||||
})
|
}),
|
||||||
|
finalize(() =>
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
return this.mergePingStream(messageId, source$);
|
return this.mergePingStream(messageId, source$);
|
||||||
@@ -413,7 +435,7 @@ export class CopilotController {
|
|||||||
user.id,
|
user.id,
|
||||||
sessionId
|
sessionId
|
||||||
);
|
);
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||||
const source$ = from(
|
const source$ = from(
|
||||||
provider.generateImagesStream(session.finish(params), session.model, {
|
provider.generateImagesStream(session.finish(params), session.model, {
|
||||||
...session.config.promptConfig,
|
...session.config.promptConfig,
|
||||||
@@ -445,7 +467,7 @@ export class CopilotController {
|
|||||||
});
|
});
|
||||||
return from(session.save());
|
return from(session.save());
|
||||||
}),
|
}),
|
||||||
switchMap(() => EMPTY)
|
mergeMap(() => EMPTY)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -454,7 +476,10 @@ export class CopilotController {
|
|||||||
.counter('images_stream_errors')
|
.counter('images_stream_errors')
|
||||||
.add(1, { model: session.model });
|
.add(1, { model: session.model });
|
||||||
return mapSseError(e);
|
return mapSseError(e);
|
||||||
})
|
}),
|
||||||
|
finalize(() =>
|
||||||
|
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1)
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
return this.mergePingStream(messageId, source$);
|
return this.mergePingStream(messageId, source$);
|
||||||
|
|||||||
Reference in New Issue
Block a user