@@ -3,8 +3,8 @@ import { Logger } from '@nestjs/common';
import { AiPrompt } from '@prisma/client';
import Mustache from 'mustache';

import { getTokenEncoder } from '../../../native';
import { PromptConfig, PromptMessage, PromptParams } from '../providers';
import { getTokenEncoder } from '../types';

// disable escaping
Mustache.escape = (text: string) => text;
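Note on the Mustache.escape override above: Mustache HTML-escapes interpolated values by default, which would mangle prompt text. A minimal sketch of the effect (the template and value below are hypothetical, not from this diff):

import Mustache from 'mustache';

// Identity escaper: interpolated values are inserted verbatim.
Mustache.escape = (text: string) => text;

// Without the override, '<doc>' would come out as '&lt;doc&gt;'.
Mustache.render('Summarize: {{content}}', { content: '<doc>code & prose</doc>' });
// => 'Summarize: <doc>code & prose</doc>'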
@@ -56,8 +56,7 @@ export class ChatPrompt {
    private readonly messages: PromptMessage[]
  ) {
    this.encoder = getTokenEncoder(model);
    this.promptTokenSize =
      this.encoder?.count(messages.map(m => m.content).join('') || '') || 0;
    this.promptTokenSize = this.encode(messages.map(m => m.content).join(''));
    this.templateParamKeys = extractMustacheParams(
      messages.map(m => m.content).join('')
    );

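Note: the constructor now routes token counting through this.encode(...) instead of calling the encoder inline. A hedged sketch of what that helper presumably looks like, based on the encoder?.count(...) pattern used elsewhere in this diff (the actual implementation is not shown here):

// Assumed helper on ChatPrompt: wraps the optional tokenizer so a missing
// encoder (e.g. image models) and empty input both count as 0 tokens.
private encode(text: string): number {
  return this.encoder?.count(text) ?? 0;
}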
@@ -39,15 +39,12 @@ import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
  AvailableModels,
  type ChatHistory,
  type ChatMessage,
  type ChatSessionState,
  SubmittedMessage,
} from './types';

registerEnumType(AvailableModels, { name: 'CopilotModel' });

export const COPILOT_LOCKER = 'copilot';

// ================== Input Types ==================
@@ -301,8 +298,6 @@ class CopilotPromptMessageType {
  params!: Record<string, string> | null;
}

registerEnumType(AvailableModels, { name: 'CopilotModels' });

@ObjectType()
class CopilotPromptType {
  @Field(() => String)
@@ -533,7 +528,7 @@ export class CopilotResolver {
    }

    await this.chatSession.checkQuota(user.id);
    return await this.chatSession.updateSession({
    return await this.chatSession.update({
      ...options,
      userId: user.id,
    });
@@ -682,8 +677,8 @@ class CreateCopilotPromptInput {
  @Field(() => String)
  name!: string;

  @Field(() => AvailableModels)
  model!: AvailableModels;
  @Field(() => String)
  model!: string;

  @Field(() => String, { nullable: true })
  action!: string | null;

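Note: changing model from the AvailableModels enum to String means the GraphQL schema no longer restricts prompt creation to the registered enum values; any model name is accepted at the API boundary. If a runtime guard is still wanted, a hypothetical check (not part of this diff) could reuse the enum:

// Hypothetical validation helper, not present in the diff.
const isKnownModel = (model: string): model is AvailableModels =>
  Object.values(AvailableModels).includes(model as AvailableModels);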
@@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto';

import { Injectable, Logger } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import { AiPromptRole } from '@prisma/client';

import {
  CopilotActionTaken,
@@ -14,10 +14,11 @@ import {
} from '../../base';
import { QuotaService } from '../../core/quota';
import {
  CleanupSessionOptions,
  ListSessionOptions,
  Models,
  type UpdateChatSession,
  UpdateChatSessionData,
  UpdateChatSessionOptions,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
@@ -29,7 +30,6 @@ import {
  type ChatSessionForkOptions,
  type ChatSessionOptions,
  type ChatSessionState,
  getTokenEncoder,
  type SubmittedMessage,
} from './types';

@@ -224,46 +224,12 @@ export class ChatSessionService {
  private readonly logger = new Logger(ChatSessionService.name);

  constructor(
    private readonly db: PrismaClient,
    private readonly quota: QuotaService,
    private readonly messageCache: ChatMessageCache,
    private readonly prompt: PromptService,
    private readonly models: Models
  ) {}

  @Transactional()
  private async setSession(state: ChatSessionState): Promise<string> {
    const session = this.models.copilotSession;
    let sessionId = state.sessionId;

    // reuse an existing session when this is a chat session (prompt without action)
    if (!state.prompt.action) {
      const id = await session.getChatSessionId(state);
      if (id) sessionId = id;
    }

    const haveSession = await session.has(sessionId, state.userId);
    if (haveSession) {
      // messages only exist when setSession is called by session.save
      if (state.messages.length) {
        await session.setMessages(
          sessionId,
          state.messages,
          this.calculateTokenSize(state.messages, state.prompt.model)
        );
      }
    } else {
      await session.create({
        ...state,
        sessionId,
        promptName: state.prompt.name,
        promptAction: state.prompt.action ?? null,
      });
    }

    return sessionId;
  }

  async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
    const session = await this.models.copilotSession.get(sessionId);
    if (!session) return;
@@ -296,23 +262,6 @@ export class ChatSessionService {
    );
  }

  private calculateTokenSize(messages: PromptMessage[], model: string): number {
    const encoder = getTokenEncoder(model);
    return messages
      .map(m => encoder?.count(m.content) ?? 0)
      .reduce((total, length) => total + length, 0);
  }

  private async countUserMessages(userId: string): Promise<number> {
    const sessions = await this.db.aiSession.findMany({
      where: { userId },
      select: { messageCost: true, prompt: { select: { action: true } } },
    });
    return sessions
      .map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
      .reduce((prev, cost) => prev + cost, 0);
  }

  async listSessions(
    options: ListSessionOptions
  ): Promise<Omit<ChatSessionState, 'messages'>[]> {
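Note: the quota accounting in countUserMessages (now moving into the model layer) charges an action session a flat cost of 1 and a chat session its accumulated messageCost. A small worked example with hypothetical rows:

// Hypothetical query result rows.
const sessions = [
  { messageCost: 7, prompt: { action: null } },        // chat session   -> costs 7
  { messageCost: 3, prompt: { action: 'Summarize' } }, // action session -> costs 1
];
const used = sessions
  .map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
  .reduce((prev, cost) => prev + cost, 0); // 8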
@@ -431,7 +380,7 @@ export class ChatSessionService {
      limit = quota.copilotActionLimit;
    }

    const used = await this.countUserMessages(userId);
    const used = await this.models.copilotSession.countUserMessages(userId);

    return { limit, used };
  }
@@ -456,20 +405,19 @@ export class ChatSessionService {
    }

    // validate prompt compatibility with session type
    this.models.copilotSession.checkSessionPrompt(
      options,
      prompt.name,
      prompt.action
    );
    this.models.copilotSession.checkSessionPrompt(options, prompt);

    return await this.setSession({
      ...options,
      sessionId,
      prompt,
      messages: [],
      // when the client creates a chat session, we always target the root session
      parentSessionId: null,
    });
    return await this.models.copilotSession.createWithPrompt(
      {
        ...options,
        sessionId,
        prompt,
        messages: [],
        // when the client creates a chat session, we always target the root session
        parentSessionId: null,
      },
      true
    );
  }

  @Transactional()
@@ -478,13 +426,16 @@ export class ChatSessionService {
  }

  @Transactional()
  async updateSession(options: UpdateChatSession): Promise<string> {
  async update(options: UpdateChatSession): Promise<string> {
    const session = await this.getSession(options.sessionId);
    if (!session) {
      throw new CopilotSessionNotFound();
    }

    const finalData: UpdateChatSessionData = {};
    const finalData: UpdateChatSessionOptions = {
      userId: options.userId,
      sessionId: options.sessionId,
    };
    if (options.promptName) {
      const prompt = await this.prompt.get(options.promptName);
      if (!prompt) {
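Note: update now builds a single UpdateChatSessionOptions object (carrying userId and sessionId) instead of a bare UpdateChatSessionData plus positional arguments. A hedged sketch of the shape implied by this diff; the authoritative type lives in '../../models' and may differ:

// Fields inferred from how finalData is populated in this diff.
interface UpdateChatSessionOptions {
  userId: string;
  sessionId: string;
  promptName?: string;
  pinned?: boolean;
}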
@@ -492,11 +443,7 @@ export class ChatSessionService {
        throw new CopilotPromptNotFound({ name: options.promptName });
      }

      this.models.copilotSession.checkSessionPrompt(
        session,
        prompt.name,
        prompt.action
      );
      this.models.copilotSession.checkSessionPrompt(session, prompt);
      finalData.promptName = prompt.name;
    }
    finalData.pinned = options.pinned;
@@ -508,21 +455,15 @@ export class ChatSessionService {
      );
    }

    return await this.models.copilotSession.update(
      options.userId,
      options.sessionId,
      finalData
    );
    return await this.models.copilotSession.update(finalData);
  }

  @Transactional()
  async fork(options: ChatSessionForkOptions): Promise<string> {
    const state = await this.getSession(options.sessionId);
    if (!state) {
      throw new CopilotSessionNotFound();
    }
    if (state.pinned) {
      await this.unpin(options.workspaceId, options.userId);
    }

    let messages = state.messages.map(m => ({ ...m, id: undefined }));
    if (options.latestMessageId) {
@@ -538,62 +479,17 @@ export class ChatSessionService {
      messages = messages.slice(0, lastMessageIdx + 1);
    }

    const forkedState = {
    return await this.models.copilotSession.fork({
      ...state,
      userId: options.userId,
      sessionId: randomUUID(),
      messages: [],
      parentSessionId: options.sessionId,
    };
    // create session
    await this.setSession(forkedState);
    // save message
    return await this.setSession({ ...forkedState, messages });
      messages,
    });
  }

  async cleanup(
    options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
      sessionIds: string[];
    }
  ) {
    return await this.db.$transaction(async tx => {
      const sessions = await tx.aiSession.findMany({
        where: {
          id: { in: options.sessionIds },
          userId: options.userId,
          workspaceId: options.workspaceId,
          docId: options.docId,
          deletedAt: null,
        },
        select: { id: true, promptName: true },
      });
      const sessionIds = sessions.map(({ id }) => id);
      // clean up all messages
      await tx.aiSessionMessage.deleteMany({
        where: { sessionId: { in: sessionIds } },
      });

      // only mark action sessions as deleted;
      // chat sessions can always be reused
      const actionIds = (
        await Promise.all(
          sessions.map(({ id, promptName }) =>
            this.prompt
              .get(promptName)
              .then(prompt => ({ id, action: !!prompt?.action }))
          )
        )
      )
        .filter(({ action }) => action)
        .map(({ id }) => id);

      await tx.aiSession.updateMany({
        where: { id: { in: actionIds } },
        data: { pinned: false, deletedAt: new Date() },
      });

      return [...sessionIds, ...actionIds];
    });
  async cleanup(options: CleanupSessionOptions) {
    return await this.models.copilotSession.cleanup(options);
  }

  async createMessage(message: SubmittedMessage): Promise<string> {
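Note: fork and cleanup now delegate to models.copilotSession. The fork no longer needs two setSession passes (create, then save messages), and the transactional cleanup of messages and action sessions moves behind cleanup(options). A hedged sketch of CleanupSessionOptions, read off the fields the removed inline implementation used (the real declaration is in '../../models'):

// Assumed shape; docId nullability is a guess.
interface CleanupSessionOptions {
  userId: string;
  workspaceId: string;
  docId?: string | null;
  sessionIds: string[];
}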
@@ -617,7 +513,7 @@ export class ChatSessionService {
    const state = await this.getSession(sessionId);
    if (state) {
      return new ChatSession(this.messageCache, state, async state => {
        await this.setSession(state);
        await this.models.copilotSession.updateMessages(state);
      });
    }
    return null;

@@ -1,8 +1,6 @@
import { type Tokenizer } from '@affine/server-native';
import { z } from 'zod';

import { OneMB } from '../../base';
import { fromModelName } from '../../native';
import type { ChatPrompt } from './prompt';
import { PromptMessageSchema, PureMessageSchema } from './providers';

@@ -38,41 +36,6 @@ export const ChatQuerySchema = z
  })
);

export enum AvailableModels {
  // text to text
  Gpt4Omni = 'gpt-4o',
  Gpt4Omni0806 = 'gpt-4o-2024-08-06',
  Gpt4OmniMini = 'gpt-4o-mini',
  Gpt4OmniMini0718 = 'gpt-4o-mini-2024-07-18',
  Gpt41 = 'gpt-4.1',
  Gpt410414 = 'gpt-4.1-2025-04-14',
  Gpt41Mini = 'gpt-4.1-mini',
  Gpt41Nano = 'gpt-4.1-nano',
  // embeddings
  TextEmbedding3Large = 'text-embedding-3-large',
  TextEmbedding3Small = 'text-embedding-3-small',
  TextEmbeddingAda002 = 'text-embedding-ada-002',
  // text to image
  DallE3 = 'dall-e-3',
  GptImage = 'gpt-image-1',
}

const availableModels = Object.values(AvailableModels);

export function getTokenEncoder(model?: string | null): Tokenizer | null {
  if (!model) return null;
  if (!availableModels.includes(model as AvailableModels)) return null;
  if (model.startsWith('gpt')) {
    return fromModelName(model);
  } else if (model.startsWith('dall')) {
    // DALL-E doesn't need token counting
    return null;
  } else {
    // cl100k-based models
    return fromModelName('gpt-4');
  }
}

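Usage sketch for getTokenEncoder (which this diff moves out of './types'); count is the Tokenizer method already used by calculateTokenSize above, and the model names come from the AvailableModels enum:

const encoder = getTokenEncoder(AvailableModels.Gpt4Omni);
const tokens = encoder?.count('Summarize this document') ?? 0;

// Image models are listed but intentionally return no encoder.
getTokenEncoder(AvailableModels.DallE3); // => null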
// ======== ChatMessage ========

export const ChatMessageSchema = PromptMessageSchema.extend({