diff --git a/packages/backend/server/migrations/20250609063353_ai_session_independence/migration.sql b/packages/backend/server/migrations/20250609063353_ai_session_independence/migration.sql new file mode 100644 index 0000000000..6def2babdb --- /dev/null +++ b/packages/backend/server/migrations/20250609063353_ai_session_independence/migration.sql @@ -0,0 +1,21 @@ +-- AlterTable +ALTER TABLE "ai_sessions_metadata" ALTER COLUMN "doc_id" DROP NOT NULL; + +-- AlterTable +ALTER TABLE "ai_sessions_metadata" ADD COLUMN "pinned" BOOLEAN NOT NULL DEFAULT false; + +-- AlterTable +CREATE UNIQUE INDEX idx_ai_session_unique_pinned +ON ai_sessions_metadata (user_id, workspace_id) +WHERE pinned = true AND deleted_at IS NULL; + +-- AlterTable +CREATE UNIQUE INDEX idx_ai_session_unique_doc_root +ON ai_sessions_metadata (user_id, workspace_id, doc_id) +WHERE parent_session_id IS NULL AND doc_id IS NOT NULL AND deleted_at IS NULL; + +-- DropIndex +DROP INDEX "ai_sessions_metadata_user_id_workspace_id_idx"; + +-- CreateIndex +CREATE INDEX "ai_sessions_metadata_user_id_workspace_id_doc_id_idx" ON "ai_sessions_metadata"("user_id", "workspace_id", "doc_id"); diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index c21e6071bf..51e666f944 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -434,8 +434,9 @@ model AiSession { id String @id @default(uuid()) @db.VarChar userId String @map("user_id") @db.VarChar workspaceId String @map("workspace_id") @db.VarChar - docId String @map("doc_id") @db.VarChar + docId String? @map("doc_id") @db.VarChar promptName String @map("prompt_name") @db.VarChar(32) + pinned Boolean @default(false) // the session id of the parent session if this session is a forked session parentSessionId String? @map("parent_session_id") @db.VarChar messageCost Int @default(0) @@ -449,7 +450,7 @@ model AiSession { context AiContext[] @@index([userId]) - @@index([userId, workspaceId]) + @@index([userId, workspaceId, docId]) @@map("ai_sessions_metadata") } diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md index f5b84bfb9a..3c0ca41331 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md @@ -135,3 +135,31 @@ Generated by [AVA](https://avajs.dev). ], }, ] + +## should create different session types and validate prompt constraints + +> should create session with should create workspace session with text prompt + + [ + { + pinned: false, + }, + ] + +> should create session with should create pinned session with text prompt + + [ + { + docId: 'pinned-doc', + pinned: true, + }, + ] + +> should create session with should create doc session with text prompt + + [ + { + docId: 'normal-doc', + pinned: false, + }, + ] diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap index f69a3815f8..22de4d5a0c 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot.e2e.ts index f0f95f9013..3de76d525c 100644 --- a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot.e2e.ts @@ -48,7 +48,11 @@ import { createCopilotContext, createCopilotMessage, createCopilotSession, + createDocCopilotSession, + createPinnedCopilotSession, + createWorkspaceCopilotSession, forkCopilotSession, + getCopilotSession, getHistories, listContext, listContextDocAndFiles, @@ -302,12 +306,8 @@ test('should fork session correctly', async t => { // prepare session const { id } = await createWorkspace(app); - const sessionId = await createCopilotSession( - app, - id, - randomUUID(), - textPromptName - ); + const docId = randomUUID(); + const sessionId = await createCopilotSession(app, id, docId, textPromptName); let forkedSessionId: string; // should be able to fork session @@ -316,7 +316,7 @@ test('should fork session correctly', async t => { const messageId = await createCopilotMessage(app, sessionId); await chatWithText(app, sessionId, messageId); } - const histories = await getHistories(app, { workspaceId: id }); + const histories = await getHistories(app, { workspaceId: id, docId }); const latestMessageId = histories[0].messages.findLast( m => m.role === 'assistant' )?.id; @@ -375,7 +375,7 @@ test('should fork session correctly', async t => { }); await app.switchUser(u1); - const histories = await getHistories(app, { workspaceId: id }); + const histories = await getHistories(app, { workspaceId: id, docId }); const latestMessageId = histories .find(h => h.sessionId === forkedSessionId) ?.messages.findLast(m => m.role === 'assistant')?.id; @@ -612,10 +612,11 @@ test('should be able to retry with api', async t => { // normal chat { const { id } = await createWorkspace(app); + const docId = randomUUID(); const sessionId = await createCopilotSession( app, id, - randomUUID(), + docId, textPromptName ); const messageId = await createCopilotMessage(app, sessionId); @@ -623,7 +624,7 @@ test('should be able to retry with api', async t => { await chatWithText(app, sessionId, messageId); await chatWithText(app, sessionId, messageId); - const histories = await getHistories(app, { workspaceId: id }); + const histories = await getHistories(app, { workspaceId: id, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), [['generate text to text', 'generate text to text']], @@ -634,10 +635,11 @@ test('should be able to retry with api', async t => { // retry chat { const { id } = await createWorkspace(app); + const docId = randomUUID(); const sessionId = await createCopilotSession( app, id, - randomUUID(), + docId, textPromptName ); const messageId = await createCopilotMessage(app, sessionId); @@ -646,7 +648,7 @@ test('should be able to retry with api', async t => { await chatWithText(app, sessionId); // should only have 1 message - const histories = await getHistories(app, { workspaceId: id }); + const histories = await getHistories(app, { workspaceId: id, docId }); t.snapshot( cleanObject(histories), 'should be able to list history after retry' @@ -656,10 +658,11 @@ test('should be able to retry with api', async t => { // retry chat with new message id { const { id } = await createWorkspace(app); + const docId = randomUUID(); const sessionId = await createCopilotSession( app, id, - randomUUID(), + docId, textPromptName ); const messageId = await createCopilotMessage(app, sessionId); @@ -669,7 +672,7 @@ test('should be able to retry with api', async t => { await chatWithText(app, sessionId, newMessageId, '', true); // should only have 1 message - const histories = await getHistories(app, { workspaceId: id }); + const histories = await getHistories(app, { workspaceId: id, docId }); t.snapshot( cleanObject(histories), 'should be able to list history after retry' @@ -746,10 +749,11 @@ test('should be able to list history', async t => { const { app } = t.context; const { id: workspaceId } = await createWorkspace(app); + const docId = randomUUID(); const sessionId = await createCopilotSession( app, workspaceId, - randomUUID(), + docId, textPromptName ); @@ -757,7 +761,7 @@ test('should be able to list history', async t => { await chatWithText(app, sessionId, messageId); { - const histories = await getHistories(app, { workspaceId }); + const histories = await getHistories(app, { workspaceId, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), [['hello', 'generate text to text']], @@ -768,6 +772,7 @@ test('should be able to list history', async t => { { const histories = await getHistories(app, { workspaceId, + docId, options: { messageOrder: 'desc' }, }); t.deepEqual( @@ -809,17 +814,18 @@ test('should reject request that user have not permission', async t => { } { + const docId = randomUUID(); const sessionId = await createCopilotSession( app, workspaceId, - randomUUID(), + docId, textPromptName ); const messageId = await createCopilotMessage(app, sessionId); await chatWithText(app, sessionId, messageId); - const histories = await getHistories(app, { workspaceId }); + const histories = await getHistories(app, { workspaceId, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), [['generate text to text']], @@ -1072,3 +1078,93 @@ test('should be able to transcript', async t => { } } }); + +test('should create different session types and validate prompt constraints', async t => { + const { app } = t.context; + const { id: workspaceId } = await createWorkspace(app); + + const validateSession = async ( + description: string, + workspaceId: string, + createPromise: Promise + ) => { + const sessionId = await createPromise; + + t.truthy(sessionId, description); + t.snapshot( + cleanObject( + [await getCopilotSession(app, workspaceId, sessionId)], + ['id', 'workspaceId', 'promptName'] + ), + `should create session with ${description}` + ); + return sessionId; + }; + + await validateSession( + 'should create workspace session with text prompt', + workspaceId, + createWorkspaceCopilotSession(app, workspaceId, textPromptName) + ); + await validateSession( + 'should create pinned session with text prompt', + workspaceId, + createPinnedCopilotSession(app, workspaceId, 'pinned-doc', textPromptName) + ); + await validateSession( + 'should create doc session with text prompt', + workspaceId, + createDocCopilotSession(app, workspaceId, 'normal-doc', textPromptName) + ); +}); + +test('should list histories for different session types correctly', async t => { + const { app } = t.context; + const { id: workspaceId } = await createWorkspace(app); + const pinnedDocId = 'pinned-doc'; + const docId = 'normal-doc'; + + // create sessions and add messages + const [workspaceSessionId, pinnedSessionId, docSessionId] = await Promise.all( + [ + createWorkspaceCopilotSession(app, workspaceId, textPromptName), + createPinnedCopilotSession(app, workspaceId, pinnedDocId, textPromptName), + createDocCopilotSession(app, workspaceId, docId, textPromptName), + ] + ); + + await Promise.all([ + createCopilotMessage(app, workspaceSessionId, 'workspace message'), + createCopilotMessage(app, pinnedSessionId, 'pinned message'), + createCopilotMessage(app, docSessionId, 'doc message'), + ]); + + const testHistoryQuery = async ( + queryDocId: string | undefined, + expectedSessionId: string, + description: string + ) => { + const histories = await getHistories(app, { + workspaceId, + docId: queryDocId, + }); + t.is(histories.length, 1, `should return ${description}`); + t.is( + histories[0].sessionId, + expectedSessionId, + `should return correct ${description}` + ); + }; + + await testHistoryQuery( + undefined, + workspaceSessionId, + 'workspace session history' + ); + await testHistoryQuery( + pinnedDocId, + pinnedSessionId, + 'pinned session history' + ); + await testHistoryQuery(docId, docSessionId, 'doc session history'); +}); diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index 44724e8c9b..38fb5094d0 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -275,7 +275,7 @@ test('should be able to manage chat session', async t => { ]); const params = { word: 'world' }; - const commonParams = { docId: 'test', workspaceId: 'test' }; + const commonParams = { docId: 'test', workspaceId: 'test', pinned: false }; const sessionId = await session.create({ userId, @@ -342,11 +342,12 @@ test('should be able to update chat session prompt', async t => { docId: 'test', workspaceId: 'test', userId, + pinned: false, }); t.truthy(sessionId, 'should create session'); // Update the session - const updatedSessionId = await session.updateSessionPrompt({ + const updatedSessionId = await session.updateSession({ sessionId, promptName: 'Search With AFFiNE AI', userId, @@ -371,7 +372,7 @@ test('should be able to fork chat session', async t => { ]); const params = { word: 'world' }; - const commonParams = { docId: 'test', workspaceId: 'test' }; + const commonParams = { docId: 'test', workspaceId: 'test', pinned: false }; // create session const sessionId = await session.create({ userId, @@ -494,6 +495,7 @@ test('should be able to process message id', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -537,6 +539,7 @@ test('should be able to generate with message id', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -559,6 +562,7 @@ test('should be able to generate with message id', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -586,6 +590,7 @@ test('should be able to generate with message id', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -614,6 +619,7 @@ test('should save message correctly', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -643,6 +649,7 @@ test('should revert message correctly', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -742,6 +749,7 @@ test('should handle params correctly in chat session', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); const s = (await session.get(sessionId))!; @@ -1506,6 +1514,7 @@ test('should be able to manage context', async t => { workspaceId: 'test', userId, promptName: 'prompt', + pinned: false, }); // use mocked embedding client @@ -1729,6 +1738,7 @@ test('should be able to manage workspace embedding', async t => { workspaceId: ws.id, userId, promptName: 'prompt', + pinned: false, }); const contextSession = await context.create(sessionId); diff --git a/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.md b/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.md new file mode 100644 index 0000000000..301b6c87c9 --- /dev/null +++ b/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.md @@ -0,0 +1,162 @@ +# Snapshot report for `src/__tests__/models/copilot-session.spec.ts` + +The actual snapshot is saved in `copilot-session.spec.ts.snap`. + +Generated by [AVA](https://avajs.dev). + +## should list and filter session type + +> workspace sessions should include workspace and pinned sessions + + [ + { + docId: null, + pinned: true, + }, + { + docId: null, + pinned: false, + }, + ] + +> doc sessions should only include sessions with matching docId + + [ + { + docId: 'doc-id-1', + pinned: false, + }, + ] + +> session type identification results + + [ + { + session: { + docId: null, + pinned: false, + }, + type: 'workspace', + }, + { + session: { + docId: undefined, + pinned: false, + }, + type: 'workspace', + }, + { + session: { + docId: null, + pinned: true, + }, + type: 'pinned', + }, + { + session: { + docId: 'doc-id-1', + pinned: false, + }, + type: 'doc', + }, + ] + +## should pin and unpin sessions + +> session states after creating second pinned session + + [ + { + docId: null, + id: 'first-session-id', + pinned: false, + }, + { + docId: null, + id: 'second-session-id', + pinned: true, + }, + ] + +> should return false when no sessions to unpin + + false + +> all sessions should be unpinned after unpin operation + + [ + { + id: 'first-session-id', + pinned: false, + }, + { + id: 'second-session-id', + pinned: false, + }, + { + id: 'third-session-id', + pinned: false, + }, + ] + +## session updates and type conversions + +> session states after pinning - should unpin existing + + [ + { + docId: null, + id: 'session-update-id', + pinned: true, + }, + { + docId: null, + id: 'existing-pinned-session-id', + pinned: false, + }, + ] + +> session state after unpinning + + { + docId: null, + id: 'session-update-id', + pinned: false, + } + +> session type conversion steps + + [ + { + session: { + docId: 'doc-update-id', + pinned: false, + }, + step: 'workspace_to_doc', + type: 'doc', + }, + { + session: { + docId: 'doc-update-id', + pinned: true, + }, + step: 'doc_to_pinned', + type: 'pinned', + }, + { + session: { + docId: null, + pinned: false, + }, + step: 'pinned_to_workspace', + type: 'workspace', + }, + { + session: { + docId: null, + pinned: true, + }, + step: 'workspace_to_pinned', + type: 'pinned', + }, + ] diff --git a/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.snap b/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.snap new file mode 100644 index 0000000000..7c87d7102d Binary files /dev/null and b/packages/backend/server/src/__tests__/models/__snapshots__/copilot-session.spec.ts.snap differ diff --git a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts index e189f9a9ac..d36542dabf 100644 --- a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts +++ b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts @@ -5,12 +5,14 @@ import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; import { Config } from '../../base'; -import { ContextEmbedStatus } from '../../models/common/copilot'; -import { CopilotContextModel } from '../../models/copilot-context'; -import { CopilotSessionModel } from '../../models/copilot-session'; -import { CopilotWorkspaceConfigModel } from '../../models/copilot-workspace'; -import { UserModel } from '../../models/user'; -import { WorkspaceModel } from '../../models/workspace'; +import { + ContextEmbedStatus, + CopilotContextModel, + CopilotSessionModel, + CopilotWorkspaceConfigModel, + UserModel, + WorkspaceModel, +} from '../../models'; import { createTestingModule, type TestingModule } from '../utils'; import { cleanObject } from '../utils/copilot'; @@ -46,7 +48,7 @@ let docId = 'doc1'; test.beforeEach(async t => { await t.context.module.initTestingDB(); - await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o'); + await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4.1'); user = await t.context.user.create({ email: 'test@affine.pro', }); diff --git a/packages/backend/server/src/__tests__/models/copilot-session.spec.ts b/packages/backend/server/src/__tests__/models/copilot-session.spec.ts new file mode 100644 index 0000000000..36cb4cc040 --- /dev/null +++ b/packages/backend/server/src/__tests__/models/copilot-session.spec.ts @@ -0,0 +1,341 @@ +import { randomUUID } from 'node:crypto'; + +import { PrismaClient, User, Workspace } from '@prisma/client'; +import ava, { ExecutionContext, TestFn } from 'ava'; + +import { CopilotPromptInvalid } from '../../base'; +import { + CopilotSessionModel, + UpdateChatSessionData, + UserModel, + WorkspaceModel, +} from '../../models'; +import { createTestingModule, type TestingModule } from '../utils'; + +interface Context { + module: TestingModule; + db: PrismaClient; + user: UserModel; + workspace: WorkspaceModel; + copilotSession: CopilotSessionModel; +} + +const test = ava as TestFn; + +test.before(async t => { + const module = await createTestingModule(); + t.context.user = module.get(UserModel); + t.context.workspace = module.get(WorkspaceModel); + t.context.copilotSession = module.get(CopilotSessionModel); + t.context.db = module.get(PrismaClient); + t.context.module = module; +}); + +let user: User; +let workspace: Workspace; + +test.beforeEach(async t => { + await t.context.module.initTestingDB(); + user = await t.context.user.create({ + email: 'test@affine.pro', + }); + workspace = await t.context.workspace.create(user.id); +}); + +test.after(async t => { + await t.context.module.close(); +}); + +const createTestPrompts = async ( + copilotSession: CopilotSessionModel, + db: PrismaClient +) => { + await copilotSession.createPrompt('test-prompt', 'gpt-4.1'); + await db.aiPrompt.create({ + data: { name: 'action-prompt', model: 'gpt-4.1', action: 'edit' }, + }); +}; + +const createTestSession = async ( + t: ExecutionContext, + overrides: Partial<{ + sessionId: string; + userId: string; + workspaceId: string; + docId: string | null; + pinned: boolean; + promptName: string; + }> = {} +) => { + const sessionData = { + sessionId: randomUUID(), + userId: user.id, + workspaceId: workspace.id, + docId: null, + pinned: false, + promptName: 'test-prompt', + ...overrides, + }; + + await t.context.copilotSession.create(sessionData); + return sessionData; +}; + +const getSessionState = async (db: PrismaClient, sessionId: string) => { + const session = await db.aiSession.findUnique({ + where: { id: sessionId }, + select: { id: true, pinned: true, docId: true }, + }); + return session; +}; + +test('should list and filter session type', async t => { + const { copilotSession, db } = t.context; + + await createTestPrompts(copilotSession, db); + + const docId = 'doc-id-1'; + await createTestSession(t, { sessionId: randomUUID() }); + await createTestSession(t, { sessionId: randomUUID(), pinned: true }); + await createTestSession(t, { sessionId: randomUUID(), docId }); + + // should list sessions + { + const workspaceSessions = await copilotSession.list(user.id, workspace.id); + + t.snapshot( + workspaceSessions.map(s => ({ docId: s.docId, pinned: s.pinned })), + 'workspace sessions should include workspace and pinned sessions' + ); + } + + { + const docSessions = await copilotSession.list(user.id, workspace.id, docId); + + t.snapshot( + docSessions.map(s => ({ docId: s.docId, pinned: s.pinned })), + 'doc sessions should only include sessions with matching docId' + ); + } + + // should identify session types + { + // check get session type + const testCases = [ + { docId: null, pinned: false }, + { docId: undefined, pinned: false }, + { docId: null, pinned: true }, + { docId, pinned: false }, + ]; + + const sessionTypeResults = testCases.map(session => ({ + session, + type: copilotSession.getSessionType(session), + })); + + t.snapshot(sessionTypeResults, 'session type identification results'); + } +}); + +test('should check session validation for prompts', async t => { + const { copilotSession, db } = t.context; + + await createTestPrompts(copilotSession, db); + + const docId = randomUUID(); + const sessionTypes = [ + { name: 'workspace', session: { docId: null, pinned: false } }, + { name: 'pinned', session: { docId: null, pinned: true } }, + { name: 'doc', session: { docId, pinned: false } }, + ]; + + // non-action prompts should work for all session types + sessionTypes.forEach(({ name, session }) => { + t.notThrows( + () => + copilotSession.checkSessionPrompt(session, 'test-prompt', undefined), + `${name} session should allow non-action prompts` + ); + }); + + // action prompts should only work for doc session type + { + const actionPromptTests = [ + { + name: 'workspace', + session: sessionTypes[0].session, + shouldThrow: true, + }, + { name: 'pinned', session: sessionTypes[1].session, shouldThrow: true }, + { name: 'doc', session: sessionTypes[2].session, shouldThrow: false }, + ]; + + actionPromptTests.forEach(({ name, session, shouldThrow }) => { + if (shouldThrow) { + t.throws( + () => + copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'), + { instanceOf: CopilotPromptInvalid }, + `${name} session should reject action prompts` + ); + } else { + t.notThrows( + () => + copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'), + `${name} session should allow action prompts` + ); + } + }); + } +}); + +test('should pin and unpin sessions', async t => { + const { copilotSession, db } = t.context; + + await createTestPrompts(copilotSession, db); + + const firstSessionId = 'first-session-id'; + const secondSessionId = 'second-session-id'; + const thirdSessionId = 'third-session-id'; + + // should unpin existing pinned session when creating a new one + { + await copilotSession.create({ + sessionId: firstSessionId, + userId: user.id, + workspaceId: workspace.id, + docId: null, + promptName: 'test-prompt', + pinned: true, + }); + + const firstSession = await copilotSession.get(firstSessionId); + t.truthy(firstSession, 'first session should be created successfully'); + t.is(firstSession?.pinned, true, 'first session should be pinned'); + + // should unpin the first one when creating second pinned session + await copilotSession.create({ + sessionId: secondSessionId, + userId: user.id, + workspaceId: workspace.id, + docId: null, + promptName: 'test-prompt', + pinned: true, + }); + + const sessionStatesAfterSecondPin = await Promise.all([ + getSessionState(db, firstSessionId), + getSessionState(db, secondSessionId), + ]); + + t.snapshot( + sessionStatesAfterSecondPin, + 'session states after creating second pinned session' + ); + } + + // should can unpin a pinned session + { + await createTestSession(t, { sessionId: thirdSessionId, pinned: true }); + const unpinResult = await copilotSession.unpin(workspace.id, user.id); + t.is( + unpinResult, + true, + 'unpin operation should return true when sessions are unpinned' + ); + + const unpinResultAgain = await copilotSession.unpin(workspace.id, user.id); + t.snapshot( + unpinResultAgain, + 'should return false when no sessions to unpin' + ); + } + + // should unpin all sessions + { + const allSessionsAfterUnpin = await db.aiSession.findMany({ + where: { id: { in: [firstSessionId, secondSessionId, thirdSessionId] } }, + select: { pinned: true, id: true }, + orderBy: { id: 'asc' }, + }); + + t.snapshot( + allSessionsAfterUnpin, + 'all sessions should be unpinned after unpin operation' + ); + } +}); + +test('session updates and type conversions', async t => { + const { copilotSession, db } = t.context; + + await createTestPrompts(copilotSession, db); + + const sessionId = 'session-update-id'; + const docId = 'doc-update-id'; + + await createTestSession(t, { sessionId }); + + // should unpin existing pinned session + { + const existingPinnedId = 'existing-pinned-session-id'; + await createTestSession(t, { sessionId: existingPinnedId, pinned: true }); + + await copilotSession.update(user.id, sessionId, { pinned: true }); + + const sessionStatesAfterPin = await Promise.all([ + getSessionState(db, sessionId), + getSessionState(db, existingPinnedId), + ]); + + t.snapshot( + sessionStatesAfterPin, + 'session states after pinning - should unpin existing' + ); + } + + // should unpin the session + { + await copilotSession.update(user.id, sessionId, { pinned: false }); + const sessionStateAfterUnpin = await getSessionState(db, sessionId); + t.snapshot(sessionStateAfterUnpin, 'session state after unpinning'); + } + + // should convert session types + { + const conversionSteps: any[] = []; + + let session = await db.aiSession.findUnique({ + where: { id: sessionId }, + select: { docId: true, pinned: true }, + }); + + const convertSession = async ( + step: string, + data: UpdateChatSessionData + ) => { + await copilotSession.update(user.id, sessionId, data); + session = await db.aiSession.findUnique({ + where: { id: sessionId }, + select: { docId: true, pinned: true }, + }); + conversionSteps.push({ + step, + session, + type: copilotSession.getSessionType(session!), + }); + }; + + { + await convertSession('workspace_to_doc', { docId }); // Workspace → Doc session + await convertSession('doc_to_pinned', { pinned: true }); // Doc → Pinned session + await convertSession('pinned_to_workspace', { + pinned: false, + docId: null, + }); // Pinned → Workspace session + await convertSession('workspace_to_pinned', { pinned: true }); // Workspace → Pinned session + } + + t.snapshot(conversionSteps, 'session type conversion steps'); + } +}); diff --git a/packages/backend/server/src/__tests__/utils/copilot.ts b/packages/backend/server/src/__tests__/utils/copilot.ts index a907ea671c..30009e8516 100644 --- a/packages/backend/server/src/__tests__/utils/copilot.ts +++ b/packages/backend/server/src/__tests__/utils/copilot.ts @@ -20,8 +20,9 @@ export const cleanObject = ( export async function createCopilotSession( app: TestingApp, workspaceId: string, - docId: string, - promptName: string + docId: string | null, + promptName: string, + pinned: boolean = false ): Promise { const res = await app.gql( ` @@ -29,12 +30,73 @@ export async function createCopilotSession( createCopilotSession(options: $options) } `, - { options: { workspaceId, docId, promptName } } + { options: { workspaceId, docId, promptName, pinned } } ); return res.createCopilotSession; } +export async function createWorkspaceCopilotSession( + app: TestingApp, + workspaceId: string, + promptName: string +): Promise { + return createCopilotSession(app, workspaceId, null, promptName); +} + +export async function createPinnedCopilotSession( + app: TestingApp, + workspaceId: string, + docId: string, + promptName: string +): Promise { + return createCopilotSession(app, workspaceId, docId, promptName, true); +} + +export async function createDocCopilotSession( + app: TestingApp, + workspaceId: string, + docId: string, + promptName: string +): Promise { + return createCopilotSession(app, workspaceId, docId, promptName); +} + +export async function getCopilotSession( + app: TestingApp, + workspaceId: string, + sessionId: string +): Promise<{ + id: string; + docId: string | null; + parentSessionId: string | null; + pinned: boolean; + promptName: string; +}> { + const res = await app.gql( + ` + query getCopilotSession( + $workspaceId: String! + $sessionId: String! + ) { + currentUser { + copilot(workspaceId: $workspaceId) { + session(sessionId: $sessionId) { + id + docId + parentSessionId + pinned + promptName + } + } + } + }`, + { workspaceId, sessionId } + ); + + return res.currentUser?.copilot?.session; +} + export async function updateCopilotSession( app: TestingApp, sessionId: string, diff --git a/packages/backend/server/src/base/error/def.ts b/packages/backend/server/src/base/error/def.ts index ba27d64909..27ea6d77ad 100644 --- a/packages/backend/server/src/base/error/def.ts +++ b/packages/backend/server/src/base/error/def.ts @@ -643,6 +643,10 @@ export const USER_FRIENDLY_ERRORS = { type: 'resource_not_found', message: `Copilot session not found.`, }, + copilot_session_invalid_input: { + type: 'invalid_input', + message: `Copilot session input is invalid.`, + }, copilot_session_deleted: { type: 'action_forbidden', message: `Copilot session has been deleted.`, diff --git a/packages/backend/server/src/base/error/errors.gen.ts b/packages/backend/server/src/base/error/errors.gen.ts index a57749436f..be86944d96 100644 --- a/packages/backend/server/src/base/error/errors.gen.ts +++ b/packages/backend/server/src/base/error/errors.gen.ts @@ -657,6 +657,12 @@ export class CopilotSessionNotFound extends UserFriendlyError { } } +export class CopilotSessionInvalidInput extends UserFriendlyError { + constructor(message?: string) { + super('invalid_input', 'copilot_session_invalid_input', message); + } +} + export class CopilotSessionDeleted extends UserFriendlyError { constructor(message?: string) { super('action_forbidden', 'copilot_session_deleted', message); @@ -1145,6 +1151,7 @@ export enum ErrorNames { WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION, WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION, COPILOT_SESSION_NOT_FOUND, + COPILOT_SESSION_INVALID_INPUT, COPILOT_SESSION_DELETED, NO_COPILOT_PROVIDER_AVAILABLE, COPILOT_FAILED_TO_GENERATE_TEXT, diff --git a/packages/backend/server/src/models/copilot-session.ts b/packages/backend/server/src/models/copilot-session.ts index 106acdf0cd..c335ae364a 100644 --- a/packages/backend/server/src/models/copilot-session.ts +++ b/packages/backend/server/src/models/copilot-session.ts @@ -1,36 +1,366 @@ import { Injectable } from '@nestjs/common'; +import { Transactional } from '@nestjs-cls/transactional'; +import { AiPromptRole, Prisma } from '@prisma/client'; +import { omit } from 'lodash-es'; +import { + CopilotPromptInvalid, + CopilotSessionDeleted, + CopilotSessionInvalidInput, + CopilotSessionNotFound, +} from '../base'; import { BaseModel } from './base'; -interface ChatSessionState { +export enum SessionType { + Workspace = 'workspace', // docId is null and pinned is false + Pinned = 'pinned', // pinned is true + Doc = 'doc', // docId points to specific document +} + +type ChatAttachment = { attachment: string; mimeType: string } | string; + +type ChatStreamObject = { + type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result'; + textDelta?: string; + toolCallId?: string; + toolName?: string; + args?: Record; + result?: any; +}; + +type ChatMessage = { + id?: string | undefined; + role: 'system' | 'assistant' | 'user'; + content: string; + attachments?: ChatAttachment[] | null; + params?: Record | null; + streamObjects?: ChatStreamObject[] | null; + createdAt: Date; +}; + +type ChatSession = { sessionId: string; workspaceId: string; - docId: string; + docId?: string | null; + pinned?: boolean; + messages?: ChatMessage[]; // connect ids userId: string; promptName: string; -} + parentSessionId?: string | null; +}; + +export type UpdateChatSessionData = Partial< + Pick +>; +export type UpdateChatSession = Pick & + UpdateChatSessionData; + +export type ListSessionOptions = { + sessionId: string | undefined; + action: boolean | undefined; + fork: boolean | undefined; + limit: number | undefined; + skip: number | undefined; + sessionOrder: 'asc' | 'desc' | undefined; + messageOrder: 'asc' | 'desc' | undefined; +}; -// TODO(@darkskygit): not ready to replace business codes yet, just for test @Injectable() export class CopilotSessionModel extends BaseModel { - async create(state: ChatSessionState) { - const row = await this.db.aiSession.create({ - data: { - id: state.sessionId, - workspaceId: state.workspaceId, - docId: state.docId, - // connect - userId: state.userId, - promptName: state.promptName, - }, - }); - return row; + getSessionType(session: Pick): SessionType { + if (session.pinned) return SessionType.Pinned; + if (!session.docId) return SessionType.Workspace; + return SessionType.Doc; } + checkSessionPrompt( + session: Pick, + promptName: string, + promptAction: string | undefined + ): boolean { + const sessionType = this.getSessionType(session); + + // workspace and pinned sessions cannot use action prompts + if ( + [SessionType.Workspace, SessionType.Pinned].includes(sessionType) && + !!promptAction?.trim() + ) { + throw new CopilotPromptInvalid( + `${promptName} are not allowed for ${sessionType} sessions` + ); + } + + return true; + } + + // NOTE: just for test, remove it after copilot prompt model is ready async createPrompt(name: string, model: string) { await this.db.aiPrompt.create({ data: { name, model }, }); } + + @Transactional() + async create(state: ChatSession) { + if (state.pinned) { + await this.unpin(state.workspaceId, state.userId); + } + + const row = await this.db.aiSession.create({ + data: { + id: state.sessionId, + workspaceId: state.workspaceId, + docId: state.docId, + pinned: state.pinned ?? false, + // connect + userId: state.userId, + promptName: state.promptName, + parentSessionId: state.parentSessionId, + }, + }); + return row; + } + + @Transactional() + async has( + sessionId: string, + userId: string, + params?: Prisma.AiSessionCountArgs['where'] + ) { + return await this.db.aiSession + .count({ where: { id: sessionId, userId, ...params } }) + .then(c => c > 0); + } + + @Transactional() + async getChatSessionId(state: Omit) { + const extraCondition: Record = {}; + if (state.parentSessionId) { + // also check session id if provided session is forked session + extraCondition.id = state.sessionId; + extraCondition.parentSessionId = state.parentSessionId; + } + + const session = await this.db.aiSession.findFirst({ + where: { + userId: state.userId, + workspaceId: state.workspaceId, + docId: state.docId, + parentSessionId: null, + prompt: { action: { equals: null } }, + ...extraCondition, + }, + select: { id: true, deletedAt: true }, + }); + if (session?.deletedAt) throw new CopilotSessionDeleted(); + return session?.id; + } + + @Transactional() + async getExists