feat: copilot controller (#6272)

fix CLOUD-27
This commit is contained in:
darkskygit
2024-04-10 11:58:40 +00:00
parent e6a576551a
commit 7c38a54f81
18 changed files with 729 additions and 179 deletions

View File

@@ -80,11 +80,5 @@ ALTER TABLE "ai_sessions_messages" ADD CONSTRAINT "ai_sessions_messages_session_
-- AddForeignKey -- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey -- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE; ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -97,7 +97,6 @@ model Workspace {
permissions WorkspaceUserPermission[] permissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[] pagePermissions WorkspacePageUserPermission[]
features WorkspaceFeatures[] features WorkspaceFeatures[]
aiSessions AiSession[]
@@map("workspaces") @@map("workspaces")
} }
@@ -323,8 +322,6 @@ model Snapshot {
// but the created time of last seen update that has been merged into snapshot. // but the created time of last seen update that has been merged into snapshot.
updatedAt DateTime @map("updated_at") @db.Timestamptz(6) updatedAt DateTime @map("updated_at") @db.Timestamptz(6)
aiSessions AiSession[]
@@id([id, workspaceId]) @@id([id, workspaceId])
@@map("snapshots") @@map("snapshots")
} }
@@ -485,11 +482,9 @@ model AiSession {
promptName String @map("prompt_name") @db.VarChar(32) promptName String @map("prompt_name") @db.VarChar(32)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade) messages AiSessionMessage[]
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
messages AiSessionMessage[]
@@map("ai_sessions_metadata") @@map("ai_sessions_metadata")
} }

View File

