feat(server): paginated list endpoint (#13026)

fix AI-323
This commit is contained in:
DarkSky
2025-07-08 17:11:58 +08:00
committed by GitHub
parent 8c49a45162
commit 6dac94d90a
36 changed files with 1136 additions and 702 deletions

View File

@@ -1891,7 +1891,7 @@ test('should handle generateSessionTitle correctly under various conditions', as
await session.generateSessionTitle({ sessionId });
if (testCase.expectSnapshot) {
const sessionState = await session.getSession(sessionId);
const sessionState = await session.getSessionInfo(sessionId);
t.snapshot(
{
chatWithPromptCalled: testCase.expectNotCalled

View File

@@ -265,26 +265,31 @@ export class CopilotSessionModel extends BaseModel {
userId: true,
workspaceId: true,
docId: true,
pinned: true,
parentSessionId: true,
pinned: true,
title: true,
promptName: true,
tokenCost: true,
createdAt: true,
updatedAt: true,
messages: {
select: {
id: true,
role: true,
content: true,
streamObjects: true,
attachments: true,
streamObjects: true,
params: true,
createdAt: true,
},
orderBy: { createdAt: 'asc' },
},
promptName: true,
});
}
async list(options: ListSessionOptions) {
private getListConditions(
options: ListSessionOptions
): Prisma.AiSessionWhereInput {
const { userId, sessionId, workspaceId, docId, action, fork } = options;
function getNullCond<T>(
@@ -330,8 +335,18 @@ export class CopilotSessionModel extends BaseModel {
});
}
return { OR: conditions };
}
async count(options: ListSessionOptions) {
return await this.db.aiSession.count({
where: this.getListConditions(options),
});
}
async list(options: ListSessionOptions) {
return await this.db.aiSession.findMany({
where: { OR: conditions },
where: this.getListConditions(options),
select: {
id: true,
userId: true,
@@ -351,8 +366,8 @@ export class CopilotSessionModel extends BaseModel {
role: true,
content: true,
attachments: true,
params: true,
streamObjects: true,
params: true,
createdAt: true,
},
orderBy: {

View File

@@ -25,6 +25,9 @@ import {
CopilotFailedToCreateMessage,
CopilotSessionNotFound,
type FileUpload,
paginate,
Paginated,
PaginationInput,
RequestMutex,
Throttle,
TooManyRequest,
@@ -38,12 +41,7 @@ import { PromptService } from './prompt';
import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
type ChatHistory,
type ChatMessage,
type ChatSessionState,
SubmittedMessage,
} from './types';
import { type ChatHistory, type ChatMessage, SubmittedMessage } from './types';
export const COPILOT_LOCKER = 'copilot';
@@ -186,6 +184,9 @@ class QueryChatHistoriesInput
@Field(() => String, { nullable: true })
sessionId: string | undefined;
@Field(() => Boolean, { nullable: true })
withMessages: boolean | undefined;
@Field(() => Boolean, { nullable: true })
withPrompt: boolean | undefined;
}
@@ -239,7 +240,7 @@ class ChatMessageType implements Partial<ChatMessage> {
}
@ObjectType('CopilotHistories')
class CopilotHistoriesType implements Partial<ChatHistory> {
class CopilotHistoriesType implements Omit<ChatHistory, 'userId'> {
@Field(() => String)
sessionId!: string;
@@ -249,8 +250,17 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
@Field(() => String, { nullable: true })
docId!: string | null;
@Field(() => Boolean)
pinned!: boolean;
@Field(() => String, { nullable: true })
parentSessionId!: string | null;
@Field(() => String)
promptName!: string;
@Field(() => String)
model!: string;
@Field(() => [String])
optionalModels!: string[];
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
@@ -258,6 +268,12 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
})
action!: string | null;
@Field(() => Boolean)
pinned!: boolean;
@Field(() => String, { nullable: true })
title!: string | null;
@Field(() => Number, {
description: 'The number of tokens used in the session',
})
@@ -273,6 +289,11 @@ class CopilotHistoriesType implements Partial<ChatHistory> {
updatedAt!: Date;
}
@ObjectType()
export class PaginatedCopilotHistoriesType extends Paginated(
CopilotHistoriesType
) {}
@ObjectType('CopilotQuota')
class CopilotQuotaType {
@Field(() => SafeIntResolver, { nullable: true })
@@ -421,7 +442,7 @@ export class CopilotResolver {
@Args('sessionId') sessionId: string
): Promise<CopilotSessionType> {
await this.assertPermission(user, copilot);
const session = await this.chatSession.getSession(sessionId);
const session = await this.chatSession.getSessionInfo(sessionId);
if (!session) {
throw new NotFoundException('Session not found');
}
@@ -430,6 +451,7 @@ export class CopilotResolver {
@ResolveField(() => [CopilotSessionType], {
description: 'Get the session list in the workspace',
deprecationReason: 'use `chats` instead',
complexity: 2,
})
async sessions(
@@ -447,11 +469,12 @@ export class CopilotResolver {
Object.assign({}, copilot, { docId: maybeDocId })
);
const sessions = await this.chatSession.listSessions(
Object.assign({}, options, appendOptions)
const sessions = await this.chatSession.list(
Object.assign({}, options, appendOptions),
false
);
if (appendOptions.docId) {
type Session = Omit<ChatSessionState, 'messages'> & { docId: string };
type Session = ChatHistory & { docId: string };
const filtered = sessions.filter((s): s is Session => !!s.docId);
const accessible = await this.ac
.user(user.id)
@@ -463,7 +486,9 @@ export class CopilotResolver {
}
}
@ResolveField(() => [CopilotHistoriesType], {})
@ResolveField(() => [CopilotHistoriesType], {
deprecationReason: 'use `chats` instead',
})
@CallMetric('ai', 'histories')
async histories(
@Parent() copilot: CopilotType,
@@ -478,8 +503,9 @@ export class CopilotResolver {
await this.assertPermission(user, { workspaceId, docId });
}
const histories = await this.chatSession.listHistories(
Object.assign({}, options, { userId: user.id, workspaceId, docId })
const histories = await this.chatSession.list(
Object.assign({}, options, { userId: user.id, workspaceId, docId }),
true
);
return histories.map(h => ({
@@ -491,6 +517,48 @@ export class CopilotResolver {
}));
}
@ResolveField(() => PaginatedCopilotHistoriesType, {})
@CallMetric('ai', 'histories')
async chats(
@Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser,
@Args('pagination', PaginationInput.decode) pagination: PaginationInput,
@Args('docId', { nullable: true }) docId?: string,
@Args('options', { nullable: true }) options?: QueryChatHistoriesInput
): Promise<PaginatedCopilotHistoriesType> {
const workspaceId = copilot.workspaceId;
if (!workspaceId) {
return paginate([], 'updatedAt', pagination, 0);
} else {
await this.assertPermission(user, { workspaceId, docId });
}
const finalOptions = Object.assign(
{},
options,
{ userId: user.id, workspaceId, docId },
{ skip: pagination.offset, limit: pagination.first }
);
const totalCount = await this.chatSession.count(finalOptions);
const histories = await this.chatSession.list(
finalOptions,
!!options?.withMessages
);
return paginate(
histories.map(h => ({
...h,
// filter out empty messages
messages: h.messages?.filter(
m => m.content || m.attachments?.length
) as ChatMessageType[],
})),
'updatedAt',
pagination,
totalCount
);
}
@Mutation(() => String, {
description: 'Create a chat session',
})
@@ -657,18 +725,9 @@ export class CopilotResolver {
}
private transformToSessionType(
session: Omit<ChatSessionState, 'messages'>
session: Omit<ChatHistory, 'messages'>
): CopilotSessionType {
return {
id: session.sessionId,
parentSessionId: session.parentSessionId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
promptName: session.prompt.name,
model: session.prompt.model,
optionalModels: session.prompt.optionalModels,
};
return { id: session.sessionId, ...session };
}
}

View File

@@ -4,6 +4,7 @@ import { Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole } from '@prisma/client';
import { pick } from 'lodash-es';
import {
CopilotActionTaken,
@@ -25,7 +26,7 @@ import {
UpdateChatSessionOptions,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import { ChatPrompt, PromptService } from './prompt';
import {
CopilotProviderFactory,
ModelOutputType,
@@ -240,6 +241,14 @@ export class ChatSession implements AsyncDisposable {
}
}
type Session = NonNullable<
Awaited<ReturnType<Models['copilotSession']['get']>>
>;
type SessionHistory = ChatHistory & {
prompt: ChatPrompt;
};
@Injectable()
export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
@@ -253,27 +262,55 @@ export class ChatSessionService {
private readonly prompt: PromptService
) {}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
private getMessage(session: Session): ChatMessage[] {
if (!Array.isArray(session.messages) || !session.messages.length) {
return [];
}
const messages = ChatMessageSchema.array().safeParse(session.messages);
if (!messages.success) {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(messages.error)}`
);
return [];
}
return messages.data;
}
private async getHistory(session: Session): Promise<SessionHistory> {
const prompt = await this.prompt.get(session.promptName);
if (!prompt) throw new CopilotPromptNotFound({ name: session.promptName });
const messages = ChatMessageSchema.array().safeParse(session.messages);
return {
...pick(session, [
'userId',
'workspaceId',
'docId',
'parentSessionId',
'pinned',
'title',
'createdAt',
'updatedAt',
]),
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
parentSessionId: session.parentSessionId,
tokens: session.tokenCost,
messages: this.getMessage(session),
// prompt info
prompt,
messages: messages.success ? messages.data : [],
action: prompt.action || null,
model: prompt.model,
optionalModels: prompt.optionalModels || null,
promptName: prompt.name,
};
}
async getSessionInfo(sessionId: string): Promise<SessionHistory | undefined> {
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
return await this.getHistory(session);
}
// revert the latest messages not generate by user
// after revert, we can retry the action
async revertLatestMessage(
@@ -286,116 +323,70 @@ export class ChatSessionService {
);
}
async listSessions(
options: ListSessionOptions
): Promise<Omit<ChatSessionState, 'messages'>[]> {
const sessions = await this.models.copilotSession.list({
...options,
withMessages: false,
});
return Promise.all(
sessions.map(async session => {
const prompt = await this.prompt.get(session.promptName);
if (!prompt)
throw new CopilotPromptNotFound({ name: session.promptName });
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
pinned: session.pinned,
title: session.title,
parentSessionId: session.parentSessionId,
prompt,
};
})
);
async count(options: ListSessionOptions): Promise<number> {
return await this.models.copilotSession.count(options);
}
async listHistories(options: ListSessionOptions): Promise<ChatHistory[]> {
const { userId } = options;
async list(
options: ListSessionOptions,
withMessages: boolean
): Promise<ChatHistory[]> {
const { userId: reqUserId } = options;
const sessions = await this.models.copilotSession.list({
...options,
withMessages: true,
withMessages,
});
const histories = await Promise.all(
sessions.map(
async ({
userId: uid,
id,
workspaceId,
docId,
pinned,
title,
promptName,
tokenCost,
messages,
createdAt,
updatedAt,
}) => {
try {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new CopilotPromptNotFound({ name: promptName });
}
sessions.map(async session => {
const { userId, id: sessionId, createdAt } = session;
try {
const { prompt, messages, ...baseHistory } =
await this.getHistory(session);
if (withMessages) {
if (
// filter out the user's session that not match the action option
(uid === userId && !!options?.action !== !!prompt.action) ||
(userId === reqUserId && !!options?.action !== !!prompt.action) ||
// filter out the non chat session from other user
(uid !== userId && !!prompt.action)
(userId !== reqUserId && !!prompt.action)
) {
return undefined;
}
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// render system prompt
const preload = (
options?.withPrompt
? prompt
.finish(messages[0]?.params || {}, sessionId)
.filter(({ role }) => role !== 'system')
: []
) as ChatMessage[];
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});
return {
sessionId: id,
workspaceId,
docId,
pinned,
title,
action: prompt.action || null,
tokens: tokenCost,
createdAt,
updatedAt,
messages: preload.concat(ret.data).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
// `createdAt` is required for history sorting in frontend
// let's fake the creating time of prompt messages
preload.forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
});
return {
...baseHistory,
messages: preload.concat(messages).map(m => ({
...m,
attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment))
.filter(a => !!a),
})),
};
} else {
return { ...baseHistory, messages: [] };
}
return undefined;
} catch (e) {
this.logger.error('Unexpected error in list ChatHistories', e);
}
)
return undefined;
})
);
return histories.filter((v): v is NonNullable<typeof v> => !!v);
@@ -461,7 +452,7 @@ export class ChatSessionService {
@Transactional()
async update(options: UpdateChatSession): Promise<string> {
const session = await this.getSession(options.sessionId);
const session = await this.getSessionInfo(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
@@ -494,14 +485,14 @@ export class ChatSessionService {
@Transactional()
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {
const session = await this.getSessionInfo(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
let messages = state.messages.map(m => ({ ...m, id: undefined }));
let messages = session.messages.map(m => ({ ...m, id: undefined }));
if (options.latestMessageId) {
const lastMessageIdx = state.messages.findLastIndex(
const lastMessageIdx = session.messages.findLastIndex(
({ id, role }) =>
role === AiPromptRole.assistant && id === options.latestMessageId
);
@@ -514,7 +505,7 @@ export class ChatSessionService {
}
return await this.models.copilotSession.fork({
...state,
...session,
userId: options.userId,
sessionId: randomUUID(),
parentSessionId: options.sessionId,
@@ -544,7 +535,7 @@ export class ChatSessionService {
* @returns
*/
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSession(sessionId);
const state = await this.getSessionInfo(sessionId);
if (state) {
return new ChatSession(this.messageCache, state, async state => {
await this.models.copilotSession.updateMessages(state);

View File

@@ -46,12 +46,19 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const ChatHistorySchema = z
.object({
userId: z.string(),
sessionId: z.string(),
workspaceId: z.string(),
docId: z.string().nullable(),
parentSessionId: z.string().nullable(),
pinned: z.boolean(),
title: z.string().nullable(),
action: z.string().nullable(),
model: z.string(),
optionalModels: z.array(z.string()),
promptName: z.string(),
tokens: z.number(),
messages: z.array(ChatMessageSchema),
createdAt: z.date(),
@@ -69,32 +76,26 @@ export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;
// ======== Chat Session ========
export interface ChatSessionOptions {
// connect ids
userId: string;
workspaceId: string;
docId: string | null;
promptName: string;
pinned: boolean;
export type ChatSessionOptions = Pick<
ChatHistory,
'userId' | 'workspaceId' | 'docId' | 'promptName' | 'pinned'
> & {
reuseLatestChat?: boolean;
}
};
export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'pinned' | 'promptName'> {
sessionId: string;
export type ChatSessionForkOptions = Pick<
ChatHistory,
'userId' | 'sessionId' | 'workspaceId' | 'docId'
> & {
latestMessageId?: string;
}
};
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
title: string | null;
// connect ids
sessionId: string;
parentSessionId: string | null;
// states
export type ChatSessionState = Pick<
ChatHistory,
'userId' | 'sessionId' | 'workspaceId' | 'docId' | 'messages'
> & {
prompt: ChatPrompt;
messages: ChatMessage[];
}
};
export type CopilotContextFile = {
id: string; // fileId

View File

@@ -208,10 +208,11 @@ type ContextWorkspaceEmbeddingStatus {
type Copilot {
audioTranscription(blobId: String, jobId: String): TranscriptionResultType
chats(docId: String, options: QueryChatHistoriesInput, pagination: PaginationInput!): PaginatedCopilotHistoriesType!
"""Get the context list of a session"""
contexts(contextId: String, sessionId: String): [CopilotContext!]!
histories(docId: String, options: QueryChatHistoriesInput): [CopilotHistories!]!
histories(docId: String, options: QueryChatHistoriesInput): [CopilotHistories!]! @deprecated(reason: "use `chats` instead")
"""Get the quota of the user in the workspace"""
quota: CopilotQuota!
@@ -220,7 +221,7 @@ type Copilot {
session(sessionId: String!): CopilotSessionType!
"""Get the session list in the workspace"""
sessions(docId: String, options: QueryChatSessionsInput): [CopilotSessionType!]!
sessions(docId: String, options: QueryChatSessionsInput): [CopilotSessionType!]! @deprecated(reason: "use `chats` instead")
workspaceId: ID
}
@@ -313,8 +314,13 @@ type CopilotHistories {
createdAt: DateTime!
docId: String
messages: [ChatMessage!]!
model: String!
optionalModels: [String!]!
parentSessionId: String
pinned: Boolean!
promptName: String!
sessionId: String!
title: String
"""The number of tokens used in the session"""
tokens: Int!
@@ -322,6 +328,11 @@ type CopilotHistories {
workspaceId: String!
}
type CopilotHistoriesTypeEdge {
cursor: String!
node: CopilotHistories!
}
type CopilotInvalidContextDataType {
contextId: String!
}
@@ -1421,6 +1432,12 @@ type PaginatedCommentObjectType {
totalCount: Int!
}
type PaginatedCopilotHistoriesType {
edges: [CopilotHistoriesTypeEdge!]!
pageInfo: PageInfo!
totalCount: Int!
}
type PaginatedCopilotWorkspaceFileType {
edges: [CopilotWorkspaceFileTypeEdge!]!
pageInfo: PageInfo!
@@ -1552,6 +1569,7 @@ input QueryChatHistoriesInput {
sessionId: String
sessionOrder: ChatHistoryOrder
skip: Int
withMessages: Boolean
withPrompt: Boolean
}