mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 13:25:12 +00:00
639 lines
17 KiB
TypeScript
639 lines
17 KiB
TypeScript
import { Injectable } from '@nestjs/common';
|
|
import { Transactional } from '@nestjs-cls/transactional';
|
|
import { AiPromptRole, Prisma } from '@prisma/client';
|
|
import { omit } from 'lodash-es';
|
|
|
|
import {
|
|
CopilotPromptInvalid,
|
|
CopilotSessionDeleted,
|
|
CopilotSessionInvalidInput,
|
|
CopilotSessionNotFound,
|
|
} from '../base';
|
|
import { getTokenEncoder } from '../native';
|
|
import { BaseModel } from './base';
|
|
|
|
export enum SessionType {
|
|
Workspace = 'workspace', // docId is null and pinned is false
|
|
Pinned = 'pinned', // pinned is true
|
|
Doc = 'doc', // docId points to specific document
|
|
}
|
|
|
|
type ChatPrompt = {
|
|
name: string;
|
|
action?: string | null;
|
|
model: string;
|
|
};
|
|
|
|
type ChatAttachment = { attachment: string; mimeType: string } | string;
|
|
|
|
type ChatStreamObject = {
|
|
type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result';
|
|
textDelta?: string;
|
|
toolCallId?: string;
|
|
toolName?: string;
|
|
args?: Record<string, any>;
|
|
result?: any;
|
|
};
|
|
|
|
type ChatMessage = {
|
|
id?: string | undefined;
|
|
role: 'system' | 'assistant' | 'user';
|
|
content: string;
|
|
attachments?: ChatAttachment[] | null;
|
|
params?: Record<string, any> | null;
|
|
streamObjects?: ChatStreamObject[] | null;
|
|
createdAt: Date;
|
|
};
|
|
|
|
type PureChatSession = {
|
|
sessionId: string;
|
|
workspaceId: string;
|
|
docId?: string | null;
|
|
pinned?: boolean;
|
|
title: string | null;
|
|
messages?: ChatMessage[];
|
|
// connect ids
|
|
userId: string;
|
|
parentSessionId?: string | null;
|
|
};
|
|
|
|
type ChatSession = PureChatSession & {
|
|
// connect ids
|
|
promptName: string;
|
|
promptAction: string | null;
|
|
};
|
|
|
|
type ChatSessionWithPrompt = PureChatSession & {
|
|
prompt: ChatPrompt;
|
|
};
|
|
|
|
type ChatSessionBaseState = Pick<ChatSession, 'userId' | 'sessionId'>;
|
|
|
|
export type ForkSessionOptions = Omit<
|
|
ChatSession,
|
|
'messages' | 'promptName' | 'promptAction'
|
|
> & {
|
|
prompt: { name: string; action: string | null | undefined; model: string };
|
|
messages: ChatMessage[];
|
|
};
|
|
|
|
type UpdateChatSessionMessage = ChatSessionBaseState & {
|
|
prompt: { model: string };
|
|
messages: ChatMessage[];
|
|
};
|
|
|
|
export type UpdateChatSessionOptions = ChatSessionBaseState &
|
|
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName' | 'title'>;
|
|
|
|
export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions;
|
|
|
|
export type ListSessionOptions = Pick<
|
|
Partial<ChatSession>,
|
|
'sessionId' | 'workspaceId' | 'docId' | 'pinned'
|
|
> & {
|
|
userId: string | undefined;
|
|
action?: boolean;
|
|
fork?: boolean;
|
|
limit?: number;
|
|
skip?: number;
|
|
sessionOrder?: 'asc' | 'desc';
|
|
messageOrder?: 'asc' | 'desc';
|
|
|
|
// extra condition
|
|
withPrompt?: boolean;
|
|
withMessages?: boolean;
|
|
};
|
|
|
|
export type CleanupSessionOptions = Pick<
|
|
ChatSession,
|
|
'userId' | 'workspaceId' | 'docId'
|
|
> & {
|
|
sessionIds: string[];
|
|
};
|
|
|
|
@Injectable()
|
|
export class CopilotSessionModel extends BaseModel {
|
|
getSessionType(session: Pick<ChatSession, 'docId' | 'pinned'>): SessionType {
|
|
if (session.pinned) return SessionType.Pinned;
|
|
if (!session.docId) return SessionType.Workspace;
|
|
return SessionType.Doc;
|
|
}
|
|
|
|
checkSessionPrompt(
|
|
session: Pick<ChatSession, 'docId' | 'pinned'>,
|
|
prompt: Partial<ChatPrompt>
|
|
): boolean {
|
|
const sessionType = this.getSessionType(session);
|
|
const { name: promptName, action: promptAction } = prompt;
|
|
|
|
// workspace and pinned sessions cannot use action prompts
|
|
if (
|
|
[SessionType.Workspace, SessionType.Pinned].includes(sessionType) &&
|
|
!!promptAction?.trim()
|
|
) {
|
|
throw new CopilotPromptInvalid(
|
|
`${promptName} are not allowed for ${sessionType} sessions`
|
|
);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// NOTE: just for test, remove it after copilot prompt model is ready
|
|
async createPrompt(name: string, model: string, action?: string) {
|
|
await this.db.aiPrompt.create({
|
|
data: { name, model, action: action ?? null },
|
|
});
|
|
}
|
|
|
|
@Transactional()
|
|
async create(state: ChatSession, reuseChat = false): Promise<string> {
|
|
// find and return existing session if session is chat session
|
|
if (reuseChat && !state.promptAction) {
|
|
const sessionId = await this.find(state);
|
|
if (sessionId) return sessionId;
|
|
}
|
|
|
|
if (state.pinned) {
|
|
await this.unpin(state.workspaceId, state.userId);
|
|
}
|
|
|
|
const session = await this.db.aiSession.create({
|
|
data: {
|
|
id: state.sessionId,
|
|
workspaceId: state.workspaceId,
|
|
docId: state.docId,
|
|
pinned: state.pinned ?? false,
|
|
// connect
|
|
userId: state.userId,
|
|
promptName: state.promptName,
|
|
promptAction: state.promptAction,
|
|
parentSessionId: state.parentSessionId,
|
|
},
|
|
select: { id: true },
|
|
});
|
|
return session.id;
|
|
}
|
|
|
|
@Transactional()
|
|
async createWithPrompt(
|
|
state: ChatSessionWithPrompt,
|
|
reuseChat = false
|
|
): Promise<string> {
|
|
const { prompt, ...rest } = state;
|
|
return await this.models.copilotSession.create(
|
|
{ ...rest, promptName: prompt.name, promptAction: prompt.action ?? null },
|
|
reuseChat
|
|
);
|
|
}
|
|
|
|
@Transactional()
|
|
async fork(options: ForkSessionOptions): Promise<string> {
|
|
if (options.pinned) {
|
|
await this.unpin(options.workspaceId, options.userId);
|
|
}
|
|
const { messages, ...forkedState } = options;
|
|
|
|
// create session
|
|
const sessionId = await this.createWithPrompt({
|
|
...forkedState,
|
|
messages: [],
|
|
});
|
|
if (options.messages.length) {
|
|
// save message
|
|
await this.models.copilotSession.updateMessages({
|
|
...forkedState,
|
|
sessionId,
|
|
messages,
|
|
});
|
|
}
|
|
|
|
return sessionId;
|
|
}
|
|
|
|
@Transactional()
|
|
async has(
|
|
sessionId: string,
|
|
userId: string,
|
|
params?: Prisma.AiSessionCountArgs['where']
|
|
) {
|
|
return await this.db.aiSession
|
|
.count({ where: { id: sessionId, userId, ...params } })
|
|
.then(c => c > 0);
|
|
}
|
|
|
|
@Transactional()
|
|
async find(state: PureChatSession) {
|
|
const extraCondition: Record<string, any> = {};
|
|
if (state.parentSessionId) {
|
|
// also check session id if provided session is forked session
|
|
extraCondition.id = state.sessionId;
|
|
extraCondition.parentSessionId = state.parentSessionId;
|
|
}
|
|
|
|
const session = await this.db.aiSession.findFirst({
|
|
where: {
|
|
userId: state.userId,
|
|
workspaceId: state.workspaceId,
|
|
docId: state.docId,
|
|
parentSessionId: null,
|
|
prompt: { action: { equals: null } },
|
|
...extraCondition,
|
|
},
|
|
select: { id: true, deletedAt: true },
|
|
});
|
|
if (session?.deletedAt) throw new CopilotSessionDeleted();
|
|
return session?.id;
|
|
}
|
|
|
|
@Transactional()
|
|
async getExists<Select extends Prisma.AiSessionSelect>(
|
|
sessionId: string,
|
|
select?: Select,
|
|
where?: Omit<Prisma.AiSessionWhereInput, 'id' | 'deletedAt'>
|
|
) {
|
|
return (await this.db.aiSession.findUnique({
|
|
where: { ...where, id: sessionId, deletedAt: null },
|
|
select,
|
|
})) as Prisma.AiSessionGetPayload<{ select: Select }> | null;
|
|
}
|
|
|
|
@Transactional()
|
|
async get(sessionId: string) {
|
|
return await this.getExists(sessionId, {
|
|
id: true,
|
|
userId: true,
|
|
workspaceId: true,
|
|
docId: true,
|
|
parentSessionId: true,
|
|
pinned: true,
|
|
title: true,
|
|
promptName: true,
|
|
tokenCost: true,
|
|
createdAt: true,
|
|
updatedAt: true,
|
|
messages: {
|
|
select: {
|
|
id: true,
|
|
role: true,
|
|
content: true,
|
|
attachments: true,
|
|
streamObjects: true,
|
|
params: true,
|
|
createdAt: true,
|
|
},
|
|
orderBy: { createdAt: 'asc' },
|
|
},
|
|
});
|
|
}
|
|
|
|
private getListConditions(
|
|
options: ListSessionOptions
|
|
): Prisma.AiSessionWhereInput {
|
|
const { userId, sessionId, workspaceId, docId, action, fork } = options;
|
|
|
|
function getNullCond<T>(
|
|
maybeBool: boolean | undefined,
|
|
wrap: (ret: { not: null } | null) => T = ret => ret as T
|
|
): T | undefined {
|
|
return maybeBool === true
|
|
? wrap({ not: null })
|
|
: maybeBool === false
|
|
? wrap(null)
|
|
: undefined;
|
|
}
|
|
|
|
function getEqCond<T>(maybeValue: T | undefined): T | undefined {
|
|
return maybeValue !== undefined ? maybeValue : undefined;
|
|
}
|
|
|
|
const conditions: Prisma.AiSessionWhereInput['OR'] = [
|
|
{
|
|
userId,
|
|
workspaceId,
|
|
docId: getEqCond(docId),
|
|
id: getEqCond(sessionId),
|
|
deletedAt: null,
|
|
pinned: getEqCond(options.pinned),
|
|
prompt: getNullCond(action, ret => ({ action: ret })),
|
|
parentSessionId: getNullCond(fork),
|
|
},
|
|
];
|
|
|
|
if (!action && fork) {
|
|
// query forked sessions from other users
|
|
// only query forked session if fork == true and action == false
|
|
conditions.push({
|
|
userId: { not: userId },
|
|
workspaceId: workspaceId,
|
|
docId: docId ?? null,
|
|
id: getEqCond(sessionId),
|
|
prompt: { action: null },
|
|
// should only find forked session
|
|
parentSessionId: { not: null },
|
|
deletedAt: null,
|
|
});
|
|
}
|
|
|
|
return { OR: conditions };
|
|
}
|
|
|
|
async count(options: ListSessionOptions) {
|
|
return await this.db.aiSession.count({
|
|
where: this.getListConditions(options),
|
|
});
|
|
}
|
|
|
|
async list(options: ListSessionOptions) {
|
|
return await this.db.aiSession.findMany({
|
|
where: this.getListConditions(options),
|
|
select: {
|
|
id: true,
|
|
userId: true,
|
|
workspaceId: true,
|
|
docId: true,
|
|
parentSessionId: true,
|
|
pinned: true,
|
|
title: true,
|
|
promptName: true,
|
|
tokenCost: true,
|
|
createdAt: true,
|
|
updatedAt: true,
|
|
messages: options.withMessages
|
|
? {
|
|
select: {
|
|
id: true,
|
|
role: true,
|
|
content: true,
|
|
attachments: true,
|
|
streamObjects: true,
|
|
params: true,
|
|
createdAt: true,
|
|
},
|
|
orderBy: {
|
|
// message order is asc by default
|
|
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
|
|
},
|
|
}
|
|
: false,
|
|
},
|
|
take: options?.limit,
|
|
skip: options?.skip,
|
|
orderBy: {
|
|
updatedAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
|
|
},
|
|
});
|
|
}
|
|
|
|
@Transactional()
|
|
async unpin(workspaceId: string, userId: string): Promise<boolean> {
|
|
const { count } = await this.db.aiSession.updateMany({
|
|
where: { userId, workspaceId, pinned: true, deletedAt: null },
|
|
data: { pinned: false },
|
|
});
|
|
|
|
return count > 0;
|
|
}
|
|
|
|
@Transactional()
|
|
async update(options: UpdateChatSessionOptions): Promise<string> {
|
|
const { userId, sessionId, docId, promptName, pinned, title } = options;
|
|
const session = await this.getExists(
|
|
sessionId,
|
|
{
|
|
id: true,
|
|
workspaceId: true,
|
|
docId: true,
|
|
parentSessionId: true,
|
|
pinned: true,
|
|
prompt: true,
|
|
},
|
|
{ userId }
|
|
);
|
|
if (!session) {
|
|
throw new CopilotSessionNotFound();
|
|
}
|
|
|
|
// not allow to update action session
|
|
if (session.prompt.action) {
|
|
throw new CopilotSessionInvalidInput(
|
|
`Cannot update action: ${session.id}`
|
|
);
|
|
} else if (docId && session.parentSessionId) {
|
|
throw new CopilotSessionInvalidInput(
|
|
`Cannot update docId for forked session: ${session.id}`
|
|
);
|
|
}
|
|
|
|
if (promptName) {
|
|
const prompt = await this.db.aiPrompt.findFirst({
|
|
where: { name: promptName },
|
|
});
|
|
// always not allow to update to action prompt
|
|
if (!prompt || prompt.action) {
|
|
throw new CopilotSessionInvalidInput(
|
|
`Prompt ${promptName} not found or not available for session ${sessionId}`
|
|
);
|
|
}
|
|
}
|
|
if (pinned && pinned !== session.pinned) {
|
|
// if pin the session, unpin exists session in the workspace
|
|
await this.unpin(session.workspaceId, userId);
|
|
}
|
|
|
|
await this.db.aiSession.update({
|
|
where: { id: sessionId },
|
|
data: { docId, promptName, pinned, title },
|
|
});
|
|
|
|
return sessionId;
|
|
}
|
|
|
|
@Transactional()
|
|
async cleanup(options: CleanupSessionOptions): Promise<string[]> {
|
|
const sessions = await this.db.aiSession.findMany({
|
|
where: {
|
|
id: { in: options.sessionIds },
|
|
userId: options.userId,
|
|
workspaceId: options.workspaceId,
|
|
docId: options.docId,
|
|
deletedAt: null,
|
|
},
|
|
select: { id: true, prompt: true },
|
|
});
|
|
const sessionIds = sessions.map(({ id }) => id);
|
|
// cleanup all messages
|
|
await this.db.aiSessionMessage.deleteMany({
|
|
where: { sessionId: { in: sessionIds } },
|
|
});
|
|
|
|
// only mark action session as deleted
|
|
// chat session always can be reuse
|
|
const actionIds = sessions
|
|
.filter(({ prompt }) => !!prompt.action)
|
|
.map(({ id }) => id);
|
|
|
|
// 标记 action session 为已删除
|
|
if (actionIds.length > 0) {
|
|
await this.db.aiSession.updateMany({
|
|
where: { id: { in: actionIds } },
|
|
data: { pinned: false, deletedAt: new Date() },
|
|
});
|
|
}
|
|
|
|
return sessionIds;
|
|
}
|
|
|
|
@Transactional()
|
|
async getMessages(
|
|
sessionId: string,
|
|
select?: Prisma.AiSessionMessageSelect,
|
|
orderBy?: Prisma.AiSessionMessageOrderByWithRelationInput
|
|
) {
|
|
return this.db.aiSessionMessage.findMany({
|
|
where: { sessionId },
|
|
select,
|
|
orderBy: orderBy ?? { createdAt: 'asc' },
|
|
});
|
|
}
|
|
|
|
private calculateTokenSize(messages: any[], model: string): number {
|
|
const encoder = getTokenEncoder(model);
|
|
const content = messages.map(m => m.content).join('');
|
|
return encoder?.count(content) || 0;
|
|
}
|
|
|
|
@Transactional()
|
|
async updateMessages(state: UpdateChatSessionMessage) {
|
|
const { sessionId, userId, messages } = state;
|
|
const haveSession = await this.has(sessionId, userId);
|
|
if (!haveSession) {
|
|
throw new CopilotSessionNotFound();
|
|
}
|
|
|
|
if (messages.length) {
|
|
const tokenCost = this.calculateTokenSize(messages, state.prompt.model);
|
|
await this.db.aiSessionMessage.createMany({
|
|
data: messages.map(m => ({
|
|
...m,
|
|
attachments: m.attachments || undefined,
|
|
params: omit(m.params, ['docs']) || undefined,
|
|
streamObjects: m.streamObjects || undefined,
|
|
sessionId,
|
|
})),
|
|
});
|
|
|
|
// only count message generated by user
|
|
const userMessages = messages.filter(m => m.role === 'user');
|
|
await this.db.aiSession.update({
|
|
where: { id: sessionId },
|
|
data: {
|
|
messageCost: { increment: userMessages.length },
|
|
tokenCost: { increment: tokenCost },
|
|
},
|
|
});
|
|
}
|
|
}
|
|
|
|
@Transactional()
|
|
async revertLatestMessage(
|
|
sessionId: string,
|
|
removeLatestUserMessage: boolean
|
|
) {
|
|
const id = await this.getExists(sessionId, { id: true }).then(
|
|
session => session?.id
|
|
);
|
|
if (!id) {
|
|
throw new CopilotSessionNotFound();
|
|
}
|
|
const messages = await this.getMessages(id, { id: true, role: true });
|
|
const ids = messages
|
|
.slice(
|
|
messages.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
|
(removeLatestUserMessage ? 0 : 1)
|
|
)
|
|
.map(({ id }) => id);
|
|
|
|
if (ids.length) {
|
|
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
|
|
|
// clear the title if there only one round of conversation left
|
|
const remainingMessages = await this.getMessages(id, { role: true });
|
|
const userMessageCount = remainingMessages.filter(
|
|
m => m.role === AiPromptRole.user
|
|
).length;
|
|
|
|
if (userMessageCount <= 1) {
|
|
await this.db.aiSession.update({
|
|
where: { id },
|
|
data: { title: null },
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
@Transactional()
|
|
async countUserMessages(userId: string): Promise<number> {
|
|
const sessions = await this.db.aiSession.findMany({
|
|
where: { userId },
|
|
select: { messageCost: true, prompt: { select: { action: true } } },
|
|
});
|
|
return sessions
|
|
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
|
|
.reduce((prev, cost) => prev + cost, 0);
|
|
}
|
|
|
|
@Transactional()
|
|
async cleanupEmptySessions(earlyThen: Date) {
|
|
// delete never used sessions
|
|
const { count: removed } = await this.db.aiSession.deleteMany({
|
|
where: {
|
|
messageCost: 0,
|
|
deletedAt: null,
|
|
// filter session updated more than 24 hours ago
|
|
updatedAt: { lt: earlyThen },
|
|
},
|
|
});
|
|
|
|
// mark empty sessions as deleted
|
|
const { count: cleaned } = await this.db.aiSession.updateMany({
|
|
where: {
|
|
deletedAt: null,
|
|
messages: { none: {} },
|
|
// filter session updated more than 24 hours ago
|
|
updatedAt: { lt: earlyThen },
|
|
},
|
|
data: {
|
|
deletedAt: new Date(),
|
|
pinned: false,
|
|
},
|
|
});
|
|
|
|
return { removed, cleaned };
|
|
}
|
|
|
|
@Transactional()
|
|
async toBeGenerateTitle(take: number) {
|
|
const sessions = await this.db.aiSession
|
|
.findMany({
|
|
where: {
|
|
title: null,
|
|
deletedAt: null,
|
|
messages: { some: {} },
|
|
// only generate titles for non-actions sessions
|
|
prompt: { action: null },
|
|
},
|
|
select: {
|
|
id: true,
|
|
// count assistant messages
|
|
_count: { select: { messages: { where: { role: 'assistant' } } } },
|
|
},
|
|
take,
|
|
orderBy: { updatedAt: 'desc' },
|
|
})
|
|
.then(s => s.filter(s => s._count.messages > 0));
|
|
|
|
return sessions;
|
|
}
|
|
}
|