import {
  BeforeApplicationShutdown,
  Controller,
  Get,
  Logger,
  Param,
  Query,
  Req,
  Res,
  Sse,
} from '@nestjs/common';
import type { Request, Response } from 'express';
import { pipeline } from 'node:stream';
import {
  BehaviorSubject,
  catchError,
  connect,
  defer,
  filter,
  finalize,
  from,
  ignoreElements,
  interval,
  lastValueFrom,
  map,
  merge,
  mergeMap,
  Observable,
  reduce,
  Subject,
  take,
  takeUntil,
  tap,
} from 'rxjs';

import {
  BlobNotFound,
  CallMetric,
  Config,
  CopilotFailedToGenerateText,
  CopilotSessionNotFound,
  InternalServerError,
  mapAnyError,
  mapSseError,
  metrics,
  NoCopilotProviderAvailable,
  UnsplashIsNotConfigured,
} from '../../base';
import { CurrentUser, Public } from '../../core/auth';
import {
  CopilotProvider,
  CopilotProviderFactory,
  ModelInputType,
  ModelOutputType,
  StreamObject,
} from './providers';
import { StreamObjectParser } from './providers/utils';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { ChatMessage, ChatQuerySchema } from './types';
import { getSignal } from './utils';
import { CopilotWorkflowService, GraphExecutorState } from './workflow';

/** A single server-sent event emitted by the copilot streaming endpoints. */
export interface ChatEvent {
  type: 'event' | 'attachment' | 'message' | 'error' | 'ping';
  id?: string;
  data: string | object;
}

/** Interval (ms) between keep-alive `ping` events on SSE streams. */
const PING_INTERVAL = 5000;

/** Assistant message content persisted when the client aborted the request. */
const ABORTED_CONTENT = '> Request aborted';

/**
 * Builds the attribute set attached to error counters.
 *
 * Only the model name is used: the request `info` object also carries the
 * session id (unbounded cardinality) and nested objects/arrays (params,
 * prompt messages) that are not valid metric attribute values.
 */
function errorMetricAttributes(info: { model?: string }) {
  return info.model ? { model: info.model } : {};
}

/**
 * HTTP / SSE endpoints for copilot chat sessions, image generation,
 * workflows, Unsplash search proxying and copilot blob download.
 */
@Controller('/api/copilot')
export class CopilotController implements BeforeApplicationShutdown {
  private readonly logger = new Logger(CopilotController.name);
  /** Number of SSE streams currently subscribed (see `trackOngoingStream`). */
  private readonly ongoingStreamCount$ = new BehaviorSubject(0);

  constructor(
    private readonly config: Config,
    private readonly chatSession: ChatSessionService,
    private readonly provider: CopilotProviderFactory,
    private readonly workflow: CopilotWorkflowService,
    private readonly storage: CopilotStorage
  ) {}

  /**
   * Blocks application shutdown until every in-flight SSE stream has
   * finished, then completes the stream counter.
   */
  async beforeApplicationShutdown() {
    await lastValueFrom(
      this.ongoingStreamCount$.asObservable().pipe(
        filter(count => count === 0),
        take(1)
      )
    );
    this.ongoingStreamCount$.complete();
  }

  /**
   * Checks the user's quota and session ownership, then resolves the model
   * and provider to use.
   *
   * @param outputType - kind of output the provider must support
   * @param userId - requesting user; must own the session
   * @param sessionId - chat session id
   * @param messageId - optional message whose attachments decide `hasAttachment`
   * @param modelId - requested model; used only if listed in the session's
   *   `optionalModels`, otherwise the session's default model is used
   * @throws CopilotSessionNotFound if the session is missing or owned by someone else
   * @throws NoCopilotProviderAvailable if no provider serves the chosen model
   */
  private async chooseProvider(
    outputType: ModelOutputType,
    userId: string,
    sessionId: string,
    messageId?: string,
    modelId?: string
  ): Promise<{
    provider: CopilotProvider;
    model: string;
    hasAttachment: boolean;
  }> {
    const [, session] = await Promise.all([
      this.chatSession.checkQuota(userId),
      this.chatSession.get(sessionId),
    ]);

    if (!session || session.config.userId !== userId) {
      throw new CopilotSessionNotFound();
    }

    const model =
      modelId && session.optionalModels.includes(modelId)
        ? modelId
        : session.model;

    const hasAttachment = messageId
      ? !!(await session.getMessageById(messageId)).attachments?.length
      : false;

    const provider = await this.provider.getProvider({
      outputType,
      modelId: model,
    });
    if (!provider) {
      throw new NoCopilotProviderAvailable();
    }

    return { provider, model, hasAttachment };
  }

