feat: title of session (#12971)

fix AI-253
This commit is contained in:
DarkSky
2025-07-01 13:24:42 +08:00
committed by GitHub
parent 2be3f84196
commit 6e034185cf
16 changed files with 390 additions and 42 deletions

View File

@@ -330,3 +330,45 @@ Generated by [AVA](https://avajs.dev).
],
{},
]
## should handle generateSessionTitle correctly under various conditions
> should generate title when conditions are met
{
chatWithPromptCalled: undefined,
exists: true,
title: 'What is Machine Learning?',
}
> should not generate title when session already has title
{
chatWithPromptCalled: false,
exists: true,
title: 'Existing Title',
}
> should not generate title when no user messages exist
{
chatWithPromptCalled: false,
exists: true,
title: null,
}
> should not generate title when no assistant messages exist
{
chatWithPromptCalled: false,
exists: true,
title: null,
}
> should use correct prompt for title generation
{
content: `[user]: Explain quantum computing briefly␊
[assistant]: Quantum computing uses quantum mechanics principles.`,
promptName: 'Summary as title',
}

View File

@@ -11,7 +11,11 @@ import { EventBus, JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { ContextCategories, WorkspaceModel } from '../models';
import {
ContextCategories,
CopilotSessionModel,
WorkspaceModel,
} from '../models';
import { CopilotModule } from '../plugins/copilot';
import { CopilotContextService } from '../plugins/copilot/context';
import {
@@ -57,12 +61,13 @@ import { MockCopilotProvider } from './mocks';
import { createTestingModule, TestingModule } from './utils';
import { WorkflowTestCases } from './utils/copilot';
const test = ava as TestFn<{
type Context = {
auth: AuthService;
module: TestingModule;
db: PrismaClient;
event: EventBus;
workspace: WorkspaceModel;
copilotSession: CopilotSessionModel;
context: CopilotContextService;
prompt: PromptService;
transcript: CopilotTranscriptionService;
@@ -78,7 +83,8 @@ const test = ava as TestFn<{
html: CopilotCheckHtmlExecutor;
json: CopilotCheckJsonExecutor;
};
}>;
};
const test = ava as TestFn<Context>;
let userId: string;
test.before(async t => {
@@ -119,6 +125,7 @@ test.before(async t => {
const db = module.get(PrismaClient);
const event = module.get(EventBus);
const workspace = module.get(WorkspaceModel);
const copilotSession = module.get(CopilotSessionModel);
const prompt = module.get(PromptService);
const factory = module.get(CopilotProviderFactory);
@@ -136,6 +143,7 @@ test.before(async t => {
t.context.db = db;
t.context.event = event;
t.context.workspace = workspace;
t.context.copilotSession = copilotSession;
t.context.prompt = prompt;
t.context.factory = factory;
t.context.session = session;
@@ -1752,3 +1760,168 @@ test('should be able to manage workspace embedding', async t => {
t.is(ret2.length, 0, 'should not match workspace context');
}
});
test('should handle generateSessionTitle correctly under various conditions', async t => {
const { prompt, session, workspace, copilotSession } = t.context;
await prompt.set('test', 'model', [{ role: 'user', content: '{{content}}' }]);
const createSession = async (
options: {
userMessage?: string;
assistantMessage?: string;
existingTitle?: string;
} = {}
) => {
const ws = await workspace.create(userId);
const sessionId = await session.create({
docId: 'test-doc',
workspaceId: ws.id,
userId,
promptName: 'test',
pinned: false,
});
if (options.existingTitle) {
await copilotSession.update({
userId,
sessionId,
title: options.existingTitle,
});
}
const chatSession = await session.get(sessionId);
if (chatSession) {
if (options.userMessage) {
chatSession.push({
role: 'user',
content: options.userMessage,
createdAt: new Date(),
});
}
if (options.assistantMessage) {
chatSession.push({
role: 'assistant',
content: options.assistantMessage,
createdAt: new Date(),
});
}
await chatSession.save();
}
return sessionId;
};
const testCases = [
{
name: 'should generate title when conditions are met',
setup: () =>
createSession({
userMessage: 'What is machine learning?',
assistantMessage:
'Machine learning is a subset of artificial intelligence.',
}),
mockFn: () => 'What is Machine Learning?',
expectSnapshot: true,
},
{
name: 'should not generate title when session already has title',
setup: () =>
createSession({
userMessage: 'Test message',
assistantMessage: 'Test response',
existingTitle: 'Existing Title',
}),
mockFn: () => 'New Title',
expectSnapshot: true,
expectNotCalled: true,
},
{
name: 'should not generate title when no user messages exist',
setup: () =>
createSession({ assistantMessage: 'Hello! How can I help you?' }),
mockFn: () => 'New Title',
expectSnapshot: true,
expectNotCalled: true,
},
{
name: 'should not generate title when no assistant messages exist',
setup: () => createSession({ userMessage: 'What is AI?' }),
mockFn: () => 'New Title',
expectSnapshot: true,
expectNotCalled: true,
},
{
name: 'should handle errors gracefully',
setup: () =>
createSession({
userMessage: 'Test question',
assistantMessage: 'Test answer',
}),
mockFn: () => {
throw new Error('Mock error for testing');
},
expectError: 'Mock error for testing',
},
];
for (const testCase of testCases) {
const sessionId = await testCase.setup();
let chatWithPromptCalled = false;
const mockStub = Sinon.stub(session, 'chatWithPrompt').callsFake(
async () => {
chatWithPromptCalled = true;
return testCase.mockFn();
}
);
if (testCase.expectError) {
await t.throwsAsync(
() => session.generateSessionTitle({ sessionId }),
{ message: testCase.expectError },
testCase.name
);
} else {
await session.generateSessionTitle({ sessionId });
if (testCase.expectSnapshot) {
const sessionState = await session.getSession(sessionId);
t.snapshot(
{
chatWithPromptCalled: testCase.expectNotCalled
? chatWithPromptCalled
: undefined,
title: sessionState?.title,
exists: !!sessionState,
},
testCase.name
);
}
}
mockStub.restore();
}
{
const sessionId = await createSession({
userMessage: 'Explain quantum computing briefly',
assistantMessage: 'Quantum computing uses quantum mechanics principles.',
});
let capturedArgs: any[] = [];
Sinon.stub(session, 'chatWithPrompt').callsFake(async (...args) => {
capturedArgs = args;
return 'Quantum Computing Explained';
});
await session.generateSessionTitle({ sessionId });
t.snapshot(
{
promptName: capturedArgs[0],
content: capturedArgs[1]?.content,
},
'should use correct prompt for title generation'
);
}
});

View File

@@ -58,6 +58,7 @@ test.beforeEach(async t => {
workspaceId: workspace.id,
docId,
userId: user.id,
title: null,
promptName: 'prompt-name',
promptAction: null,
});

View File

@@ -82,6 +82,7 @@ const createTestSession = async (
workspaceId: workspace.id,
docId: null,
pinned: false,
title: null,
promptName: TEST_PROMPTS.NORMAL,
promptAction: null,
...overrides,
@@ -297,6 +298,7 @@ test('should pin and unpin sessions', async t => {
promptName: 'test-prompt',
promptAction: null,
pinned: true,
title: null,
});
const firstSession = await copilotSession.get(firstSessionId);
@@ -312,6 +314,7 @@ test('should pin and unpin sessions', async t => {
promptName: 'test-prompt',
promptAction: null,
pinned: true,
title: null,
});
const sessionStatesAfterSecondPin = await getSessionStates(db, [
@@ -796,6 +799,7 @@ test('should handle fork and session attachment operations', async t => {
workspaceId: workspace.id,
docId: forkConfig.docId,
pinned: forkConfig.pinned,
title: null,
parentSessionId,
prompt: { name: TEST_PROMPTS.NORMAL, action: null, model: 'gpt-4.1' },
messages: [

View File

@@ -50,6 +50,7 @@ type PureChatSession = {
workspaceId: string;
docId?: string | null;
pinned?: boolean;
title: string | null;
messages?: ChatMessage[];
// connect ids
userId: string;
@@ -82,7 +83,7 @@ type UpdateChatSessionMessage = ChatSessionBaseState & {
};
export type UpdateChatSessionOptions = ChatSessionBaseState &
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName'>;
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName' | 'title'>;
export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions;
@@ -254,7 +255,7 @@ export class CopilotSessionModel extends BaseModel {
return (await this.db.aiSession.findUnique({
where: { ...where, id: sessionId, deletedAt: null },
select,
})) as Prisma.AiSessionGetPayload<{ select: Select }>;
})) as Prisma.AiSessionGetPayload<{ select: Select }> | null;
}
@Transactional()
@@ -266,6 +267,7 @@ export class CopilotSessionModel extends BaseModel {
docId: true,
pinned: true,
parentSessionId: true,
title: true,
messages: {
select: {
id: true,
@@ -331,6 +333,7 @@ export class CopilotSessionModel extends BaseModel {
docId: true,
parentSessionId: true,
pinned: true,
title: true,
promptName: true,
tokenCost: true,
createdAt: true,
@@ -373,7 +376,7 @@ export class CopilotSessionModel extends BaseModel {
@Transactional()
async update(options: UpdateChatSessionOptions): Promise<string> {
const { userId, sessionId, docId, promptName, pinned } = options;
const { userId, sessionId, docId, promptName, pinned, title } = options;
const session = await this.getExists(
sessionId,
{
@@ -419,7 +422,7 @@ export class CopilotSessionModel extends BaseModel {
await this.db.aiSession.update({
where: { id: sessionId },
data: { docId, promptName, pinned },
data: { docId, promptName, pinned, title },
});
return sessionId;
@@ -522,17 +525,29 @@ export class CopilotSessionModel extends BaseModel {
if (!id) {
throw new CopilotSessionNotFound();
}
const ids = await this.getMessages(id, { id: true, role: true }).then(
roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id)
);
const messages = await this.getMessages(id, { id: true, role: true });
const ids = messages
.slice(
messages.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id);
if (ids.length) {
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
// clear the title if there only one round of conversation left
const remainingMessages = await this.getMessages(id, { role: true });
const userMessageCount = remainingMessages.filter(
m => m.role === AiPromptRole.user
).length;
if (userMessageCount <= 1) {
await this.db.aiSession.update({
where: { id },
data: { title: null },
});
}
}
}

View File

@@ -67,7 +67,9 @@ class CreateChatSessionInput {
}
@InputType()
class UpdateChatSessionInput implements Omit<UpdateChatSession, 'userId'> {
class UpdateChatSessionInput
implements Omit<UpdateChatSession, 'userId' | 'title'>
{
@Field(() => String)
sessionId!: string;
@@ -336,6 +338,9 @@ export class CopilotSessionType {
@Field(() => Boolean)
pinned!: boolean;
@Field(() => String, { nullable: true })
title!: string | null;
@Field(() => ID, { nullable: true })
parentSessionId!: string | null;
@@ -653,6 +658,7 @@ export class CopilotResolver {
parentSessionId: session.parentSessionId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
promptName: session.prompt.name,
model: session.prompt.model,
optionalModels: session.prompt.optionalModels,

View File

@@ -1,6 +1,7 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole } from '@prisma/client';
@@ -11,6 +12,9 @@ import {
CopilotQuotaExceeded,
CopilotSessionInvalidInput,
CopilotSessionNotFound,
JobQueue,
NoCopilotProviderAvailable,
OnJob,
} from '../../base';
import { QuotaService } from '../../core/quota';
import {
@@ -22,7 +26,12 @@ import {
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import { PromptMessage, PromptParams } from './providers';
import {
CopilotProviderFactory,
ModelOutputType,
PromptMessage,
PromptParams,
} from './providers';
import {
type ChatHistory,
type ChatMessage,
@@ -33,6 +42,14 @@ import {
type SubmittedMessage,
} from './types';
declare global {
interface Jobs {
'copilot.session.generateTitle': {
sessionId: string;
};
}
}
export class ChatSession implements AsyncDisposable {
private stashMessageCount = 0;
constructor(
@@ -224,10 +241,12 @@ export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
constructor(
private readonly moduleRef: ModuleRef,
private readonly models: Models,
private readonly jobs: JobQueue,
private readonly quota: QuotaService,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService,
private readonly models: Models
private readonly prompt: PromptService
) {}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
@@ -244,6 +263,7 @@ export class ChatSessionService {
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
@@ -282,6 +302,7 @@ export class ChatSessionService {
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
parentSessionId: session.parentSessionId,
prompt,
};
@@ -303,6 +324,7 @@ export class ChatSessionService {
workspaceId,
docId,
pinned,
title,
promptName,
tokenCost,
messages,
@@ -347,6 +369,7 @@ export class ChatSessionService {
workspaceId,
docId,
pinned,
title,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
@@ -418,6 +441,7 @@ export class ChatSessionService {
...options,
sessionId,
prompt,
title: null,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
@@ -520,8 +544,78 @@ export class ChatSessionService {
if (state) {
return new ChatSession(this.messageCache, state, async state => {
await this.models.copilotSession.updateMessages(state);
if (!state.prompt.action) {
await this.jobs.add('copilot.session.generateTitle', { sessionId });
}
});
}
return null;
}
// public for test mock
async chatWithPrompt(
promptName: string,
message: Partial<PromptMessage>
): Promise<string> {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
const cond = { modelId: prompt.model };
const msg = { role: 'user' as const, content: '', ...message };
const config = Object.assign({}, prompt.config);
const provider = await this.moduleRef
.get(CopilotProviderFactory)
.getProvider({
outputType: ModelOutputType.Text,
modelId: prompt.model,
});
if (!provider) {
throw new NoCopilotProviderAvailable();
}
return provider.text(cond, [...prompt.finish({}), msg], config);
}
@OnJob('copilot.session.generateTitle')
async generateSessionTitle(job: Jobs['copilot.session.generateTitle']) {
const { sessionId } = job;
try {
const session = await this.models.copilotSession.get(sessionId);
if (!session) {
this.logger.warn(
`Session ${sessionId} not found when generating title`
);
return;
}
const { userId, title, messages } = session;
if (
title ||
!messages.length ||
messages.filter(m => m.role === 'user').length === 0 ||
messages.filter(m => m.role === 'assistant').length === 0
) {
return;
}
{
const title = await this.chatWithPrompt('Summary as title', {
content: session.messages
.map(m => `[${m.role}]: ${m.content}`)
.join('\n'),
});
await this.models.copilotSession.update({ userId, sessionId, title });
}
} catch (error) {
console.error(
`Failed to generate title for session ${sessionId}:`,
error
);
throw error;
}
}
}

View File

@@ -50,6 +50,7 @@ export const ChatHistorySchema = z
workspaceId: z.string(),
docId: z.string().nullable(),
pinned: z.boolean(),
title: z.string().nullable(),
action: z.string().nullable(),
tokens: z.number(),
messages: z.array(ChatMessageSchema),
@@ -85,6 +86,7 @@ export interface ChatSessionForkOptions
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
title: string | null;
// connect ids
sessionId: string;
parentSessionId: string | null;

View File

@@ -324,6 +324,7 @@ type CopilotSessionType {
parentSessionId: ID
pinned: Boolean!
promptName: String!
title: String
}
type CopilotWorkspaceConfig {