mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -11,6 +11,7 @@ import {
|
||||
ChatMessageSchema,
|
||||
getTokenEncoder,
|
||||
PromptMessage,
|
||||
PromptMessageSchema,
|
||||
PromptParams,
|
||||
} from './types';
|
||||
|
||||
@@ -105,37 +106,62 @@ export class ChatSession implements AsyncDisposable {
|
||||
@Injectable()
|
||||
export class ChatSessionService {
|
||||
private readonly logger = new Logger(ChatSessionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
|
||||
private async setSession(state: ChatSessionState): Promise<void> {
|
||||
await this.db.aiSession.upsert({
|
||||
where: {
|
||||
id: state.sessionId,
|
||||
},
|
||||
update: {
|
||||
messages: {
|
||||
create: state.messages.map((m, idx) => ({ idx, ...m })),
|
||||
},
|
||||
},
|
||||
create: {
|
||||
id: state.sessionId,
|
||||
messages: { create: state.messages },
|
||||
// connect
|
||||
user: { connect: { id: state.userId } },
|
||||
workspace: { connect: { id: state.workspaceId } },
|
||||
doc: {
|
||||
connect: {
|
||||
id_workspaceId: {
|
||||
id: state.docId,
|
||||
private async setSession(state: ChatSessionState): Promise<string> {
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = state.sessionId;
|
||||
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const { id } =
|
||||
(await tx.aiSession.findFirst({
|
||||
where: {
|
||||
userId: state.userId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
},
|
||||
select: { id: true },
|
||||
})) || {};
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
await tx.aiSession.upsert({
|
||||
where: {
|
||||
id: sessionId,
|
||||
userId: state.userId,
|
||||
},
|
||||
update: {
|
||||
messages: {
|
||||
// delete old messages
|
||||
deleteMany: {},
|
||||
create: state.messages.map(m => ({
|
||||
...m,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
prompt: { connect: { name: state.prompt.name } },
|
||||
},
|
||||
create: {
|
||||
id: sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
messages: {
|
||||
create: state.messages.map(m => ({
|
||||
...m,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
// connect
|
||||
user: { connect: { id: state.userId } },
|
||||
prompt: { connect: { name: state.prompt.name } },
|
||||
},
|
||||
});
|
||||
return sessionId;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -171,6 +197,7 @@ export class ChatSessionService {
|
||||
})
|
||||
.then(async session => {
|
||||
if (!session) return;
|
||||
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
|
||||
return {
|
||||
@@ -184,18 +211,58 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async listHistories(
|
||||
private calculateTokenSize(
|
||||
messages: PromptMessage[],
|
||||
model: AvailableModel
|
||||
): number {
|
||||
const encoder = getTokenEncoder(model);
|
||||
return messages
|
||||
.map(m => encoder?.encode_ordinary(m.content).length || 0)
|
||||
.reduce((total, length) => total + length, 0);
|
||||
}
|
||||
|
||||
async countUserActions(userId: string): Promise<number> {
|
||||
return await this.db.aiSession.count({
|
||||
where: { userId, prompt: { action: { not: null } } },
|
||||
});
|
||||
}
|
||||
|
||||
async listSessions(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
options: ListHistoriesOptions
|
||||
options?: { docId?: string; action?: boolean }
|
||||
): Promise<string[]> {
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
userId,
|
||||
workspaceId,
|
||||
docId: workspaceId === options?.docId ? undefined : options?.docId,
|
||||
prompt: {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
},
|
||||
select: { id: true },
|
||||
})
|
||||
.then(sessions => sessions.map(({ id }) => id));
|
||||
}
|
||||
|
||||
async listHistories(
|
||||
userId: string,
|
||||
workspaceId?: string,
|
||||
docId?: string,
|
||||
options?: ListHistoriesOptions
|
||||
): Promise<ChatHistory[]> {
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
prompt: { action: { not: null } },
|
||||
id: options.sessionId ? { equals: options.sessionId } : undefined,
|
||||
prompt: {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
@@ -210,20 +277,33 @@ export class ChatSessionService {
|
||||
},
|
||||
},
|
||||
},
|
||||
take: options.limit,
|
||||
skip: options.skip,
|
||||
take: options?.limit,
|
||||
skip: options?.skip,
|
||||
orderBy: { createdAt: 'desc' },
|
||||
})
|
||||
.then(sessions =>
|
||||
sessions
|
||||
.map(({ id, prompt, messages }) => {
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const encoder = getTokenEncoder(prompt.model as AvailableModel);
|
||||
const tokens = ret.data
|
||||
.map(m => encoder?.encode_ordinary(m.content).length || 0)
|
||||
.reduce((total, length) => total + length, 0);
|
||||
return { sessionId: id, tokens, messages: ret.data };
|
||||
try {
|
||||
const ret = PromptMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const tokens = this.calculateTokenSize(
|
||||
ret.data,
|
||||
prompt.model as AvailableModel
|
||||
);
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || undefined,
|
||||
tokens,
|
||||
messages: ret.data,
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
return undefined;
|
||||
})
|
||||
@@ -238,8 +318,12 @@ export class ChatSessionService {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
await this.setSession({ ...options, sessionId, prompt, messages: [] });
|
||||
return sessionId;
|
||||
return await this.setSession({
|
||||
...options,
|
||||
sessionId,
|
||||
prompt,
|
||||
messages: [],
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user