mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
feat(server): allow chat session dangling & pin session support (#12849)
fix AI-181 fix AI-179 fix AI-178 fix PD-2682 fix PD-2683
This commit is contained in:
@@ -135,3 +135,31 @@ Generated by [AVA](https://avajs.dev).
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
## should create different session types and validate prompt constraints
|
||||
|
||||
> should create session with should create workspace session with text prompt
|
||||
|
||||
[
|
||||
{
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
> should create session with should create pinned session with text prompt
|
||||
|
||||
[
|
||||
{
|
||||
docId: 'pinned-doc',
|
||||
pinned: true,
|
||||
},
|
||||
]
|
||||
|
||||
> should create session with should create doc session with text prompt
|
||||
|
||||
[
|
||||
{
|
||||
docId: 'normal-doc',
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
Binary file not shown.
@@ -48,7 +48,11 @@ import {
|
||||
createCopilotContext,
|
||||
createCopilotMessage,
|
||||
createCopilotSession,
|
||||
createDocCopilotSession,
|
||||
createPinnedCopilotSession,
|
||||
createWorkspaceCopilotSession,
|
||||
forkCopilotSession,
|
||||
getCopilotSession,
|
||||
getHistories,
|
||||
listContext,
|
||||
listContextDocAndFiles,
|
||||
@@ -302,12 +306,8 @@ test('should fork session correctly', async t => {
|
||||
|
||||
// prepare session
|
||||
const { id } = await createWorkspace(app);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
textPromptName
|
||||
);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(app, id, docId, textPromptName);
|
||||
|
||||
let forkedSessionId: string;
|
||||
// should be able to fork session
|
||||
@@ -316,7 +316,7 @@ test('should fork session correctly', async t => {
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
}
|
||||
const histories = await getHistories(app, { workspaceId: id });
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
const latestMessageId = histories[0].messages.findLast(
|
||||
m => m.role === 'assistant'
|
||||
)?.id;
|
||||
@@ -375,7 +375,7 @@ test('should fork session correctly', async t => {
|
||||
});
|
||||
|
||||
await app.switchUser(u1);
|
||||
const histories = await getHistories(app, { workspaceId: id });
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
const latestMessageId = histories
|
||||
.find(h => h.sessionId === forkedSessionId)
|
||||
?.messages.findLast(m => m.role === 'assistant')?.id;
|
||||
@@ -612,10 +612,11 @@ test('should be able to retry with api', async t => {
|
||||
// normal chat
|
||||
{
|
||||
const { id } = await createWorkspace(app);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
docId,
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
@@ -623,7 +624,7 @@ test('should be able to retry with api', async t => {
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
|
||||
const histories = await getHistories(app, { workspaceId: id });
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text', 'generate text to text']],
|
||||
@@ -634,10 +635,11 @@ test('should be able to retry with api', async t => {
|
||||
// retry chat
|
||||
{
|
||||
const { id } = await createWorkspace(app);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
docId,
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
@@ -646,7 +648,7 @@ test('should be able to retry with api', async t => {
|
||||
await chatWithText(app, sessionId);
|
||||
|
||||
// should only have 1 message
|
||||
const histories = await getHistories(app, { workspaceId: id });
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
t.snapshot(
|
||||
cleanObject(histories),
|
||||
'should be able to list history after retry'
|
||||
@@ -656,10 +658,11 @@ test('should be able to retry with api', async t => {
|
||||
// retry chat with new message id
|
||||
{
|
||||
const { id } = await createWorkspace(app);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
id,
|
||||
randomUUID(),
|
||||
docId,
|
||||
textPromptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
@@ -669,7 +672,7 @@ test('should be able to retry with api', async t => {
|
||||
await chatWithText(app, sessionId, newMessageId, '', true);
|
||||
|
||||
// should only have 1 message
|
||||
const histories = await getHistories(app, { workspaceId: id });
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
t.snapshot(
|
||||
cleanObject(histories),
|
||||
'should be able to list history after retry'
|
||||
@@ -746,10 +749,11 @@ test('should be able to list history', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const { id: workspaceId } = await createWorkspace(app);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
docId,
|
||||
textPromptName
|
||||
);
|
||||
|
||||
@@ -757,7 +761,7 @@ test('should be able to list history', async t => {
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
|
||||
{
|
||||
const histories = await getHistories(app, { workspaceId });
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['hello', 'generate text to text']],
|
||||
@@ -768,6 +772,7 @@ test('should be able to list history', async t => {
|
||||
{
|
||||
const histories = await getHistories(app, {
|
||||
workspaceId,
|
||||
docId,
|
||||
options: { messageOrder: 'desc' },
|
||||
});
|
||||
t.deepEqual(
|
||||
@@ -809,17 +814,18 @@ test('should reject request that user have not permission', async t => {
|
||||
}
|
||||
|
||||
{
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
docId,
|
||||
textPromptName
|
||||
);
|
||||
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
await chatWithText(app, sessionId, messageId);
|
||||
|
||||
const histories = await getHistories(app, { workspaceId });
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text']],
|
||||
@@ -1072,3 +1078,93 @@ test('should be able to transcript', async t => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
test('should create different session types and validate prompt constraints', async t => {
|
||||
const { app } = t.context;
|
||||
const { id: workspaceId } = await createWorkspace(app);
|
||||
|
||||
const validateSession = async (
|
||||
description: string,
|
||||
workspaceId: string,
|
||||
createPromise: Promise<string>
|
||||
) => {
|
||||
const sessionId = await createPromise;
|
||||
|
||||
t.truthy(sessionId, description);
|
||||
t.snapshot(
|
||||
cleanObject(
|
||||
[await getCopilotSession(app, workspaceId, sessionId)],
|
||||
['id', 'workspaceId', 'promptName']
|
||||
),
|
||||
`should create session with ${description}`
|
||||
);
|
||||
return sessionId;
|
||||
};
|
||||
|
||||
await validateSession(
|
||||
'should create workspace session with text prompt',
|
||||
workspaceId,
|
||||
createWorkspaceCopilotSession(app, workspaceId, textPromptName)
|
||||
);
|
||||
await validateSession(
|
||||
'should create pinned session with text prompt',
|
||||
workspaceId,
|
||||
createPinnedCopilotSession(app, workspaceId, 'pinned-doc', textPromptName)
|
||||
);
|
||||
await validateSession(
|
||||
'should create doc session with text prompt',
|
||||
workspaceId,
|
||||
createDocCopilotSession(app, workspaceId, 'normal-doc', textPromptName)
|
||||
);
|
||||
});
|
||||
|
||||
test('should list histories for different session types correctly', async t => {
|
||||
const { app } = t.context;
|
||||
const { id: workspaceId } = await createWorkspace(app);
|
||||
const pinnedDocId = 'pinned-doc';
|
||||
const docId = 'normal-doc';
|
||||
|
||||
// create sessions and add messages
|
||||
const [workspaceSessionId, pinnedSessionId, docSessionId] = await Promise.all(
|
||||
[
|
||||
createWorkspaceCopilotSession(app, workspaceId, textPromptName),
|
||||
createPinnedCopilotSession(app, workspaceId, pinnedDocId, textPromptName),
|
||||
createDocCopilotSession(app, workspaceId, docId, textPromptName),
|
||||
]
|
||||
);
|
||||
|
||||
await Promise.all([
|
||||
createCopilotMessage(app, workspaceSessionId, 'workspace message'),
|
||||
createCopilotMessage(app, pinnedSessionId, 'pinned message'),
|
||||
createCopilotMessage(app, docSessionId, 'doc message'),
|
||||
]);
|
||||
|
||||
const testHistoryQuery = async (
|
||||
queryDocId: string | undefined,
|
||||
expectedSessionId: string,
|
||||
description: string
|
||||
) => {
|
||||
const histories = await getHistories(app, {
|
||||
workspaceId,
|
||||
docId: queryDocId,
|
||||
});
|
||||
t.is(histories.length, 1, `should return ${description}`);
|
||||
t.is(
|
||||
histories[0].sessionId,
|
||||
expectedSessionId,
|
||||
`should return correct ${description}`
|
||||
);
|
||||
};
|
||||
|
||||
await testHistoryQuery(
|
||||
undefined,
|
||||
workspaceSessionId,
|
||||
'workspace session history'
|
||||
);
|
||||
await testHistoryQuery(
|
||||
pinnedDocId,
|
||||
pinnedSessionId,
|
||||
'pinned session history'
|
||||
);
|
||||
await testHistoryQuery(docId, docSessionId, 'doc session history');
|
||||
});
|
||||
|
||||
@@ -275,7 +275,7 @@ test('should be able to manage chat session', async t => {
|
||||
]);
|
||||
|
||||
const params = { word: 'world' };
|
||||
const commonParams = { docId: 'test', workspaceId: 'test' };
|
||||
const commonParams = { docId: 'test', workspaceId: 'test', pinned: false };
|
||||
|
||||
const sessionId = await session.create({
|
||||
userId,
|
||||
@@ -342,11 +342,12 @@ test('should be able to update chat session prompt', async t => {
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
pinned: false,
|
||||
});
|
||||
t.truthy(sessionId, 'should create session');
|
||||
|
||||
// Update the session
|
||||
const updatedSessionId = await session.updateSessionPrompt({
|
||||
const updatedSessionId = await session.updateSession({
|
||||
sessionId,
|
||||
promptName: 'Search With AFFiNE AI',
|
||||
userId,
|
||||
@@ -371,7 +372,7 @@ test('should be able to fork chat session', async t => {
|
||||
]);
|
||||
|
||||
const params = { word: 'world' };
|
||||
const commonParams = { docId: 'test', workspaceId: 'test' };
|
||||
const commonParams = { docId: 'test', workspaceId: 'test', pinned: false };
|
||||
// create session
|
||||
const sessionId = await session.create({
|
||||
userId,
|
||||
@@ -494,6 +495,7 @@ test('should be able to process message id', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -537,6 +539,7 @@ test('should be able to generate with message id', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -559,6 +562,7 @@ test('should be able to generate with message id', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -586,6 +590,7 @@ test('should be able to generate with message id', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -614,6 +619,7 @@ test('should save message correctly', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -643,6 +649,7 @@ test('should revert message correctly', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
@@ -742,6 +749,7 @@ test('should handle params correctly in chat session', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
|
||||
const s = (await session.get(sessionId))!;
|
||||
@@ -1506,6 +1514,7 @@ test('should be able to manage context', async t => {
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
|
||||
// use mocked embedding client
|
||||
@@ -1729,6 +1738,7 @@ test('should be able to manage workspace embedding', async t => {
|
||||
workspaceId: ws.id,
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
pinned: false,
|
||||
});
|
||||
const contextSession = await context.create(sessionId);
|
||||
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
# Snapshot report for `src/__tests__/models/copilot-session.spec.ts`
|
||||
|
||||
The actual snapshot is saved in `copilot-session.spec.ts.snap`.
|
||||
|
||||
Generated by [AVA](https://avajs.dev).
|
||||
|
||||
## should list and filter session type
|
||||
|
||||
> workspace sessions should include workspace and pinned sessions
|
||||
|
||||
[
|
||||
{
|
||||
docId: null,
|
||||
pinned: true,
|
||||
},
|
||||
{
|
||||
docId: null,
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
> doc sessions should only include sessions with matching docId
|
||||
|
||||
[
|
||||
{
|
||||
docId: 'doc-id-1',
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
> session type identification results
|
||||
|
||||
[
|
||||
{
|
||||
session: {
|
||||
docId: null,
|
||||
pinned: false,
|
||||
},
|
||||
type: 'workspace',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: undefined,
|
||||
pinned: false,
|
||||
},
|
||||
type: 'workspace',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: null,
|
||||
pinned: true,
|
||||
},
|
||||
type: 'pinned',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: 'doc-id-1',
|
||||
pinned: false,
|
||||
},
|
||||
type: 'doc',
|
||||
},
|
||||
]
|
||||
|
||||
## should pin and unpin sessions
|
||||
|
||||
> session states after creating second pinned session
|
||||
|
||||
[
|
||||
{
|
||||
docId: null,
|
||||
id: 'first-session-id',
|
||||
pinned: false,
|
||||
},
|
||||
{
|
||||
docId: null,
|
||||
id: 'second-session-id',
|
||||
pinned: true,
|
||||
},
|
||||
]
|
||||
|
||||
> should return false when no sessions to unpin
|
||||
|
||||
false
|
||||
|
||||
> all sessions should be unpinned after unpin operation
|
||||
|
||||
[
|
||||
{
|
||||
id: 'first-session-id',
|
||||
pinned: false,
|
||||
},
|
||||
{
|
||||
id: 'second-session-id',
|
||||
pinned: false,
|
||||
},
|
||||
{
|
||||
id: 'third-session-id',
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
## session updates and type conversions
|
||||
|
||||
> session states after pinning - should unpin existing
|
||||
|
||||
[
|
||||
{
|
||||
docId: null,
|
||||
id: 'session-update-id',
|
||||
pinned: true,
|
||||
},
|
||||
{
|
||||
docId: null,
|
||||
id: 'existing-pinned-session-id',
|
||||
pinned: false,
|
||||
},
|
||||
]
|
||||
|
||||
> session state after unpinning
|
||||
|
||||
{
|
||||
docId: null,
|
||||
id: 'session-update-id',
|
||||
pinned: false,
|
||||
}
|
||||
|
||||
> session type conversion steps
|
||||
|
||||
[
|
||||
{
|
||||
session: {
|
||||
docId: 'doc-update-id',
|
||||
pinned: false,
|
||||
},
|
||||
step: 'workspace_to_doc',
|
||||
type: 'doc',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: 'doc-update-id',
|
||||
pinned: true,
|
||||
},
|
||||
step: 'doc_to_pinned',
|
||||
type: 'pinned',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: null,
|
||||
pinned: false,
|
||||
},
|
||||
step: 'pinned_to_workspace',
|
||||
type: 'workspace',
|
||||
},
|
||||
{
|
||||
session: {
|
||||
docId: null,
|
||||
pinned: true,
|
||||
},
|
||||
step: 'workspace_to_pinned',
|
||||
type: 'pinned',
|
||||
},
|
||||
]
|
||||
Binary file not shown.
@@ -5,12 +5,14 @@ import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { Config } from '../../base';
|
||||
import { ContextEmbedStatus } from '../../models/common/copilot';
|
||||
import { CopilotContextModel } from '../../models/copilot-context';
|
||||
import { CopilotSessionModel } from '../../models/copilot-session';
|
||||
import { CopilotWorkspaceConfigModel } from '../../models/copilot-workspace';
|
||||
import { UserModel } from '../../models/user';
|
||||
import { WorkspaceModel } from '../../models/workspace';
|
||||
import {
|
||||
ContextEmbedStatus,
|
||||
CopilotContextModel,
|
||||
CopilotSessionModel,
|
||||
CopilotWorkspaceConfigModel,
|
||||
UserModel,
|
||||
WorkspaceModel,
|
||||
} from '../../models';
|
||||
import { createTestingModule, type TestingModule } from '../utils';
|
||||
import { cleanObject } from '../utils/copilot';
|
||||
|
||||
@@ -46,7 +48,7 @@ let docId = 'doc1';
|
||||
|
||||
test.beforeEach(async t => {
|
||||
await t.context.module.initTestingDB();
|
||||
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o');
|
||||
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4.1');
|
||||
user = await t.context.user.create({
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
|
||||
@@ -0,0 +1,341 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { PrismaClient, User, Workspace } from '@prisma/client';
|
||||
import ava, { ExecutionContext, TestFn } from 'ava';
|
||||
|
||||
import { CopilotPromptInvalid } from '../../base';
|
||||
import {
|
||||
CopilotSessionModel,
|
||||
UpdateChatSessionData,
|
||||
UserModel,
|
||||
WorkspaceModel,
|
||||
} from '../../models';
|
||||
import { createTestingModule, type TestingModule } from '../utils';
|
||||
|
||||
interface Context {
|
||||
module: TestingModule;
|
||||
db: PrismaClient;
|
||||
user: UserModel;
|
||||
workspace: WorkspaceModel;
|
||||
copilotSession: CopilotSessionModel;
|
||||
}
|
||||
|
||||
const test = ava as TestFn<Context>;
|
||||
|
||||
test.before(async t => {
|
||||
const module = await createTestingModule();
|
||||
t.context.user = module.get(UserModel);
|
||||
t.context.workspace = module.get(WorkspaceModel);
|
||||
t.context.copilotSession = module.get(CopilotSessionModel);
|
||||
t.context.db = module.get(PrismaClient);
|
||||
t.context.module = module;
|
||||
});
|
||||
|
||||
let user: User;
|
||||
let workspace: Workspace;
|
||||
|
||||
test.beforeEach(async t => {
|
||||
await t.context.module.initTestingDB();
|
||||
user = await t.context.user.create({
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
workspace = await t.context.workspace.create(user.id);
|
||||
});
|
||||
|
||||
test.after(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
const createTestPrompts = async (
|
||||
copilotSession: CopilotSessionModel,
|
||||
db: PrismaClient
|
||||
) => {
|
||||
await copilotSession.createPrompt('test-prompt', 'gpt-4.1');
|
||||
await db.aiPrompt.create({
|
||||
data: { name: 'action-prompt', model: 'gpt-4.1', action: 'edit' },
|
||||
});
|
||||
};
|
||||
|
||||
const createTestSession = async (
|
||||
t: ExecutionContext<Context>,
|
||||
overrides: Partial<{
|
||||
sessionId: string;
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
docId: string | null;
|
||||
pinned: boolean;
|
||||
promptName: string;
|
||||
}> = {}
|
||||
) => {
|
||||
const sessionData = {
|
||||
sessionId: randomUUID(),
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
docId: null,
|
||||
pinned: false,
|
||||
promptName: 'test-prompt',
|
||||
...overrides,
|
||||
};
|
||||
|
||||
await t.context.copilotSession.create(sessionData);
|
||||
return sessionData;
|
||||
};
|
||||
|
||||
const getSessionState = async (db: PrismaClient, sessionId: string) => {
|
||||
const session = await db.aiSession.findUnique({
|
||||
where: { id: sessionId },
|
||||
select: { id: true, pinned: true, docId: true },
|
||||
});
|
||||
return session;
|
||||
};
|
||||
|
||||
test('should list and filter session type', async t => {
|
||||
const { copilotSession, db } = t.context;
|
||||
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
const docId = 'doc-id-1';
|
||||
await createTestSession(t, { sessionId: randomUUID() });
|
||||
await createTestSession(t, { sessionId: randomUUID(), pinned: true });
|
||||
await createTestSession(t, { sessionId: randomUUID(), docId });
|
||||
|
||||
// should list sessions
|
||||
{
|
||||
const workspaceSessions = await copilotSession.list(user.id, workspace.id);
|
||||
|
||||
t.snapshot(
|
||||
workspaceSessions.map(s => ({ docId: s.docId, pinned: s.pinned })),
|
||||
'workspace sessions should include workspace and pinned sessions'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const docSessions = await copilotSession.list(user.id, workspace.id, docId);
|
||||
|
||||
t.snapshot(
|
||||
docSessions.map(s => ({ docId: s.docId, pinned: s.pinned })),
|
||||
'doc sessions should only include sessions with matching docId'
|
||||
);
|
||||
}
|
||||
|
||||
// should identify session types
|
||||
{
|
||||
// check get session type
|
||||
const testCases = [
|
||||
{ docId: null, pinned: false },
|
||||
{ docId: undefined, pinned: false },
|
||||
{ docId: null, pinned: true },
|
||||
{ docId, pinned: false },
|
||||
];
|
||||
|
||||
const sessionTypeResults = testCases.map(session => ({
|
||||
session,
|
||||
type: copilotSession.getSessionType(session),
|
||||
}));
|
||||
|
||||
t.snapshot(sessionTypeResults, 'session type identification results');
|
||||
}
|
||||
});
|
||||
|
||||
test('should check session validation for prompts', async t => {
|
||||
const { copilotSession, db } = t.context;
|
||||
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
const docId = randomUUID();
|
||||
const sessionTypes = [
|
||||
{ name: 'workspace', session: { docId: null, pinned: false } },
|
||||
{ name: 'pinned', session: { docId: null, pinned: true } },
|
||||
{ name: 'doc', session: { docId, pinned: false } },
|
||||
];
|
||||
|
||||
// non-action prompts should work for all session types
|
||||
sessionTypes.forEach(({ name, session }) => {
|
||||
t.notThrows(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'test-prompt', undefined),
|
||||
`${name} session should allow non-action prompts`
|
||||
);
|
||||
});
|
||||
|
||||
// action prompts should only work for doc session type
|
||||
{
|
||||
const actionPromptTests = [
|
||||
{
|
||||
name: 'workspace',
|
||||
session: sessionTypes[0].session,
|
||||
shouldThrow: true,
|
||||
},
|
||||
{ name: 'pinned', session: sessionTypes[1].session, shouldThrow: true },
|
||||
{ name: 'doc', session: sessionTypes[2].session, shouldThrow: false },
|
||||
];
|
||||
|
||||
actionPromptTests.forEach(({ name, session, shouldThrow }) => {
|
||||
if (shouldThrow) {
|
||||
t.throws(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
|
||||
{ instanceOf: CopilotPromptInvalid },
|
||||
`${name} session should reject action prompts`
|
||||
);
|
||||
} else {
|
||||
t.notThrows(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
|
||||
`${name} session should allow action prompts`
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
test('should pin and unpin sessions', async t => {
|
||||
const { copilotSession, db } = t.context;
|
||||
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
const firstSessionId = 'first-session-id';
|
||||
const secondSessionId = 'second-session-id';
|
||||
const thirdSessionId = 'third-session-id';
|
||||
|
||||
// should unpin existing pinned session when creating a new one
|
||||
{
|
||||
await copilotSession.create({
|
||||
sessionId: firstSessionId,
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
docId: null,
|
||||
promptName: 'test-prompt',
|
||||
pinned: true,
|
||||
});
|
||||
|
||||
const firstSession = await copilotSession.get(firstSessionId);
|
||||
t.truthy(firstSession, 'first session should be created successfully');
|
||||
t.is(firstSession?.pinned, true, 'first session should be pinned');
|
||||
|
||||
// should unpin the first one when creating second pinned session
|
||||
await copilotSession.create({
|
||||
sessionId: secondSessionId,
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
docId: null,
|
||||
promptName: 'test-prompt',
|
||||
pinned: true,
|
||||
});
|
||||
|
||||
const sessionStatesAfterSecondPin = await Promise.all([
|
||||
getSessionState(db, firstSessionId),
|
||||
getSessionState(db, secondSessionId),
|
||||
]);
|
||||
|
||||
t.snapshot(
|
||||
sessionStatesAfterSecondPin,
|
||||
'session states after creating second pinned session'
|
||||
);
|
||||
}
|
||||
|
||||
// should can unpin a pinned session
|
||||
{
|
||||
await createTestSession(t, { sessionId: thirdSessionId, pinned: true });
|
||||
const unpinResult = await copilotSession.unpin(workspace.id, user.id);
|
||||
t.is(
|
||||
unpinResult,
|
||||
true,
|
||||
'unpin operation should return true when sessions are unpinned'
|
||||
);
|
||||
|
||||
const unpinResultAgain = await copilotSession.unpin(workspace.id, user.id);
|
||||
t.snapshot(
|
||||
unpinResultAgain,
|
||||
'should return false when no sessions to unpin'
|
||||
);
|
||||
}
|
||||
|
||||
// should unpin all sessions
|
||||
{
|
||||
const allSessionsAfterUnpin = await db.aiSession.findMany({
|
||||
where: { id: { in: [firstSessionId, secondSessionId, thirdSessionId] } },
|
||||
select: { pinned: true, id: true },
|
||||
orderBy: { id: 'asc' },
|
||||
});
|
||||
|
||||
t.snapshot(
|
||||
allSessionsAfterUnpin,
|
||||
'all sessions should be unpinned after unpin operation'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('session updates and type conversions', async t => {
|
||||
const { copilotSession, db } = t.context;
|
||||
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
const sessionId = 'session-update-id';
|
||||
const docId = 'doc-update-id';
|
||||
|
||||
await createTestSession(t, { sessionId });
|
||||
|
||||
// should unpin existing pinned session
|
||||
{
|
||||
const existingPinnedId = 'existing-pinned-session-id';
|
||||
await createTestSession(t, { sessionId: existingPinnedId, pinned: true });
|
||||
|
||||
await copilotSession.update(user.id, sessionId, { pinned: true });
|
||||
|
||||
const sessionStatesAfterPin = await Promise.all([
|
||||
getSessionState(db, sessionId),
|
||||
getSessionState(db, existingPinnedId),
|
||||
]);
|
||||
|
||||
t.snapshot(
|
||||
sessionStatesAfterPin,
|
||||
'session states after pinning - should unpin existing'
|
||||
);
|
||||
}
|
||||
|
||||
// should unpin the session
|
||||
{
|
||||
await copilotSession.update(user.id, sessionId, { pinned: false });
|
||||
const sessionStateAfterUnpin = await getSessionState(db, sessionId);
|
||||
t.snapshot(sessionStateAfterUnpin, 'session state after unpinning');
|
||||
}
|
||||
|
||||
// should convert session types
|
||||
{
|
||||
const conversionSteps: any[] = [];
|
||||
|
||||
let session = await db.aiSession.findUnique({
|
||||
where: { id: sessionId },
|
||||
select: { docId: true, pinned: true },
|
||||
});
|
||||
|
||||
const convertSession = async (
|
||||
step: string,
|
||||
data: UpdateChatSessionData
|
||||
) => {
|
||||
await copilotSession.update(user.id, sessionId, data);
|
||||
session = await db.aiSession.findUnique({
|
||||
where: { id: sessionId },
|
||||
select: { docId: true, pinned: true },
|
||||
});
|
||||
conversionSteps.push({
|
||||
step,
|
||||
session,
|
||||
type: copilotSession.getSessionType(session!),
|
||||
});
|
||||
};
|
||||
|
||||
{
|
||||
await convertSession('workspace_to_doc', { docId }); // Workspace → Doc session
|
||||
await convertSession('doc_to_pinned', { pinned: true }); // Doc → Pinned session
|
||||
await convertSession('pinned_to_workspace', {
|
||||
pinned: false,
|
||||
docId: null,
|
||||
}); // Pinned → Workspace session
|
||||
await convertSession('workspace_to_pinned', { pinned: true }); // Workspace → Pinned session
|
||||
}
|
||||
|
||||
t.snapshot(conversionSteps, 'session type conversion steps');
|
||||
}
|
||||
});
|
||||
@@ -20,8 +20,9 @@ export const cleanObject = (
|
||||
export async function createCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
promptName: string
|
||||
docId: string | null,
|
||||
promptName: string,
|
||||
pinned: boolean = false
|
||||
): Promise<string> {
|
||||
const res = await app.gql(
|
||||
`
|
||||
@@ -29,12 +30,73 @@ export async function createCopilotSession(
|
||||
createCopilotSession(options: $options)
|
||||
}
|
||||
`,
|
||||
{ options: { workspaceId, docId, promptName } }
|
||||
{ options: { workspaceId, docId, promptName, pinned } }
|
||||
);
|
||||
|
||||
return res.createCopilotSession;
|
||||
}
|
||||
|
||||
export async function createWorkspaceCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
promptName: string
|
||||
): Promise<string> {
|
||||
return createCopilotSession(app, workspaceId, null, promptName);
|
||||
}
|
||||
|
||||
export async function createPinnedCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
promptName: string
|
||||
): Promise<string> {
|
||||
return createCopilotSession(app, workspaceId, docId, promptName, true);
|
||||
}
|
||||
|
||||
export async function createDocCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
promptName: string
|
||||
): Promise<string> {
|
||||
return createCopilotSession(app, workspaceId, docId, promptName);
|
||||
}
|
||||
|
||||
export async function getCopilotSession(
|
||||
app: TestingApp,
|
||||
workspaceId: string,
|
||||
sessionId: string
|
||||
): Promise<{
|
||||
id: string;
|
||||
docId: string | null;
|
||||
parentSessionId: string | null;
|
||||
pinned: boolean;
|
||||
promptName: string;
|
||||
}> {
|
||||
const res = await app.gql(
|
||||
`
|
||||
query getCopilotSession(
|
||||
$workspaceId: String!
|
||||
$sessionId: String!
|
||||
) {
|
||||
currentUser {
|
||||
copilot(workspaceId: $workspaceId) {
|
||||
session(sessionId: $sessionId) {
|
||||
id
|
||||
docId
|
||||
parentSessionId
|
||||
pinned
|
||||
promptName
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
{ workspaceId, sessionId }
|
||||
);
|
||||
|
||||
return res.currentUser?.copilot?.session;
|
||||
}
|
||||
|
||||
export async function updateCopilotSession(
|
||||
app: TestingApp,
|
||||
sessionId: string,
|
||||
|
||||
@@ -643,6 +643,10 @@ export const USER_FRIENDLY_ERRORS = {
|
||||
type: 'resource_not_found',
|
||||
message: `Copilot session not found.`,
|
||||
},
|
||||
copilot_session_invalid_input: {
|
||||
type: 'invalid_input',
|
||||
message: `Copilot session input is invalid.`,
|
||||
},
|
||||
copilot_session_deleted: {
|
||||
type: 'action_forbidden',
|
||||
message: `Copilot session has been deleted.`,
|
||||
|
||||
@@ -657,6 +657,12 @@ export class CopilotSessionNotFound extends UserFriendlyError {
|
||||
}
|
||||
}
|
||||
|
||||
export class CopilotSessionInvalidInput extends UserFriendlyError {
|
||||
constructor(message?: string) {
|
||||
super('invalid_input', 'copilot_session_invalid_input', message);
|
||||
}
|
||||
}
|
||||
|
||||
export class CopilotSessionDeleted extends UserFriendlyError {
|
||||
constructor(message?: string) {
|
||||
super('action_forbidden', 'copilot_session_deleted', message);
|
||||
@@ -1145,6 +1151,7 @@ export enum ErrorNames {
|
||||
WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION,
|
||||
WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION,
|
||||
COPILOT_SESSION_NOT_FOUND,
|
||||
COPILOT_SESSION_INVALID_INPUT,
|
||||
COPILOT_SESSION_DELETED,
|
||||
NO_COPILOT_PROVIDER_AVAILABLE,
|
||||
COPILOT_FAILED_TO_GENERATE_TEXT,
|
||||
|
||||
@@ -1,36 +1,366 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole, Prisma } from '@prisma/client';
|
||||
import { omit } from 'lodash-es';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotSessionDeleted,
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
} from '../base';
|
||||
import { BaseModel } from './base';
|
||||
|
||||
interface ChatSessionState {
|
||||
export enum SessionType {
|
||||
Workspace = 'workspace', // docId is null and pinned is false
|
||||
Pinned = 'pinned', // pinned is true
|
||||
Doc = 'doc', // docId points to specific document
|
||||
}
|
||||
|
||||
type ChatAttachment = { attachment: string; mimeType: string } | string;
|
||||
|
||||
type ChatStreamObject = {
|
||||
type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result';
|
||||
textDelta?: string;
|
||||
toolCallId?: string;
|
||||
toolName?: string;
|
||||
args?: Record<string, any>;
|
||||
result?: any;
|
||||
};
|
||||
|
||||
type ChatMessage = {
|
||||
id?: string | undefined;
|
||||
role: 'system' | 'assistant' | 'user';
|
||||
content: string;
|
||||
attachments?: ChatAttachment[] | null;
|
||||
params?: Record<string, any> | null;
|
||||
streamObjects?: ChatStreamObject[] | null;
|
||||
createdAt: Date;
|
||||
};
|
||||
|
||||
type ChatSession = {
|
||||
sessionId: string;
|
||||
workspaceId: string;
|
||||
docId: string;
|
||||
docId?: string | null;
|
||||
pinned?: boolean;
|
||||
messages?: ChatMessage[];
|
||||
// connect ids
|
||||
userId: string;
|
||||
promptName: string;
|
||||
}
|
||||
parentSessionId?: string | null;
|
||||
};
|
||||
|
||||
export type UpdateChatSessionData = Partial<
|
||||
Pick<ChatSession, 'docId' | 'pinned' | 'promptName'>
|
||||
>;
|
||||
export type UpdateChatSession = Pick<ChatSession, 'userId' | 'sessionId'> &
|
||||
UpdateChatSessionData;
|
||||
|
||||
export type ListSessionOptions = {
|
||||
sessionId: string | undefined;
|
||||
action: boolean | undefined;
|
||||
fork: boolean | undefined;
|
||||
limit: number | undefined;
|
||||
skip: number | undefined;
|
||||
sessionOrder: 'asc' | 'desc' | undefined;
|
||||
messageOrder: 'asc' | 'desc' | undefined;
|
||||
};
|
||||
|
||||
// TODO(@darkskygit): not ready to replace business codes yet, just for test
|
||||
@Injectable()
|
||||
export class CopilotSessionModel extends BaseModel {
|
||||
async create(state: ChatSessionState) {
|
||||
const row = await this.db.aiSession.create({
|
||||
data: {
|
||||
id: state.sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.promptName,
|
||||
},
|
||||
});
|
||||
return row;
|
||||
getSessionType(session: Pick<ChatSession, 'docId' | 'pinned'>): SessionType {
|
||||
if (session.pinned) return SessionType.Pinned;
|
||||
if (!session.docId) return SessionType.Workspace;
|
||||
return SessionType.Doc;
|
||||
}
|
||||
|
||||
checkSessionPrompt(
|
||||
session: Pick<ChatSession, 'docId' | 'pinned'>,
|
||||
promptName: string,
|
||||
promptAction: string | undefined
|
||||
): boolean {
|
||||
const sessionType = this.getSessionType(session);
|
||||
|
||||
// workspace and pinned sessions cannot use action prompts
|
||||
if (
|
||||
[SessionType.Workspace, SessionType.Pinned].includes(sessionType) &&
|
||||
!!promptAction?.trim()
|
||||
) {
|
||||
throw new CopilotPromptInvalid(
|
||||
`${promptName} are not allowed for ${sessionType} sessions`
|
||||
);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// NOTE: just for test, remove it after copilot prompt model is ready
|
||||
async createPrompt(name: string, model: string) {
|
||||
await this.db.aiPrompt.create({
|
||||
data: { name, model },
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async create(state: ChatSession) {
|
||||
if (state.pinned) {
|
||||
await this.unpin(state.workspaceId, state.userId);
|
||||
}
|
||||
|
||||
const row = await this.db.aiSession.create({
|
||||
data: {
|
||||
id: state.sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
pinned: state.pinned ?? false,
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.promptName,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
});
|
||||
return row;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async has(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
params?: Prisma.AiSessionCountArgs['where']
|
||||
) {
|
||||
return await this.db.aiSession
|
||||
.count({ where: { id: sessionId, userId, ...params } })
|
||||
.then(c => c > 0);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async getChatSessionId(state: Omit<ChatSession, 'promptName'>) {
|
||||
const extraCondition: Record<string, any> = {};
|
||||
if (state.parentSessionId) {
|
||||
// also check session id if provided session is forked session
|
||||
extraCondition.id = state.sessionId;
|
||||
extraCondition.parentSessionId = state.parentSessionId;
|
||||
}
|
||||
|
||||
const session = await this.db.aiSession.findFirst({
|
||||
where: {
|
||||
userId: state.userId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
parentSessionId: null,
|
||||
prompt: { action: { equals: null } },
|
||||
...extraCondition,
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
});
|
||||
if (session?.deletedAt) throw new CopilotSessionDeleted();
|
||||
return session?.id;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async getExists<Select extends Prisma.AiSessionSelect>(
|
||||
sessionId: string,
|
||||
select?: Select,
|
||||
where?: Omit<Prisma.AiSessionWhereInput, 'id' | 'deletedAt'>
|
||||
) {
|
||||
return (await this.db.aiSession.findUnique({
|
||||
where: { ...where, id: sessionId, deletedAt: null },
|
||||
select,
|
||||
})) as Prisma.AiSessionGetPayload<{ select: Select }>;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async get(sessionId: string) {
|
||||
return await this.getExists(sessionId, {
|
||||
id: true,
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
pinned: true,
|
||||
parentSessionId: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
});
|
||||
}
|
||||
|
||||
async list(
|
||||
userId: string,
|
||||
workspaceId?: string,
|
||||
docId?: string,
|
||||
options?: ListSessionOptions
|
||||
) {
|
||||
const extraCondition = [];
|
||||
|
||||
if (!options?.action && options?.fork) {
|
||||
// only query forked session if fork == true and action == false
|
||||
extraCondition.push({
|
||||
userId: { not: userId },
|
||||
workspaceId: workspaceId,
|
||||
docId: docId ?? null,
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
// should only find forked session
|
||||
parentSessionId: { not: null },
|
||||
deletedAt: null,
|
||||
});
|
||||
}
|
||||
|
||||
return await this.db.aiSession.findMany({
|
||||
where: {
|
||||
OR: [
|
||||
{
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: docId ?? null,
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
...extraCondition,
|
||||
],
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
docId: true,
|
||||
pinned: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
streamObjects: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: {
|
||||
// message order is asc by default
|
||||
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
|
||||
},
|
||||
},
|
||||
},
|
||||
take: options?.limit,
|
||||
skip: options?.skip,
|
||||
orderBy: {
|
||||
// session order is desc by default
|
||||
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async unpin(workspaceId: string, userId: string): Promise<boolean> {
|
||||
const { count } = await this.db.aiSession.updateMany({
|
||||
where: { userId, workspaceId, pinned: true, deletedAt: null },
|
||||
data: { pinned: false },
|
||||
});
|
||||
|
||||
return count > 0;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async update(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
data: UpdateChatSessionData
|
||||
): Promise<string> {
|
||||
const session = await this.getExists(
|
||||
sessionId,
|
||||
{ id: true, workspaceId: true, docId: true, pinned: true, prompt: true },
|
||||
{ userId }
|
||||
);
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
if (data.promptName && session.prompt.action) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
`Cannot update prompt for action: ${session.id}`
|
||||
);
|
||||
}
|
||||
if (data.pinned && data.pinned !== session.pinned) {
|
||||
// if pin the session, unpin exists session in the workspace
|
||||
await this.unpin(session.workspaceId, userId);
|
||||
}
|
||||
|
||||
await this.db.aiSession.update({ where: { id: sessionId }, data });
|
||||
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async getMessages(
|
||||
sessionId: string,
|
||||
select?: Prisma.AiSessionMessageSelect,
|
||||
orderBy?: Prisma.AiSessionMessageOrderByWithRelationInput
|
||||
) {
|
||||
return this.db.aiSessionMessage.findMany({
|
||||
where: { sessionId },
|
||||
select,
|
||||
orderBy: orderBy ?? { createdAt: 'asc' },
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async setMessages(
|
||||
sessionId: string,
|
||||
messages: ChatMessage[],
|
||||
tokenCost: number
|
||||
) {
|
||||
await this.db.aiSessionMessage.createMany({
|
||||
data: messages.map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = messages.filter(m => m.role === 'user');
|
||||
await this.db.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: { increment: tokenCost },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async revertLatestMessage(
|
||||
sessionId: string,
|
||||
removeLatestUserMessage: boolean
|
||||
) {
|
||||
const id = await this.getExists(sessionId, { id: true }).then(
|
||||
session => session?.id
|
||||
);
|
||||
if (!id) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const ids = await this.getMessages(id, { id: true, role: true }).then(
|
||||
roles =>
|
||||
roles
|
||||
.slice(
|
||||
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
||||
(removeLatestUserMessage ? 0 : 1)
|
||||
)
|
||||
.map(({ id }) => id)
|
||||
);
|
||||
if (ids.length) {
|
||||
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +102,8 @@ export class ModelsModule {}
|
||||
export * from './common';
|
||||
export * from './copilot-context';
|
||||
export * from './copilot-job';
|
||||
export * from './copilot-session';
|
||||
export * from './copilot-workspace';
|
||||
export * from './doc';
|
||||
export * from './doc-user';
|
||||
export * from './feature';
|
||||
|
||||
@@ -33,6 +33,7 @@ import { CurrentUser } from '../../core/auth';
|
||||
import { Admin } from '../../core/common';
|
||||
import { AccessController } from '../../core/permission';
|
||||
import { UserType } from '../../core/user';
|
||||
import type { UpdateChatSession } from '../../models';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage, StreamObject } from './providers';
|
||||
import { ChatSessionService } from './session';
|
||||
@@ -57,22 +58,38 @@ class CreateChatSessionInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
@Field(() => String, { nullable: true })
|
||||
docId?: string;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'The prompt name to use for the session',
|
||||
})
|
||||
promptName!: string;
|
||||
|
||||
@Field(() => Boolean, { nullable: true })
|
||||
pinned?: boolean;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class UpdateChatSessionInput {
|
||||
class UpdateChatSessionInput implements Omit<UpdateChatSession, 'userId'> {
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'The workspace id of the session',
|
||||
nullable: true,
|
||||
})
|
||||
docId!: string | null | undefined;
|
||||
|
||||
@Field(() => Boolean, {
|
||||
description: 'Whether to pin the session',
|
||||
nullable: true,
|
||||
})
|
||||
pinned!: boolean | undefined;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'The prompt name to use for the session',
|
||||
nullable: true,
|
||||
})
|
||||
promptName!: string;
|
||||
}
|
||||
@@ -219,6 +236,9 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => Boolean)
|
||||
pinned!: boolean;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'An mark identifying which view to use to display the session',
|
||||
nullable: true,
|
||||
@@ -304,6 +324,12 @@ export class CopilotSessionType {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
docId!: string | null;
|
||||
|
||||
@Field(() => Boolean)
|
||||
pinned!: boolean;
|
||||
|
||||
@Field(() => ID, { nullable: true })
|
||||
parentSessionId!: string | null;
|
||||
|
||||
@@ -459,22 +485,33 @@ export class CopilotResolver {
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<string> {
|
||||
await this.ac.user(user.id).doc(options).allowLocal().assert('Doc.Update');
|
||||
// permission check based on session type
|
||||
if (options.docId) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.doc({ workspaceId: options.workspaceId, docId: options.docId })
|
||||
.allowLocal()
|
||||
.assert('Doc.Update');
|
||||
} else {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(options.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
}
|
||||
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
throw new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
if (options.workspaceId === options.docId) {
|
||||
// filter out session create request for root doc
|
||||
throw new CopilotDocNotFound({ docId: options.docId });
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
return await this.chatSession.create({
|
||||
...options,
|
||||
pinned: options.pinned ?? false,
|
||||
docId: options.docId ?? null,
|
||||
userId: user.id,
|
||||
});
|
||||
}
|
||||
@@ -493,11 +530,19 @@ export class CopilotResolver {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const { workspaceId, docId } = session.config;
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.doc(workspaceId, docId)
|
||||
.allowLocal()
|
||||
.assert('Doc.Update');
|
||||
if (docId) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.doc(workspaceId, docId)
|
||||
.allowLocal()
|
||||
.assert('Doc.Update');
|
||||
} else {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
}
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
@@ -505,7 +550,7 @@ export class CopilotResolver {
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
return await this.chatSession.updateSessionPrompt({
|
||||
return await this.chatSession.updateSession({
|
||||
...options,
|
||||
userId: user.id,
|
||||
});
|
||||
@@ -619,6 +664,8 @@ export class CopilotResolver {
|
||||
return {
|
||||
id: session.sessionId,
|
||||
parentSessionId: session.parentSessionId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
promptName: session.prompt.name,
|
||||
model: session.prompt.model,
|
||||
optionalModels: session.prompt.optionalModels,
|
||||
|
||||
@@ -1,34 +1,36 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { AiPromptRole, Prisma, PrismaClient } from '@prisma/client';
|
||||
import { omit } from 'lodash-es';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole, PrismaClient } from '@prisma/client';
|
||||
|
||||
import {
|
||||
CopilotActionTaken,
|
||||
CopilotMessageNotFound,
|
||||
CopilotPromptNotFound,
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionDeleted,
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
PrismaTransaction,
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { Models } from '../../models';
|
||||
import {
|
||||
Models,
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionData,
|
||||
} from '../../models';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import { PromptMessage, PromptParams } from './providers';
|
||||
import {
|
||||
ChatHistory,
|
||||
ChatMessage,
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
ChatMessageSchema,
|
||||
ChatSessionForkOptions,
|
||||
ChatSessionOptions,
|
||||
ChatSessionPromptUpdateOptions,
|
||||
ChatSessionState,
|
||||
type ChatSessionForkOptions,
|
||||
type ChatSessionOptions,
|
||||
type ChatSessionState,
|
||||
getTokenEncoder,
|
||||
ListHistoriesOptions,
|
||||
SubmittedMessage,
|
||||
type ListHistoriesOptions,
|
||||
type SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
@@ -229,141 +231,56 @@ export class ChatSessionService {
|
||||
private readonly models: Models
|
||||
) {}
|
||||
|
||||
private async haveSession(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
tx?: PrismaTransaction,
|
||||
params?: Prisma.AiSessionCountArgs['where']
|
||||
) {
|
||||
const executor = tx ?? this.db;
|
||||
return await executor.aiSession
|
||||
.count({
|
||||
where: {
|
||||
id: sessionId,
|
||||
userId,
|
||||
...params,
|
||||
},
|
||||
})
|
||||
.then(c => c > 0);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
private async setSession(state: ChatSessionState): Promise<string> {
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = state.sessionId;
|
||||
const session = this.models.copilotSession;
|
||||
let sessionId = state.sessionId;
|
||||
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const extraCondition: Record<string, any> = {};
|
||||
if (state.parentSessionId) {
|
||||
// also check session id if provided session is forked session
|
||||
extraCondition.id = state.sessionId;
|
||||
extraCondition.parentSessionId = state.parentSessionId;
|
||||
}
|
||||
const { id, deletedAt } =
|
||||
(await tx.aiSession.findFirst({
|
||||
where: {
|
||||
userId: state.userId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
parentSessionId: null,
|
||||
...extraCondition,
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
if (deletedAt) throw new CopilotSessionDeleted();
|
||||
if (id) sessionId = id;
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const id = await session.getChatSessionId(state);
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
const haveSession = await session.has(sessionId, state.userId);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
await session.setMessages(
|
||||
sessionId,
|
||||
state.messages,
|
||||
this.calculateTokenSize(state.messages, state.prompt.model)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
await session.create({
|
||||
...state,
|
||||
sessionId,
|
||||
promptName: state.prompt.name,
|
||||
});
|
||||
}
|
||||
|
||||
const haveSession = await this.haveSession(sessionId, state.userId, tx);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
await tx.aiSessionMessage.createMany({
|
||||
data: state.messages.map(m => ({
|
||||
...m,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = state.messages.filter(m => m.role === 'user');
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: {
|
||||
increment: this.calculateTokenSize(
|
||||
state.messages,
|
||||
state.prompt.model
|
||||
),
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
await tx.aiSession.create({
|
||||
data: {
|
||||
id: sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.prompt.name,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return sessionId;
|
||||
});
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
|
||||
return await this.db.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
parentSessionId: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
},
|
||||
})
|
||||
.then(async session => {
|
||||
if (!session) return;
|
||||
const prompt = await this.prompt.get(session.promptName);
|
||||
if (!prompt)
|
||||
throw new CopilotPromptNotFound({ name: session.promptName });
|
||||
const session = await this.models.copilotSession.get(sessionId);
|
||||
if (!session) return;
|
||||
const prompt = await this.prompt.get(session.promptName);
|
||||
if (!prompt) throw new CopilotPromptNotFound({ name: session.promptName });
|
||||
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
|
||||
return {
|
||||
sessionId: session.id,
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
});
|
||||
return {
|
||||
sessionId: session.id,
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
}
|
||||
|
||||
// revert the latest messages not generate by user
|
||||
@@ -372,34 +289,10 @@ export class ChatSessionService {
|
||||
sessionId: string,
|
||||
removeLatestUserMessage: boolean
|
||||
) {
|
||||
await this.db.$transaction(async tx => {
|
||||
const id = await tx.aiSession
|
||||
.findUnique({
|
||||
where: { id: sessionId, deletedAt: null },
|
||||
select: { id: true },
|
||||
})
|
||||
.then(session => session?.id);
|
||||
if (!id) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const ids = await tx.aiSessionMessage
|
||||
.findMany({
|
||||
where: { sessionId: id },
|
||||
select: { id: true, role: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
})
|
||||
.then(roles =>
|
||||
roles
|
||||
.slice(
|
||||
roles.findLastIndex(({ role }) => role === AiPromptRole.user) +
|
||||
(removeLatestUserMessage ? 0 : 1)
|
||||
)
|
||||
.map(({ id }) => id)
|
||||
);
|
||||
if (ids.length) {
|
||||
await tx.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
}
|
||||
});
|
||||
await this.models.copilotSession.revertLatestMessage(
|
||||
sessionId,
|
||||
removeLatestUserMessage
|
||||
);
|
||||
}
|
||||
|
||||
private calculateTokenSize(messages: PromptMessage[], model: string): number {
|
||||
@@ -441,6 +334,7 @@ export class ChatSessionService {
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
pinned: true,
|
||||
parentSessionId: true,
|
||||
promptName: true,
|
||||
},
|
||||
@@ -457,6 +351,7 @@ export class ChatSessionService {
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
pinned: session.pinned,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
};
|
||||
@@ -471,138 +366,83 @@ export class ChatSessionService {
|
||||
docId?: string,
|
||||
options?: ListHistoriesOptions
|
||||
): Promise<ChatHistory[]> {
|
||||
const extraCondition = [];
|
||||
|
||||
if (!options?.action && options?.fork) {
|
||||
// only query forked session if fork == true and action == false
|
||||
extraCondition.push({
|
||||
userId: { not: userId },
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
// should only find forked session
|
||||
parentSessionId: { not: null },
|
||||
deletedAt: null,
|
||||
});
|
||||
}
|
||||
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
OR: [
|
||||
{
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
...extraCondition,
|
||||
],
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
streamObjects: true,
|
||||
attachments: true,
|
||||
params: true,
|
||||
createdAt: true,
|
||||
},
|
||||
orderBy: {
|
||||
// message order is asc by default
|
||||
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
|
||||
},
|
||||
},
|
||||
},
|
||||
take: options?.limit,
|
||||
skip: options?.skip,
|
||||
orderBy: {
|
||||
// session order is desc by default
|
||||
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
|
||||
},
|
||||
})
|
||||
.then(sessions =>
|
||||
Promise.all(
|
||||
sessions.map(
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
// render system prompt
|
||||
const preload = (
|
||||
options?.withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: []
|
||||
) as ChatMessage[];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
preload.forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data).map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments
|
||||
?.map(a => (typeof a === 'string' ? a : a.attachment))
|
||||
.filter(a => !!a),
|
||||
})),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
const sessions = await this.models.copilotSession.list(
|
||||
userId,
|
||||
workspaceId,
|
||||
docId,
|
||||
options
|
||||
);
|
||||
const histories = await Promise.all(
|
||||
sessions.map(
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
pinned,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
// render system prompt
|
||||
const preload = (
|
||||
options?.withPrompt
|
||||
? prompt
|
||||
.finish(ret.data[0]?.params || {}, id)
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: []
|
||||
) as ChatMessage[];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
preload.forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
);
|
||||
});
|
||||
|
||||
return {
|
||||
sessionId: id,
|
||||
pinned,
|
||||
action: prompt.action || null,
|
||||
tokens: tokenCost,
|
||||
createdAt,
|
||||
messages: preload.concat(ret.data).map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments
|
||||
?.map(a => (typeof a === 'string' ? a : a.attachment))
|
||||
.filter(a => !!a),
|
||||
})),
|
||||
};
|
||||
} else {
|
||||
this.logger.error(
|
||||
`Unexpected message schema: ${JSON.stringify(ret.error)}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Unexpected error in listHistories', e);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
)
|
||||
.then(histories =>
|
||||
histories.filter((v): v is NonNullable<typeof v> => !!v)
|
||||
);
|
||||
);
|
||||
|
||||
return histories.filter((v): v is NonNullable<typeof v> => !!v);
|
||||
}
|
||||
|
||||
async getQuota(userId: string) {
|
||||
@@ -637,6 +477,17 @@ export class ChatSessionService {
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
|
||||
if (options.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
|
||||
// validate prompt compatibility with session type
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
options,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
|
||||
return await this.setSession({
|
||||
...options,
|
||||
sessionId,
|
||||
@@ -647,30 +498,47 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async updateSessionPrompt(
|
||||
options: ChatSessionPromptUpdateOptions
|
||||
): Promise<string> {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
@Transactional()
|
||||
async unpin(workspaceId: string, userId: string) {
|
||||
await this.models.copilotSession.unpin(workspaceId, userId);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async updateSession(options: UpdateChatSession): Promise<string> {
|
||||
const session = await this.getSession(options.sessionId);
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = options.sessionId;
|
||||
const haveSession = await this.haveSession(
|
||||
sessionId,
|
||||
options.userId,
|
||||
tx,
|
||||
{ prompt: { action: null } }
|
||||
);
|
||||
if (haveSession) {
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: { promptName: prompt.name },
|
||||
});
|
||||
|
||||
const finalData: UpdateChatSessionData = {};
|
||||
if (options.promptName) {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
return sessionId;
|
||||
});
|
||||
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
session,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
finalData.promptName = prompt.name;
|
||||
}
|
||||
finalData.pinned = options.pinned;
|
||||
finalData.docId = options.docId;
|
||||
|
||||
if (Object.keys(finalData).length === 0) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
'No valid fields to update in the session'
|
||||
);
|
||||
}
|
||||
|
||||
return await this.models.copilotSession.update(
|
||||
options.userId,
|
||||
options.sessionId,
|
||||
finalData
|
||||
);
|
||||
}
|
||||
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
@@ -678,6 +546,10 @@ export class ChatSessionService {
|
||||
if (!state) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
if (state.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
|
||||
let messages = state.messages.map(m => ({ ...m, id: undefined }));
|
||||
if (options.latestMessageId) {
|
||||
const lastMessageIdx = state.messages.findLastIndex(
|
||||
@@ -706,7 +578,9 @@ export class ChatSessionService {
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
|
||||
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
|
||||
sessionIds: string[];
|
||||
}
|
||||
) {
|
||||
return await this.db.$transaction(async tx => {
|
||||
const sessions = await tx.aiSession.findMany({
|
||||
|
||||
@@ -84,6 +84,7 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
|
||||
export const ChatHistorySchema = z
|
||||
.object({
|
||||
sessionId: z.string(),
|
||||
pinned: z.boolean(),
|
||||
action: z.string().nullable(),
|
||||
tokens: z.number(),
|
||||
messages: z.array(ChatMessageSchema),
|
||||
@@ -105,17 +106,13 @@ export interface ChatSessionOptions {
|
||||
// connect ids
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
docId: string;
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionPromptUpdateOptions
|
||||
extends Pick<ChatSessionState, 'sessionId' | 'userId'> {
|
||||
docId: string | null;
|
||||
promptName: string;
|
||||
pinned: boolean;
|
||||
}
|
||||
|
||||
export interface ChatSessionForkOptions
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
extends Omit<ChatSessionOptions, 'pinned' | 'promptName'> {
|
||||
sessionId: string;
|
||||
latestMessageId?: string;
|
||||
}
|
||||
|
||||
@@ -241,6 +241,7 @@ type CopilotHistories {
|
||||
action: String
|
||||
createdAt: DateTime!
|
||||
messages: [ChatMessage!]!
|
||||
pinned: Boolean!
|
||||
sessionId: String!
|
||||
|
||||
"""The number of tokens used in the session"""
|
||||
@@ -332,10 +333,12 @@ type CopilotQuota {
|
||||
}
|
||||
|
||||
type CopilotSessionType {
|
||||
docId: String
|
||||
id: ID!
|
||||
model: String!
|
||||
optionalModels: [String!]!
|
||||
parentSessionId: ID
|
||||
pinned: Boolean!
|
||||
promptName: String!
|
||||
}
|
||||
|
||||
@@ -386,7 +389,8 @@ input CreateChatMessageInput {
|
||||
}
|
||||
|
||||
input CreateChatSessionInput {
|
||||
docId: String!
|
||||
docId: String
|
||||
pinned: Boolean
|
||||
|
||||
"""The prompt name to use for the session"""
|
||||
promptName: String!
|
||||
@@ -570,6 +574,7 @@ enum ErrorNames {
|
||||
COPILOT_PROVIDER_SIDE_ERROR
|
||||
COPILOT_QUOTA_EXCEEDED
|
||||
COPILOT_SESSION_DELETED
|
||||
COPILOT_SESSION_INVALID_INPUT
|
||||
COPILOT_SESSION_NOT_FOUND
|
||||
COPILOT_TRANSCRIPTION_AUDIO_NOT_PROVIDED
|
||||
COPILOT_TRANSCRIPTION_JOB_EXISTS
|
||||
@@ -1751,8 +1756,14 @@ input UpdateAppConfigInput {
|
||||
}
|
||||
|
||||
input UpdateChatSessionInput {
|
||||
"""The workspace id of the session"""
|
||||
docId: String
|
||||
|
||||
"""Whether to pin the session"""
|
||||
pinned: Boolean
|
||||
|
||||
"""The prompt name to use for the session"""
|
||||
promptName: String!
|
||||
promptName: String
|
||||
sessionId: String!
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user