diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index f58ba8754c..ccbc0d4401 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -15,7 +15,11 @@ import { OpenAIProvider, registerCopilotProvider, } from './providers'; -import { CopilotResolver, UserCopilotResolver } from './resolver'; +import { + CopilotResolver, + PromptsManagementResolver, + UserCopilotResolver, +} from './resolver'; import { ChatSessionService } from './session'; import { CopilotStorage } from './storage'; @@ -34,6 +38,7 @@ registerCopilotProvider(OpenAIProvider); PromptService, CopilotProviderService, CopilotStorage, + PromptsManagementResolver, ], controllers: [CopilotController], contributesTo: ServerFeature.Copilot, diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index d23ece0ca3..e07eb434bb 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -142,12 +142,32 @@ export class PromptService { * list prompt names * @returns prompt names */ - async list() { + async listNames() { return this.db.aiPrompt .findMany({ select: { name: true } }) .then(prompts => Array.from(new Set(prompts.map(p => p.name)))); } + async list() { + return this.db.aiPrompt.findMany({ + select: { + name: true, + action: true, + model: true, + messages: { + select: { + role: true, + content: true, + params: true, + }, + orderBy: { + idx: 'asc', + }, + }, + }, + }); + } + /** * get prompt messages by prompt name * @param name prompt name diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 582080021a..6f7cdc5041 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -9,14 +9,17 @@ import { Mutation, ObjectType, Parent, + Query, registerEnumType, ResolveField, Resolver, } from '@nestjs/graphql'; +import { AiPromptRole } from '@prisma/client'; import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars'; import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import { CurrentUser } from '../../core/auth'; +import { Admin } from '../../core/common'; import { UserType } from '../../core/user'; import { PermissionService } from '../../core/workspaces/permission'; import { @@ -25,6 +28,7 @@ import { Throttle, TooManyRequestsException, } from '../../fundamentals'; +import { PromptService } from './prompt'; import { ChatSessionService } from './session'; import { CopilotStorage } from './storage'; import { @@ -152,6 +156,40 @@ class CopilotQuotaType { used!: number; } +registerEnumType(AiPromptRole, { + name: 'CopilotPromptMessageRole', +}); + +@InputType('CopilotPromptMessageInput') +@ObjectType() +class CopilotPromptMessageType { + @Field(() => AiPromptRole) + role!: AiPromptRole; + + @Field(() => String) + content!: string; + + @Field(() => GraphQLJSON, { nullable: true }) + params!: Record | null; +} + +registerEnumType(AvailableModels, { name: 'CopilotModels' }); + +@ObjectType() +class CopilotPromptType { + @Field(() => String) + name!: string; + + @Field(() => AvailableModels) + model!: AvailableModels; + + @Field(() => String, { nullable: true }) + action!: string | null; + + @Field(() => [CopilotPromptMessageType]) + messages!: CopilotPromptMessageType[]; +} + // ================== Resolver ================== @ObjectType('Copilot') @@ -370,3 +408,54 @@ export class UserCopilotResolver { return { workspaceId }; } } + +@InputType() +class CreateCopilotPromptInput { + @Field(() => String) + name!: string; + + @Field(() => AvailableModels) + model!: AvailableModels; + + @Field(() => String, { nullable: true }) + action!: string | null; + + @Field(() => [CopilotPromptMessageType]) + messages!: CopilotPromptMessageType[]; +} + +@Admin() +@Resolver(() => String) +export class PromptsManagementResolver { + constructor(private readonly promptService: PromptService) {} + + @Query(() => [CopilotPromptType], { + description: 'List all copilot prompts', + }) + async listCopilotPrompts() { + return this.promptService.list(); + } + + @Mutation(() => CopilotPromptType, { + description: 'Create a copilot prompt', + }) + async createCopilotPrompt( + @Args({ type: () => CreateCopilotPromptInput, name: 'input' }) + input: CreateCopilotPromptInput + ) { + await this.promptService.set(input.name, input.model, input.messages); + return this.promptService.get(input.name); + } + + @Mutation(() => CopilotPromptType, { + description: 'Update a copilot prompt', + }) + async updateCopilotPrompt( + @Args('name') name: string, + @Args('messages', { type: () => [CopilotPromptMessageType] }) + messages: CopilotPromptMessageType[] + ) { + await this.promptService.update(name, messages); + return this.promptService.get(name); + } +} diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 639f6a71cf..0e3ce5f07e 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -34,6 +34,44 @@ type CopilotHistories { tokens: Int! } +enum CopilotModels { + DallE3 + Gpt4Omni + Gpt4TurboPreview + Gpt4VisionPreview + Gpt35Turbo + TextEmbedding3Large + TextEmbedding3Small + TextEmbeddingAda002 + TextModerationLatest + TextModerationStable +} + +input CopilotPromptMessageInput { + content: String! + params: JSON + role: CopilotPromptMessageRole! +} + +enum CopilotPromptMessageRole { + assistant + system + user +} + +type CopilotPromptMessageType { + content: String! + params: JSON + role: CopilotPromptMessageRole! +} + +type CopilotPromptType { + action: String + messages: [CopilotPromptMessageType!]! + model: CopilotModels! + name: String! +} + type CopilotQuota { limit: SafeInt used: SafeInt! @@ -63,6 +101,13 @@ input CreateCheckoutSessionInput { successCallbackLink: String! } +input CreateCopilotPromptInput { + action: String + messages: [CopilotPromptMessageInput!]! + model: CopilotModels! + name: String! +} + type CredentialsRequirementType { password: PasswordLimitsType! } @@ -206,6 +251,9 @@ type Mutation { """Create a chat message""" createCopilotMessage(options: CreateChatMessageInput!): String! + """Create a copilot prompt""" + createCopilotPrompt(input: CreateCopilotPromptInput!): CopilotPromptType! + """Create a chat session""" createCopilotSession(options: CreateChatSessionInput!): String! @@ -238,6 +286,9 @@ type Mutation { setBlob(blob: Upload!, workspaceId: String!): String! setWorkspaceExperimentalFeature(enable: Boolean!, feature: FeatureType!, workspaceId: String!): Boolean! sharePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "renamed to publishPage") + + """Update a copilot prompt""" + updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType! updateProfile(input: UpdateUserInput!): UserType! """update server runtime configurable setting""" @@ -296,6 +347,9 @@ type Query { """List blobs of workspace""" listBlobs(workspaceId: String!): [String!]! @deprecated(reason: "use `workspace.blobs` instead") + + """List all copilot prompts""" + listCopilotPrompts: [CopilotPromptType!]! listWorkspaceFeatures(feature: FeatureType!): [WorkspaceType!]! prices: [SubscriptionPrice!]! diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index dc8d4590ed..7d0b12e35f 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -77,13 +77,13 @@ test.beforeEach(async t => { test('should be able to manage prompt', async t => { const { prompt } = t.context; - t.is((await prompt.list()).length, 0, 'should have no prompt'); + t.is((await prompt.listNames()).length, 0, 'should have no prompt'); await prompt.set('test', 'test', [ { role: 'system', content: 'hello' }, { role: 'user', content: 'hello' }, ]); - t.is((await prompt.list()).length, 1, 'should have one prompt'); + t.is((await prompt.listNames()).length, 1, 'should have one prompt'); t.is( (await prompt.get('test'))!.finish({}).length, 2, @@ -98,7 +98,7 @@ test('should be able to manage prompt', async t => { ); await prompt.delete('test'); - t.is((await prompt.list()).length, 0, 'should have no prompt'); + t.is((await prompt.listNames()).length, 0, 'should have no prompt'); t.is(await prompt.get('test'), null, 'should not have the prompt'); });