From 3c01d944fb2ec954f7b2a1f36a5b5ceedf8d91d8 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Tue, 2 Apr 2024 07:04:54 +0000 Subject: [PATCH] feat: add prompt service (#6241) fix CLOUD-19 --- .../20240321065017_ai_prompts/migration.sql | 16 +++++ .../20240325125057_ai_sessions/migration.sql | 24 +++++++ packages/backend/server/schema.prisma | 45 ++++++++++++ .../server/src/plugins/copilot/index.ts | 5 +- .../server/src/plugins/copilot/prompt.ts | 72 +++++++++++++++++++ .../{provider.ts => providers/index.ts} | 4 +- .../server/src/plugins/copilot/types.ts | 18 +++-- packages/backend/server/tests/copilot.spec.ts | 70 ++++++++++++++++++ 8 files changed, 246 insertions(+), 8 deletions(-) create mode 100644 packages/backend/server/migrations/20240321065017_ai_prompts/migration.sql create mode 100644 packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql create mode 100644 packages/backend/server/src/plugins/copilot/prompt.ts rename packages/backend/server/src/plugins/copilot/{provider.ts => providers/index.ts} (98%) create mode 100644 packages/backend/server/tests/copilot.spec.ts diff --git a/packages/backend/server/migrations/20240321065017_ai_prompts/migration.sql b/packages/backend/server/migrations/20240321065017_ai_prompts/migration.sql new file mode 100644 index 0000000000..668f635154 --- /dev/null +++ b/packages/backend/server/migrations/20240321065017_ai_prompts/migration.sql @@ -0,0 +1,16 @@ +-- CreateEnum +CREATE TYPE "AiPromptRole" AS ENUM ('system', 'assistant', 'user'); + +-- CreateTable +CREATE TABLE "ai_prompts" ( + "id" VARCHAR NOT NULL, + "name" VARCHAR(20) NOT NULL, + "idx" INTEGER NOT NULL, + "role" "AiPromptRole" NOT NULL, + "content" TEXT NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "ai_prompts_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "ai_prompts_name_idx_key" ON "ai_prompts"("name", "idx"); \ No newline at end of file diff --git a/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql b/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql new file mode 100644 index 0000000000..42242736e7 --- /dev/null +++ b/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql @@ -0,0 +1,24 @@ +-- CreateTable +CREATE TABLE "ai_sessions" ( + "id" VARCHAR NOT NULL, + "user_id" VARCHAR NOT NULL, + "workspace_id" VARCHAR NOT NULL, + "doc_id" VARCHAR NOT NULL, + "prompt_name" VARCHAR NOT NULL, + "action" BOOLEAN NOT NULL, + "model" VARCHAR NOT NULL, + "messages" JSON NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(6) NOT NULL, + + CONSTRAINT "ai_sessions_pkey" PRIMARY KEY ("id") +); + +-- AddForeignKey +ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 21bc899883..593e6a8fd1 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -30,6 +30,7 @@ model User { pagePermissions WorkspacePageUserPermission[] connectedAccounts ConnectedAccount[] sessions UserSession[] + AiSession AiSession[] @@map("users") } @@ -96,6 +97,7 @@ model Workspace { permissions WorkspaceUserPermission[] pagePermissions WorkspacePageUserPermission[] features WorkspaceFeatures[] + AiSession AiSession[] @@map("workspaces") } @@ -321,6 +323,8 @@ model Snapshot { // but the created time of last seen update that has been merged into snapshot. updatedAt DateTime @map("updated_at") @db.Timestamptz(6) + AiSession AiSession[] + @@id([id, workspaceId]) @@map("snapshots") } @@ -422,6 +426,47 @@ model UserInvoice { @@map("user_invoices") } +enum AiPromptRole { + system + assistant + user +} + +model AiPrompt { + id String @id @default(uuid()) @db.VarChar + // prompt name + name String @db.VarChar(20) + // if a group of prompts contains multiple sentences, idx specifies the order of each sentence + idx Int @db.Integer + // system/assistant/user + role AiPromptRole + // prompt content + content String @db.Text + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + + @@unique([name, idx]) + @@map("ai_prompts") +} + +model AiSession { + id String @id @default(uuid()) @db.VarChar + userId String @map("user_id") @db.VarChar + workspaceId String @map("workspace_id") @db.VarChar + docId String @map("doc_id") @db.VarChar + promptName String @map("prompt_name") @db.VarChar + action Boolean @db.Boolean + model String @db.VarChar + messages Json @db.Json + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) + doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade) + + @@map("ai_sessions") +} + model DataMigration { id String @id @default(uuid()) @db.VarChar(36) name String @db.VarChar diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index 53dd28a178..954bfb4a7a 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -1,10 +1,11 @@ import { ServerFeature } from '../../core/config'; import { Plugin } from '../registry'; -import { assertProvidersConfigs, CopilotProviderService } from './provider'; +import { PromptService } from './prompt'; +import { assertProvidersConfigs, CopilotProviderService } from './providers'; @Plugin({ name: 'copilot', - providers: [CopilotProviderService], + providers: [PromptService, CopilotProviderService], contributesTo: ServerFeature.Copilot, if: config => { if (config.flavor.graphql) { diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts new file mode 100644 index 0000000000..e259383b10 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -0,0 +1,72 @@ +import { Injectable } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; + +import { ChatMessage } from './types'; + +@Injectable() +export class PromptService { + constructor(private readonly db: PrismaClient) {} + + /** + * list prompt names + * @returns prompt names + */ + async list() { + return this.db.aiPrompt + .findMany({ select: { name: true } }) + .then(prompts => Array.from(new Set(prompts.map(p => p.name)))); + } + + /** + * get prompt messages by prompt name + * @param name prompt name + * @returns prompt messages + */ + async get(name: string): Promise { + return this.db.aiPrompt.findMany({ + where: { + name, + }, + select: { + role: true, + content: true, + }, + orderBy: { + idx: 'asc', + }, + }); + } + + async set(name: string, messages: ChatMessage[]) { + return this.db.$transaction(async tx => { + const prompts = await tx.aiPrompt.count({ where: { name } }); + if (prompts > 0) { + return 0; + } + return tx.aiPrompt + .createMany({ + data: messages.map((m, idx) => ({ name, idx, ...m })), + }) + .then(ret => ret.count); + }); + } + + async update(name: string, messages: ChatMessage[]) { + return this.db.$transaction(async tx => { + await tx.aiPrompt.deleteMany({ where: { name } }); + return tx.aiPrompt + .createMany({ + data: messages.map((m, idx) => ({ name, idx, ...m })), + }) + .then(ret => ret.count); + }); + } + + async delete(name: string) { + return this.db.aiPrompt + .deleteMany({ + where: { name }, + }) + .then(ret => ret.count); + } +} diff --git a/packages/backend/server/src/plugins/copilot/provider.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts similarity index 98% rename from packages/backend/server/src/plugins/copilot/provider.ts rename to packages/backend/server/src/plugins/copilot/providers/index.ts index 24bf67ba63..2b66669d88 100644 --- a/packages/backend/server/src/plugins/copilot/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -2,14 +2,14 @@ import assert from 'node:assert'; import { Injectable, Logger } from '@nestjs/common'; -import { Config } from '../../fundamentals'; +import { Config } from '../../../fundamentals'; import { CapabilityToCopilotProvider, CopilotConfig, CopilotProvider, CopilotProviderCapability, CopilotProviderType, -} from './types'; +} from '../types'; type CopilotProviderConfig = CopilotConfig[keyof CopilotConfig]; diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index fd72c4fd32..b0ab8b16a7 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -1,4 +1,6 @@ +import { AiPromptRole } from '@prisma/client'; import type { ClientOptions as OpenAIClientOptions } from 'openai'; +import { z } from 'zod'; export interface CopilotConfig { openai: OpenAIClientOptions; @@ -23,10 +25,18 @@ export interface CopilotProvider { getCapabilities(): CopilotProviderCapability[]; } -export type ChatMessage = { - role: 'system' | 'assistant' | 'user'; - content: string; -}; +export const ChatMessageSchema = z + .object({ + role: z.enum( + Array.from(Object.values(AiPromptRole)) as [ + 'system' | 'assistant' | 'user', + ] + ), + content: z.string(), + }) + .strict(); + +export type ChatMessage = z.infer; export interface CopilotTextToTextProvider extends CopilotProvider { generateText(messages: ChatMessage[], model: string): Promise; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts new file mode 100644 index 0000000000..5fee16d985 --- /dev/null +++ b/packages/backend/server/tests/copilot.spec.ts @@ -0,0 +1,70 @@ +/// + +import { TestingModule } from '@nestjs/testing'; +import type { TestFn } from 'ava'; +import ava from 'ava'; + +import { AuthService } from '../src/core/auth'; +import { QuotaManagementService, QuotaModule } from '../src/core/quota'; +import { ConfigModule } from '../src/fundamentals/config'; +import { CopilotModule } from '../src/plugins/copilot'; +import { PromptService } from '../src/plugins/copilot/prompt'; +import { createTestingModule } from './utils'; + +const test = ava as TestFn<{ + auth: AuthService; + quotaManager: QuotaManagementService; + module: TestingModule; + prompt: PromptService; +}>; + +test.beforeEach(async t => { + const module = await createTestingModule({ + imports: [ + ConfigModule.forRoot({ + plugins: { + copilot: { + openai: { + apiKey: '1', + }, + }, + }, + }), + QuotaModule, + CopilotModule, + ], + }); + + const quotaManager = module.get(QuotaManagementService); + const auth = module.get(AuthService); + const prompt = module.get(PromptService); + + t.context.module = module; + t.context.quotaManager = quotaManager; + t.context.auth = auth; + t.context.prompt = prompt; +}); + +test.afterEach.always(async t => { + await t.context.module.close(); +}); + +test('should be able to manage prompt', async t => { + const { prompt } = t.context; + + t.is((await prompt.list()).length, 0, 'should have no prompt'); + + await prompt.set('test', [ + { role: 'system', content: 'hello' }, + { role: 'user', content: 'hello' }, + ]); + t.is((await prompt.list()).length, 1, 'should have one prompt'); + t.is((await prompt.get('test')).length, 2, 'should have two messages'); + + await prompt.update('test', [{ role: 'system', content: 'hello' }]); + t.is((await prompt.get('test')).length, 1, 'should have one message'); + + await prompt.delete('test'); + t.is((await prompt.list()).length, 0, 'should have no prompt'); + t.is((await prompt.get('test')).length, 0, 'should have no messages'); +});