@@ -26,6 +26,22 @@ export class PermissionService {
return data?.type as Permission; return data?.type as Permission;
} }
/**
* check whether a workspace exists and has any one can access it
* @param workspaceId workspace id
* @returns
*/
async hasWorkspace(workspaceId: string) {
return await this.prisma.workspaceUserPermission
.count({
where: {
workspaceId,
accepted: true,
},
})
.then(count => count > 0);
}
async getOwnedWorkspaces(userId: string) { async getOwnedWorkspaces(userId: string) {
return this.prisma.workspaceUserPermission return this.prisma.workspaceUserPermission
.findMany({ .findMany({
@@ -96,6 +112,23 @@ export class PermissionService {
return count !== 0; return count !== 0;
} }
/**
* only check permission if the workspace is a cloud workspace
* @param workspaceId workspace id
* @param userId user id, check if is a public workspace if not provided
* @param permission default is read
*/
async checkCloudWorkspace(
workspaceId: string,
userId?: string,
permission: Permission = Permission.Read
) {
const hasWorkspace = await this.hasWorkspace(workspaceId);
if (hasWorkspace) {
await this.checkWorkspace(workspaceId, userId, permission);
}
}
async checkWorkspace( async checkWorkspace(
ws: string, ws: string,
user?: string, user?: string,
@@ -263,6 +296,25 @@ export class PermissionService {
/// End regin: workspace permission /// End regin: workspace permission
/// Start regin: page permission /// Start regin: page permission
/**
* only check permission if the workspace is a cloud workspace
* @param workspaceId workspace id
* @param pageId page id aka doc id
* @param userId user id, check if is a public page if not provided
* @param permission default is read
*/
async checkCloudPagePermission(
workspaceId: string,
pageId: string,
userId?: string,
permission = Permission.Read
) {
const hasWorkspace = await this.hasWorkspace(workspaceId);
if (hasWorkspace) {
await this.checkPagePermission(workspaceId, pageId, userId, permission);
}
}
async checkPagePermission( async checkPagePermission(
ws: string, ws: string,
page: string, page: string,

View File

@@ -0,0 +1,151 @@
import {
BadRequestException,
Controller,
Get,
InternalServerErrorException,
Param,
Query,
Req,
Sse,
} from '@nestjs/common';
import {
concatMap,
connect,
EMPTY,
from,
map,
merge,
Observable,
switchMap,
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
import { CopilotCapability } from './types';
export interface ChatEvent {
data: string;
id?: string;
}
@Controller('/api/copilot')
export class CopilotController {
constructor(
private readonly chatSession: ChatSessionService,
private readonly provider: CopilotProviderService
) {}
@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
try {
delete params.message;
const content = await provider.generateText(
session.finish(params),
session.model,
{
signal: req.signal,
user: user.id,
}
);
session.push({
role: 'assistant',
content,
createdAt: new Date(),
});
await session.save();
return content;
} catch (e: any) {
throw new InternalServerErrorException(
e.message || "Couldn't generate text"
);
}
}
@Public()
@Sse('/chat/:sessionId/stream')
async chatStream(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
delete params.message;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: req.signal,
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(map(data => ({ id: sessionId, data }))),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
)
);
}
}

View File

@@ -1,6 +1,8 @@
import { ServerFeature } from '../../core/config'; import { ServerFeature } from '../../core/config';
import { QuotaService } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission'; import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry'; import { Plugin } from '../registry';
import { CopilotController } from './controller';
import { PromptService } from './prompt'; import { PromptService } from './prompt';
import { import {
assertProvidersConfigs, assertProvidersConfigs,
@@ -8,6 +10,7 @@ import {
OpenAIProvider, OpenAIProvider,
registerCopilotProvider, registerCopilotProvider,
} from './providers'; } from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session'; import { ChatSessionService } from './session';
registerCopilotProvider(OpenAIProvider); registerCopilotProvider(OpenAIProvider);
@@ -16,10 +19,14 @@ registerCopilotProvider(OpenAIProvider);
name: 'copilot', name: 'copilot',
providers: [ providers: [
PermissionService, PermissionService,
QuotaService,
ChatSessionService, ChatSessionService,
CopilotResolver,
UserCopilotResolver,
PromptService, PromptService,
CopilotProviderService, CopilotProviderService,
], ],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot, contributesTo: ServerFeature.Copilot,
if: config => { if: config => {
if (config.flavor.graphql) { if (config.flavor.graphql) {

View File

@@ -38,16 +38,16 @@ export class ChatPrompt {
) { ) {
return new ChatPrompt( return new ChatPrompt(
options.name, options.name,
options.action, options.action || undefined,
options.model, options.model || undefined,
options.messages options.messages
); );
} }
constructor( constructor(
public readonly name: string, public readonly name: string,
public readonly action: string | null, public readonly action: string | undefined,
public readonly model: string | null, public readonly model: string | undefined,
private readonly messages: PromptMessage[] private readonly messages: PromptMessage[]
) { ) {
this.encoder = getTokenEncoder(model); this.encoder = getTokenEncoder(model);

View File

@@ -3,14 +3,16 @@ import assert from 'node:assert';
import { ClientOptions, OpenAI } from 'openai'; import { ClientOptions, OpenAI } from 'openai';
import { import {
ChatMessage,
ChatMessageRole, ChatMessageRole,
CopilotCapability, CopilotCapability,
CopilotProviderType, CopilotProviderType,
CopilotTextToEmbeddingProvider, CopilotTextToEmbeddingProvider,
CopilotTextToTextProvider, CopilotTextToTextProvider,
PromptMessage,
} from '../types'; } from '../types';
const DEFAULT_DIMENSIONS = 256;
export class OpenAIProvider export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
{ {
@@ -50,7 +52,7 @@ export class OpenAIProvider
return OpenAIProvider.capabilities; return OpenAIProvider.capabilities;
} }
private chatToGPTMessage(messages: ChatMessage[]) { private chatToGPTMessage(messages: PromptMessage[]) {
// filter redundant fields // filter redundant fields
return messages.map(message => ({ return messages.map(message => ({
role: message.role, role: message.role,
@@ -63,7 +65,7 @@ export class OpenAIProvider
embeddings, embeddings,
model, model,
}: { }: {
messages?: ChatMessage[]; messages?: PromptMessage[];
embeddings?: string[]; embeddings?: string[];
model: string; model: string;
}) { }) {
@@ -106,7 +108,7 @@ export class OpenAIProvider
// ====== text to text ====== // ====== text to text ======
async generateText( async generateText(
messages: ChatMessage[], messages: PromptMessage[],
model: string = 'gpt-3.5-turbo', model: string = 'gpt-3.5-turbo',
options: { options: {
temperature?: number; temperature?: number;
@@ -134,8 +136,8 @@ export class OpenAIProvider
} }
async *generateTextStream( async *generateTextStream(
messages: ChatMessage[], messages: PromptMessage[],
model: string, model: string = 'gpt-3.5-turbo',
options: { options: {
temperature?: number; temperature?: number;
maxTokens?: number; maxTokens?: number;
@@ -179,7 +181,7 @@ export class OpenAIProvider
dimensions: number; dimensions: number;
signal?: AbortSignal; signal?: AbortSignal;
user?: string; user?: string;
} = { dimensions: 256 } } = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> { ): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages]; messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model }); this.checkParams({ embeddings: messages, model });
@@ -187,7 +189,7 @@ export class OpenAIProvider
const result = await this.instance.embeddings.create({ const result = await this.instance.embeddings.create({
model: model, model: model,
input: messages, input: messages,
dimensions: options.dimensions, dimensions: options.dimensions || DEFAULT_DIMENSIONS,
user: options.user, user: options.user,
}); });
return result.data.map(e => e.embedding); return result.data.map(e => e.embedding);

View File

@@ -0,0 +1,260 @@
import {
Args,
Field,
ID,
InputType,
Mutation,
ObjectType,
Parent,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser, Public } from '../../core/auth';
import { QuotaService } from '../../core/quota';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
import {
MutexService,
PaymentRequiredException,
TooManyRequestsException,
} from '../../fundamentals';
import { ChatSessionService, ListHistoriesOptions } from './session';
import { AvailableModels, type ChatHistory, type ChatMessage } from './types';
registerEnumType(AvailableModels, { name: 'CopilotModel' });
// ================== Input Types ==================
@InputType()
class CreateChatSessionInput {
@Field(() => String)
workspaceId!: string;
@Field(() => String)
docId!: string;
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
nullable: true,
})
action!: string | undefined;
@Field(() => String, {
description: 'The prompt name to use for the session',
})
promptName!: string;
}
@InputType()
class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@Field(() => Boolean, { nullable: true })
action: boolean | undefined;
@Field(() => Number, { nullable: true })
limit: number | undefined;
@Field(() => Number, { nullable: true })
skip: number | undefined;
@Field(() => String, { nullable: true })
sessionId: string | undefined;
}
// ================== Return Types ==================
@ObjectType('ChatMessage')
class ChatMessageType implements Partial<ChatMessage> {
@Field(() => String)
role!: 'system' | 'assistant' | 'user';
@Field(() => String)
content!: string;
@Field(() => [String], { nullable: true })
attachments!: string[];
@Field(() => Date, { nullable: true })
createdAt!: Date | undefined;
}
@ObjectType('CopilotHistories')
class CopilotHistoriesType implements Partial<ChatHistory> {
@Field(() => String)
sessionId!: string;
@Field(() => String, {
description: 'An mark identifying which view to use to display the session',
})
action!: string;
@Field(() => Number, {
description: 'The number of tokens used in the session',
})
tokens!: number;
@Field(() => [ChatMessageType])
messages!: ChatMessageType[];
}
@ObjectType('CopilotQuota')
class CopilotQuotaType {
@Field(() => SafeIntResolver)
limit!: number;
@Field(() => SafeIntResolver)
used!: number;
}
// ================== Resolver ==================
@ObjectType('Copilot')
export class CopilotType {
@Field(() => ID, { nullable: true })
workspaceId!: string | undefined;
}
@Resolver(() => CopilotType)
export class CopilotResolver {
constructor(
private readonly permissions: PermissionService,
private readonly quota: QuotaService,
private readonly mutex: MutexService,
private readonly chatSession: ChatSessionService
) {}
@ResolveField(() => CopilotQuotaType, {
name: 'quota',
description: 'Get the quota of the user in the workspace',
complexity: 2,
})
async getQuota(@CurrentUser() user: CurrentUser) {
const quota = await this.quota.getUserQuota(user.id);
const limit = quota.feature.copilotActionLimit;
const actions = await this.chatSession.countUserActions(user.id);
const chats = await this.chatSession
.listHistories(user.id)
.then(histories =>
histories.reduce(
(acc, h) => acc + h.messages.filter(m => m.role === 'user').length,
0
)
);
return { limit, used: actions + chats };
}
@ResolveField(() => [String], {
description: 'Get the session list of chats in the workspace',
complexity: 2,
})
async chats(
@Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser
) {
if (!copilot.workspaceId) return [];
await this.permissions.checkCloudWorkspace(copilot.workspaceId, user.id);
return await this.chatSession.listSessions(user.id, copilot.workspaceId);
}
@ResolveField(() => [String], {
description: 'Get the session list of actions in the workspace',
complexity: 2,
})
async actions(
@Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser
) {
if (!copilot.workspaceId) return [];
await this.permissions.checkCloudWorkspace(copilot.workspaceId, user.id);
return await this.chatSession.listSessions(user.id, copilot.workspaceId, {
action: true,
});
}
@ResolveField(() => [CopilotHistoriesType], {})
async histories(
@Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser,
@Args('docId', { nullable: true }) docId?: string,
@Args({
name: 'options',
type: () => QueryChatHistoriesInput,
nullable: true,
})
options?: QueryChatHistoriesInput
) {
const workspaceId = copilot.workspaceId;
if (!workspaceId) {
return [];
} else if (docId) {
await this.permissions.checkCloudPagePermission(
workspaceId,
docId,
user.id
);
} else {
await this.permissions.checkCloudWorkspace(workspaceId, user.id);
}
return await this.chatSession.listHistories(
user.id,
workspaceId,
docId,
options
);
}
@Public()
@Mutation(() => String, {
description: 'Create a chat session',
})
async createCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatSessionInput })
options: CreateChatSessionInput
) {
await this.permissions.checkCloudPagePermission(
options.workspaceId,
options.docId,
user.id
);
const lockFlag = `session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
}
const { limit, used } = await this.getQuota(user);
if (limit && Number.isFinite(limit) && used >= limit) {
return new PaymentRequiredException(
`You have reached the limit of actions in this workspace, please upgrade your plan.`
);
}
const session = await this.chatSession.create({
...options,
userId: user.id,
});
return session;
}
}
@Resolver(() => UserType)
export class UserCopilotResolver {
constructor(private readonly permissions: PermissionService) {}
@ResolveField(() => CopilotType)
async copilot(
@CurrentUser() user: CurrentUser,
@Args('workspaceId', { nullable: true }) workspaceId?: string
) {
if (workspaceId) {
await this.permissions.checkCloudWorkspace(workspaceId, user.id);
}
return { workspaceId };
}
}

View File

@@ -11,6 +11,7 @@ import {
ChatMessageSchema, ChatMessageSchema,
getTokenEncoder, getTokenEncoder,
PromptMessage, PromptMessage,
PromptMessageSchema,
PromptParams, PromptParams,
} from './types'; } from './types';
@@ -105,37 +106,62 @@ export class ChatSession implements AsyncDisposable {
@Injectable() @Injectable()
export class ChatSessionService { export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name); private readonly logger = new Logger(ChatSessionService.name);
constructor( constructor(
private readonly db: PrismaClient, private readonly db: PrismaClient,
private readonly prompt: PromptService private readonly prompt: PromptService
) {} ) {}
private async setSession(state: ChatSessionState): Promise<void> { private async setSession(state: ChatSessionState): Promise<string> {
await this.db.aiSession.upsert({ return await this.db.$transaction(async tx => {
where: { let sessionId = state.sessionId;
id: state.sessionId,
}, // find existing session if session is chat session
update: { if (!state.prompt.action) {
messages: { const { id } =
create: state.messages.map((m, idx) => ({ idx, ...m })), (await tx.aiSession.findFirst({
}, where: {
}, userId: state.userId,
create: {
id: state.sessionId,
messages: { create: state.messages },
// connect
user: { connect: { id: state.userId } },
workspace: { connect: { id: state.workspaceId } },
doc: {
connect: {
id_workspaceId: {
id: state.docId,
workspaceId: state.workspaceId, workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
}, },
select: { id: true },
})) || {};
if (id) sessionId = id;
}
await tx.aiSession.upsert({
where: {
id: sessionId,
userId: state.userId,
},
update: {
messages: {
// delete old messages
deleteMany: {},
create: state.messages.map(m => ({
...m,
params: m.params || undefined,
})),
}, },
}, },
prompt: { connect: { name: state.prompt.name } }, create: {
}, id: sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
messages: {
create: state.messages.map(m => ({
...m,
params: m.params || undefined,
})),
},
// connect
user: { connect: { id: state.userId } },
prompt: { connect: { name: state.prompt.name } },
},
});
return sessionId;
}); });
} }
@@ -171,6 +197,7 @@ export class ChatSessionService {
}) })
.then(async session => { .then(async session => {
if (!session) return; if (!session) return;
const messages = ChatMessageSchema.array().safeParse(session.messages); const messages = ChatMessageSchema.array().safeParse(session.messages);
return { return {
@@ -184,18 +211,58 @@ export class ChatSessionService {
}); });
} }
async listHistories( private calculateTokenSize(
messages: PromptMessage[],
model: AvailableModel
): number {
const encoder = getTokenEncoder(model);
return messages
.map(m => encoder?.encode_ordinary(m.content).length || 0)
.reduce((total, length) => total + length, 0);
}
async countUserActions(userId: string): Promise<number> {
return await this.db.aiSession.count({
where: { userId, prompt: { action: { not: null } } },
});
}
async listSessions(
userId: string,
workspaceId: string, workspaceId: string,
docId: string, options?: { docId?: string; action?: boolean }
options: ListHistoriesOptions ): Promise<string[]> {
return await this.db.aiSession
.findMany({
where: {
userId,
workspaceId,
docId: workspaceId === options?.docId ? undefined : options?.docId,
prompt: {
action: options?.action ? { not: null } : null,
},
},
select: { id: true },
})
.then(sessions => sessions.map(({ id }) => id));
}
async listHistories(
userId: string,
workspaceId?: string,
docId?: string,
options?: ListHistoriesOptions
): Promise<ChatHistory[]> { ): Promise<ChatHistory[]> {
return await this.db.aiSession return await this.db.aiSession
.findMany({ .findMany({
where: { where: {
userId,
workspaceId: workspaceId, workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId, docId: workspaceId === docId ? undefined : docId,
prompt: { action: { not: null } }, prompt: {
id: options.sessionId ? { equals: options.sessionId } : undefined, action: options?.action ? { not: null } : null,
},
id: options?.sessionId ? { equals: options.sessionId } : undefined,
}, },
select: { select: {
id: true, id: true,
@@ -210,20 +277,33 @@ export class ChatSessionService {
}, },
}, },
}, },
take: options.limit, take: options?.limit,
skip: options.skip, skip: options?.skip,
orderBy: { createdAt: 'desc' }, orderBy: { createdAt: 'desc' },
}) })
.then(sessions => .then(sessions =>
sessions sessions
.map(({ id, prompt, messages }) => { .map(({ id, prompt, messages }) => {
const ret = ChatMessageSchema.array().safeParse(messages); try {
if (ret.success) { const ret = PromptMessageSchema.array().safeParse(messages);
const encoder = getTokenEncoder(prompt.model as AvailableModel); if (ret.success) {
const tokens = ret.data const tokens = this.calculateTokenSize(
.map(m => encoder?.encode_ordinary(m.content).length || 0) ret.data,
.reduce((total, length) => total + length, 0); prompt.model as AvailableModel
return { sessionId: id, tokens, messages: ret.data }; );
return {
sessionId: id,
action: prompt.action || undefined,
tokens,
messages: ret.data,
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
} }
return undefined; return undefined;
}) })
@@ -238,8 +318,12 @@ export class ChatSessionService {
this.logger.error(`Prompt not found: ${options.promptName}`); this.logger.error(`Prompt not found: ${options.promptName}`);
throw new Error('Prompt not found'); throw new Error('Prompt not found');
} }
await this.setSession({ ...options, sessionId, prompt, messages: [] }); return await this.setSession({
return sessionId; ...options,
sessionId,
prompt,
messages: [],
});
} }
/** /**

View File

@@ -76,8 +76,9 @@ export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const ChatHistorySchema = z export const ChatHistorySchema = z
.object({ .object({
sessionId: z.string(), sessionId: z.string(),
action: z.string().optional(),
tokens: z.number(), tokens: z.number(),
messages: z.array(ChatMessageSchema), messages: z.array(PromptMessageSchema.or(ChatMessageSchema)),
}) })
.strict(); .strict();
@@ -104,8 +105,8 @@ export interface CopilotProvider {
export interface CopilotTextToTextProvider extends CopilotProvider { export interface CopilotTextToTextProvider extends CopilotProvider {
generateText( generateText(
messages: PromptMessage[], messages: PromptMessage[],
model: string, model?: string,
options: { options?: {
temperature?: number; temperature?: number;
maxTokens?: number; maxTokens?: number;
signal?: AbortSignal; signal?: AbortSignal;
@@ -114,8 +115,8 @@ export interface CopilotTextToTextProvider extends CopilotProvider {
): Promise<string>; ): Promise<string>;
generateTextStream( generateTextStream(
messages: PromptMessage[], messages: PromptMessage[],
model: string, model?: string,
options: { options?: {
temperature?: number; temperature?: number;
maxTokens?: number; maxTokens?: number;
signal?: AbortSignal; signal?: AbortSignal;

View File

@@ -2,6 +2,51 @@
# THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY) # THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY)
# ------------------------------------------------------ # ------------------------------------------------------
type ChatMessage {
attachments: [String!]
content: String!
createdAt: DateTime
role: String!
}
type Copilot {
"""Get the session list of actions in the workspace"""
actions: [String!]!
"""Get the session list of chats in the workspace"""
chats: [String!]!
histories(docId: String, options: QueryChatHistoriesInput): [CopilotHistories!]!
"""Get the quota of the user in the workspace"""
quota: CopilotQuota!
workspaceId: ID
}
type CopilotHistories {
"""An mark identifying which view to use to display the session"""
action: String!
messages: [ChatMessage!]!
sessionId: String!
"""The number of tokens used in the session"""
tokens: Int!
}
type CopilotQuota {
limit: SafeInt!
used: SafeInt!
}
input CreateChatSessionInput {
"""An mark identifying which view to use to display the session"""
action: String
docId: String!
"""The prompt name to use for the session"""
promptName: String!
workspaceId: String!
}
input CreateCheckoutSessionInput { input CreateCheckoutSessionInput {
coupon: String coupon: String
idempotencyKey: String! idempotencyKey: String!
@@ -122,6 +167,9 @@ type Mutation {
"""Create a subscription checkout link of stripe""" """Create a subscription checkout link of stripe"""
createCheckoutSession(input: CreateCheckoutSessionInput!): String! createCheckoutSession(input: CreateCheckoutSessionInput!): String!
"""Create a chat session"""
createCopilotSession(options: CreateChatSessionInput!): String!
"""Create a stripe customer portal to manage payment methods""" """Create a stripe customer portal to manage payment methods"""
createCustomerPortal: String! createCustomerPortal: String!
@@ -223,6 +271,13 @@ type Query {
workspaces: [WorkspaceType!]! workspaces: [WorkspaceType!]!
} }
input QueryChatHistoriesInput {
action: Boolean
limit: Int
sessionId: String
skip: Int
}
type QuotaQueryType { type QuotaQueryType {
blobLimit: SafeInt! blobLimit: SafeInt!
copilotActionLimit: SafeInt copilotActionLimit: SafeInt
@@ -380,6 +435,7 @@ type UserSubscription {
type UserType { type UserType {
"""User avatar url""" """User avatar url"""
avatarUrl: String avatarUrl: String
copilot(workspaceId: String): Copilot!
"""User email verified""" """User email verified"""
createdAt: DateTime @deprecated(reason: "useless") createdAt: DateTime @deprecated(reason: "useless")

View File

@@ -1,17 +0,0 @@
query getCopilotAnonymousHistories(
$workspaceId: String!
$docId: String
$options: QueryChatHistoriesInput
) {
copilotAnonymous(workspaceId: $workspaceId) {
histories(docId: $docId, options: $options) {
sessionId
tokens
messages {
role
content
attachments
}
}
}
}

View File

@@ -1,6 +0,0 @@
query getCopilotAnonymousSessions($workspaceId: String!) {
copilotAnonymous(workspaceId: $workspaceId) {
chats
actions
}
}

View File

@@ -12,6 +12,7 @@ query getCopilotHistories(
role role
content content
attachments attachments
createdAt
} }
} }
} }

View File

@@ -0,0 +1,10 @@
query getCopilotQuota($workspaceId: String!, $docId: String!) {
currentUser {
copilot {
quota {
limit
used
}
}
}
}

View File

@@ -1,8 +1,8 @@
query getCopilotSessions($workspaceId: String!) { query getCopilotSessions($workspaceId: String!) {
currentUser { currentUser {
copilot(workspaceId: $workspaceId) { copilot(workspaceId: $workspaceId) {
chats
actions actions
chats
} }
} }
} }

View File

@@ -251,41 +251,6 @@ mutation removeEarlyAccess($email: String!) {
}`, }`,
}; };
export const getCopilotAnonymousHistoriesQuery = {
id: 'getCopilotAnonymousHistoriesQuery' as const,
operationName: 'getCopilotAnonymousHistories',
definitionName: 'copilotAnonymous',
containsFile: false,
query: `
query getCopilotAnonymousHistories($workspaceId: String!, $docId: String, $options: QueryChatHistoriesInput) {
copilotAnonymous(workspaceId: $workspaceId) {
histories(docId: $docId, options: $options) {
sessionId
tokens
messages {
role
content
attachments
}
}
}
}`,
};
export const getCopilotAnonymousSessionsQuery = {
id: 'getCopilotAnonymousSessionsQuery' as const,
operationName: 'getCopilotAnonymousSessions',
definitionName: 'copilotAnonymous',
containsFile: false,
query: `
query getCopilotAnonymousSessions($workspaceId: String!) {
copilotAnonymous(workspaceId: $workspaceId) {
chats
actions
}
}`,
};
export const getCopilotHistoriesQuery = { export const getCopilotHistoriesQuery = {
id: 'getCopilotHistoriesQuery' as const, id: 'getCopilotHistoriesQuery' as const,
operationName: 'getCopilotHistories', operationName: 'getCopilotHistories',
@@ -302,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query
role role
content content
attachments attachments
createdAt
} }
} }
} }
@@ -309,6 +275,24 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query
}`, }`,
}; };
export const getCopilotQuotaQuery = {
id: 'getCopilotQuotaQuery' as const,
operationName: 'getCopilotQuota',
definitionName: 'currentUser',
containsFile: false,
query: `
query getCopilotQuota($workspaceId: String!, $docId: String!) {
currentUser {
copilot {
quota {
limit
used
}
}
}
}`,
};
export const getCopilotSessionsQuery = { export const getCopilotSessionsQuery = {
id: 'getCopilotSessionsQuery' as const, id: 'getCopilotSessionsQuery' as const,
operationName: 'getCopilotSessions', operationName: 'getCopilotSessions',
@@ -318,8 +302,8 @@ export const getCopilotSessionsQuery = {
query getCopilotSessions($workspaceId: String!) { query getCopilotSessions($workspaceId: String!) {
currentUser { currentUser {
copilot(workspaceId: $workspaceId) { copilot(workspaceId: $workspaceId) {
chats
actions actions
chats
} }
} }
}`, }`,

View File

@@ -35,9 +35,10 @@ export interface Scalars {
} }
export interface CreateChatSessionInput { export interface CreateChatSessionInput {
action: Scalars['Boolean']['input']; /** An mark identifying which view to use to display the session */
action: InputMaybe<Scalars['String']['input']>;
docId: Scalars['String']['input']; docId: Scalars['String']['input'];
model: Scalars['String']['input']; /** The prompt name to use for the session */
promptName: Scalars['String']['input']; promptName: Scalars['String']['input'];
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }
@@ -333,43 +334,6 @@ export type PasswordLimitsFragment = {
maxLength: number; maxLength: number;
}; };
export type GetCopilotAnonymousHistoriesQueryVariables = Exact<{
workspaceId: Scalars['String']['input'];
docId: InputMaybe<Scalars['String']['input']>;
options: InputMaybe<QueryChatHistoriesInput>;
}>;
export type GetCopilotAnonymousHistoriesQuery = {
__typename?: 'Query';
copilotAnonymous: {
__typename?: 'Copilot';
histories: Array<{
__typename?: 'CopilotHistories';
sessionId: string;
tokens: number;
messages: Array<{
__typename?: 'ChatMessage';
role: string;
content: string;
attachments: Array<string> | null;
}>;
}>;
};
};
export type GetCopilotAnonymousSessionsQueryVariables = Exact<{
workspaceId: Scalars['String']['input'];
}>;
export type GetCopilotAnonymousSessionsQuery = {
__typename?: 'Query';
copilotAnonymous: {
__typename?: 'Copilot';
chats: Array<string>;
actions: Array<string>;
};
};
export type GetCopilotHistoriesQueryVariables = Exact<{ export type GetCopilotHistoriesQueryVariables = Exact<{
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
docId: InputMaybe<Scalars['String']['input']>; docId: InputMaybe<Scalars['String']['input']>;
@@ -391,12 +355,29 @@ export type GetCopilotHistoriesQuery = {
role: string; role: string;
content: string; content: string;
attachments: Array<string> | null; attachments: Array<string> | null;
createdAt: string | null;
}>; }>;
}>; }>;
}; };
} | null; } | null;
}; };
export type GetCopilotQuotaQueryVariables = Exact<{
workspaceId: Scalars['String']['input'];
docId: Scalars['String']['input'];
}>;
export type GetCopilotQuotaQuery = {
__typename?: 'Query';
currentUser: {
__typename?: 'UserType';
copilot: {
__typename?: 'Copilot';
quota: { __typename?: 'CopilotQuota'; limit: number; used: number };
};
} | null;
};
export type GetCopilotSessionsQueryVariables = Exact<{ export type GetCopilotSessionsQueryVariables = Exact<{
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
}>; }>;
@@ -407,8 +388,8 @@ export type GetCopilotSessionsQuery = {
__typename?: 'UserType'; __typename?: 'UserType';
copilot: { copilot: {
__typename?: 'Copilot'; __typename?: 'Copilot';
chats: Array<string>;
actions: Array<string>; actions: Array<string>;
chats: Array<string>;
}; };
} | null; } | null;
}; };
@@ -1057,21 +1038,16 @@ export type Queries =
variables: EarlyAccessUsersQueryVariables; variables: EarlyAccessUsersQueryVariables;
response: EarlyAccessUsersQuery; response: EarlyAccessUsersQuery;
} }
| {
name: 'getCopilotAnonymousHistoriesQuery';
variables: GetCopilotAnonymousHistoriesQueryVariables;
response: GetCopilotAnonymousHistoriesQuery;
}
| {
name: 'getCopilotAnonymousSessionsQuery';
variables: GetCopilotAnonymousSessionsQueryVariables;
response: GetCopilotAnonymousSessionsQuery;
}
| { | {
name: 'getCopilotHistoriesQuery'; name: 'getCopilotHistoriesQuery';
variables: GetCopilotHistoriesQueryVariables; variables: GetCopilotHistoriesQueryVariables;
response: GetCopilotHistoriesQuery; response: GetCopilotHistoriesQuery;
} }
| {
name: 'getCopilotQuotaQuery';
variables: GetCopilotQuotaQueryVariables;
response: GetCopilotQuotaQuery;
}
| { | {
name: 'getCopilotSessionsQuery'; name: 'getCopilotSessionsQuery';
variables: GetCopilotSessionsQueryVariables; variables: GetCopilotSessionsQueryVariables;