  /**
   * Loads a session and prepares it for a new assistant turn.
   *
   * Without `messageId`, or with `retry`, the latest assistant message is
   * reverted (and, when `messageId` is given, the latest user message too).
   * With a `messageId`, that stashed message is pushed onto the session.
   *
   * @returns the latest user message (only when no `messageId` was given)
   *   and the session
   * @throws CopilotSessionNotFound if the session does not exist
   */
  private async appendSessionMessage(
    sessionId: string,
    messageId?: string,
    retry = false
  ): Promise<[ChatMessage | undefined, ChatSession]> {
    const session = await this.chatSession.get(sessionId);
    if (!session) {
      throw new CopilotSessionNotFound();
    }
    let latestMessage: ChatMessage | undefined;
    if (!messageId || retry) {
      // revert the latest message generated by the assistant
      // if messageId is provided, we will also revert latest user message
      await this.chatSession.revertLatestMessage(sessionId, !!messageId);
      session.revertLatestMessage(!!messageId);
      if (!messageId) {
        latestMessage = session.latestUserMessage;
      }
    }

    if (messageId) {
      await session.pushByMessageId(messageId);
    }

    return [latestMessage, session];
  }

  /**
   * Parses a base-10 integer from a query value (first element if an array).
   * Returns `undefined` for missing or non-numeric input.
   */
  private parseNumber(value: string | string[] | undefined) {
    if (!value) {
      return undefined;
    }
    const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
    if (Number.isNaN(num)) {
      return undefined;
    }
    return num;
  }

  /**
   * Interleaves `ping` events every `PING_INTERVAL` ms into `source$`,
   * stopping the pings once the source finalizes.
   */
  private mergePingStream(
    messageId: string,
    source$: Observable<ChatEvent>
  ): Observable<ChatEvent> {
    const subject$ = new Subject<void>();
    const ping$ = interval(PING_INTERVAL).pipe(
      map(() => ({ type: 'ping' as const, id: messageId, data: '' })),
      takeUntil(subject$)
    );

    return merge(source$.pipe(finalize(() => subject$.next())), ping$);
  }

  /**
   * Counts `source$` as an ongoing stream for exactly as long as it is
   * subscribed.
   *
   * Increment and decrement live in the same subscription, so the counter
   * cannot drift when a handler throws after setting up the stream or when
   * the returned observable is never subscribed.
   */
  private trackOngoingStream<T>(source$: Observable<T>): Observable<T> {
    return defer(() => {
      this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
      return source$.pipe(
        finalize(() => {
          this.ongoingStreamCount$.next(this.ongoingStreamCount$.value - 1);
        })
      );
    });
  }

  /**
   * Returns the request's abort signal plus a getter that reports whether
   * `getSignal`'s `onConnectionClosed` hook fired with `isAborted === true`.
   */
  private watchConnection(req: Request) {
    const { signal, onConnectionClosed } = getSignal(req);
    let abortedByClient = false;
    onConnectionClosed(isAborted => {
      if (isAborted) {
        abortedByClient = true;
      }
    });
    return { signal, isAborted: () => abortedByClient };
  }

  /** Saves the session without awaiting it; failures are logged only. */
  private saveSessionInBackground(session: ChatSession) {
    void session
      .save()
      .catch(err =>
        this.logger.error('Failed to save session in sse stream', err)
      );
  }

