feat(server): allow chat session dangling & pin session support (#12849)

fix AI-181
fix AI-179
fix AI-178
fix PD-2682
fix PD-2683
This commit is contained in:
DarkSky
2025-06-19 13:17:01 +08:00
committed by GitHub
parent d80bfac1d2
commit bd04930560
28 changed files with 1422 additions and 394 deletions

View File

@@ -0,0 +1,21 @@
-- AlterTable
ALTER TABLE "ai_sessions_metadata" ALTER COLUMN "doc_id" DROP NOT NULL;
-- AlterTable
ALTER TABLE "ai_sessions_metadata" ADD COLUMN "pinned" BOOLEAN NOT NULL DEFAULT false;
-- AlterTable
CREATE UNIQUE INDEX idx_ai_session_unique_pinned
ON ai_sessions_metadata (user_id, workspace_id)
WHERE pinned = true AND deleted_at IS NULL;
-- AlterTable
CREATE UNIQUE INDEX idx_ai_session_unique_doc_root
ON ai_sessions_metadata (user_id, workspace_id, doc_id)
WHERE parent_session_id IS NULL AND doc_id IS NOT NULL AND deleted_at IS NULL;
-- DropIndex
DROP INDEX "ai_sessions_metadata_user_id_workspace_id_idx";
-- CreateIndex
CREATE INDEX "ai_sessions_metadata_user_id_workspace_id_doc_id_idx" ON "ai_sessions_metadata"("user_id", "workspace_id", "doc_id");

View File

@@ -434,8 +434,9 @@ model AiSession {
id String @id @default(uuid()) @db.VarChar
userId String @map("user_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
docId String @map("doc_id") @db.VarChar
docId String? @map("doc_id") @db.VarChar
promptName String @map("prompt_name") @db.VarChar(32)
pinned Boolean @default(false)
// the session id of the parent session if this session is a forked session
parentSessionId String? @map("parent_session_id") @db.VarChar
messageCost Int @default(0)
@@ -449,7 +450,7 @@ model AiSession {
context AiContext[]
@@index([userId])
@@index([userId, workspaceId])
@@index([userId, workspaceId, docId])
@@map("ai_sessions_metadata")
}

View File

@@ -135,3 +135,31 @@ Generated by [AVA](https://avajs.dev).
],
},
]
## should create different session types and validate prompt constraints
> should create session with should create workspace session with text prompt
[
{
pinned: false,
},
]
> should create session with should create pinned session with text prompt
[
{
docId: 'pinned-doc',
pinned: true,
},
]
> should create session with should create doc session with text prompt
[
{
docId: 'normal-doc',
pinned: false,
},
]

View File

@@ -48,7 +48,11 @@ import {
createCopilotContext,
createCopilotMessage,
createCopilotSession,
createDocCopilotSession,
createPinnedCopilotSession,
createWorkspaceCopilotSession,
forkCopilotSession,
getCopilotSession,
getHistories,
listContext,
listContextDocAndFiles,
@@ -302,12 +306,8 @@ test('should fork session correctly', async t => {
// prepare session
const { id } = await createWorkspace(app);
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
textPromptName
);
const docId = randomUUID();
const sessionId = await createCopilotSession(app, id, docId, textPromptName);
let forkedSessionId: string;
// should be able to fork session
@@ -316,7 +316,7 @@ test('should fork session correctly', async t => {
const messageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, messageId);
}
const histories = await getHistories(app, { workspaceId: id });
const histories = await getHistories(app, { workspaceId: id, docId });
const latestMessageId = histories[0].messages.findLast(
m => m.role === 'assistant'
)?.id;
@@ -375,7 +375,7 @@ test('should fork session correctly', async t => {
});
await app.switchUser(u1);
const histories = await getHistories(app, { workspaceId: id });
const histories = await getHistories(app, { workspaceId: id, docId });
const latestMessageId = histories
.find(h => h.sessionId === forkedSessionId)
?.messages.findLast(m => m.role === 'assistant')?.id;
@@ -612,10 +612,11 @@ test('should be able to retry with api', async t => {
// normal chat
{
const { id } = await createWorkspace(app);
const docId = randomUUID();
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
docId,
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
@@ -623,7 +624,7 @@ test('should be able to retry with api', async t => {
await chatWithText(app, sessionId, messageId);
await chatWithText(app, sessionId, messageId);
const histories = await getHistories(app, { workspaceId: id });
const histories = await getHistories(app, { workspaceId: id, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text', 'generate text to text']],
@@ -634,10 +635,11 @@ test('should be able to retry with api', async t => {
// retry chat
{
const { id } = await createWorkspace(app);
const docId = randomUUID();
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
docId,
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
@@ -646,7 +648,7 @@ test('should be able to retry with api', async t => {
await chatWithText(app, sessionId);
// should only have 1 message
const histories = await getHistories(app, { workspaceId: id });
const histories = await getHistories(app, { workspaceId: id, docId });
t.snapshot(
cleanObject(histories),
'should be able to list history after retry'
@@ -656,10 +658,11 @@ test('should be able to retry with api', async t => {
// retry chat with new message id
{
const { id } = await createWorkspace(app);
const docId = randomUUID();
const sessionId = await createCopilotSession(
app,
id,
randomUUID(),
docId,
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
@@ -669,7 +672,7 @@ test('should be able to retry with api', async t => {
await chatWithText(app, sessionId, newMessageId, '', true);
// should only have 1 message
const histories = await getHistories(app, { workspaceId: id });
const histories = await getHistories(app, { workspaceId: id, docId });
t.snapshot(
cleanObject(histories),
'should be able to list history after retry'
@@ -746,10 +749,11 @@ test('should be able to list history', async t => {
const { app } = t.context;
const { id: workspaceId } = await createWorkspace(app);
const docId = randomUUID();
const sessionId = await createCopilotSession(
app,
workspaceId,
randomUUID(),
docId,
textPromptName
);
@@ -757,7 +761,7 @@ test('should be able to list history', async t => {
await chatWithText(app, sessionId, messageId);
{
const histories = await getHistories(app, { workspaceId });
const histories = await getHistories(app, { workspaceId, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['hello', 'generate text to text']],
@@ -768,6 +772,7 @@ test('should be able to list history', async t => {
{
const histories = await getHistories(app, {
workspaceId,
docId,
options: { messageOrder: 'desc' },
});
t.deepEqual(
@@ -809,17 +814,18 @@ test('should reject request that user have not permission', async t => {
}
{
const docId = randomUUID();
const sessionId = await createCopilotSession(
app,
workspaceId,
randomUUID(),
docId,
textPromptName
);
const messageId = await createCopilotMessage(app, sessionId);
await chatWithText(app, sessionId, messageId);
const histories = await getHistories(app, { workspaceId });
const histories = await getHistories(app, { workspaceId, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
@@ -1072,3 +1078,93 @@ test('should be able to transcript', async t => {
}
}
});
test('should create different session types and validate prompt constraints', async t => {
const { app } = t.context;
const { id: workspaceId } = await createWorkspace(app);
const validateSession = async (
description: string,
workspaceId: string,
createPromise: Promise<string>
) => {
const sessionId = await createPromise;
t.truthy(sessionId, description);
t.snapshot(
cleanObject(
[await getCopilotSession(app, workspaceId, sessionId)],
['id', 'workspaceId', 'promptName']
),
`should create session with ${description}`
);
return sessionId;
};
await validateSession(
'should create workspace session with text prompt',
workspaceId,
createWorkspaceCopilotSession(app, workspaceId, textPromptName)
);
await validateSession(
'should create pinned session with text prompt',
workspaceId,
createPinnedCopilotSession(app, workspaceId, 'pinned-doc', textPromptName)
);
await validateSession(
'should create doc session with text prompt',
workspaceId,
createDocCopilotSession(app, workspaceId, 'normal-doc', textPromptName)
);
});
test('should list histories for different session types correctly', async t => {
const { app } = t.context;
const { id: workspaceId } = await createWorkspace(app);
const pinnedDocId = 'pinned-doc';
const docId = 'normal-doc';
// create sessions and add messages
const [workspaceSessionId, pinnedSessionId, docSessionId] = await Promise.all(
[
createWorkspaceCopilotSession(app, workspaceId, textPromptName),
createPinnedCopilotSession(app, workspaceId, pinnedDocId, textPromptName),
createDocCopilotSession(app, workspaceId, docId, textPromptName),
]
);
await Promise.all([
createCopilotMessage(app, workspaceSessionId, 'workspace message'),
createCopilotMessage(app, pinnedSessionId, 'pinned message'),
createCopilotMessage(app, docSessionId, 'doc message'),
]);
const testHistoryQuery = async (
queryDocId: string | undefined,
expectedSessionId: string,
description: string
) => {
const histories = await getHistories(app, {
workspaceId,
docId: queryDocId,
});
t.is(histories.length, 1, `should return ${description}`);
t.is(
histories[0].sessionId,
expectedSessionId,
`should return correct ${description}`
);
};
await testHistoryQuery(
undefined,
workspaceSessionId,
'workspace session history'
);
await testHistoryQuery(
pinnedDocId,
pinnedSessionId,
'pinned session history'
);
await testHistoryQuery(docId, docSessionId, 'doc session history');
});

View File

@@ -275,7 +275,7 @@ test('should be able to manage chat session', async t => {
]);
const params = { word: 'world' };
const commonParams = { docId: 'test', workspaceId: 'test' };
const commonParams = { docId: 'test', workspaceId: 'test', pinned: false };
const sessionId = await session.create({
userId,
@@ -342,11 +342,12 @@ test('should be able to update chat session prompt', async t => {
docId: 'test',
workspaceId: 'test',
userId,
pinned: false,
});
t.truthy(sessionId, 'should create session');
// Update the session
const updatedSessionId = await session.updateSessionPrompt({
const updatedSessionId = await session.updateSession({
sessionId,
promptName: 'Search With AFFiNE AI',
userId,
@@ -371,7 +372,7 @@ test('should be able to fork chat session', async t => {
]);
const params = { word: 'world' };
const commonParams = { docId: 'test', workspaceId: 'test' };
const commonParams = { docId: 'test', workspaceId: 'test', pinned: false };
// create session
const sessionId = await session.create({
userId,
@@ -494,6 +495,7 @@ test('should be able to process message id', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -537,6 +539,7 @@ test('should be able to generate with message id', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -559,6 +562,7 @@ test('should be able to generate with message id', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -586,6 +590,7 @@ test('should be able to generate with message id', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -614,6 +619,7 @@ test('should save message correctly', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -643,6 +649,7 @@ test('should revert message correctly', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -742,6 +749,7 @@ test('should handle params correctly in chat session', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
const s = (await session.get(sessionId))!;
@@ -1506,6 +1514,7 @@ test('should be able to manage context', async t => {
workspaceId: 'test',
userId,
promptName: 'prompt',
pinned: false,
});
// use mocked embedding client
@@ -1729,6 +1738,7 @@ test('should be able to manage workspace embedding', async t => {
workspaceId: ws.id,
userId,
promptName: 'prompt',
pinned: false,
});
const contextSession = await context.create(sessionId);

View File

@@ -0,0 +1,162 @@
# Snapshot report for `src/__tests__/models/copilot-session.spec.ts`
The actual snapshot is saved in `copilot-session.spec.ts.snap`.
Generated by [AVA](https://avajs.dev).
## should list and filter session type
> workspace sessions should include workspace and pinned sessions
[
{
docId: null,
pinned: true,
},
{
docId: null,
pinned: false,
},
]
> doc sessions should only include sessions with matching docId
[
{
docId: 'doc-id-1',
pinned: false,
},
]
> session type identification results
[
{
session: {
docId: null,
pinned: false,
},
type: 'workspace',
},
{
session: {
docId: undefined,
pinned: false,
},
type: 'workspace',
},
{
session: {
docId: null,
pinned: true,
},
type: 'pinned',
},
{
session: {
docId: 'doc-id-1',
pinned: false,
},
type: 'doc',
},
]
## should pin and unpin sessions
> session states after creating second pinned session
[
{
docId: null,
id: 'first-session-id',
pinned: false,
},
{
docId: null,
id: 'second-session-id',
pinned: true,
},
]
> should return false when no sessions to unpin
false
> all sessions should be unpinned after unpin operation
[
{
id: 'first-session-id',
pinned: false,
},
{
id: 'second-session-id',
pinned: false,
},
{
id: 'third-session-id',
pinned: false,
},
]
## session updates and type conversions
> session states after pinning - should unpin existing
[
{
docId: null,
id: 'session-update-id',
pinned: true,
},
{
docId: null,
id: 'existing-pinned-session-id',
pinned: false,
},
]
> session state after unpinning
{
docId: null,
id: 'session-update-id',
pinned: false,
}
> session type conversion steps
[
{
session: {
docId: 'doc-update-id',
pinned: false,
},
step: 'workspace_to_doc',
type: 'doc',
},
{
session: {
docId: 'doc-update-id',
pinned: true,
},
step: 'doc_to_pinned',
type: 'pinned',
},
{
session: {
docId: null,
pinned: false,
},
step: 'pinned_to_workspace',
type: 'workspace',
},
{
session: {
docId: null,
pinned: true,
},
step: 'workspace_to_pinned',
type: 'pinned',
},
]

View File

@@ -5,12 +5,14 @@ import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { Config } from '../../base';
import { ContextEmbedStatus } from '../../models/common/copilot';
import { CopilotContextModel } from '../../models/copilot-context';
import { CopilotSessionModel } from '../../models/copilot-session';
import { CopilotWorkspaceConfigModel } from '../../models/copilot-workspace';
import { UserModel } from '../../models/user';
import { WorkspaceModel } from '../../models/workspace';
import {
ContextEmbedStatus,
CopilotContextModel,
CopilotSessionModel,
CopilotWorkspaceConfigModel,
UserModel,
WorkspaceModel,
} from '../../models';
import { createTestingModule, type TestingModule } from '../utils';
import { cleanObject } from '../utils/copilot';
@@ -46,7 +48,7 @@ let docId = 'doc1';
test.beforeEach(async t => {
await t.context.module.initTestingDB();
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o');
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4.1');
user = await t.context.user.create({
email: 'test@affine.pro',
});

View File

@@ -0,0 +1,341 @@
import { randomUUID } from 'node:crypto';
import { PrismaClient, User, Workspace } from '@prisma/client';
import ava, { ExecutionContext, TestFn } from 'ava';
import { CopilotPromptInvalid } from '../../base';
import {
CopilotSessionModel,
UpdateChatSessionData,
UserModel,
WorkspaceModel,
} from '../../models';
import { createTestingModule, type TestingModule } from '../utils';
interface Context {
module: TestingModule;
db: PrismaClient;
user: UserModel;
workspace: WorkspaceModel;
copilotSession: CopilotSessionModel;
}
const test = ava as TestFn<Context>;
test.before(async t => {
const module = await createTestingModule();
t.context.user = module.get(UserModel);
t.context.workspace = module.get(WorkspaceModel);
t.context.copilotSession = module.get(CopilotSessionModel);
t.context.db = module.get(PrismaClient);
t.context.module = module;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.module.initTestingDB();
user = await t.context.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.workspace.create(user.id);
});
test.after(async t => {
await t.context.module.close();
});
const createTestPrompts = async (
copilotSession: CopilotSessionModel,
db: PrismaClient
) => {
await copilotSession.createPrompt('test-prompt', 'gpt-4.1');
await db.aiPrompt.create({
data: { name: 'action-prompt', model: 'gpt-4.1', action: 'edit' },
});
};
const createTestSession = async (
t: ExecutionContext<Context>,
overrides: Partial<{
sessionId: string;
userId: string;
workspaceId: string;
docId: string | null;
pinned: boolean;
promptName: string;
}> = {}
) => {
const sessionData = {
sessionId: randomUUID(),
userId: user.id,
workspaceId: workspace.id,
docId: null,
pinned: false,
promptName: 'test-prompt',
...overrides,
};
await t.context.copilotSession.create(sessionData);
return sessionData;
};
const getSessionState = async (db: PrismaClient, sessionId: string) => {
const session = await db.aiSession.findUnique({
where: { id: sessionId },
select: { id: true, pinned: true, docId: true },
});
return session;
};
test('should list and filter session type', async t => {
const { copilotSession, db } = t.context;
await createTestPrompts(copilotSession, db);
const docId = 'doc-id-1';
await createTestSession(t, { sessionId: randomUUID() });
await createTestSession(t, { sessionId: randomUUID(), pinned: true });
await createTestSession(t, { sessionId: randomUUID(), docId });
// should list sessions
{
const workspaceSessions = await copilotSession.list(user.id, workspace.id);
t.snapshot(
workspaceSessions.map(s => ({ docId: s.docId, pinned: s.pinned })),
'workspace sessions should include workspace and pinned sessions'
);
}
{
const docSessions = await copilotSession.list(user.id, workspace.id, docId);
t.snapshot(
docSessions.map(s => ({ docId: s.docId, pinned: s.pinned })),
'doc sessions should only include sessions with matching docId'
);
}
// should identify session types
{
// check get session type
const testCases = [
{ docId: null, pinned: false },
{ docId: undefined, pinned: false },
{ docId: null, pinned: true },
{ docId, pinned: false },
];
const sessionTypeResults = testCases.map(session => ({
session,
type: copilotSession.getSessionType(session),
}));
t.snapshot(sessionTypeResults, 'session type identification results');
}
});
test('should check session validation for prompts', async t => {
const { copilotSession, db } = t.context;
await createTestPrompts(copilotSession, db);
const docId = randomUUID();
const sessionTypes = [
{ name: 'workspace', session: { docId: null, pinned: false } },
{ name: 'pinned', session: { docId: null, pinned: true } },
{ name: 'doc', session: { docId, pinned: false } },
];
// non-action prompts should work for all session types
sessionTypes.forEach(({ name, session }) => {
t.notThrows(
() =>
copilotSession.checkSessionPrompt(session, 'test-prompt', undefined),
`${name} session should allow non-action prompts`
);
});
// action prompts should only work for doc session type
{
const actionPromptTests = [
{
name: 'workspace',
session: sessionTypes[0].session,
shouldThrow: true,
},
{ name: 'pinned', session: sessionTypes[1].session, shouldThrow: true },
{ name: 'doc', session: sessionTypes[2].session, shouldThrow: false },
];
actionPromptTests.forEach(({ name, session, shouldThrow }) => {
if (shouldThrow) {
t.throws(
() =>
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
{ instanceOf: CopilotPromptInvalid },
`${name} session should reject action prompts`
);
} else {
t.notThrows(
() =>
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
`${name} session should allow action prompts`
);
}
});
}
});
test('should pin and unpin sessions', async t => {
const { copilotSession, db } = t.context;
await createTestPrompts(copilotSession, db);
const firstSessionId = 'first-session-id';
const secondSessionId = 'second-session-id';
const thirdSessionId = 'third-session-id';
// should unpin existing pinned session when creating a new one
{
await copilotSession.create({
sessionId: firstSessionId,
userId: user.id,
workspaceId: workspace.id,
docId: null,
promptName: 'test-prompt',
pinned: true,
});
const firstSession = await copilotSession.get(firstSessionId);
t.truthy(firstSession, 'first session should be created successfully');
t.is(firstSession?.pinned, true, 'first session should be pinned');
// should unpin the first one when creating second pinned session
await copilotSession.create({
sessionId: secondSessionId,
userId: user.id,
workspaceId: workspace.id,
docId: null,
promptName: 'test-prompt',
pinned: true,
});
const sessionStatesAfterSecondPin = await Promise.all([
getSessionState(db, firstSessionId),
getSessionState(db, secondSessionId),
]);
t.snapshot(
sessionStatesAfterSecondPin,
'session states after creating second pinned session'
);
}
// should can unpin a pinned session
{
await createTestSession(t, { sessionId: thirdSessionId, pinned: true });
const unpinResult = await copilotSession.unpin(workspace.id, user.id);
t.is(
unpinResult,
true,
'unpin operation should return true when sessions are unpinned'
);
const unpinResultAgain = await copilotSession.unpin(workspace.id, user.id);
t.snapshot(
unpinResultAgain,
'should return false when no sessions to unpin'
);
}
// should unpin all sessions
{
const allSessionsAfterUnpin = await db.aiSession.findMany({
where: { id: { in: [firstSessionId, secondSessionId, thirdSessionId] } },
select: { pinned: true, id: true },
orderBy: { id: 'asc' },
});
t.snapshot(
allSessionsAfterUnpin,
'all sessions should be unpinned after unpin operation'
);
}
});
test('session updates and type conversions', async t => {
const { copilotSession, db } = t.context;
await createTestPrompts(copilotSession, db);
const sessionId = 'session-update-id';
const docId = 'doc-update-id';
await createTestSession(t, { sessionId });
// should unpin existing pinned session
{
const existingPinnedId = 'existing-pinned-session-id';
await createTestSession(t, { sessionId: existingPinnedId, pinned: true });
await copilotSession.update(user.id, sessionId, { pinned: true });
const sessionStatesAfterPin = await Promise.all([
getSessionState(db, sessionId),
getSessionState(db, existingPinnedId),
]);
t.snapshot(
sessionStatesAfterPin,
'session states after pinning - should unpin existing'
);
}
// should unpin the session
{
await copilotSession.update(user.id, sessionId, { pinned: false });
const sessionStateAfterUnpin = await getSessionState(db, sessionId);
t.snapshot(sessionStateAfterUnpin, 'session state after unpinning');
}
// should convert session types
{
const conversionSteps: any[] = [];
let session = await db.aiSession.findUnique({
where: { id: sessionId },
select: { docId: true, pinned: true },
});
const convertSession = async (
step: string,
data: UpdateChatSessionData
) => {
await copilotSession.update(user.id, sessionId, data);
session = await db.aiSession.findUnique({
where: { id: sessionId },
select: { docId: true, pinned: true },
});
conversionSteps.push({
step,
session,
type: copilotSession.getSessionType(session!),
});
};
{
await convertSession('workspace_to_doc', { docId }); // Workspace → Doc session
await convertSession('doc_to_pinned', { pinned: true }); // Doc → Pinned session
await convertSession('pinned_to_workspace', {
pinned: false,
docId: null,
}); // Pinned → Workspace session
await convertSession('workspace_to_pinned', { pinned: true }); // Workspace → Pinned session
}
t.snapshot(conversionSteps, 'session type conversion steps');
}
});

View File

@@ -20,8 +20,9 @@ export const cleanObject = (
export async function createCopilotSession(
app: TestingApp,
workspaceId: string,
docId: string,
promptName: string
docId: string | null,
promptName: string,
pinned: boolean = false
): Promise<string> {
const res = await app.gql(
`
@@ -29,12 +30,73 @@ export async function createCopilotSession(
createCopilotSession(options: $options)
}
`,
{ options: { workspaceId, docId, promptName } }
{ options: { workspaceId, docId, promptName, pinned } }
);
return res.createCopilotSession;
}
export async function createWorkspaceCopilotSession(
app: TestingApp,
workspaceId: string,
promptName: string
): Promise<string> {
return createCopilotSession(app, workspaceId, null, promptName);
}
export async function createPinnedCopilotSession(
app: TestingApp,
workspaceId: string,
docId: string,
promptName: string
): Promise<string> {
return createCopilotSession(app, workspaceId, docId, promptName, true);
}
export async function createDocCopilotSession(
app: TestingApp,
workspaceId: string,
docId: string,
promptName: string
): Promise<string> {
return createCopilotSession(app, workspaceId, docId, promptName);
}
export async function getCopilotSession(
app: TestingApp,
workspaceId: string,
sessionId: string
): Promise<{
id: string;
docId: string | null;
parentSessionId: string | null;
pinned: boolean;
promptName: string;
}> {
const res = await app.gql(
`
query getCopilotSession(
$workspaceId: String!
$sessionId: String!
) {
currentUser {
copilot(workspaceId: $workspaceId) {
session(sessionId: $sessionId) {
id
docId
parentSessionId
pinned
promptName
}
}
}
}`,
{ workspaceId, sessionId }
);
return res.currentUser?.copilot?.session;
}
export async function updateCopilotSession(
app: TestingApp,
sessionId: string,

View File

@@ -643,6 +643,10 @@ export const USER_FRIENDLY_ERRORS = {
type: 'resource_not_found',
message: `Copilot session not found.`,
},
copilot_session_invalid_input: {
type: 'invalid_input',
message: `Copilot session input is invalid.`,
},
copilot_session_deleted: {
type: 'action_forbidden',
message: `Copilot session has been deleted.`,

View File

@@ -657,6 +657,12 @@ export class CopilotSessionNotFound extends UserFriendlyError {
}
}
export class CopilotSessionInvalidInput extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'copilot_session_invalid_input', message);
}
}
export class CopilotSessionDeleted extends UserFriendlyError {
constructor(message?: string) {
super('action_forbidden', 'copilot_session_deleted', message);
@@ -1145,6 +1151,7 @@ export enum ErrorNames {
WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION,
WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION,
COPILOT_SESSION_NOT_FOUND,
COPILOT_SESSION_INVALID_INPUT,
COPILOT_SESSION_DELETED,
NO_COPILOT_PROVIDER_AVAILABLE,
COPILOT_FAILED_TO_GENERATE_TEXT,

View File

@@ -1,36 +1,366 @@
import { Injectable } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, Prisma } from '@prisma/client';
import { omit } from 'lodash-es';
import {
CopilotPromptInvalid,
CopilotSessionDeleted,
CopilotSessionInvalidInput,
CopilotSessionNotFound,
} from '../base';
import { BaseModel } from './base';
interface ChatSessionState {
export enum SessionType {
Workspace = 'workspace', // docId is null and pinned is false
Pinned = 'pinned', // pinned is true
Doc = 'doc', // docId points to specific document
}
type ChatAttachment = { attachment: string; mimeType: string } | string;
type ChatStreamObject = {
type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result';
textDelta?: string;
toolCallId?: string;
toolName?: string;
args?: Record<string, any>;
result?: any;
};
type ChatMessage = {
id?: string | undefined;
role: 'system' | 'assistant' | 'user';
content: string;
attachments?: ChatAttachment[] | null;
params?: Record<string, any> | null;
streamObjects?: ChatStreamObject[] | null;
createdAt: Date;
};
type ChatSession = {
sessionId: string;
workspaceId: string;
docId: string;
docId?: string | null;
pinned?: boolean;
messages?: ChatMessage[];
// connect ids
userId: string;
promptName: string;
}
parentSessionId?: string | null;
};
export type UpdateChatSessionData = Partial<
Pick<ChatSession, 'docId' | 'pinned' | 'promptName'>
>;
export type UpdateChatSession = Pick<ChatSession, 'userId' | 'sessionId'> &
UpdateChatSessionData;
export type ListSessionOptions = {
sessionId: string | undefined;
action: boolean | undefined;
fork: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionOrder: 'asc' | 'desc' | undefined;
messageOrder: 'asc' | 'desc' | undefined;
};
// TODO(@darkskygit): not ready to replace business codes yet, just for test
@Injectable()
export class CopilotSessionModel extends BaseModel {
async create(state: ChatSessionState) {
const row = await this.db.aiSession.create({
data: {
id: state.sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
// connect
userId: state.userId,
promptName: state.promptName,
},
});
return row;
getSessionType(session: Pick<ChatSession, 'docId' | 'pinned'>): SessionType {
if (session.pinned) return SessionType.Pinned;
if (!session.docId) return SessionType.Workspace;
return SessionType.Doc;
}
checkSessionPrompt(
session: Pick<ChatSession, 'docId' | 'pinned'>,
promptName: string,
promptAction: string | undefined
): boolean {
const sessionType = this.getSessionType(session);
// workspace and pinned sessions cannot use action prompts
if (
[SessionType.Workspace, SessionType.Pinned].includes(sessionType) &&
!!promptAction?.trim()
) {
throw new CopilotPromptInvalid(
`${promptName} are not allowed for ${sessionType} sessions`
);
}
return true;
}
// NOTE: just for test, remove it after copilot prompt model is ready
async createPrompt(name: string, model: string) {
await this.db.aiPrompt.create({
data: { name, model },
});
}
@Transactional()
async create(state: ChatSession) {
if (state.pinned) {
await this.unpin(state.workspaceId, state.userId);
}
const row = await this.db.aiSession.create({
data: {
id: state.sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
pinned: state.pinned ?? false,
// connect
userId: state.userId,
promptName: state.promptName,
parentSessionId: state.parentSessionId,
},
});
return row;
}
@Transactional()
async has(
sessionId: string,
userId: string,
params?: Prisma.AiSessionCountArgs['where']
) {
return await this.db.aiSession
.count({ where: { id: sessionId, userId, ...params } })
.then(c => c > 0);
}
@Transactional()
async getChatSessionId(state: Omit<ChatSession, 'promptName'>) {
const extraCondition: Record<string, any> = {};
if (state.parentSessionId) {
// also check session id if provided session is forked session
extraCondition.id = state.sessionId;
extraCondition.parentSessionId = state.parentSessionId;
}
const session = await this.db.aiSession.findFirst({
where: {
userId: state.userId,
workspaceId: state.workspaceId,
docId: state.docId,
parentSessionId: null,
prompt: { action: { equals: null } },
...extraCondition,
},
select: { id: true, deletedAt: true },
});
if (session?.deletedAt) throw new CopilotSessionDeleted();
return session?.id;
}
@Transactional()
async getExists<Select extends Prisma.AiSessionSelect>(
sessionId: string,
select?: Select,
where?: Omit<Prisma.AiSessionWhereInput, 'id' | 'deletedAt'>
) {
return (await this.db.aiSession.findUnique({
where: { ...where, id: sessionId, deletedAt: null },
select,
})) as Prisma.AiSessionGetPayload<{ select: Select }>;
}
@Transactional()
async get(sessionId: string) {
return await this.getExists(sessionId, {
id: true,
userId: true,
workspaceId: true,
docId: true,
pinned: true,
parentSessionId: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
params: true,
createdAt: true,
},
orderBy: { createdAt: 'asc' },
},
promptName: true,
});
}
async list(
userId: string,
workspaceId?: string,
docId?: string,
options?: ListSessionOptions
) {
const extraCondition = [];
if (!options?.action && options?.fork) {
// only query forked session if fork == true and action == false
extraCondition.push({
userId: { not: userId },
workspaceId: workspaceId,
docId: docId ?? null,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
});
}
return await this.db.aiSession.findMany({
where: {
OR: [
{
userId,
workspaceId: workspaceId,
docId: docId ?? null,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
deletedAt: null,
},
...extraCondition,
],
},
select: {
id: true,
userId: true,
docId: true,
pinned: true,
promptName: true,
tokenCost: true,
createdAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
params: true,
streamObjects: true,
createdAt: true,
},
orderBy: {
// message order is asc by default
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
},
},
},
take: options?.limit,
skip: options?.skip,
orderBy: {
// session order is desc by default
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
},
});
}
@Transactional()
async unpin(workspaceId: string, userId: string): Promise<boolean> {
const { count } = await this.db.aiSession.updateMany({
where: { userId, workspaceId, pinned: true, deletedAt: null },
data: { pinned: false },
});
return count > 0;
}
@Transactional()
async update(
userId: string,
sessionId: string,
data: UpdateChatSessionData
): Promise<string> {
const session = await this.getExists(
sessionId,
{ id: true, workspaceId: true, docId: true, pinned: true, prompt: true },
{ userId }
);
if (!session) {
throw new CopilotSessionNotFound();
}
if (data.promptName && session.prompt.action) {
throw new CopilotSessionInvalidInput(
`Cannot update prompt for action: ${session.id}`
);
}
if (data.pinned && data.pinned !== session.pinned) {
// if pin the session, unpin exists session in the workspace
await this.unpin(session.workspaceId, userId);
}
await this.db.aiSession.update({ where: { id: sessionId }, data });
return sessionId;
}
@Transactional()
async getMessages(
sessionId: string,
select?: Prisma.AiSessionMessageSelect,
orderBy?: Prisma.AiSessionMessageOrderByWithRelationInput
) {
return this.db.aiSessionMessage.findMany({
where: { sessionId },
select,
orderBy: orderBy ?? { createdAt: 'asc' },
});
}
@Transactional()
async setMessages(
sessionId: string,
messages: ChatMessage[],
tokenCost: number
) {
await this.db.aiSessionMessage.createMany({
data: messages.map(m => ({
...m,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
streamObjects: m.streamObjects || undefined,
sessionId,
})),
});
// only count message generated by user
const userMessages = messages.filter(m => m.role === 'user');
await this.db.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: { increment: tokenCost },
},
});
}
@Transactional()
async revertLatestMessage(
sessionId: string,
removeLatestUserMessage: boolean
) {
const id = await this.getExists(sessionId, { id: true }).then(
session => session?.id
);
if (!id) {
throw new CopilotSessionNotFound();
}
const ids = await this.getMessages(id, { id: true, role: true }).then(
roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id)
);
if (ids.length) {
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
}
}
}

View File

@@ -102,6 +102,8 @@ export class ModelsModule {}
export * from './common';
export * from './copilot-context';
export * from './copilot-job';
export * from './copilot-session';
export * from './copilot-workspace';
export * from './doc';
export * from './doc-user';
export * from './feature';

View File

@@ -33,6 +33,7 @@ import { CurrentUser } from '../../core/auth';
import { Admin } from '../../core/common';
import { AccessController } from '../../core/permission';
import { UserType } from '../../core/user';
import type { UpdateChatSession } from '../../models';
import { PromptService } from './prompt';
import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
@@ -57,22 +58,38 @@ class CreateChatSessionInput {
@Field(() => String)
workspaceId!: string;
@Field(() => String)
docId!: string;
@Field(() => String, { nullable: true })
docId?: string;
@Field(() => String, {
description: 'The prompt name to use for the session',
})
promptName!: string;
@Field(() => Boolean, { nullable: true })
pinned?: boolean;
}
@InputType()
class UpdateChatSessionInput {
class UpdateChatSessionInput implements Omit<UpdateChatSession, 'userId'> {
@Field(() => String)
sessionId!: string;
@Field(() => String, {
description: 'The workspace id of the session',
nullable: true,
})
docId!: string | null | undefined;
@Field(() => Boolean, {
description: 'Whether to pin the session',
nullable: true,
})
pinned!: boolean | undefined;
@Field(() => String, {
description: 'The prompt name to use for the session',
nullable: true,
})
promptName!: string;
}
@@ -219,6 +236,9 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
@Field(() => String)
sessionId!: string;
@Field(() => Boolean)
pinned!: boolean;
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
nullable: true,
@@ -304,6 +324,12 @@ export class CopilotSessionType {
@Field(() => ID)
id!: string;
@Field(() => String, { nullable: true })
docId!: string | null;
@Field(() => Boolean)
pinned!: boolean;
@Field(() => ID, { nullable: true })
parentSessionId!: string | null;
@@ -459,22 +485,33 @@ export class CopilotResolver {
@Args({ name: 'options', type: () => CreateChatSessionInput })
options: CreateChatSessionInput
): Promise<string> {
await this.ac.user(user.id).doc(options).allowLocal().assert('Doc.Update');
// permission check based on session type
if (options.docId) {
await this.ac
.user(user.id)
.doc({ workspaceId: options.workspaceId, docId: options.docId })
.allowLocal()
.assert('Doc.Update');
} else {
await this.ac
.user(user.id)
.workspace(options.workspaceId)
.allowLocal()
.assert('Workspace.Copilot');
}
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
throw new TooManyRequest('Server is busy');
}
if (options.workspaceId === options.docId) {
// filter out session create request for root doc
throw new CopilotDocNotFound({ docId: options.docId });
}
await this.chatSession.checkQuota(user.id);
return await this.chatSession.create({
...options,
pinned: options.pinned ?? false,
docId: options.docId ?? null,
userId: user.id,
});
}
@@ -493,11 +530,19 @@ export class CopilotResolver {
throw new CopilotSessionNotFound();
}
const { workspaceId, docId } = session.config;
await this.ac
.user(user.id)
.doc(workspaceId, docId)
.allowLocal()
.assert('Doc.Update');
if (docId) {
await this.ac
.user(user.id)
.doc(workspaceId, docId)
.allowLocal()
.assert('Doc.Update');
} else {
await this.ac
.user(user.id)
.workspace(workspaceId)
.allowLocal()
.assert('Workspace.Copilot');
}
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
@@ -505,7 +550,7 @@ export class CopilotResolver {
}
await this.chatSession.checkQuota(user.id);
return await this.chatSession.updateSessionPrompt({
return await this.chatSession.updateSession({
...options,
userId: user.id,
});
@@ -619,6 +664,8 @@ export class CopilotResolver {
return {
id: session.sessionId,
parentSessionId: session.parentSessionId,
docId: session.docId,
pinned: session.pinned,
promptName: session.prompt.name,
model: session.prompt.model,
optionalModels: session.prompt.optionalModels,

View File

@@ -1,34 +1,36 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { AiPromptRole, Prisma, PrismaClient } from '@prisma/client';
import { omit } from 'lodash-es';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import {
CopilotActionTaken,
CopilotMessageNotFound,
CopilotPromptNotFound,
CopilotQuotaExceeded,
CopilotSessionDeleted,
CopilotSessionInvalidInput,
CopilotSessionNotFound,
PrismaTransaction,
} from '../../base';
import { QuotaService } from '../../core/quota';
import { Models } from '../../models';
import {
Models,
type UpdateChatSession,
UpdateChatSessionData,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import { PromptMessage, PromptParams } from './providers';
import {
ChatHistory,
ChatMessage,
type ChatHistory,
type ChatMessage,
ChatMessageSchema,
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionPromptUpdateOptions,
ChatSessionState,
type ChatSessionForkOptions,
type ChatSessionOptions,
type ChatSessionState,
getTokenEncoder,
ListHistoriesOptions,
SubmittedMessage,
type ListHistoriesOptions,
type SubmittedMessage,
} from './types';
export class ChatSession implements AsyncDisposable {
@@ -229,141 +231,56 @@ export class ChatSessionService {
private readonly models: Models
) {}
private async haveSession(
sessionId: string,
userId: string,
tx?: PrismaTransaction,
params?: Prisma.AiSessionCountArgs['where']
) {
const executor = tx ?? this.db;
return await executor.aiSession
.count({
where: {
id: sessionId,
userId,
...params,
},
})
.then(c => c > 0);
}
@Transactional()
private async setSession(state: ChatSessionState): Promise<string> {
return await this.db.$transaction(async tx => {
let sessionId = state.sessionId;
const session = this.models.copilotSession;
let sessionId = state.sessionId;
// find existing session if session is chat session
if (!state.prompt.action) {
const extraCondition: Record<string, any> = {};
if (state.parentSessionId) {
// also check session id if provided session is forked session
extraCondition.id = state.sessionId;
extraCondition.parentSessionId = state.parentSessionId;
}
const { id, deletedAt } =
(await tx.aiSession.findFirst({
where: {
userId: state.userId,
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
parentSessionId: null,
...extraCondition,
},
select: { id: true, deletedAt: true },
})) || {};
if (deletedAt) throw new CopilotSessionDeleted();
if (id) sessionId = id;
// find existing session if session is chat session
if (!state.prompt.action) {
const id = await session.getChatSessionId(state);
if (id) sessionId = id;
}
const haveSession = await session.has(sessionId, state.userId);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await session.setMessages(
sessionId,
state.messages,
this.calculateTokenSize(state.messages, state.prompt.model)
);
}
} else {
await session.create({
...state,
sessionId,
promptName: state.prompt.name,
});
}
const haveSession = await this.haveSession(sessionId, state.userId, tx);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await tx.aiSessionMessage.createMany({
data: state.messages.map(m => ({
...m,
streamObjects: m.streamObjects || undefined,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
sessionId,
})),
});
// only count message generated by user
const userMessages = state.messages.filter(m => m.role === 'user');
await tx.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: {
increment: this.calculateTokenSize(
state.messages,
state.prompt.model
),
},
},
});
}
} else {
await tx.aiSession.create({
data: {
id: sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
// connect
userId: state.userId,
promptName: state.prompt.name,
parentSessionId: state.parentSessionId,
},
});
}
return sessionId;
});
return sessionId;
}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
return await this.db.aiSession
.findUnique({
where: { id: sessionId, deletedAt: null },
select: {
id: true,
userId: true,
workspaceId: true,
docId: true,
parentSessionId: true,
messages: {
select: {
id: true,
role: true,
content: true,
attachments: true,
params: true,
createdAt: true,
},
orderBy: { createdAt: 'asc' },
},
promptName: true,
},
})
.then(async session => {
if (!session) return;
const prompt = await this.prompt.get(session.promptName);
if (!prompt)
throw new CopilotPromptNotFound({ name: session.promptName });
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
const prompt = await this.prompt.get(session.promptName);
if (!prompt) throw new CopilotPromptNotFound({ name: session.promptName });
const messages = ChatMessageSchema.array().safeParse(session.messages);
const messages = ChatMessageSchema.array().safeParse(session.messages);
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
});
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
parentSessionId: session.parentSessionId,
prompt,
messages: messages.success ? messages.data : [],
};
}
// revert the latest messages not generate by user
@@ -372,34 +289,10 @@ export class ChatSessionService {
sessionId: string,
removeLatestUserMessage: boolean
) {
await this.db.$transaction(async tx => {
const id = await tx.aiSession
.findUnique({
where: { id: sessionId, deletedAt: null },
select: { id: true },
})
.then(session => session?.id);
if (!id) {
throw new CopilotSessionNotFound();
}
const ids = await tx.aiSessionMessage
.findMany({
where: { sessionId: id },
select: { id: true, role: true },
orderBy: { createdAt: 'asc' },
})
.then(roles =>
roles
.slice(
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
(removeLatestUserMessage ? 0 : 1)
)
.map(({ id }) => id)
);
if (ids.length) {
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
}
});
await this.models.copilotSession.revertLatestMessage(
sessionId,
removeLatestUserMessage
);
}
private calculateTokenSize(messages: PromptMessage[], model: string): number {
@@ -441,6 +334,7 @@ export class ChatSessionService {
userId: true,
workspaceId: true,
docId: true,
pinned: true,
parentSessionId: true,
promptName: true,
},
@@ -457,6 +351,7 @@ export class ChatSessionService {
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
parentSessionId: session.parentSessionId,
prompt,
};
@@ -471,138 +366,83 @@ export class ChatSessionService {
docId?: string,
options?: ListHistoriesOptions
): Promise<ChatHistory[]> {
const extraCondition = [];
if (!options?.action && options?.fork) {
// only query forked session if fork == true and action == false
extraCondition.push({
userId: { not: userId },
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
});
}
return await this.db.aiSession
.findMany({
where: {
OR: [
{
userId,
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId
? { equals: options.sessionId }
: undefined,
deletedAt: null,
},
...extraCondition,
],
},
select: {
id: true,
userId: true,
promptName: true,
tokenCost: true,
createdAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
streamObjects: true,
attachments: true,
params: true,
createdAt: true,
},
orderBy: {
// message order is asc by default
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
},
},
},
take: options?.limit,
skip: options?.skip,
orderBy: {
// session order is desc by default
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
},
})
.then(sessions =>
Promise.all(
sessions.map(
async ({
id,
userId: uid,
promptName,
tokenCost,
messages,
createdAt,
}) => {
try {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
) {
return undefined;
}
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});
return {
sessionId: id,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
messages: preload.concat(ret.data).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
const sessions = await this.models.copilotSession.list(
userId,
workspaceId,
docId,
options
);
const histories = await Promise.all(
sessions.map(
async ({
id,
userId: uid,
pinned,
promptName,
tokenCost,
messages,
createdAt,
}) => {
try {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
) {
return undefined;
}
)
)
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});
return {
sessionId: id,
pinned,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
messages: preload.concat(ret.data).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
return undefined;
}
)
.then(histories =>
histories.filter((v): v is NonNullable<typeof v> => !!v)
);
);
return histories.filter((v): v is NonNullable<typeof v> => !!v);
}
async getQuota(userId: string) {
@@ -637,6 +477,17 @@ export class ChatSessionService {
throw new CopilotPromptNotFound({ name: options.promptName });
}
if (options.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
// validate prompt compatibility with session type
this.models.copilotSession.checkSessionPrompt(
options,
prompt.name,
prompt.action
);
return await this.setSession({
...options,
sessionId,
@@ -647,30 +498,47 @@ export class ChatSessionService {
});
}
async updateSessionPrompt(
options: ChatSessionPromptUpdateOptions
): Promise<string> {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new CopilotPromptNotFound({ name: options.promptName });
@Transactional()
async unpin(workspaceId: string, userId: string) {
await this.models.copilotSession.unpin(workspaceId, userId);
}
@Transactional()
async updateSession(options: UpdateChatSession): Promise<string> {
const session = await this.getSession(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
return await this.db.$transaction(async tx => {
let sessionId = options.sessionId;
const haveSession = await this.haveSession(
sessionId,
options.userId,
tx,
{ prompt: { action: null } }
);
if (haveSession) {
await tx.aiSession.update({
where: { id: sessionId },
data: { promptName: prompt.name },
});
const finalData: UpdateChatSessionData = {};
if (options.promptName) {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new CopilotPromptNotFound({ name: options.promptName });
}
return sessionId;
});
this.models.copilotSession.checkSessionPrompt(
session,
prompt.name,
prompt.action
);
finalData.promptName = prompt.name;
}
finalData.pinned = options.pinned;
finalData.docId = options.docId;
if (Object.keys(finalData).length === 0) {
throw new CopilotSessionInvalidInput(
'No valid fields to update in the session'
);
}
return await this.models.copilotSession.update(
options.userId,
options.sessionId,
finalData
);
}
async fork(options: ChatSessionForkOptions): Promise<string> {
@@ -678,6 +546,10 @@ export class ChatSessionService {
if (!state) {
throw new CopilotSessionNotFound();
}
if (state.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
let messages = state.messages.map(m => ({ ...m, id: undefined }));
if (options.latestMessageId) {
const lastMessageIdx = state.messages.findLastIndex(
@@ -706,7 +578,9 @@ export class ChatSessionService {
}
async cleanup(
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
sessionIds: string[];
}
) {
return await this.db.$transaction(async tx => {
const sessions = await tx.aiSession.findMany({

View File

@@ -84,6 +84,7 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const ChatHistorySchema = z
.object({
sessionId: z.string(),
pinned: z.boolean(),
action: z.string().nullable(),
tokens: z.number(),
messages: z.array(ChatMessageSchema),
@@ -105,17 +106,13 @@ export interface ChatSessionOptions {
// connect ids
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionPromptUpdateOptions
extends Pick<ChatSessionState, 'sessionId' | 'userId'> {
docId: string | null;
promptName: string;
pinned: boolean;
}
export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'promptName'> {
extends Omit<ChatSessionOptions, 'pinned' | 'promptName'> {
sessionId: string;
latestMessageId?: string;
}

View File

@@ -241,6 +241,7 @@ type CopilotHistories {
action: String
createdAt: DateTime!
messages: [ChatMessage!]!
pinned: Boolean!
sessionId: String!
"""The number of tokens used in the session"""
@@ -332,10 +333,12 @@ type CopilotQuota {
}
type CopilotSessionType {
docId: String
id: ID!
model: String!
optionalModels: [String!]!
parentSessionId: ID
pinned: Boolean!
promptName: String!
}
@@ -386,7 +389,8 @@ input CreateChatMessageInput {
}
input CreateChatSessionInput {
docId: String!
docId: String
pinned: Boolean
"""The prompt name to use for the session"""
promptName: String!
@@ -570,6 +574,7 @@ enum ErrorNames {
COPILOT_PROVIDER_SIDE_ERROR
COPILOT_QUOTA_EXCEEDED
COPILOT_SESSION_DELETED
COPILOT_SESSION_INVALID_INPUT
COPILOT_SESSION_NOT_FOUND
COPILOT_TRANSCRIPTION_AUDIO_NOT_PROVIDED
COPILOT_TRANSCRIPTION_JOB_EXISTS
@@ -1751,8 +1756,14 @@ input UpdateAppConfigInput {
}
input UpdateChatSessionInput {
"""The workspace id of the session"""
docId: String
"""Whether to pin the session"""
pinned: Boolean
"""The prompt name to use for the session"""
promptName: String!
promptName: String
sessionId: String!
}