feat(server): allow chat session dangling & pin session support (#12849)

fix AI-181
fix AI-179
fix AI-178
fix PD-2682
fix PD-2683
This commit is contained in:
DarkSky
2025-06-19 13:17:01 +08:00
committed by GitHub
parent d80bfac1d2
commit bd04930560
28 changed files with 1422 additions and 394 deletions

View File

@@ -1,34 +1,36 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { AiPromptRole, Prisma, PrismaClient } from '@prisma/client';
import { omit } from 'lodash-es';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import {
CopilotActionTaken,
CopilotMessageNotFound,
CopilotPromptNotFound,
CopilotQuotaExceeded,
CopilotSessionDeleted,
CopilotSessionInvalidInput,
CopilotSessionNotFound,
PrismaTransaction,
} from '../../base';
import { QuotaService } from '../../core/quota';
import { Models } from '../../models';
import {
Models,
type UpdateChatSession,
UpdateChatSessionData,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import { PromptMessage, PromptParams } from './providers';
import {
ChatHistory,
ChatMessage,
type ChatHistory,
type ChatMessage,
ChatMessageSchema,
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionPromptUpdateOptions,
ChatSessionState,
type ChatSessionForkOptions,
type ChatSessionOptions,
type ChatSessionState,
getTokenEncoder,
ListHistoriesOptions,
SubmittedMessage,
type ListHistoriesOptions,
type SubmittedMessage,
} from './types';
export class ChatSession implements AsyncDisposable {
@@ -229,141 +231,56 @@ export class ChatSessionService {
private readonly models: Models
) {}
private async haveSession(
sessionId: string,
userId: string,
tx?: PrismaTransaction,
params?: Prisma.AiSessionCountArgs['where']
) {
const executor = tx ?? this.db;
return await executor.aiSession
.count({
where: {
id: sessionId,
userId,
...params,
},
})
.then(c => c > 0);
}
@Transactional()
private async setSession(state: ChatSessionState): Promise<string> {
return await this.db.$transaction(async tx => {
let sessionId = state.sessionId;
const session = this.models.copilotSession;
let sessionId = state.sessionId;
// find existing session if session is chat session
if (!state.prompt.action) {
const extraCondition: Record<string, any> = {};
if (state.parentSessionId) {
// also check session id if provided session is forked session
extraCondition.id = state.sessionId;
extraCondition.parentSessionId = state.parentSessionId;
}
const { id, deletedAt } =
(await tx.aiSession.findFirst({
where: {
userId: state.userId,
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
parentSessionId: null,
...extraCondition,
},
select: { id: true, deletedAt: true },
})) || {};
if (deletedAt) throw new CopilotSessionDeleted();
if (id) sessionId = id;
// find existing session if session is chat session
if (!state.prompt.action) {
const id = await session.getChatSessionId(state);
if (id) sessionId = id;
}
const haveSession = await session.has(sessionId, state.userId);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await session.setMessages(
sessionId,
state.messages,
this.calculateTokenSize(state.messages, state.prompt.model)
);
}
} else {
await session.create({
...state,
sessionId,
promptName: state.prompt.name,
});
}
const haveSession = await this.haveSession(sessionId, state.userId, tx);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await tx.aiSessionMessage.createMany({
data: state.messages.map(m => ({
...m,
streamObjects: m.streamObjects || undefined,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
sessionId,
})),
});
// only count message generated by user
const userMessages = state.messages.filter(m => m.role === 'user');
await tx.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: {
increment: this.calculateTokenSize(
state.messages,
state.prompt.model
),
},
},
});
}
} else {
await tx.aiSession.create({
data: {
id: sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
// connect
userId: state.userId,
promptName: state.prompt.name,
parentSessionId: state.parentSessionId,
},
});
}
return sessionId;
});
return sessionId;
}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
return await this.db.aiSession
.findUnique({
where: { id: sessionId, deletedAt: null },
select: {
id: true,
userId: true,
workspaceId: true,
docId: true,
parentSessionId: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
params: true,
createdAt: true,
},
orderBy: { createdAt: 'asc' },
},
promptName: true,
},
})
.then(async session => {
if (!session) return;
const prompt = await this.prompt.get(session.promptName);
if (!prompt)
throw new CopilotPromptNotFound({ name: session.promptName });
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
const prompt = await this.prompt.get(session.promptName);
if (!prompt) throw new CopilotPromptNotFound({ name: session.promptName });
const messages = ChatMessageSchema.array().safeParse(session.messages);
const messages = ChatMessageSchema.array().safeParse(session.messages);
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
});
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
}
// revert the latest messages not generate by user
@@ -372,34 +289,10 @@ export class ChatSessionService {
sessionId: string,
removeLatestUserMessage: boolean
) {
await this.db.$transaction(async tx => {
const id = await tx.aiSession
.findUnique({
where: { id: sessionId, deletedAt: null },
select: { id: true },
})
.then(session => session?.id);
if (!id) {
throw new CopilotSessionNotFound();
}
const ids = await tx.aiSessionMessage
.findMany({
where: { sessionId: id },
select: { id: true, role: true },
orderBy: { createdAt: 'asc' },
})
.then(roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id)
);
if (ids.length) {
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
}
});
await this.models.copilotSession.revertLatestMessage(
sessionId,
removeLatestUserMessage
);
}
private calculateTokenSize(messages: PromptMessage[], model: string): number {
@@ -441,6 +334,7 @@ export class ChatSessionService {
userId: true,
workspaceId: true,
docId: true,
pinned: true,
parentSessionId: true,
promptName: true,
},
@@ -457,6 +351,7 @@ export class ChatSessionService {
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
parentSessionId: session.parentSessionId,
prompt,
};
@@ -471,138 +366,83 @@ export class ChatSessionService {
docId?: string,
options?: ListHistoriesOptions
): Promise<ChatHistory[]> {
const extraCondition = [];
if (!options?.action && options?.fork) {
// only query forked session if fork == true and action == false
extraCondition.push({
userId: { not: userId },
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
});
}
return await this.db.aiSession
.findMany({
where: {
OR: [
{
userId,
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId
? { equals: options.sessionId }
: undefined,
deletedAt: null,
},
...extraCondition,
],
},
select: {
id: true,
userId: true,
promptName: true,
tokenCost: true,
createdAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
streamObjects: true,
attachments: true,
params: true,
createdAt: true,
},
orderBy: {
// message order is asc by default
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
},
},
},
take: options?.limit,
skip: options?.skip,
orderBy: {
// session order is desc by default
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
},
})
.then(sessions =>
Promise.all(
sessions.map(
async ({
id,
userId: uid,
promptName,
tokenCost,
messages,
createdAt,
}) => {
try {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
) {
return undefined;
}
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});
return {
sessionId: id,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
messages: preload.concat(ret.data).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
const sessions = await this.models.copilotSession.list(
userId,
workspaceId,
docId,
options
);
const histories = await Promise.all(
sessions.map(
async ({
id,
userId: uid,
pinned,
promptName,
tokenCost,
messages,
createdAt,
}) => {
try {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
) {
return undefined;
}
)
)
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});
return {
sessionId: id,
pinned,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
messages: preload.concat(ret.data).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
return undefined;
}
)
.then(histories =>
histories.filter((v): v is NonNullable<typeof v> => !!v)
);
);
return histories.filter((v): v is NonNullable<typeof v> => !!v);
}
async getQuota(userId: string) {
@@ -637,6 +477,17 @@ export class ChatSessionService {
throw new CopilotPromptNotFound({ name: options.promptName });
}
if (options.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
// validate prompt compatibility with session type
this.models.copilotSession.checkSessionPrompt(
options,
prompt.name,
prompt.action
);
return await this.setSession({
...options,
sessionId,
@@ -647,30 +498,47 @@ export class ChatSessionService {
});
}
async updateSessionPrompt(
options: ChatSessionPromptUpdateOptions
): Promise<string> {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new CopilotPromptNotFound({ name: options.promptName });
@Transactional()
async unpin(workspaceId: string, userId: string) {
await this.models.copilotSession.unpin(workspaceId, userId);
}
@Transactional()
async updateSession(options: UpdateChatSession): Promise<string> {
const session = await this.getSession(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
return await this.db.$transaction(async tx => {
let sessionId = options.sessionId;
const haveSession = await this.haveSession(
sessionId,
options.userId,
tx,
{ prompt: { action: null } }
);
if (haveSession) {
await tx.aiSession.update({
where: { id: sessionId },
data: { promptName: prompt.name },
});
const finalData: UpdateChatSessionData = {};
if (options.promptName) {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new CopilotPromptNotFound({ name: options.promptName });
}
return sessionId;
});
this.models.copilotSession.checkSessionPrompt(
session,
prompt.name,
prompt.action
);
finalData.promptName = prompt.name;
}
finalData.pinned = options.pinned;
finalData.docId = options.docId;
if (Object.keys(finalData).length === 0) {
throw new CopilotSessionInvalidInput(
'No valid fields to update in the session'
);
}
return await this.models.copilotSession.update(
options.userId,
options.sessionId,
finalData
);
}
async fork(options: ChatSessionForkOptions): Promise<string> {
@@ -678,6 +546,10 @@ export class ChatSessionService {
if (!state) {
throw new CopilotSessionNotFound();
}
if (state.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
let messages = state.messages.map(m => ({ ...m, id: undefined }));
if (options.latestMessageId) {
const lastMessageIdx = state.messages.findLastIndex(
@@ -706,7 +578,9 @@ export class ChatSessionService {
}
async cleanup(
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
sessionIds: string[];
}
) {
return await this.db.$transaction(async tx => {
const sessions = await tx.aiSession.findMany({