  /**
   * Shared setup for the chat endpoints: choose the provider, prepare the
   * session and build the final prompt messages.
   */
  private async prepareChatSession(
    user: CurrentUser,
    sessionId: string,
    query: Record<string, string | string[]>,
    outputType: ModelOutputType
  ) {
    let { messageId, retry, modelId, params } = ChatQuerySchema.parse(query);

    const { provider, model } = await this.chooseProvider(
      outputType,
      user.id,
      sessionId,
      messageId,
      modelId
    );

    const [latestMessage, session] = await this.appendSessionMessage(
      sessionId,
      messageId,
      retry
    );

    if (latestMessage) {
      params = Object.assign({}, params, latestMessage.params, {
        content: latestMessage.content,
        attachments: latestMessage.attachments,
      });
    }

    const finalMessage = session.finish(params);

    return {
      provider,
      model,
      session,
      finalMessage,
    };
  }

  /**
   * Non-streaming chat: generates the full assistant reply, persists it to
   * the session and returns it.
   */
  @Get('/chat/:sessionId')
  @CallMetric('ai', 'chat', { timer: true })
  async chat(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<string> {
    const info: any = { sessionId, params: query };

    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Text
        );

      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_calls').add(1, { model });

      const { reasoning, webSearch } = ChatQuerySchema.parse(query);
      const content = await provider.text({ modelId: model }, finalMessage, {
        ...session.config.promptConfig,
        signal: getSignal(req).signal,
        user: user.id,
        session: session.config.sessionId,
        workspace: session.config.workspaceId,
        reasoning,
        webSearch,
      });

      session.push({
        role: 'assistant',
        content,
        createdAt: new Date(),
      });
      await session.save();

      return content;
    } catch (e: any) {
      metrics.ai.counter('chat_errors').add(1, errorMetricAttributes(info));
      let error = mapAnyError(e);
      if (error instanceof InternalServerError) {
        error = new CopilotFailedToGenerateText(e.message);
      }
      error.log('CopilotChat', info);
      throw error;
    }
  }

  /**
   * Streams text chunks as `message` events; the concatenated text is
   * persisted as the assistant message once the stream completes.
   */
  @Sse('/chat/:sessionId/stream')
  @CallMetric('ai', 'chat_stream', { timer: true })
  async chatStream(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };

    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Text
        );

      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_stream_calls').add(1, { model });

      const { signal, isAborted } = this.watchConnection(req);
      const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);

      const source$ = from(
        provider.streamText({ modelId: model }, finalMessage, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
          reasoning,
          webSearch,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => ({ type: 'message' as const, id: messageId, data }))
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce((acc, chunk) => acc + chunk, ''),
              tap(buffer => {
                session.push({
                  role: 'assistant',
                  content: isAborted() ? ABORTED_CONTENT : buffer,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('chat_stream_errors')
            .add(1, errorMetricAttributes(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        })
      );

      return this.mergePingStream(
        messageId || '',
        this.trackOngoingStream(source$)
      );
    } catch (err) {
      metrics.ai
        .counter('chat_stream_errors')
        .add(1, errorMetricAttributes(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Streams structured `StreamObject` chunks as `message` events; merged
   * text content and stream objects are persisted once the stream completes.
   */
  @Sse('/chat/:sessionId/stream-object')
  @CallMetric('ai', 'chat_object_stream', { timer: true })
  async chatStreamObject(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };

    try {
      const { provider, model, session, finalMessage } =
        await this.prepareChatSession(
          user,
          sessionId,
          query,
          ModelOutputType.Object
        );

      info.model = model;
      info.finalMessage = finalMessage.filter(m => m.role !== 'system');
      metrics.ai.counter('chat_object_stream_calls').add(1, { model });

      const { signal, isAborted } = this.watchConnection(req);
      const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);

      const source$ = from(
        provider.streamObject({ modelId: model }, finalMessage, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
          reasoning,
          webSearch,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => ({ type: 'message' as const, id: messageId, data }))
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce(
                (acc, chunk) => acc.concat([chunk]),
                [] as StreamObject[]
              ),
              tap(result => {
                const parser = new StreamObjectParser();
                const streamObjects = parser.mergeTextDelta(result);
                const content = parser.mergeContent(streamObjects);
                const aborted = isAborted();
                session.push({
                  role: 'assistant',
                  content: aborted ? ABORTED_CONTENT : content,
                  streamObjects: aborted ? null : streamObjects,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('chat_object_stream_errors')
            .add(1, errorMetricAttributes(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        })
      );

      return this.mergePingStream(
        messageId || '',
        this.trackOngoingStream(source$)
      );
    } catch (err) {
      metrics.ai
        .counter('chat_object_stream_errors')
        .add(1, errorMetricAttributes(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Runs the session's workflow graph, emitting `message`, `attachment`
   * and node-status `event` events; emitted text is persisted at the end.
   */
  @Sse('/chat/:sessionId/workflow')
  @CallMetric('ai', 'chat_workflow', { timer: true })
  async chatWorkflow(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };

    try {
      let { messageId, params } = ChatQuerySchema.parse(query);

      // Verify ownership BEFORE appendSessionMessage, which may revert
      // messages in storage. The other endpoints get this check through
      // chooseProvider; this path does not call it.
      const owned = await this.chatSession.get(sessionId);
      if (!owned || owned.config.userId !== user.id) {
        throw new CopilotSessionNotFound();
      }

      const [, session] = await this.appendSessionMessage(sessionId, messageId);
      info.model = session.model;
      metrics.ai.counter('workflow_calls').add(1, { model: session.model });

      const latestMessage = session.stashMessages.findLast(
        m => m.role === 'user'
      );
      if (latestMessage) {
        params = Object.assign({}, params, latestMessage.params, {
          content: latestMessage.content,
          attachments: latestMessage.attachments,
        });
      }

      const { signal, isAborted } = this.watchConnection(req);

      const source$ = from(
        this.workflow.runGraph(params, session.model, {
          ...session.config.promptConfig,
          signal,
          user: user.id,
          session: session.config.sessionId,
          workspace: session.config.workspaceId,
        })
      ).pipe(
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(data => {
                switch (data.status) {
                  case GraphExecutorState.EmitContent:
                    return {
                      type: 'message' as const,
                      id: messageId,
                      data: data.content,
                    };
                  case GraphExecutorState.EmitAttachment:
                    return {
                      type: 'attachment' as const,
                      id: messageId,
                      data: data.attachment,
                    };
                  default:
                    return {
                      type: 'event' as const,
                      id: messageId,
                      data: {
                        status: data.status,
                        id: data.node.id,
                        type: data.node.config.nodeType,
                      } as any,
                    };
                }
              })
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce((acc, chunk) => {
                if (chunk.status === GraphExecutorState.EmitContent) {
                  acc += chunk.content;
                }
                return acc;
              }, ''),
              tap(content => {
                session.push({
                  role: 'assistant',
                  content: isAborted() ? ABORTED_CONTENT : content,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('workflow_errors')
            .add(1, errorMetricAttributes(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        })
      );

      return this.mergePingStream(
        messageId || '',
        this.trackOngoingStream(source$)
      );
    } catch (err) {
      metrics.ai.counter('workflow_errors').add(1, errorMetricAttributes(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Streams generated images as `attachment` events (each passed through
   * `storage.handleRemoteLink` first); collected attachments are persisted
   * on completion.
   */
  @Sse('/chat/:sessionId/images')
  @CallMetric('ai', 'chat_images', { timer: true })
  async chatImagesStream(
    @CurrentUser() user: CurrentUser,
    @Req() req: Request,
    @Param('sessionId') sessionId: string,
    @Query() query: Record<string, string | string[]>
  ): Promise<Observable<ChatEvent>> {
    const info: any = { sessionId, params: query, throwInStream: false };

    try {
      let { messageId, params } = ChatQuerySchema.parse(query);

      const { provider, model, hasAttachment } = await this.chooseProvider(
        ModelOutputType.Image,
        user.id,
        sessionId,
        messageId
      );

      const [latestMessage, session] = await this.appendSessionMessage(
        sessionId,
        messageId
      );
      info.model = model;
      metrics.ai.counter('images_stream_calls').add(1, { model });

      if (latestMessage) {
        params = Object.assign({}, params, latestMessage.params, {
          content: latestMessage.content,
          attachments: latestMessage.attachments,
        });
      }

      const handleRemoteLink = this.storage.handleRemoteLink.bind(
        this.storage,
        user.id,
        sessionId
      );
      const { signal, isAborted } = this.watchConnection(req);

      const source$ = from(
        provider.streamImages(
          {
            modelId: model,
            inputTypes: hasAttachment
              ? [ModelInputType.Image]
              : [ModelInputType.Text],
          },
          session.finish(params),
          {
            ...session.config.promptConfig,
            quality: params.quality || undefined,
            seed: this.parseNumber(params.seed),
            signal,
            user: user.id,
            session: session.config.sessionId,
            workspace: session.config.workspaceId,
          }
        )
      ).pipe(
        mergeMap(handleRemoteLink),
        connect(shared$ =>
          merge(
            // actual chat event stream
            shared$.pipe(
              map(attachment => ({
                type: 'attachment' as const,
                id: messageId,
                data: attachment,
              }))
            ),
            // save the generated text to the session
            shared$.pipe(
              reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
              tap(attachments => {
                const aborted = isAborted();
                session.push({
                  role: 'assistant',
                  content: aborted ? ABORTED_CONTENT : '',
                  attachments: aborted ? [] : attachments,
                  createdAt: new Date(),
                });
                this.saveSessionInBackground(session);
              }),
              ignoreElements()
            )
          )
        ),
        catchError(e => {
          metrics.ai
            .counter('images_stream_errors')
            .add(1, errorMetricAttributes(info));
          info.throwInStream = true;
          return mapSseError(e, info);
        })
      );

      return this.mergePingStream(
        messageId || '',
        this.trackOngoingStream(source$)
      );
    } catch (err) {
      metrics.ai
        .counter('images_stream_errors')
        .add(1, errorMetricAttributes(info));
      return mapSseError(err, info);
    }
  }

  /**
   * Proxies a photo search to the Unsplash API using the configured key and
   * relays status, JSON body and selected headers.
   *
   * @throws UnsplashIsNotConfigured if no Unsplash key is configured
   */
  @Get('/unsplash/photos')
  @CallMetric('ai', 'unsplash')
  async unsplashPhotos(
    @Req() req: Request,
    @Res() res: Response,
    @Query() params: Record<string, string>
  ) {
    const { key } = this.config.copilot.unsplash;
    if (!key) {
      throw new UnsplashIsNotConfigured();
    }

    const query = new URLSearchParams(params);
    const response = await fetch(
      `https://api.unsplash.com/search/photos?${query}`,
      {
        headers: { Authorization: `Client-ID ${key}` },
        signal: getSignal(req).signal,
      }
    );

    // Forward only headers that are present: `res.set` would stringify a
    // missing (null) value to the literal "null". Content-Length is not
    // forwarded because the body is re-serialized and Express computes it.
    for (const name of [
      'Content-Type',
      'X-Ratelimit-Limit',
      'X-Ratelimit-Remaining',
    ]) {
      const value = response.headers.get(name);
      if (value !== null) {
        res.set(name, value);
      }
    }
    res.status(response.status).send(await response.json());
  }

  /**
   * Public download of a copilot blob: redirects to a signed URL when the
   * storage provides one, otherwise streams the body with cache headers.
   *
   * @throws BlobNotFound if there is neither a redirect URL nor a body
   */
  @Public()
  @Get('/blob/:userId/:workspaceId/:key')
  async getBlob(
    @Res() res: Response,
    @Param('userId') userId: string,
    @Param('workspaceId') workspaceId: string,
    @Param('key') key: string
  ) {
    const { body, metadata, redirectUrl } = await this.storage.get(
      userId,
      workspaceId,
      key,
      true
    );

    if (redirectUrl) {
      // redirect to signed url
      return res.redirect(redirectUrl);
    }

    if (!body) {
      throw new BlobNotFound({
        spaceId: workspaceId,
        blobId: key,
      });
    }

    // metadata should always exists if body is not null
    if (metadata) {
      res.setHeader('content-type', metadata.contentType);
      res.setHeader('last-modified', metadata.lastModified.toUTCString());
      res.setHeader('content-length', metadata.contentLength);
    } else {
      this.logger.warn(`Blob ${workspaceId}/${key} has no metadata`);
    }

    res.setHeader('cache-control', 'public, max-age=2592000, immutable');
    // pipeline (unlike .pipe) forwards source errors and destroys the source
    // if the client disconnects, instead of leaving an unhandled 'error'.
    pipeline(body, res, err => {
      if (err) {
        this.logger.warn(`Failed to stream blob ${workspaceId}/${key}`, err);
      }
    });
  }
}