feat: copilot controller (#6272)

fix CLOUD-27
This commit is contained in:
darkskygit
2024-04-10 11:58:40 +00:00
parent e6a576551a
commit 7c38a54f81
18 changed files with 729 additions and 179 deletions

View File

@@ -11,6 +11,7 @@ import {
ChatMessageSchema,
getTokenEncoder,
PromptMessage,
PromptMessageSchema,
PromptParams,
} from './types';
@@ -105,37 +106,62 @@ export class ChatSession implements AsyncDisposable {
@Injectable()
export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
constructor(
private readonly db: PrismaClient,
private readonly prompt: PromptService
) {}
private async setSession(state: ChatSessionState): Promise<void> {
await this.db.aiSession.upsert({
where: {
id: state.sessionId,
},
update: {
messages: {
create: state.messages.map((m, idx) => ({ idx, ...m })),
},
},
create: {
id: state.sessionId,
messages: { create: state.messages },
// connect
user: { connect: { id: state.userId } },
workspace: { connect: { id: state.workspaceId } },
doc: {
connect: {
id_workspaceId: {
id: state.docId,
private async setSession(state: ChatSessionState): Promise<string> {
return await this.db.$transaction(async tx => {
let sessionId = state.sessionId;
// find existing session if session is chat session
if (!state.prompt.action) {
const { id } =
(await tx.aiSession.findFirst({
where: {
userId: state.userId,
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
},
select: { id: true },
})) || {};
if (id) sessionId = id;
}
await tx.aiSession.upsert({
where: {
id: sessionId,
userId: state.userId,
},
update: {
messages: {
// delete old messages
deleteMany: {},
create: state.messages.map(m => ({
...m,
params: m.params || undefined,
})),
},
},
prompt: { connect: { name: state.prompt.name } },
},
create: {
id: sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
messages: {
create: state.messages.map(m => ({
...m,
params: m.params || undefined,
})),
},
// connect
user: { connect: { id: state.userId } },
prompt: { connect: { name: state.prompt.name } },
},
});
return sessionId;
});
}
@@ -171,6 +197,7 @@ export class ChatSessionService {
})
.then(async session => {
if (!session) return;
const messages = ChatMessageSchema.array().safeParse(session.messages);
return {
@@ -184,18 +211,58 @@ export class ChatSessionService {
});
}
async listHistories(
private calculateTokenSize(
messages: PromptMessage[],
model: AvailableModel
): number {
const encoder = getTokenEncoder(model);
return messages
.map(m => encoder?.encode_ordinary(m.content).length || 0)
.reduce((total, length) => total + length, 0);
}
async countUserActions(userId: string): Promise<number> {
return await this.db.aiSession.count({
where: { userId, prompt: { action: { not: null } } },
});
}
async listSessions(
userId: string,
workspaceId: string,
docId: string,
options: ListHistoriesOptions
options?: { docId?: string; action?: boolean }
): Promise<string[]> {
return await this.db.aiSession
.findMany({
where: {
userId,
workspaceId,
docId: workspaceId === options?.docId ? undefined : options?.docId,
prompt: {
action: options?.action ? { not: null } : null,
},
},
select: { id: true },
})
.then(sessions => sessions.map(({ id }) => id));
}
async listHistories(
userId: string,
workspaceId?: string,
docId?: string,
options?: ListHistoriesOptions
): Promise<ChatHistory[]> {
return await this.db.aiSession
.findMany({
where: {
userId,
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
prompt: { action: { not: null } },
id: options.sessionId ? { equals: options.sessionId } : undefined,
prompt: {
action: options?.action ? { not: null } : null,
},
id: options?.sessionId ? { equals: options.sessionId } : undefined,
},
select: {
id: true,
@@ -210,20 +277,33 @@ export class ChatSessionService {
},
},
},
take: options.limit,
skip: options.skip,
take: options?.limit,
skip: options?.skip,
orderBy: { createdAt: 'desc' },
})
.then(sessions =>
sessions
.map(({ id, prompt, messages }) => {
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
const encoder = getTokenEncoder(prompt.model as AvailableModel);
const tokens = ret.data
.map(m => encoder?.encode_ordinary(m.content).length || 0)
.reduce((total, length) => total + length, 0);
return { sessionId: id, tokens, messages: ret.data };
try {
const ret = PromptMessageSchema.array().safeParse(messages);
if (ret.success) {
const tokens = this.calculateTokenSize(
ret.data,
prompt.model as AvailableModel
);
return {
sessionId: id,
action: prompt.action || undefined,
tokens,
messages: ret.data,
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
return undefined;
})
@@ -238,8 +318,12 @@ export class ChatSessionService {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new Error('Prompt not found');
}
await this.setSession({ ...options, sessionId, prompt, messages: [] });
return sessionId;
return await this.setSession({
...options,
sessionId,
prompt,
messages: [],
});
}
/**