diff --git a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql index c2c3bfc649..837d9601ea 100644 --- a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql +++ b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql @@ -80,11 +80,5 @@ ALTER TABLE "ai_sessions_messages" ADD CONSTRAINT "ai_sessions_messages_session_ -- AddForeignKey ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; --- AddForeignKey -ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE; - --- AddForeignKey -ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE; - -- AddForeignKey ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 09f82c1266..f9f5ae0696 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -97,7 +97,6 @@ model Workspace { permissions WorkspaceUserPermission[] pagePermissions WorkspacePageUserPermission[] features WorkspaceFeatures[] - aiSessions AiSession[] @@map("workspaces") } @@ -323,8 +322,6 @@ model Snapshot { // but the created time of last seen update that has been merged into snapshot. updatedAt DateTime @map("updated_at") @db.Timestamptz(6) - aiSessions AiSession[] - @@id([id, workspaceId]) @@map("snapshots") } @@ -485,11 +482,9 @@ model AiSession { promptName String @map("prompt_name") @db.VarChar(32) createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) - doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade) - prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade) - messages AiSessionMessage[] + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade) + messages AiSessionMessage[] @@map("ai_sessions_metadata") } diff --git a/packages/backend/server/src/core/workspaces/permission.ts b/packages/backend/server/src/core/workspaces/permission.ts index 083a7e892e..9cca48766f 100644 --- a/packages/backend/server/src/core/workspaces/permission.ts +++ b/packages/backend/server/src/core/workspaces/permission.ts @@ -26,6 +26,22 @@ export class PermissionService { return data?.type as Permission; } + /** + * check whether a workspace exists and has any one can access it + * @param workspaceId workspace id + * @returns + */ + async hasWorkspace(workspaceId: string) { + return await this.prisma.workspaceUserPermission + .count({ + where: { + workspaceId, + accepted: true, + }, + }) + .then(count => count > 0); + } + async getOwnedWorkspaces(userId: string) { return this.prisma.workspaceUserPermission .findMany({ @@ -96,6 +112,23 @@ export class PermissionService { return count !== 0; } + /** + * only check permission if the workspace is a cloud workspace + * @param workspaceId workspace id + * @param userId user id, check if is a public workspace if not provided + * @param permission default is read + */ + async checkCloudWorkspace( + workspaceId: string, + userId?: string, + permission: Permission = Permission.Read + ) { + const hasWorkspace = await this.hasWorkspace(workspaceId); + if (hasWorkspace) { + await this.checkWorkspace(workspaceId, userId, permission); + } + } + async checkWorkspace( ws: string, user?: string, @@ -263,6 +296,25 @@ export class PermissionService { /// End regin: workspace permission /// Start regin: page permission + /** + * only check permission if the workspace is a cloud workspace + * @param workspaceId workspace id + * @param pageId page id aka doc id + * @param userId user id, check if is a public page if not provided + * @param permission default is read + */ + async checkCloudPagePermission( + workspaceId: string, + pageId: string, + userId?: string, + permission = Permission.Read + ) { + const hasWorkspace = await this.hasWorkspace(workspaceId); + if (hasWorkspace) { + await this.checkPagePermission(workspaceId, pageId, userId, permission); + } + } + async checkPagePermission( ws: string, page: string, diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts new file mode 100644 index 0000000000..58678efd1a --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -0,0 +1,151 @@ +import { + BadRequestException, + Controller, + Get, + InternalServerErrorException, + Param, + Query, + Req, + Sse, +} from '@nestjs/common'; +import { + concatMap, + connect, + EMPTY, + from, + map, + merge, + Observable, + switchMap, + toArray, +} from 'rxjs'; + +import { Public } from '../../core/auth'; +import { CurrentUser } from '../../core/auth/current-user'; +import { CopilotProviderService } from './providers'; +import { ChatSessionService } from './session'; +import { CopilotCapability } from './types'; + +export interface ChatEvent { + data: string; + id?: string; +} + +@Controller('/api/copilot') +export class CopilotController { + constructor( + private readonly chatSession: ChatSessionService, + private readonly provider: CopilotProviderService + ) {} + + @Public() + @Get('/chat/:sessionId') + async chat( + @CurrentUser() user: CurrentUser, + @Req() req: Request, + @Param('sessionId') sessionId: string, + @Query('message') content: string, + @Query() params: Record + ): Promise { + const provider = this.provider.getProviderByCapability( + CopilotCapability.TextToText + ); + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + const session = await this.chatSession.get(sessionId); + if (!session) { + throw new BadRequestException('Session not found'); + } + if (!content || !content.trim()) { + throw new BadRequestException('Message is empty'); + } + session.push({ + role: 'user', + content: decodeURIComponent(content), + createdAt: new Date(), + }); + + try { + delete params.message; + const content = await provider.generateText( + session.finish(params), + session.model, + { + signal: req.signal, + user: user.id, + } + ); + + session.push({ + role: 'assistant', + content, + createdAt: new Date(), + }); + await session.save(); + + return content; + } catch (e: any) { + throw new InternalServerErrorException( + e.message || "Couldn't generate text" + ); + } + } + + @Public() + @Sse('/chat/:sessionId/stream') + async chatStream( + @CurrentUser() user: CurrentUser, + @Req() req: Request, + @Param('sessionId') sessionId: string, + @Query('message') content: string, + @Query() params: Record + ): Promise> { + const provider = this.provider.getProviderByCapability( + CopilotCapability.TextToText + ); + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + const session = await this.chatSession.get(sessionId); + if (!session) { + throw new BadRequestException('Session not found'); + } + if (!content || !content.trim()) { + throw new BadRequestException('Message is empty'); + } + session.push({ + role: 'user', + content: decodeURIComponent(content), + createdAt: new Date(), + }); + + delete params.message; + return from( + provider.generateTextStream(session.finish(params), session.model, { + signal: req.signal, + user: user.id, + }) + ).pipe( + connect(shared$ => + merge( + // actual chat event stream + shared$.pipe(map(data => ({ id: sessionId, data }))), + // save the generated text to the session + shared$.pipe( + toArray(), + concatMap(values => { + session.push({ + role: 'assistant', + content: values.join(''), + createdAt: new Date(), + }); + return from(session.save()); + }), + switchMap(() => EMPTY) + ) + ) + ) + ); + } +} diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index 732109abff..d3f7185f93 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -1,6 +1,8 @@ import { ServerFeature } from '../../core/config'; +import { QuotaService } from '../../core/quota'; import { PermissionService } from '../../core/workspaces/permission'; import { Plugin } from '../registry'; +import { CopilotController } from './controller'; import { PromptService } from './prompt'; import { assertProvidersConfigs, @@ -8,6 +10,7 @@ import { OpenAIProvider, registerCopilotProvider, } from './providers'; +import { CopilotResolver, UserCopilotResolver } from './resolver'; import { ChatSessionService } from './session'; registerCopilotProvider(OpenAIProvider); @@ -16,10 +19,14 @@ registerCopilotProvider(OpenAIProvider); name: 'copilot', providers: [ PermissionService, + QuotaService, ChatSessionService, + CopilotResolver, + UserCopilotResolver, PromptService, CopilotProviderService, ], + controllers: [CopilotController], contributesTo: ServerFeature.Copilot, if: config => { if (config.flavor.graphql) { diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index b0857dd03f..7d7f4929d8 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -38,16 +38,16 @@ export class ChatPrompt { ) { return new ChatPrompt( options.name, - options.action, - options.model, + options.action || undefined, + options.model || undefined, options.messages ); } constructor( public readonly name: string, - public readonly action: string | null, - public readonly model: string | null, + public readonly action: string | undefined, + public readonly model: string | undefined, private readonly messages: PromptMessage[] ) { this.encoder = getTokenEncoder(model); diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 0a863430f3..af85794466 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -3,14 +3,16 @@ import assert from 'node:assert'; import { ClientOptions, OpenAI } from 'openai'; import { - ChatMessage, ChatMessageRole, CopilotCapability, CopilotProviderType, CopilotTextToEmbeddingProvider, CopilotTextToTextProvider, + PromptMessage, } from '../types'; +const DEFAULT_DIMENSIONS = 256; + export class OpenAIProvider implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider { @@ -50,7 +52,7 @@ export class OpenAIProvider return OpenAIProvider.capabilities; } - private chatToGPTMessage(messages: ChatMessage[]) { + private chatToGPTMessage(messages: PromptMessage[]) { // filter redundant fields return messages.map(message => ({ role: message.role, @@ -63,7 +65,7 @@ export class OpenAIProvider embeddings, model, }: { - messages?: ChatMessage[]; + messages?: PromptMessage[]; embeddings?: string[]; model: string; }) { @@ -106,7 +108,7 @@ export class OpenAIProvider // ====== text to text ====== async generateText( - messages: ChatMessage[], + messages: PromptMessage[], model: string = 'gpt-3.5-turbo', options: { temperature?: number; @@ -134,8 +136,8 @@ export class OpenAIProvider } async *generateTextStream( - messages: ChatMessage[], - model: string, + messages: PromptMessage[], + model: string = 'gpt-3.5-turbo', options: { temperature?: number; maxTokens?: number; @@ -179,7 +181,7 @@ export class OpenAIProvider dimensions: number; signal?: AbortSignal; user?: string; - } = { dimensions: 256 } + } = { dimensions: DEFAULT_DIMENSIONS } ): Promise { messages = Array.isArray(messages) ? messages : [messages]; this.checkParams({ embeddings: messages, model }); @@ -187,7 +189,7 @@ export class OpenAIProvider const result = await this.instance.embeddings.create({ model: model, input: messages, - dimensions: options.dimensions, + dimensions: options.dimensions || DEFAULT_DIMENSIONS, user: options.user, }); return result.data.map(e => e.embedding); diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts new file mode 100644 index 0000000000..4126def0c1 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -0,0 +1,260 @@ +import { + Args, + Field, + ID, + InputType, + Mutation, + ObjectType, + Parent, + registerEnumType, + ResolveField, + Resolver, +} from '@nestjs/graphql'; +import { SafeIntResolver } from 'graphql-scalars'; + +import { CurrentUser, Public } from '../../core/auth'; +import { QuotaService } from '../../core/quota'; +import { UserType } from '../../core/user'; +import { PermissionService } from '../../core/workspaces/permission'; +import { + MutexService, + PaymentRequiredException, + TooManyRequestsException, +} from '../../fundamentals'; +import { ChatSessionService, ListHistoriesOptions } from './session'; +import { AvailableModels, type ChatHistory, type ChatMessage } from './types'; + +registerEnumType(AvailableModels, { name: 'CopilotModel' }); + +// ================== Input Types ================== + +@InputType() +class CreateChatSessionInput { + @Field(() => String) + workspaceId!: string; + + @Field(() => String) + docId!: string; + + @Field(() => String, { + description: 'An mark identifying which view to use to display the session', + nullable: true, + }) + action!: string | undefined; + + @Field(() => String, { + description: 'The prompt name to use for the session', + }) + promptName!: string; +} + +@InputType() +class QueryChatHistoriesInput implements Partial { + @Field(() => Boolean, { nullable: true }) + action: boolean | undefined; + + @Field(() => Number, { nullable: true }) + limit: number | undefined; + + @Field(() => Number, { nullable: true }) + skip: number | undefined; + + @Field(() => String, { nullable: true }) + sessionId: string | undefined; +} + +// ================== Return Types ================== + +@ObjectType('ChatMessage') +class ChatMessageType implements Partial { + @Field(() => String) + role!: 'system' | 'assistant' | 'user'; + + @Field(() => String) + content!: string; + + @Field(() => [String], { nullable: true }) + attachments!: string[]; + + @Field(() => Date, { nullable: true }) + createdAt!: Date | undefined; +} + +@ObjectType('CopilotHistories') +class CopilotHistoriesType implements Partial { + @Field(() => String) + sessionId!: string; + + @Field(() => String, { + description: 'An mark identifying which view to use to display the session', + }) + action!: string; + + @Field(() => Number, { + description: 'The number of tokens used in the session', + }) + tokens!: number; + + @Field(() => [ChatMessageType]) + messages!: ChatMessageType[]; +} + +@ObjectType('CopilotQuota') +class CopilotQuotaType { + @Field(() => SafeIntResolver) + limit!: number; + + @Field(() => SafeIntResolver) + used!: number; +} + +// ================== Resolver ================== + +@ObjectType('Copilot') +export class CopilotType { + @Field(() => ID, { nullable: true }) + workspaceId!: string | undefined; +} + +@Resolver(() => CopilotType) +export class CopilotResolver { + constructor( + private readonly permissions: PermissionService, + private readonly quota: QuotaService, + private readonly mutex: MutexService, + private readonly chatSession: ChatSessionService + ) {} + + @ResolveField(() => CopilotQuotaType, { + name: 'quota', + description: 'Get the quota of the user in the workspace', + complexity: 2, + }) + async getQuota(@CurrentUser() user: CurrentUser) { + const quota = await this.quota.getUserQuota(user.id); + const limit = quota.feature.copilotActionLimit; + + const actions = await this.chatSession.countUserActions(user.id); + const chats = await this.chatSession + .listHistories(user.id) + .then(histories => + histories.reduce( + (acc, h) => acc + h.messages.filter(m => m.role === 'user').length, + 0 + ) + ); + + return { limit, used: actions + chats }; + } + + @ResolveField(() => [String], { + description: 'Get the session list of chats in the workspace', + complexity: 2, + }) + async chats( + @Parent() copilot: CopilotType, + @CurrentUser() user: CurrentUser + ) { + if (!copilot.workspaceId) return []; + await this.permissions.checkCloudWorkspace(copilot.workspaceId, user.id); + return await this.chatSession.listSessions(user.id, copilot.workspaceId); + } + + @ResolveField(() => [String], { + description: 'Get the session list of actions in the workspace', + complexity: 2, + }) + async actions( + @Parent() copilot: CopilotType, + @CurrentUser() user: CurrentUser + ) { + if (!copilot.workspaceId) return []; + await this.permissions.checkCloudWorkspace(copilot.workspaceId, user.id); + return await this.chatSession.listSessions(user.id, copilot.workspaceId, { + action: true, + }); + } + + @ResolveField(() => [CopilotHistoriesType], {}) + async histories( + @Parent() copilot: CopilotType, + @CurrentUser() user: CurrentUser, + @Args('docId', { nullable: true }) docId?: string, + @Args({ + name: 'options', + type: () => QueryChatHistoriesInput, + nullable: true, + }) + options?: QueryChatHistoriesInput + ) { + const workspaceId = copilot.workspaceId; + if (!workspaceId) { + return []; + } else if (docId) { + await this.permissions.checkCloudPagePermission( + workspaceId, + docId, + user.id + ); + } else { + await this.permissions.checkCloudWorkspace(workspaceId, user.id); + } + + return await this.chatSession.listHistories( + user.id, + workspaceId, + docId, + options + ); + } + + @Public() + @Mutation(() => String, { + description: 'Create a chat session', + }) + async createCopilotSession( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => CreateChatSessionInput }) + options: CreateChatSessionInput + ) { + await this.permissions.checkCloudPagePermission( + options.workspaceId, + options.docId, + user.id + ); + const lockFlag = `session:${user.id}:${options.workspaceId}`; + await using lock = await this.mutex.lock(lockFlag); + if (!lock) { + return new TooManyRequestsException('Server is busy'); + } + + const { limit, used } = await this.getQuota(user); + if (limit && Number.isFinite(limit) && used >= limit) { + return new PaymentRequiredException( + `You have reached the limit of actions in this workspace, please upgrade your plan.` + ); + } + + const session = await this.chatSession.create({ + ...options, + userId: user.id, + }); + return session; + } +} + +@Resolver(() => UserType) +export class UserCopilotResolver { + constructor(private readonly permissions: PermissionService) {} + + @ResolveField(() => CopilotType) + async copilot( + @CurrentUser() user: CurrentUser, + @Args('workspaceId', { nullable: true }) workspaceId?: string + ) { + if (workspaceId) { + await this.permissions.checkCloudWorkspace(workspaceId, user.id); + } + return { workspaceId }; + } +} diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 6ac0691f17..6cf1656496 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -11,6 +11,7 @@ import { ChatMessageSchema, getTokenEncoder, PromptMessage, + PromptMessageSchema, PromptParams, } from './types'; @@ -105,37 +106,62 @@ export class ChatSession implements AsyncDisposable { @Injectable() export class ChatSessionService { private readonly logger = new Logger(ChatSessionService.name); + constructor( private readonly db: PrismaClient, private readonly prompt: PromptService ) {} - private async setSession(state: ChatSessionState): Promise { - await this.db.aiSession.upsert({ - where: { - id: state.sessionId, - }, - update: { - messages: { - create: state.messages.map((m, idx) => ({ idx, ...m })), - }, - }, - create: { - id: state.sessionId, - messages: { create: state.messages }, - // connect - user: { connect: { id: state.userId } }, - workspace: { connect: { id: state.workspaceId } }, - doc: { - connect: { - id_workspaceId: { - id: state.docId, + private async setSession(state: ChatSessionState): Promise { + return await this.db.$transaction(async tx => { + let sessionId = state.sessionId; + + // find existing session if session is chat session + if (!state.prompt.action) { + const { id } = + (await tx.aiSession.findFirst({ + where: { + userId: state.userId, workspaceId: state.workspaceId, + docId: state.docId, + prompt: { action: { equals: null } }, }, + select: { id: true }, + })) || {}; + if (id) sessionId = id; + } + + await tx.aiSession.upsert({ + where: { + id: sessionId, + userId: state.userId, + }, + update: { + messages: { + // delete old messages + deleteMany: {}, + create: state.messages.map(m => ({ + ...m, + params: m.params || undefined, + })), }, }, - prompt: { connect: { name: state.prompt.name } }, - }, + create: { + id: sessionId, + workspaceId: state.workspaceId, + docId: state.docId, + messages: { + create: state.messages.map(m => ({ + ...m, + params: m.params || undefined, + })), + }, + // connect + user: { connect: { id: state.userId } }, + prompt: { connect: { name: state.prompt.name } }, + }, + }); + return sessionId; }); } @@ -171,6 +197,7 @@ export class ChatSessionService { }) .then(async session => { if (!session) return; + const messages = ChatMessageSchema.array().safeParse(session.messages); return { @@ -184,18 +211,58 @@ export class ChatSessionService { }); } - async listHistories( + private calculateTokenSize( + messages: PromptMessage[], + model: AvailableModel + ): number { + const encoder = getTokenEncoder(model); + return messages + .map(m => encoder?.encode_ordinary(m.content).length || 0) + .reduce((total, length) => total + length, 0); + } + + async countUserActions(userId: string): Promise { + return await this.db.aiSession.count({ + where: { userId, prompt: { action: { not: null } } }, + }); + } + + async listSessions( + userId: string, workspaceId: string, - docId: string, - options: ListHistoriesOptions + options?: { docId?: string; action?: boolean } + ): Promise { + return await this.db.aiSession + .findMany({ + where: { + userId, + workspaceId, + docId: workspaceId === options?.docId ? undefined : options?.docId, + prompt: { + action: options?.action ? { not: null } : null, + }, + }, + select: { id: true }, + }) + .then(sessions => sessions.map(({ id }) => id)); + } + + async listHistories( + userId: string, + workspaceId?: string, + docId?: string, + options?: ListHistoriesOptions ): Promise { return await this.db.aiSession .findMany({ where: { + userId, workspaceId: workspaceId, docId: workspaceId === docId ? undefined : docId, - prompt: { action: { not: null } }, - id: options.sessionId ? { equals: options.sessionId } : undefined, + prompt: { + action: options?.action ? { not: null } : null, + }, + id: options?.sessionId ? { equals: options.sessionId } : undefined, }, select: { id: true, @@ -210,20 +277,33 @@ export class ChatSessionService { }, }, }, - take: options.limit, - skip: options.skip, + take: options?.limit, + skip: options?.skip, orderBy: { createdAt: 'desc' }, }) .then(sessions => sessions .map(({ id, prompt, messages }) => { - const ret = ChatMessageSchema.array().safeParse(messages); - if (ret.success) { - const encoder = getTokenEncoder(prompt.model as AvailableModel); - const tokens = ret.data - .map(m => encoder?.encode_ordinary(m.content).length || 0) - .reduce((total, length) => total + length, 0); - return { sessionId: id, tokens, messages: ret.data }; + try { + const ret = PromptMessageSchema.array().safeParse(messages); + if (ret.success) { + const tokens = this.calculateTokenSize( + ret.data, + prompt.model as AvailableModel + ); + return { + sessionId: id, + action: prompt.action || undefined, + tokens, + messages: ret.data, + }; + } else { + this.logger.error( + `Unexpected message schema: ${JSON.stringify(ret.error)}` + ); + } + } catch (e) { + this.logger.error('Unexpected error in listHistories', e); } return undefined; }) @@ -238,8 +318,12 @@ export class ChatSessionService { this.logger.error(`Prompt not found: ${options.promptName}`); throw new Error('Prompt not found'); } - await this.setSession({ ...options, sessionId, prompt, messages: [] }); - return sessionId; + return await this.setSession({ + ...options, + sessionId, + prompt, + messages: [], + }); } /** diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 34c6c996f3..86a73a86df 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -76,8 +76,9 @@ export type ChatMessage = z.infer; export const ChatHistorySchema = z .object({ sessionId: z.string(), + action: z.string().optional(), tokens: z.number(), - messages: z.array(ChatMessageSchema), + messages: z.array(PromptMessageSchema.or(ChatMessageSchema)), }) .strict(); @@ -104,8 +105,8 @@ export interface CopilotProvider { export interface CopilotTextToTextProvider extends CopilotProvider { generateText( messages: PromptMessage[], - model: string, - options: { + model?: string, + options?: { temperature?: number; maxTokens?: number; signal?: AbortSignal; @@ -114,8 +115,8 @@ export interface CopilotTextToTextProvider extends CopilotProvider { ): Promise; generateTextStream( messages: PromptMessage[], - model: string, - options: { + model?: string, + options?: { temperature?: number; maxTokens?: number; signal?: AbortSignal; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index b1537af422..e46289db84 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -2,6 +2,51 @@ # THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY) # ------------------------------------------------------ +type ChatMessage { + attachments: [String!] + content: String! + createdAt: DateTime + role: String! +} + +type Copilot { + """Get the session list of actions in the workspace""" + actions: [String!]! + + """Get the session list of chats in the workspace""" + chats: [String!]! + histories(docId: String, options: QueryChatHistoriesInput): [CopilotHistories!]! + + """Get the quota of the user in the workspace""" + quota: CopilotQuota! + workspaceId: ID +} + +type CopilotHistories { + """An mark identifying which view to use to display the session""" + action: String! + messages: [ChatMessage!]! + sessionId: String! + + """The number of tokens used in the session""" + tokens: Int! +} + +type CopilotQuota { + limit: SafeInt! + used: SafeInt! +} + +input CreateChatSessionInput { + """An mark identifying which view to use to display the session""" + action: String + docId: String! + + """The prompt name to use for the session""" + promptName: String! + workspaceId: String! +} + input CreateCheckoutSessionInput { coupon: String idempotencyKey: String! @@ -122,6 +167,9 @@ type Mutation { """Create a subscription checkout link of stripe""" createCheckoutSession(input: CreateCheckoutSessionInput!): String! + """Create a chat session""" + createCopilotSession(options: CreateChatSessionInput!): String! + """Create a stripe customer portal to manage payment methods""" createCustomerPortal: String! @@ -223,6 +271,13 @@ type Query { workspaces: [WorkspaceType!]! } +input QueryChatHistoriesInput { + action: Boolean + limit: Int + sessionId: String + skip: Int +} + type QuotaQueryType { blobLimit: SafeInt! copilotActionLimit: SafeInt @@ -380,6 +435,7 @@ type UserSubscription { type UserType { """User avatar url""" avatarUrl: String + copilot(workspaceId: String): Copilot! """User email verified""" createdAt: DateTime @deprecated(reason: "useless") diff --git a/packages/frontend/graphql/src/graphql/get-copilot-anonymous-histories.gql b/packages/frontend/graphql/src/graphql/get-copilot-anonymous-histories.gql deleted file mode 100644 index f04af20b34..0000000000 --- a/packages/frontend/graphql/src/graphql/get-copilot-anonymous-histories.gql +++ /dev/null @@ -1,17 +0,0 @@ -query getCopilotAnonymousHistories( - $workspaceId: String! - $docId: String - $options: QueryChatHistoriesInput -) { - copilotAnonymous(workspaceId: $workspaceId) { - histories(docId: $docId, options: $options) { - sessionId - tokens - messages { - role - content - attachments - } - } - } -} diff --git a/packages/frontend/graphql/src/graphql/get-copilot-anonymous-sessions.gql b/packages/frontend/graphql/src/graphql/get-copilot-anonymous-sessions.gql deleted file mode 100644 index 57c4f77a5a..0000000000 --- a/packages/frontend/graphql/src/graphql/get-copilot-anonymous-sessions.gql +++ /dev/null @@ -1,6 +0,0 @@ -query getCopilotAnonymousSessions($workspaceId: String!) { - copilotAnonymous(workspaceId: $workspaceId) { - chats - actions - } -} diff --git a/packages/frontend/graphql/src/graphql/get-copilot-histories.gql b/packages/frontend/graphql/src/graphql/get-copilot-histories.gql index 75541dfd36..496c772598 100644 --- a/packages/frontend/graphql/src/graphql/get-copilot-histories.gql +++ b/packages/frontend/graphql/src/graphql/get-copilot-histories.gql @@ -12,6 +12,7 @@ query getCopilotHistories( role content attachments + createdAt } } } diff --git a/packages/frontend/graphql/src/graphql/get-copilot-quota.gql b/packages/frontend/graphql/src/graphql/get-copilot-quota.gql new file mode 100644 index 0000000000..61e589b2fd --- /dev/null +++ b/packages/frontend/graphql/src/graphql/get-copilot-quota.gql @@ -0,0 +1,10 @@ +query getCopilotQuota($workspaceId: String!, $docId: String!) { + currentUser { + copilot { + quota { + limit + used + } + } + } +} diff --git a/packages/frontend/graphql/src/graphql/get-copilot-sessions.gql b/packages/frontend/graphql/src/graphql/get-copilot-sessions.gql index 1c065f8d1d..66ce82960a 100644 --- a/packages/frontend/graphql/src/graphql/get-copilot-sessions.gql +++ b/packages/frontend/graphql/src/graphql/get-copilot-sessions.gql @@ -1,8 +1,8 @@ query getCopilotSessions($workspaceId: String!) { currentUser { copilot(workspaceId: $workspaceId) { - chats actions + chats } } } diff --git a/packages/frontend/graphql/src/graphql/index.ts b/packages/frontend/graphql/src/graphql/index.ts index 0672676298..cdb7e92477 100644 --- a/packages/frontend/graphql/src/graphql/index.ts +++ b/packages/frontend/graphql/src/graphql/index.ts @@ -251,41 +251,6 @@ mutation removeEarlyAccess($email: String!) { }`, }; -export const getCopilotAnonymousHistoriesQuery = { - id: 'getCopilotAnonymousHistoriesQuery' as const, - operationName: 'getCopilotAnonymousHistories', - definitionName: 'copilotAnonymous', - containsFile: false, - query: ` -query getCopilotAnonymousHistories($workspaceId: String!, $docId: String, $options: QueryChatHistoriesInput) { - copilotAnonymous(workspaceId: $workspaceId) { - histories(docId: $docId, options: $options) { - sessionId - tokens - messages { - role - content - attachments - } - } - } -}`, -}; - -export const getCopilotAnonymousSessionsQuery = { - id: 'getCopilotAnonymousSessionsQuery' as const, - operationName: 'getCopilotAnonymousSessions', - definitionName: 'copilotAnonymous', - containsFile: false, - query: ` -query getCopilotAnonymousSessions($workspaceId: String!) { - copilotAnonymous(workspaceId: $workspaceId) { - chats - actions - } -}`, -}; - export const getCopilotHistoriesQuery = { id: 'getCopilotHistoriesQuery' as const, operationName: 'getCopilotHistories', @@ -302,6 +267,7 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query role content attachments + createdAt } } } @@ -309,6 +275,24 @@ query getCopilotHistories($workspaceId: String!, $docId: String, $options: Query }`, }; +export const getCopilotQuotaQuery = { + id: 'getCopilotQuotaQuery' as const, + operationName: 'getCopilotQuota', + definitionName: 'currentUser', + containsFile: false, + query: ` +query getCopilotQuota($workspaceId: String!, $docId: String!) { + currentUser { + copilot { + quota { + limit + used + } + } + } +}`, +}; + export const getCopilotSessionsQuery = { id: 'getCopilotSessionsQuery' as const, operationName: 'getCopilotSessions', @@ -318,8 +302,8 @@ export const getCopilotSessionsQuery = { query getCopilotSessions($workspaceId: String!) { currentUser { copilot(workspaceId: $workspaceId) { - chats actions + chats } } }`, diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 37acadc048..9c1bd6f58c 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -35,9 +35,10 @@ export interface Scalars { } export interface CreateChatSessionInput { - action: Scalars['Boolean']['input']; + /** An mark identifying which view to use to display the session */ + action: InputMaybe; docId: Scalars['String']['input']; - model: Scalars['String']['input']; + /** The prompt name to use for the session */ promptName: Scalars['String']['input']; workspaceId: Scalars['String']['input']; } @@ -333,43 +334,6 @@ export type PasswordLimitsFragment = { maxLength: number; }; -export type GetCopilotAnonymousHistoriesQueryVariables = Exact<{ - workspaceId: Scalars['String']['input']; - docId: InputMaybe; - options: InputMaybe; -}>; - -export type GetCopilotAnonymousHistoriesQuery = { - __typename?: 'Query'; - copilotAnonymous: { - __typename?: 'Copilot'; - histories: Array<{ - __typename?: 'CopilotHistories'; - sessionId: string; - tokens: number; - messages: Array<{ - __typename?: 'ChatMessage'; - role: string; - content: string; - attachments: Array | null; - }>; - }>; - }; -}; - -export type GetCopilotAnonymousSessionsQueryVariables = Exact<{ - workspaceId: Scalars['String']['input']; -}>; - -export type GetCopilotAnonymousSessionsQuery = { - __typename?: 'Query'; - copilotAnonymous: { - __typename?: 'Copilot'; - chats: Array; - actions: Array; - }; -}; - export type GetCopilotHistoriesQueryVariables = Exact<{ workspaceId: Scalars['String']['input']; docId: InputMaybe; @@ -391,12 +355,29 @@ export type GetCopilotHistoriesQuery = { role: string; content: string; attachments: Array | null; + createdAt: string | null; }>; }>; }; } | null; }; +export type GetCopilotQuotaQueryVariables = Exact<{ + workspaceId: Scalars['String']['input']; + docId: Scalars['String']['input']; +}>; + +export type GetCopilotQuotaQuery = { + __typename?: 'Query'; + currentUser: { + __typename?: 'UserType'; + copilot: { + __typename?: 'Copilot'; + quota: { __typename?: 'CopilotQuota'; limit: number; used: number }; + }; + } | null; +}; + export type GetCopilotSessionsQueryVariables = Exact<{ workspaceId: Scalars['String']['input']; }>; @@ -407,8 +388,8 @@ export type GetCopilotSessionsQuery = { __typename?: 'UserType'; copilot: { __typename?: 'Copilot'; - chats: Array; actions: Array; + chats: Array; }; } | null; }; @@ -1057,21 +1038,16 @@ export type Queries = variables: EarlyAccessUsersQueryVariables; response: EarlyAccessUsersQuery; } - | { - name: 'getCopilotAnonymousHistoriesQuery'; - variables: GetCopilotAnonymousHistoriesQueryVariables; - response: GetCopilotAnonymousHistoriesQuery; - } - | { - name: 'getCopilotAnonymousSessionsQuery'; - variables: GetCopilotAnonymousSessionsQueryVariables; - response: GetCopilotAnonymousSessionsQuery; - } | { name: 'getCopilotHistoriesQuery'; variables: GetCopilotHistoriesQueryVariables; response: GetCopilotHistoriesQuery; } + | { + name: 'getCopilotQuotaQuery'; + variables: GetCopilotQuotaQueryVariables; + response: GetCopilotQuotaQuery; + } | { name: 'getCopilotSessionsQuery'; variables: GetCopilotSessionsQueryVariables;