mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
@@ -67,7 +67,9 @@ class CreateChatSessionInput {
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class UpdateChatSessionInput implements Omit<UpdateChatSession, 'userId'> {
|
||||
class UpdateChatSessionInput
|
||||
implements Omit<UpdateChatSession, 'userId' | 'title'>
|
||||
{
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@@ -336,6 +338,9 @@ export class CopilotSessionType {
|
||||
@Field(() => Boolean)
|
||||
pinned!: boolean;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
title!: string | null;
|
||||
|
||||
@Field(() => ID, { nullable: true })
|
||||
parentSessionId!: string | null;
|
||||
|
||||
@@ -653,6 +658,7 @@ export class CopilotResolver {
|
||||
parentSessionId: session.parentSessionId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
promptName: session.prompt.name,
|
||||
model: session.prompt.model,
|
||||
optionalModels: session.prompt.optionalModels,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
|
||||
@@ -11,6 +12,9 @@ import {
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
JobQueue,
|
||||
NoCopilotProviderAvailable,
|
||||
OnJob,
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import {
|
||||
@@ -22,7 +26,12 @@ import {
|
||||
} from '../../models';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage, PromptParams } from './providers';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
PromptParams,
|
||||
} from './providers';
|
||||
import {
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
@@ -33,6 +42,14 @@ import {
|
||||
type SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
declare global {
|
||||
interface Jobs {
|
||||
'copilot.session.generateTitle': {
|
||||
sessionId: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
private stashMessageCount = 0;
|
||||
constructor(
|
||||
@@ -224,10 +241,12 @@ export class ChatSessionService {
|
||||
private readonly logger = new Logger(ChatSessionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly models: Models,
|
||||
private readonly jobs: JobQueue,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly prompt: PromptService,
|
||||
private readonly models: Models
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
|
||||
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
|
||||
@@ -244,6 +263,7 @@ export class ChatSessionService {
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
@@ -282,6 +302,7 @@ export class ChatSessionService {
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
};
|
||||
@@ -303,6 +324,7 @@ export class ChatSessionService {
|
||||
workspaceId,
|
||||
docId,
|
||||
pinned,
|
||||
title,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
@@ -347,6 +369,7 @@ export class ChatSessionService {
|
||||
workspaceId,
|
||||
docId,
|
||||
pinned,
|
||||
title,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
@@ -418,6 +441,7 @@ export class ChatSessionService {
|
||||
...options,
|
||||
sessionId,
|
||||
prompt,
|
||||
title: null,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
@@ -520,8 +544,78 @@ export class ChatSessionService {
|
||||
if (state) {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
}
|
||||
});
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// public for test mock
|
||||
async chatWithPrompt(
|
||||
promptName: string,
|
||||
message: Partial<PromptMessage>
|
||||
): Promise<string> {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
|
||||
const cond = { modelId: prompt.model };
|
||||
const msg = { role: 'user' as const, content: '', ...message };
|
||||
const config = Object.assign({}, prompt.config);
|
||||
|
||||
const provider = await this.moduleRef
|
||||
.get(CopilotProviderFactory)
|
||||
.getProvider({
|
||||
outputType: ModelOutputType.Text,
|
||||
modelId: prompt.model,
|
||||
});
|
||||
|
||||
if (!provider) {
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
return provider.text(cond, [...prompt.finish({}), msg], config);
|
||||
}
|
||||
|
||||
@OnJob('copilot.session.generateTitle')
|
||||
async generateSessionTitle(job: Jobs['copilot.session.generateTitle']) {
|
||||
const { sessionId } = job;
|
||||
|
||||
try {
|
||||
const session = await this.models.copilotSession.get(sessionId);
|
||||
if (!session) {
|
||||
this.logger.warn(
|
||||
`Session ${sessionId} not found when generating title`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const { userId, title, messages } = session;
|
||||
if (
|
||||
title ||
|
||||
!messages.length ||
|
||||
messages.filter(m => m.role === 'user').length === 0 ||
|
||||
messages.filter(m => m.role === 'assistant').length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
const title = await this.chatWithPrompt('Summary as title', {
|
||||
content: session.messages
|
||||
.map(m => `[${m.role}]: ${m.content}`)
|
||||
.join('\n'),
|
||||
});
|
||||
await this.models.copilotSession.update({ userId, sessionId, title });
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to generate title for session ${sessionId}:`,
|
||||
error
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ export const ChatHistorySchema = z
|
||||
workspaceId: z.string(),
|
||||
docId: z.string().nullable(),
|
||||
pinned: z.boolean(),
|
||||
title: z.string().nullable(),
|
||||
action: z.string().nullable(),
|
||||
tokens: z.number(),
|
||||
messages: z.array(ChatMessageSchema),
|
||||
@@ -85,6 +86,7 @@ export interface ChatSessionForkOptions
|
||||
|
||||
export interface ChatSessionState
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
title: string | null;
|
||||
// connect ids
|
||||
sessionId: string;
|
||||
parentSessionId: string | null;
|
||||
|
||||
Reference in New Issue
Block a user