mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
@@ -330,3 +330,45 @@ Generated by [AVA](https://avajs.dev).
|
||||
],
|
||||
{},
|
||||
]
|
||||
|
||||
## should handle generateSessionTitle correctly under various conditions
|
||||
|
||||
> should generate title when conditions are met
|
||||
|
||||
{
|
||||
chatWithPromptCalled: undefined,
|
||||
exists: true,
|
||||
title: 'What is Machine Learning?',
|
||||
}
|
||||
|
||||
> should not generate title when session already has title
|
||||
|
||||
{
|
||||
chatWithPromptCalled: false,
|
||||
exists: true,
|
||||
title: 'Existing Title',
|
||||
}
|
||||
|
||||
> should not generate title when no user messages exist
|
||||
|
||||
{
|
||||
chatWithPromptCalled: false,
|
||||
exists: true,
|
||||
title: null,
|
||||
}
|
||||
|
||||
> should not generate title when no assistant messages exist
|
||||
|
||||
{
|
||||
chatWithPromptCalled: false,
|
||||
exists: true,
|
||||
title: null,
|
||||
}
|
||||
|
||||
> should use correct prompt for title generation
|
||||
|
||||
{
|
||||
content: `[user]: Explain quantum computing briefly␊
|
||||
[assistant]: Quantum computing uses quantum mechanics principles.`,
|
||||
promptName: 'Summary as title',
|
||||
}
|
||||
|
||||
Binary file not shown.
@@ -11,7 +11,11 @@ import { EventBus, JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { ContextCategories, WorkspaceModel } from '../models';
|
||||
import {
|
||||
ContextCategories,
|
||||
CopilotSessionModel,
|
||||
WorkspaceModel,
|
||||
} from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import { CopilotContextService } from '../plugins/copilot/context';
|
||||
import {
|
||||
@@ -57,12 +61,13 @@ import { MockCopilotProvider } from './mocks';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { WorkflowTestCases } from './utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
type Context = {
|
||||
auth: AuthService;
|
||||
module: TestingModule;
|
||||
db: PrismaClient;
|
||||
event: EventBus;
|
||||
workspace: WorkspaceModel;
|
||||
copilotSession: CopilotSessionModel;
|
||||
context: CopilotContextService;
|
||||
prompt: PromptService;
|
||||
transcript: CopilotTranscriptionService;
|
||||
@@ -78,7 +83,8 @@ const test = ava as TestFn<{
|
||||
html: CopilotCheckHtmlExecutor;
|
||||
json: CopilotCheckJsonExecutor;
|
||||
};
|
||||
}>;
|
||||
};
|
||||
const test = ava as TestFn<Context>;
|
||||
let userId: string;
|
||||
|
||||
test.before(async t => {
|
||||
@@ -119,6 +125,7 @@ test.before(async t => {
|
||||
const db = module.get(PrismaClient);
|
||||
const event = module.get(EventBus);
|
||||
const workspace = module.get(WorkspaceModel);
|
||||
const copilotSession = module.get(CopilotSessionModel);
|
||||
const prompt = module.get(PromptService);
|
||||
const factory = module.get(CopilotProviderFactory);
|
||||
|
||||
@@ -136,6 +143,7 @@ test.before(async t => {
|
||||
t.context.db = db;
|
||||
t.context.event = event;
|
||||
t.context.workspace = workspace;
|
||||
t.context.copilotSession = copilotSession;
|
||||
t.context.prompt = prompt;
|
||||
t.context.factory = factory;
|
||||
t.context.session = session;
|
||||
@@ -1752,3 +1760,168 @@ test('should be able to manage workspace embedding', async t => {
|
||||
t.is(ret2.length, 0, 'should not match workspace context');
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle generateSessionTitle correctly under various conditions', async t => {
|
||||
const { prompt, session, workspace, copilotSession } = t.context;
|
||||
|
||||
await prompt.set('test', 'model', [{ role: 'user', content: '{{content}}' }]);
|
||||
const createSession = async (
|
||||
options: {
|
||||
userMessage?: string;
|
||||
assistantMessage?: string;
|
||||
existingTitle?: string;
|
||||
} = {}
|
||||
) => {
|
||||
const ws = await workspace.create(userId);
|
||||
const sessionId = await session.create({
|
||||
docId: 'test-doc',
|
||||
workspaceId: ws.id,
|
||||
userId,
|
||||
promptName: 'test',
|
||||
pinned: false,
|
||||
});
|
||||
|
||||
if (options.existingTitle) {
|
||||
await copilotSession.update({
|
||||
userId,
|
||||
sessionId,
|
||||
title: options.existingTitle,
|
||||
});
|
||||
}
|
||||
|
||||
const chatSession = await session.get(sessionId);
|
||||
if (chatSession) {
|
||||
if (options.userMessage) {
|
||||
chatSession.push({
|
||||
role: 'user',
|
||||
content: options.userMessage,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
}
|
||||
if (options.assistantMessage) {
|
||||
chatSession.push({
|
||||
role: 'assistant',
|
||||
content: options.assistantMessage,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
}
|
||||
await chatSession.save();
|
||||
}
|
||||
|
||||
return sessionId;
|
||||
};
|
||||
|
||||
const testCases = [
|
||||
{
|
||||
name: 'should generate title when conditions are met',
|
||||
setup: () =>
|
||||
createSession({
|
||||
userMessage: 'What is machine learning?',
|
||||
assistantMessage:
|
||||
'Machine learning is a subset of artificial intelligence.',
|
||||
}),
|
||||
mockFn: () => 'What is Machine Learning?',
|
||||
expectSnapshot: true,
|
||||
},
|
||||
{
|
||||
name: 'should not generate title when session already has title',
|
||||
setup: () =>
|
||||
createSession({
|
||||
userMessage: 'Test message',
|
||||
assistantMessage: 'Test response',
|
||||
existingTitle: 'Existing Title',
|
||||
}),
|
||||
mockFn: () => 'New Title',
|
||||
expectSnapshot: true,
|
||||
expectNotCalled: true,
|
||||
},
|
||||
{
|
||||
name: 'should not generate title when no user messages exist',
|
||||
setup: () =>
|
||||
createSession({ assistantMessage: 'Hello! How can I help you?' }),
|
||||
mockFn: () => 'New Title',
|
||||
expectSnapshot: true,
|
||||
expectNotCalled: true,
|
||||
},
|
||||
{
|
||||
name: 'should not generate title when no assistant messages exist',
|
||||
setup: () => createSession({ userMessage: 'What is AI?' }),
|
||||
mockFn: () => 'New Title',
|
||||
expectSnapshot: true,
|
||||
expectNotCalled: true,
|
||||
},
|
||||
{
|
||||
name: 'should handle errors gracefully',
|
||||
setup: () =>
|
||||
createSession({
|
||||
userMessage: 'Test question',
|
||||
assistantMessage: 'Test answer',
|
||||
}),
|
||||
mockFn: () => {
|
||||
throw new Error('Mock error for testing');
|
||||
},
|
||||
expectError: 'Mock error for testing',
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
const sessionId = await testCase.setup();
|
||||
let chatWithPromptCalled = false;
|
||||
|
||||
const mockStub = Sinon.stub(session, 'chatWithPrompt').callsFake(
|
||||
async () => {
|
||||
chatWithPromptCalled = true;
|
||||
return testCase.mockFn();
|
||||
}
|
||||
);
|
||||
|
||||
if (testCase.expectError) {
|
||||
await t.throwsAsync(
|
||||
() => session.generateSessionTitle({ sessionId }),
|
||||
{ message: testCase.expectError },
|
||||
testCase.name
|
||||
);
|
||||
} else {
|
||||
await session.generateSessionTitle({ sessionId });
|
||||
|
||||
if (testCase.expectSnapshot) {
|
||||
const sessionState = await session.getSession(sessionId);
|
||||
t.snapshot(
|
||||
{
|
||||
chatWithPromptCalled: testCase.expectNotCalled
|
||||
? chatWithPromptCalled
|
||||
: undefined,
|
||||
title: sessionState?.title,
|
||||
exists: !!sessionState,
|
||||
},
|
||||
testCase.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mockStub.restore();
|
||||
}
|
||||
|
||||
{
|
||||
const sessionId = await createSession({
|
||||
userMessage: 'Explain quantum computing briefly',
|
||||
assistantMessage: 'Quantum computing uses quantum mechanics principles.',
|
||||
});
|
||||
|
||||
let capturedArgs: any[] = [];
|
||||
Sinon.stub(session, 'chatWithPrompt').callsFake(async (...args) => {
|
||||
capturedArgs = args;
|
||||
return 'Quantum Computing Explained';
|
||||
});
|
||||
|
||||
await session.generateSessionTitle({ sessionId });
|
||||
|
||||
t.snapshot(
|
||||
{
|
||||
promptName: capturedArgs[0],
|
||||
content: capturedArgs[1]?.content,
|
||||
},
|
||||
'should use correct prompt for title generation'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -58,6 +58,7 @@ test.beforeEach(async t => {
|
||||
workspaceId: workspace.id,
|
||||
docId,
|
||||
userId: user.id,
|
||||
title: null,
|
||||
promptName: 'prompt-name',
|
||||
promptAction: null,
|
||||
});
|
||||
|
||||
@@ -82,6 +82,7 @@ const createTestSession = async (
|
||||
workspaceId: workspace.id,
|
||||
docId: null,
|
||||
pinned: false,
|
||||
title: null,
|
||||
promptName: TEST_PROMPTS.NORMAL,
|
||||
promptAction: null,
|
||||
...overrides,
|
||||
@@ -297,6 +298,7 @@ test('should pin and unpin sessions', async t => {
|
||||
promptName: 'test-prompt',
|
||||
promptAction: null,
|
||||
pinned: true,
|
||||
title: null,
|
||||
});
|
||||
|
||||
const firstSession = await copilotSession.get(firstSessionId);
|
||||
@@ -312,6 +314,7 @@ test('should pin and unpin sessions', async t => {
|
||||
promptName: 'test-prompt',
|
||||
promptAction: null,
|
||||
pinned: true,
|
||||
title: null,
|
||||
});
|
||||
|
||||
const sessionStatesAfterSecondPin = await getSessionStates(db, [
|
||||
@@ -796,6 +799,7 @@ test('should handle fork and session attachment operations', async t => {
|
||||
workspaceId: workspace.id,
|
||||
docId: forkConfig.docId,
|
||||
pinned: forkConfig.pinned,
|
||||
title: null,
|
||||
parentSessionId,
|
||||
prompt: { name: TEST_PROMPTS.NORMAL, action: null, model: 'gpt-4.1' },
|
||||
messages: [
|
||||
|
||||
@@ -50,6 +50,7 @@ type PureChatSession = {
|
||||
workspaceId: string;
|
||||
docId?: string | null;
|
||||
pinned?: boolean;
|
||||
title: string | null;
|
||||
messages?: ChatMessage[];
|
||||
// connect ids
|
||||
userId: string;
|
||||
@@ -82,7 +83,7 @@ type UpdateChatSessionMessage = ChatSessionBaseState & {
|
||||
};
|
||||
|
||||
export type UpdateChatSessionOptions = ChatSessionBaseState &
|
||||
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName'>;
|
||||
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName' | 'title'>;
|
||||
|
||||
export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions;
|
||||
|
||||
@@ -254,7 +255,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
return (await this.db.aiSession.findUnique({
|
||||
where: { ...where, id: sessionId, deletedAt: null },
|
||||
select,
|
||||
})) as Prisma.AiSessionGetPayload<{ select: Select }>;
|
||||
})) as Prisma.AiSessionGetPayload<{ select: Select }> | null;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@@ -266,6 +267,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
docId: true,
|
||||
pinned: true,
|
||||
parentSessionId: true,
|
||||
title: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
@@ -331,6 +333,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
docId: true,
|
||||
parentSessionId: true,
|
||||
pinned: true,
|
||||
title: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
@@ -373,7 +376,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
|
||||
@Transactional()
|
||||
async update(options: UpdateChatSessionOptions): Promise<string> {
|
||||
const { userId, sessionId, docId, promptName, pinned } = options;
|
||||
const { userId, sessionId, docId, promptName, pinned, title } = options;
|
||||
const session = await this.getExists(
|
||||
sessionId,
|
||||
{
|
||||
@@ -419,7 +422,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
|
||||
await this.db.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: { docId, promptName, pinned },
|
||||
data: { docId, promptName, pinned, title },
|
||||
});
|
||||
|
||||
return sessionId;
|
||||
@@ -522,17 +525,29 @@ export class CopilotSessionModel extends BaseModel {
|
||||
if (!id) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const ids = await this.getMessages(id, { id: true, role: true }).then(
|
||||
roles =>
|
||||
roles
|
||||
.slice(
|
||||
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
||||
(removeLatestUserMessage ? 0 : 1)
|
||||
)
|
||||
.map(({ id }) => id)
|
||||
);
|
||||
const messages = await this.getMessages(id, { id: true, role: true });
|
||||
const ids = messages
|
||||
.slice(
|
||||
messages.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
||||
(removeLatestUserMessage ? 0 : 1)
|
||||
)
|
||||
.map(({ id }) => id);
|
||||
|
||||
if (ids.length) {
|
||||
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
|
||||
// clear the title if there only one round of conversation left
|
||||
const remainingMessages = await this.getMessages(id, { role: true });
|
||||
const userMessageCount = remainingMessages.filter(
|
||||
m => m.role === AiPromptRole.user
|
||||
).length;
|
||||
|
||||
if (userMessageCount <= 1) {
|
||||
await this.db.aiSession.update({
|
||||
where: { id },
|
||||
data: { title: null },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,9 @@ class CreateChatSessionInput {
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class UpdateChatSessionInput implements Omit<UpdateChatSession, 'userId'> {
|
||||
class UpdateChatSessionInput
|
||||
implements Omit<UpdateChatSession, 'userId' | 'title'>
|
||||
{
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@@ -336,6 +338,9 @@ export class CopilotSessionType {
|
||||
@Field(() => Boolean)
|
||||
pinned!: boolean;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
title!: string | null;
|
||||
|
||||
@Field(() => ID, { nullable: true })
|
||||
parentSessionId!: string | null;
|
||||
|
||||
@@ -653,6 +658,7 @@ export class CopilotResolver {
|
||||
parentSessionId: session.parentSessionId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
promptName: session.prompt.name,
|
||||
model: session.prompt.model,
|
||||
optionalModels: session.prompt.optionalModels,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
|
||||
@@ -11,6 +12,9 @@ import {
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
JobQueue,
|
||||
NoCopilotProviderAvailable,
|
||||
OnJob,
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import {
|
||||
@@ -22,7 +26,12 @@ import {
|
||||
} from '../../models';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage, PromptParams } from './providers';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
PromptParams,
|
||||
} from './providers';
|
||||
import {
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
@@ -33,6 +42,14 @@ import {
|
||||
type SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
declare global {
|
||||
interface Jobs {
|
||||
'copilot.session.generateTitle': {
|
||||
sessionId: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
private stashMessageCount = 0;
|
||||
constructor(
|
||||
@@ -224,10 +241,12 @@ export class ChatSessionService {
|
||||
private readonly logger = new Logger(ChatSessionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly models: Models,
|
||||
private readonly jobs: JobQueue,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly prompt: PromptService,
|
||||
private readonly models: Models
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
|
||||
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
|
||||
@@ -244,6 +263,7 @@ export class ChatSessionService {
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
@@ -282,6 +302,7 @@ export class ChatSessionService {
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
title: session.title,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
};
|
||||
@@ -303,6 +324,7 @@ export class ChatSessionService {
|
||||
workspaceId,
|
||||
docId,
|
||||
pinned,
|
||||
title,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
@@ -347,6 +369,7 @@ export class ChatSessionService {
|
||||
workspaceId,
|
||||
docId,
|
||||
pinned,
|
||||
title,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
@@ -418,6 +441,7 @@ export class ChatSessionService {
|
||||
...options,
|
||||
sessionId,
|
||||
prompt,
|
||||
title: null,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
@@ -520,8 +544,78 @@ export class ChatSessionService {
|
||||
if (state) {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
}
|
||||
});
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// public for test mock
|
||||
async chatWithPrompt(
|
||||
promptName: string,
|
||||
message: Partial<PromptMessage>
|
||||
): Promise<string> {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
|
||||
const cond = { modelId: prompt.model };
|
||||
const msg = { role: 'user' as const, content: '', ...message };
|
||||
const config = Object.assign({}, prompt.config);
|
||||
|
||||
const provider = await this.moduleRef
|
||||
.get(CopilotProviderFactory)
|
||||
.getProvider({
|
||||
outputType: ModelOutputType.Text,
|
||||
modelId: prompt.model,
|
||||
});
|
||||
|
||||
if (!provider) {
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
return provider.text(cond, [...prompt.finish({}), msg], config);
|
||||
}
|
||||
|
||||
@OnJob('copilot.session.generateTitle')
|
||||
async generateSessionTitle(job: Jobs['copilot.session.generateTitle']) {
|
||||
const { sessionId } = job;
|
||||
|
||||
try {
|
||||
const session = await this.models.copilotSession.get(sessionId);
|
||||
if (!session) {
|
||||
this.logger.warn(
|
||||
`Session ${sessionId} not found when generating title`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const { userId, title, messages } = session;
|
||||
if (
|
||||
title ||
|
||||
!messages.length ||
|
||||
messages.filter(m => m.role === 'user').length === 0 ||
|
||||
messages.filter(m => m.role === 'assistant').length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
const title = await this.chatWithPrompt('Summary as title', {
|
||||
content: session.messages
|
||||
.map(m => `[${m.role}]: ${m.content}`)
|
||||
.join('\n'),
|
||||
});
|
||||
await this.models.copilotSession.update({ userId, sessionId, title });
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to generate title for session ${sessionId}:`,
|
||||
error
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ export const ChatHistorySchema = z
|
||||
workspaceId: z.string(),
|
||||
docId: z.string().nullable(),
|
||||
pinned: z.boolean(),
|
||||
title: z.string().nullable(),
|
||||
action: z.string().nullable(),
|
||||
tokens: z.number(),
|
||||
messages: z.array(ChatMessageSchema),
|
||||
@@ -85,6 +86,7 @@ export interface ChatSessionForkOptions
|
||||
|
||||
export interface ChatSessionState
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
title: string | null;
|
||||
// connect ids
|
||||
sessionId: string;
|
||||
parentSessionId: string | null;
|
||||
|
||||
@@ -324,6 +324,7 @@ type CopilotSessionType {
|
||||
parentSessionId: ID
|
||||
pinned: Boolean!
|
||||
promptName: String!
|
||||
title: String
|
||||
}
|
||||
|
||||
type CopilotWorkspaceConfig {
|
||||
|
||||
Reference in New Issue
Block a user