feat(server): improve session modify (#12928)

fix AI-248
This commit is contained in:
DarkSky
2025-06-25 20:02:21 +08:00
committed by GitHub
parent 697e0bf9ba
commit e32c9a814a
11 changed files with 282 additions and 270 deletions

View File

@@ -3,8 +3,8 @@ import { Logger } from '@nestjs/common';
import { AiPrompt } from '@prisma/client';
import Mustache from 'mustache';
import { getTokenEncoder } from '../../../native';
import { PromptConfig, PromptMessage, PromptParams } from '../providers';
import { getTokenEncoder } from '../types';
// disable escaping
Mustache.escape = (text: string) => text;
@@ -56,8 +56,7 @@ export class ChatPrompt {
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);
this.promptTokenSize =
this.encoder?.count(messages.map(m => m.content).join('') || '') || 0;
this.promptTokenSize = this.encode(messages.map(m => m.content).join(''));
this.templateParamKeys = extractMustacheParams(
messages.map(m => m.content).join('')
);

View File

@@ -39,15 +39,12 @@ import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
AvailableModels,
type ChatHistory,
type ChatMessage,
type ChatSessionState,
SubmittedMessage,
} from './types';
registerEnumType(AvailableModels, { name: 'CopilotModel' });
export const COPILOT_LOCKER = 'copilot';
// ================== Input Types ==================
@@ -301,8 +298,6 @@ class CopilotPromptMessageType {
params!: Record<string, string> | null;
}
registerEnumType(AvailableModels, { name: 'CopilotModels' });
@ObjectType()
class CopilotPromptType {
@Field(() => String)
@@ -533,7 +528,7 @@ export class CopilotResolver {
}
await this.chatSession.checkQuota(user.id);
return await this.chatSession.updateSession({
return await this.chatSession.update({
...options,
userId: user.id,
});
@@ -682,8 +677,8 @@ class CreateCopilotPromptInput {
@Field(() => String)
name!: string;
@Field(() => AvailableModels)
model!: AvailableModels;
@Field(() => String)
model!: string;
@Field(() => String, { nullable: true })
action!: string | null;

View File

@@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import { AiPromptRole } from '@prisma/client';
import {
CopilotActionTaken,
@@ -14,10 +14,11 @@ import {
} from '../../base';
import { QuotaService } from '../../core/quota';
import {
CleanupSessionOptions,
ListSessionOptions,
Models,
type UpdateChatSession,
UpdateChatSessionData,
UpdateChatSessionOptions,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
@@ -29,7 +30,6 @@ import {
type ChatSessionForkOptions,
type ChatSessionOptions,
type ChatSessionState,
getTokenEncoder,
type SubmittedMessage,
} from './types';
@@ -224,46 +224,12 @@ export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
constructor(
private readonly db: PrismaClient,
private readonly quota: QuotaService,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService,
private readonly models: Models
) {}
@Transactional()
private async setSession(state: ChatSessionState): Promise<string> {
const session = this.models.copilotSession;
let sessionId = state.sessionId;
// find existing session if session is chat session
if (!state.prompt.action) {
const id = await session.getChatSessionId(state);
if (id) sessionId = id;
}
const haveSession = await session.has(sessionId, state.userId);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await session.setMessages(
sessionId,
state.messages,
this.calculateTokenSize(state.messages, state.prompt.model)
);
}
} else {
await session.create({
...state,
sessionId,
promptName: state.prompt.name,
promptAction: state.prompt.action ?? null,
});
}
return sessionId;
}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
@@ -296,23 +262,6 @@ export class ChatSessionService {
);
}
private calculateTokenSize(messages: PromptMessage[], model: string): number {
const encoder = getTokenEncoder(model);
return messages
.map(m => encoder?.count(m.content) ?? 0)
.reduce((total, length) => total + length, 0);
}
private async countUserMessages(userId: string): Promise<number> {
const sessions = await this.db.aiSession.findMany({
where: { userId },
select: { messageCost: true, prompt: { select: { action: true } } },
});
return sessions
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
.reduce((prev, cost) => prev + cost, 0);
}
async listSessions(
options: ListSessionOptions
): Promise<Omit<ChatSessionState, 'messages'>[]> {
@@ -431,7 +380,7 @@ export class ChatSessionService {
limit = quota.copilotActionLimit;
}
const used = await this.countUserMessages(userId);
const used = await this.models.copilotSession.countUserMessages(userId);
return { limit, used };
}
@@ -456,20 +405,19 @@ export class ChatSessionService {
}
// validate prompt compatibility with session type
this.models.copilotSession.checkSessionPrompt(
options,
prompt.name,
prompt.action
);
this.models.copilotSession.checkSessionPrompt(options, prompt);
return await this.setSession({
...options,
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
});
return await this.models.copilotSession.createWithPrompt(
{
...options,
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
},
true
);
}
@Transactional()
@@ -478,13 +426,16 @@ export class ChatSessionService {
}
@Transactional()
async updateSession(options: UpdateChatSession): Promise<string> {
async update(options: UpdateChatSession): Promise<string> {
const session = await this.getSession(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
const finalData: UpdateChatSessionData = {};
const finalData: UpdateChatSessionOptions = {
userId: options.userId,
sessionId: options.sessionId,
};
if (options.promptName) {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
@@ -492,11 +443,7 @@ export class ChatSessionService {
throw new CopilotPromptNotFound({ name: options.promptName });
}
this.models.copilotSession.checkSessionPrompt(
session,
prompt.name,
prompt.action
);
this.models.copilotSession.checkSessionPrompt(session, prompt);
finalData.promptName = prompt.name;
}
finalData.pinned = options.pinned;
@@ -508,21 +455,15 @@ export class ChatSessionService {
);
}
return await this.models.copilotSession.update(
options.userId,
options.sessionId,
finalData
);
return await this.models.copilotSession.update(finalData);
}
@Transactional()
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {
throw new CopilotSessionNotFound();
}
if (state.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
let messages = state.messages.map(m => ({ ...m, id: undefined }));
if (options.latestMessageId) {
@@ -538,62 +479,17 @@ export class ChatSessionService {
messages = messages.slice(0, lastMessageIdx + 1);
}
const forkedState = {
return await this.models.copilotSession.fork({
...state,
userId: options.userId,
sessionId: randomUUID(),
messages: [],
parentSessionId: options.sessionId,
};
// create session
await this.setSession(forkedState);
// save message
return await this.setSession({ ...forkedState, messages });
messages,
});
}
async cleanup(
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
sessionIds: string[];
}
) {
return await this.db.$transaction(async tx => {
const sessions = await tx.aiSession.findMany({
where: {
id: { in: options.sessionIds },
userId: options.userId,
workspaceId: options.workspaceId,
docId: options.docId,
deletedAt: null,
},
select: { id: true, promptName: true },
});
const sessionIds = sessions.map(({ id }) => id);
// cleanup all messages
await tx.aiSessionMessage.deleteMany({
where: { sessionId: { in: sessionIds } },
});
// only mark action session as deleted
// chat session always can be reuse
const actionIds = (
await Promise.all(
sessions.map(({ id, promptName }) =>
this.prompt
.get(promptName)
.then(prompt => ({ id, action: !!prompt?.action }))
)
)
)
.filter(({ action }) => action)
.map(({ id }) => id);
await tx.aiSession.updateMany({
where: { id: { in: actionIds } },
data: { pinned: false, deletedAt: new Date() },
});
return [...sessionIds, ...actionIds];
});
async cleanup(options: CleanupSessionOptions) {
return await this.models.copilotSession.cleanup(options);
}
async createMessage(message: SubmittedMessage): Promise<string> {
@@ -617,7 +513,7 @@ export class ChatSessionService {
const state = await this.getSession(sessionId);
if (state) {
return new ChatSession(this.messageCache, state, async state => {
await this.setSession(state);
await this.models.copilotSession.updateMessages(state);
});
}
return null;

View File

@@ -1,8 +1,6 @@
import { type Tokenizer } from '@affine/server-native';
import { z } from 'zod';
import { OneMB } from '../../base';
import { fromModelName } from '../../native';
import type { ChatPrompt } from './prompt';
import { PromptMessageSchema, PureMessageSchema } from './providers';
@@ -38,41 +36,6 @@ export const ChatQuerySchema = z
})
);
export enum AvailableModels {
// text to text
Gpt4Omni = 'gpt-4o',
Gpt4Omni0806 = 'gpt-4o-2024-08-06',
Gpt4OmniMini = 'gpt-4o-mini',
Gpt4OmniMini0718 = 'gpt-4o-mini-2024-07-18',
Gpt41 = 'gpt-4.1',
Gpt410414 = 'gpt-4.1-2025-04-14',
Gpt41Mini = 'gpt-4.1-mini',
Gpt41Nano = 'gpt-4.1-nano',
// embeddings
TextEmbedding3Large = 'text-embedding-3-large',
TextEmbedding3Small = 'text-embedding-3-small',
TextEmbeddingAda002 = 'text-embedding-ada-002',
// text to image
DallE3 = 'dall-e-3',
GptImage = 'gpt-image-1',
}
const availableModels = Object.values(AvailableModels);
export function getTokenEncoder(model?: string | null): Tokenizer | null {
if (!model) return null;
if (!availableModels.includes(model as AvailableModels)) return null;
if (model.startsWith('gpt')) {
return fromModelName(model);
} else if (model.startsWith('dall')) {
// dalle don't need to calc the token
return null;
} else {
// c100k based model
return fromModelName('gpt-4');
}
}
// ======== ChatMessage ========
export const ChatMessageSchema = PromptMessageSchema.extend({