mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-26 02:35:58 +08:00
feat(server): allow chat session dangling & pin session support (#12849)
fix AI-181 fix AI-179 fix AI-178 fix PD-2682 fix PD-2683
This commit is contained in:
@@ -1,34 +1,36 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { AiPromptRole, Prisma, PrismaClient } from '@prisma/client';
|
||||
import { omit } from 'lodash-es';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole, PrismaClient } from '@prisma/client';
|
||||
|
||||
import {
|
||||
CopilotActionTaken,
|
||||
CopilotMessageNotFound,
|
||||
CopilotPromptNotFound,
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionDeleted,
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
PrismaTransaction,
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { Models } from '../../models';
|
||||
import {
|
||||
Models,
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionData,
|
||||
} from '../../models';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage, PromptParams } from './providers';
|
||||
import {
|
||||
ChatHistory,
|
||||
ChatMessage,
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
ChatMessageSchema,
|
||||
ChatSessionForkOptions,
|
||||
ChatSessionOptions,
|
||||
ChatSessionPromptUpdateOptions,
|
||||
ChatSessionState,
|
||||
type ChatSessionForkOptions,
|
||||
type ChatSessionOptions,
|
||||
type ChatSessionState,
|
||||
getTokenEncoder,
|
||||
ListHistoriesOptions,
|
||||
SubmittedMessage,
|
||||
type ListHistoriesOptions,
|
||||
type SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
@@ -229,141 +231,56 @@ export class ChatSessionService {
|
||||
private readonly models: Models
|
||||
) {}
|
||||
|
||||
private async haveSession(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
tx?: PrismaTransaction,
|
||||
params?: Prisma.AiSessionCountArgs['where']
|
||||
) {
|
||||
const executor = tx ?? this.db;
|
||||
return await executor.aiSession
|
||||
.count({
|
||||
where: {
|
||||
id: sessionId,
|
||||
userId,
|
||||
...params,
|
||||
},
|
||||
})
|
||||
.then(c => c > 0);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
private async setSession(state: ChatSessionState): Promise<string> {
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = state.sessionId;
|
||||
const session = this.models.copilotSession;
|
||||
let sessionId = state.sessionId;
|
||||
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const extraCondition: Record<string, any> = {};
|
||||
if (state.parentSessionId) {
|
||||
// also check session id if provided session is forked session
|
||||
extraCondition.id = state.sessionId;
|
||||
extraCondition.parentSessionId = state.parentSessionId;
|
||||
}
|
||||
const { id, deletedAt } =
|
||||
(await tx.aiSession.findFirst({
|
||||
where: {
|
||||
userId: state.userId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
parentSessionId: null,
|
||||
...extraCondition,
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
if (deletedAt) throw new CopilotSessionDeleted();
|
||||
if (id) sessionId = id;
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const id = await session.getChatSessionId(state);
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
const haveSession = await session.has(sessionId, state.userId);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
await session.setMessages(
|
||||
sessionId,
|
||||
state.messages,
|
||||
this.calculateTokenSize(state.messages, state.prompt.model)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
await session.create({
|
||||
...state,
|
||||
sessionId,
|
||||
promptName: state.prompt.name,
|
||||
});
|
||||
}
|
||||
|
||||
const haveSession = await this.haveSession(sessionId, state.userId, tx);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
await tx.aiSessionMessage.createMany({
|
||||
data: state.messages.map(m => ({
|
||||
...m,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = state.messages.filter(m => m.role === 'user');
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: {
|
||||
increment: this.calculateTokenSize(
|
||||
state.messages,
|
||||
state.prompt.model
|
||||
),
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
await tx.aiSession.create({
|
||||
data: {
|
||||
id: sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.prompt.name,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return sessionId;
|
||||
});
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
|
||||
return await this.db.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
parentSessionId: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
},
|
||||
})
|
||||
.then(async session => {
|
||||
if (!session) return;
|
||||
const prompt = await this.prompt.get(session.promptName);
|
||||
if (!prompt)
|
||||
throw new CopilotPromptNotFound({ name: session.promptName });
|
||||
const session = await this.models.copilotSession.get(sessionId);
|
||||
if (!session) return;
|
||||
const prompt = await this.prompt.get(session.promptName);
|
||||
if (!prompt) throw new CopilotPromptNotFound({ name: session.promptName });
|
||||
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
|
||||
return {
|
||||
sessionId: session.id,
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
});
|
||||
return {
|
||||
sessionId: session.id,
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
}
|
||||
|
||||
// revert the latest messages not generate by user
|
||||
@@ -372,34 +289,10 @@ export class ChatSessionService {
|
||||
sessionId: string,
|
||||
removeLatestUserMessage: boolean
|
||||
) {
|
||||
await this.db.$transaction(async tx => {
|
||||
const id = await tx.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: { id: true },
|
||||
})
|
||||
.then(session => session?.id);
|
||||
if (!id) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const ids = await tx.aiSessionMessage
|
||||
.findMany({
|
||||
where: { sessionId: id },
|
||||
select: { id: true, role: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
})
|
||||
.then(roles =>
|
||||
roles
|
||||
.slice(
|
||||
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
||||
(removeLatestUserMessage ? 0 : 1)
|
||||
)
|
||||
.map(({ id }) => id)
|
||||
);
|
||||
if (ids.length) {
|
||||
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
}
|
||||
});
|
||||
await this.models.copilotSession.revertLatestMessage(
|
||||
sessionId,
|
||||
removeLatestUserMessage
|
||||
);
|
||||
}
|
||||
|
||||
private calculateTokenSize(messages: PromptMessage[], model: string): number {
|
||||
@@ -441,6 +334,7 @@ export class ChatSessionService {
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
pinned: true,
|
||||
parentSessionId: true,
|
||||
promptName: true,
|
||||
},
|
||||
@@ -457,6 +351,7 @@ export class ChatSessionService {
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
};
|
||||
@@ -471,138 +366,83 @@ export class ChatSessionService {
|
||||
docId?: string,
|
||||
options?: ListHistoriesOptions
|
||||
): Promise<ChatHistory[]> {
|
||||
const extraCondition = [];
|
||||
|
||||
if (!options?.action && options?.fork) {
|
||||
// only query forked session if fork == true and action == false
|
||||
extraCondition.push({
|
||||
userId: { not: userId },
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
// should only find forked session
|
||||
parentSessionId: { not: null },
|
||||
deletedAt: null,
|
||||
});
|
||||
}
|
||||
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
OR: [
|
||||
{
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
...extraCondition,
|
||||
],
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
streamObjects: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: {
|
||||
// message order is asc by default
|
||||
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
|
||||
},
|
||||
},
|
||||
},
|
||||
take: options?.limit,
|
||||
skip: options?.skip,
|
||||
orderBy: {
|
||||
// session order is desc by default
|
||||
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
|
||||
},
|
||||
})
|
||||
.then(sessions =>
|
||||
Promise.all(
|
||||
sessions.map(
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
// render system prompt
|
||||
const preload = (
|
||||
options?.withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: []
|
||||
) as ChatMessage[];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
preload.forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data).map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments
|
||||
?.map(a => (typeof a === 'string' ? a : a.attachment))
|
||||
.filter(a => !!a),
|
||||
})),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
const sessions = await this.models.copilotSession.list(
|
||||
userId,
|
||||
workspaceId,
|
||||
docId,
|
||||
options
|
||||
);
|
||||
const histories = await Promise.all(
|
||||
sessions.map(
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
pinned,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
// render system prompt
|
||||
const preload = (
|
||||
options?.withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: []
|
||||
) as ChatMessage[];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
preload.forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
pinned,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data).map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments
|
||||
?.map(a => (typeof a === 'string' ? a : a.attachment))
|
||||
.filter(a => !!a),
|
||||
})),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
)
|
||||
.then(histories =>
|
||||
histories.filter((v): v is NonNullable<typeof v> => !!v)
|
||||
);
|
||||
);
|
||||
|
||||
return histories.filter((v): v is NonNullable<typeof v> => !!v);
|
||||
}
|
||||
|
||||
async getQuota(userId: string) {
|
||||
@@ -637,6 +477,17 @@ export class ChatSessionService {
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
|
||||
if (options.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
|
||||
// validate prompt compatibility with session type
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
options,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
|
||||
return await this.setSession({
|
||||
...options,
|
||||
sessionId,
|
||||
@@ -647,30 +498,47 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async updateSessionPrompt(
|
||||
options: ChatSessionPromptUpdateOptions
|
||||
): Promise<string> {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
@Transactional()
|
||||
async unpin(workspaceId: string, userId: string) {
|
||||
await this.models.copilotSession.unpin(workspaceId, userId);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async updateSession(options: UpdateChatSession): Promise<string> {
|
||||
const session = await this.getSession(options.sessionId);
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = options.sessionId;
|
||||
const haveSession = await this.haveSession(
|
||||
sessionId,
|
||||
options.userId,
|
||||
tx,
|
||||
{ prompt: { action: null } }
|
||||
);
|
||||
if (haveSession) {
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: { promptName: prompt.name },
|
||||
});
|
||||
|
||||
const finalData: UpdateChatSessionData = {};
|
||||
if (options.promptName) {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
return sessionId;
|
||||
});
|
||||
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
session,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
finalData.promptName = prompt.name;
|
||||
}
|
||||
finalData.pinned = options.pinned;
|
||||
finalData.docId = options.docId;
|
||||
|
||||
if (Object.keys(finalData).length === 0) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
'No valid fields to update in the session'
|
||||
);
|
||||
}
|
||||
|
||||
return await this.models.copilotSession.update(
|
||||
options.userId,
|
||||
options.sessionId,
|
||||
finalData
|
||||
);
|
||||
}
|
||||
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
@@ -678,6 +546,10 @@ export class ChatSessionService {
|
||||
if (!state) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
if (state.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
|
||||
let messages = state.messages.map(m => ({ ...m, id: undefined }));
|
||||
if (options.latestMessageId) {
|
||||
const lastMessageIdx = state.messages.findLastIndex(
|
||||
@@ -706,7 +578,9 @@ export class ChatSessionService {
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
|
||||
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
|
||||
sessionIds: string[];
|
||||
}
|
||||
) {
|
||||
return await this.db.$transaction(async tx => {
|
||||
const sessions = await tx.aiSession.findMany({
|
||||
|
||||
Reference in New Issue
Block a user