diff --git a/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql b/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql index 42242736e7..3e66cfc73e 100644 --- a/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql +++ b/packages/backend/server/migrations/20240325125057_ai_sessions/migration.sql @@ -1,11 +1,12 @@ -- CreateTable CREATE TABLE "ai_sessions" ( - "id" VARCHAR NOT NULL, + "id" VARCHAR(36) NOT NULL, "user_id" VARCHAR NOT NULL, "workspace_id" VARCHAR NOT NULL, "doc_id" VARCHAR NOT NULL, "prompt_name" VARCHAR NOT NULL, "action" BOOLEAN NOT NULL, + "flavor" VARCHAR NOT NULL, "model" VARCHAR NOT NULL, "messages" JSON NOT NULL, "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql new file mode 100644 index 0000000000..c2c3bfc649 --- /dev/null +++ b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql @@ -0,0 +1,90 @@ +/* + Warnings: + + - You are about to drop the `ai_prompts` table. If the table is not empty, all the data it contains will be lost. + - You are about to drop the `ai_sessions` table. If the table is not empty, all the data it contains will be lost. + +*/ +-- DropForeignKey +ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey"; + +-- DropForeignKey +ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_user_id_fkey"; + +-- DropForeignKey +ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_workspace_id_fkey"; + +-- DropTable +DROP TABLE "ai_prompts"; + +-- DropTable +DROP TABLE "ai_sessions"; + +-- CreateTable +CREATE TABLE "ai_prompts_messages" ( + "prompt_id" INTEGER NOT NULL, + "idx" INTEGER NOT NULL, + "role" "AiPromptRole" NOT NULL, + "content" TEXT NOT NULL, + "params" JSON, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- CreateTable +CREATE TABLE "ai_prompts_metadata" ( + "id" SERIAL NOT NULL, + "name" VARCHAR(32) NOT NULL, + "action" VARCHAR, + "model" VARCHAR, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "ai_prompts_metadata_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "ai_sessions_messages" ( + "id" VARCHAR(36) NOT NULL, + "session_id" VARCHAR(36) NOT NULL, + "role" "AiPromptRole" NOT NULL, + "content" TEXT NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(6) NOT NULL, + + CONSTRAINT "ai_sessions_messages_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "ai_sessions_metadata" ( + "id" VARCHAR(36) NOT NULL, + "user_id" VARCHAR(36) NOT NULL, + "workspace_id" VARCHAR(36) NOT NULL, + "doc_id" VARCHAR(36) NOT NULL, + "prompt_name" VARCHAR(32) NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "ai_sessions_metadata_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "ai_prompts_messages_prompt_id_idx_key" ON "ai_prompts_messages"("prompt_id", "idx"); + +-- CreateIndex +CREATE UNIQUE INDEX "ai_prompts_metadata_name_key" ON "ai_prompts_metadata"("name"); + +-- AddForeignKey +ALTER TABLE "ai_prompts_messages" ADD CONSTRAINT "ai_prompts_messages_prompt_id_fkey" FOREIGN KEY ("prompt_id") REFERENCES "ai_prompts_metadata"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions_messages" ADD CONSTRAINT "ai_sessions_messages_session_id_fkey" FOREIGN KEY ("session_id") REFERENCES "ai_sessions_metadata"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 61a1cd94c8..7c9dc23fe1 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -72,6 +72,7 @@ "keyv": "^4.5.4", "lodash-es": "^4.17.21", "mixpanel": "^0.18.0", + "mustache": "^4.2.0", "nanoid": "^5.0.6", "nest-commander": "^3.12.5", "nestjs-throttler-storage-redis": "^0.4.1", @@ -87,6 +88,7 @@ "semver": "^7.6.0", "socket.io": "^4.7.4", "stripe": "^14.18.0", + "tiktoken": "^1.0.13", "ts-node": "^10.9.2", "typescript": "^5.3.3", "ws": "^8.16.0", @@ -105,6 +107,7 @@ "@types/keyv": "^4.2.0", "@types/lodash-es": "^4.17.12", "@types/mixpanel": "^2.14.8", + "@types/mustache": "^4", "@types/node": "^20.11.20", "@types/nodemailer": "^6.4.14", "@types/on-headers": "^1.0.3", diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 593e6a8fd1..09f82c1266 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -30,7 +30,7 @@ model User { pagePermissions WorkspacePageUserPermission[] connectedAccounts ConnectedAccount[] sessions UserSession[] - AiSession AiSession[] + aiSessions AiSession[] @@map("users") } @@ -97,7 +97,7 @@ model Workspace { permissions WorkspaceUserPermission[] pagePermissions WorkspacePageUserPermission[] features WorkspaceFeatures[] - AiSession AiSession[] + aiSessions AiSession[] @@map("workspaces") } @@ -323,7 +323,7 @@ model Snapshot { // but the created time of last seen update that has been merged into snapshot. updatedAt DateTime @map("updated_at") @db.Timestamptz(6) - AiSession AiSession[] + aiSessions AiSession[] @@id([id, workspaceId]) @@map("snapshots") @@ -432,39 +432,66 @@ enum AiPromptRole { user } -model AiPrompt { - id String @id @default(uuid()) @db.VarChar - // prompt name - name String @db.VarChar(20) +model AiPromptMessage { + promptId Int @map("prompt_id") @db.Integer // if a group of prompts contains multiple sentences, idx specifies the order of each sentence idx Int @db.Integer // system/assistant/user role AiPromptRole // prompt content content String @db.Text + params Json? @db.Json createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) - @@unique([name, idx]) - @@map("ai_prompts") + prompt AiPrompt @relation(fields: [promptId], references: [id], onDelete: Cascade) + + @@unique([promptId, idx]) + @@map("ai_prompts_messages") +} + +model AiPrompt { + id Int @id @default(autoincrement()) @db.Integer + name String @unique @db.VarChar(32) + // an mark identifying which view to use to display the session + // it is only used in the frontend and does not affect the backend + action String? @db.VarChar + model String? @db.VarChar + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + + messages AiPromptMessage[] + sessions AiSession[] + + @@map("ai_prompts_metadata") +} + +model AiSessionMessage { + id String @id @default(uuid()) @db.VarChar(36) + sessionId String @map("session_id") @db.VarChar(36) + role AiPromptRole + content String @db.Text + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + + session AiSession @relation(fields: [sessionId], references: [id], onDelete: Cascade) + + @@map("ai_sessions_messages") } model AiSession { - id String @id @default(uuid()) @db.VarChar - userId String @map("user_id") @db.VarChar - workspaceId String @map("workspace_id") @db.VarChar - docId String @map("doc_id") @db.VarChar - promptName String @map("prompt_name") @db.VarChar - action Boolean @db.Boolean - model String @db.VarChar - messages Json @db.Json + id String @id @default(uuid()) @db.VarChar(36) + userId String @map("user_id") @db.VarChar(36) + workspaceId String @map("workspace_id") @db.VarChar(36) + docId String @map("doc_id") @db.VarChar(36) + promptName String @map("prompt_name") @db.VarChar(32) createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) - updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) - doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) + doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade) + prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade) + messages AiSessionMessage[] - @@map("ai_sessions") + @@map("ai_sessions_metadata") } model DataMigration { diff --git a/packages/backend/server/src/config/affine.env.ts b/packages/backend/server/src/config/affine.env.ts index d069be52de..148cb6585e 100644 --- a/packages/backend/server/src/config/affine.env.ts +++ b/packages/backend/server/src/config/affine.env.ts @@ -19,6 +19,7 @@ AFFiNE.ENV_MAP = { MAILER_SECURE: ['mailer.secure', 'boolean'], THROTTLE_TTL: ['rateLimiter.ttl', 'int'], THROTTLE_LIMIT: ['rateLimiter.limit', 'int'], + COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey', REDIS_SERVER_HOST: 'plugins.redis.host', REDIS_SERVER_PORT: ['plugins.redis.port', 'int'], REDIS_SERVER_USER: 'plugins.redis.username', diff --git a/packages/backend/server/src/config/affine.self.ts b/packages/backend/server/src/config/affine.self.ts index 5cc19b37f7..736b87e9f3 100644 --- a/packages/backend/server/src/config/affine.self.ts +++ b/packages/backend/server/src/config/affine.self.ts @@ -39,9 +39,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) { } AFFiNE.plugins.use('copilot', { - openai: { - apiKey: 'test', - }, + openai: {}, }); AFFiNE.plugins.use('redis'); AFFiNE.plugins.use('payment', { diff --git a/packages/backend/server/src/data/migrations/1712068777394-prompts.ts b/packages/backend/server/src/data/migrations/1712068777394-prompts.ts new file mode 100644 index 0000000000..125bcb60e3 --- /dev/null +++ b/packages/backend/server/src/data/migrations/1712068777394-prompts.ts @@ -0,0 +1,33 @@ +import { PrismaClient } from '@prisma/client'; + +import { prompts } from './utils/prompts'; + +export class Prompts1712068777394 { + // do the migration + static async up(db: PrismaClient) { + await db.$transaction(async tx => { + await Promise.all( + prompts.map(prompt => + tx.aiPrompt.create({ + data: { + name: prompt.name, + action: prompt.action, + model: prompt.model, + messages: { + create: prompt.messages.map((message, idx) => ({ + idx, + role: message.role, + content: message.content, + params: message.params, + })), + }, + }, + }) + ) + ); + }); + } + + // revert the migration + static async down(_db: PrismaClient) {} +} diff --git a/packages/backend/server/src/data/migrations/1712224382221-refresh-free-plan.ts b/packages/backend/server/src/data/migrations/1712224382221-refresh-free-plan.ts new file mode 100644 index 0000000000..7d94a9d74a --- /dev/null +++ b/packages/backend/server/src/data/migrations/1712224382221-refresh-free-plan.ts @@ -0,0 +1,16 @@ +import { PrismaClient } from '@prisma/client'; + +import { Quotas } from '../../core/quota'; +import { upgradeQuotaVersion } from './utils/user-quotas'; + +export class RefreshFreePlan1712224382221 { + // do the migration + static async up(db: PrismaClient) { + // free plan 1.0 + const quota = Quotas[4]; + await upgradeQuotaVersion(db, quota, 'free plan 1.1 migration'); + } + + // revert the migration + static async down(_db: PrismaClient) {} +} diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts new file mode 100644 index 0000000000..6423658a0d --- /dev/null +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -0,0 +1,275 @@ +import { AiPromptRole } from '@prisma/client'; + +type PromptMessage = { + role: AiPromptRole; + content: string; + params?: Record; +}; + +type Prompt = { + name: string; + action?: string; + model: string; + messages: PromptMessage[]; +}; + +export const prompts: Prompt[] = [ + { + name: 'debug:chat:gpt4', + model: 'gpt-4-turbo-preview', + messages: [], + }, + { + name: 'debug:action:gpt4', + action: 'text', + model: 'gpt-4-turbo-preview', + messages: [], + }, + { + name: 'debug:action:vision4', + action: 'text', + model: 'gpt-4-vision-preview', + messages: [], + }, + { + name: 'Summary', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Summarize the key points from the following content in a clear and concise manner, suitable for a reader who is seeking a quick understanding of the original content. Ensure to capture the main ideas and any significant details without unnecessary elaboration:\n\n{{content}}', + }, + ], + }, + { + name: 'Summary the webpage', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Summarize the insights from the following webpage content:\n\nFirst, provide a brief summary of the webpage content below. Then, list the insights derived from it, one by one.\n\n{{#links}}\n- {{.}}\n{{/links}}', + }, + ], + }, + { + name: 'Explain this image', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Describe the scene captured in this image, focusing on the details, colors, emotions, and any interactions between subjects or objects present.\n\n{{image}}', + }, + ], + }, + { + name: 'Explain this code', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Analyze and explain the functionality of the following code snippet, highlighting its purpose, the logic behind its operations, and its potential output:\n\n{{code}}', + }, + ], + }, + { + name: 'Translate to', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Please translate the following content into {{language}} and return it to us, adhering to the original format of the content:\n\n{{content}}', + params: { + language: [ + 'English', + 'Spanish', + 'German', + 'French', + 'Italian', + 'Simplified Chinese', + 'Traditional Chinese', + 'Japanese', + 'Russian', + 'Korean', + ], + }, + }, + ], + }, + { + name: 'Write an article about this', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: 'Write an article about following content:\n\n{{content}}', + }, + ], + }, + { + name: 'Write a twitter about this', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: 'Write a twitter about following content:\n\n{{content}}', + }, + ], + }, + { + name: 'Write a poem about this', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: 'Write a poem about following content:\n\n{{content}}', + }, + ], + }, + { + name: 'Write a blog post about this', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: 'Write a blog post about following content:\n\n{{content}}', + }, + ], + }, + { + name: 'Change tone to', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Please rephrase the following content to convey a more {{tone}} tone:\n\n{{content}}', + params: { tone: ['professional', 'informal', 'friendly', 'critical'] }, + }, + ], + }, + { + name: 'Brainstorm ideas about this', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Using the information following content, brainstorm ideas and output your thoughts in a bulleted points format.\n\n{{content}}', + }, + ], + }, + { + name: 'Improve writing for it', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Please rewrite the following content to enhance its clarity, coherence, and overall quality, ensuring that the message is effectively communicated and free of any grammatical errors. Provide a refined version that maintains the original intent but exhibits improved structure and readability:\n\n{{content}}', + }, + ], + }, + { + name: 'Improve grammar for it', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Please correct the grammar in the following content to ensure that it is free from any grammatical errors, maintaining proper sentence structure, correct tense usage, and accurate punctuation. Ensure that the final content is grammatically sound while preserving the original message:\n\n{{content}}', + }, + ], + }, + { + name: 'Fix spelling for it', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + "Please carefully review the following content and correct all spelling mistakes. Ensure that each word is spelled correctly, adhering to standard {{language}} spelling conventions. The content's meaning should remain unchanged; only the spelling errors need to be addressed:\n\n{{content}}", + params: { + language: [ + 'English', + 'Spanish', + 'German', + 'French', + 'Italian', + 'Simplified Chinese', + 'Traditional Chinese', + 'Japanese', + 'Russian', + 'Korean', + ], + }, + }, + ], + }, + { + name: 'Find action items from it', + action: 'todo-list', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Identify action items from the following content and return them as a to-do list in Markdown format:\n\n{{content}}', + }, + ], + }, + { + name: 'Check code error', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Review the following code snippet for any syntax errors and list them individually:\n\n{{content}}', + }, + ], + }, + { + name: 'Create a presentation', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'I want to write a PPT, that has many pages, each page has 1 to 4 sections,\neach section has a title of no more than 30 words and no more than 500 words of content,\nbut also need some keywords that match the content of the paragraph used to generate images,\nTry to have a different number of section per page\nThe first page is the cover, which generates a general title (no more than 4 words) and description based on the topic\nthis is a template:\n- page name\n - title\n - keywords\n - description\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n\n\nplease help me to write this ppt, do not output any content that does not belong to the ppt content itself outside of the content, Directly output the title content keywords without prefix like Title:xxx, Content: xxx, Keywords: xxx\nThe PPT is based on the following topics:\n\n{{content}}', + }, + ], + }, + { + name: 'Create headings', + action: 'text', + model: 'gpt-3.5-turbo', + messages: [ + { + role: 'assistant', + content: + 'Craft a distilled heading from the following content, maximum 10 words, format: H1.\n\n{{content}}', + }, + ], + }, +]; diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index 954bfb4a7a..0cb50c5ae4 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -2,10 +2,11 @@ import { ServerFeature } from '../../core/config'; import { Plugin } from '../registry'; import { PromptService } from './prompt'; import { assertProvidersConfigs, CopilotProviderService } from './providers'; +import { ChatSessionService } from './session'; @Plugin({ name: 'copilot', - providers: [PromptService, CopilotProviderService], + providers: [ChatSessionService, PromptService, CopilotProviderService], contributesTo: ServerFeature.Copilot, if: config => { if (config.flavor.graphql) { diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index e259383b10..b0857dd03f 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -1,7 +1,124 @@ import { Injectable } from '@nestjs/common'; -import { PrismaClient } from '@prisma/client'; +import { AiPrompt, PrismaClient } from '@prisma/client'; +import Mustache from 'mustache'; +import { Tiktoken } from 'tiktoken'; -import { ChatMessage } from './types'; +import { + getTokenEncoder, + PromptMessage, + PromptMessageSchema, + PromptParams, +} from './types'; + +// disable escaping +Mustache.escape = (text: string) => text; + +function extractMustacheParams(template: string) { + const regex = /\{\{\s*([^{}]+)\s*\}\}/g; + const params = []; + let match; + + while ((match = regex.exec(template)) !== null) { + params.push(match[1]); + } + + return Array.from(new Set(params)); +} + +export class ChatPrompt { + public readonly encoder?: Tiktoken; + private readonly promptTokenSize: number; + private readonly templateParamKeys: string[] = []; + private readonly templateParams: PromptParams = {}; + + static createFromPrompt( + options: Omit & { + messages: PromptMessage[]; + } + ) { + return new ChatPrompt( + options.name, + options.action, + options.model, + options.messages + ); + } + + constructor( + public readonly name: string, + public readonly action: string | null, + public readonly model: string | null, + private readonly messages: PromptMessage[] + ) { + this.encoder = getTokenEncoder(model); + this.promptTokenSize = + this.encoder?.encode_ordinary(messages.map(m => m.content).join('') || '') + .length || 0; + this.templateParamKeys = extractMustacheParams( + messages.map(m => m.content).join('') + ); + this.templateParams = messages.reduce( + (acc, m) => Object.assign(acc, m.params), + {} as PromptParams + ); + } + + /** + * get prompt token size + */ + get tokens() { + return this.promptTokenSize; + } + + /** + * get prompt param keys in template + */ + get paramKeys() { + return this.templateParamKeys.slice(); + } + + /** + * get prompt params + */ + get params() { + return { ...this.templateParams }; + } + + encode(message: string) { + return this.encoder?.encode_ordinary(message).length || 0; + } + + private checkParams(params: PromptParams) { + const selfParams = this.templateParams; + for (const key of Object.keys(selfParams)) { + const options = selfParams[key]; + const income = params[key]; + if ( + typeof income !== 'string' || + (Array.isArray(options) && !options.includes(income)) + ) { + throw new Error(`Invalid param: ${key}`); + } + } + } + + /** + * render prompt messages with params + * @param params record of params, e.g. { name: 'Alice' } + * @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }] + */ + finish(params: PromptParams) { + this.checkParams(params); + return this.messages.map(m => ({ + ...m, + content: Mustache.render(m.content, params), + })); + } + + free() { + this.encoder?.free(); + } +} @Injectable() export class PromptService { @@ -22,51 +139,74 @@ export class PromptService { * @param name prompt name * @returns prompt messages */ - async get(name: string): Promise { - return this.db.aiPrompt.findMany({ - where: { - name, - }, - select: { - role: true, - content: true, - }, - orderBy: { - idx: 'asc', - }, - }); + async get(name: string): Promise { + return this.db.aiPrompt + .findUnique({ + where: { + name, + }, + select: { + name: true, + action: true, + model: true, + messages: { + select: { + role: true, + content: true, + params: true, + }, + orderBy: { + idx: 'asc', + }, + }, + }, + }) + .then(p => { + const messages = PromptMessageSchema.array().safeParse(p?.messages); + if (p && messages.success) { + return ChatPrompt.createFromPrompt({ ...p, messages: messages.data }); + } + return null; + }); } - async set(name: string, messages: ChatMessage[]) { - return this.db.$transaction(async tx => { - const prompts = await tx.aiPrompt.count({ where: { name } }); - if (prompts > 0) { - return 0; - } - return tx.aiPrompt - .createMany({ - data: messages.map((m, idx) => ({ name, idx, ...m })), - }) - .then(ret => ret.count); - }); + async set(name: string, messages: PromptMessage[]) { + return await this.db.aiPrompt + .create({ + data: { + name, + messages: { + create: messages.map((m, idx) => ({ + idx, + ...m, + params: m.params || undefined, + })), + }, + }, + }) + .then(ret => ret.id); } - async update(name: string, messages: ChatMessage[]) { - return this.db.$transaction(async tx => { - await tx.aiPrompt.deleteMany({ where: { name } }); - return tx.aiPrompt - .createMany({ - data: messages.map((m, idx) => ({ name, idx, ...m })), - }) - .then(ret => ret.count); - }); + async update(name: string, messages: PromptMessage[]) { + return this.db.aiPrompt + .update({ + where: { name }, + data: { + messages: { + // cleanup old messages + deleteMany: {}, + create: messages.map((m, idx) => ({ + idx, + ...m, + params: m.params || undefined, + })), + }, + }, + }) + .then(ret => ret.id); } async delete(name: string) { - return this.db.aiPrompt - .deleteMany({ - where: { name }, - }) - .then(ret => ret.count); + return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id); } } diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts new file mode 100644 index 0000000000..e762cbe552 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -0,0 +1,203 @@ +import { randomUUID } from 'node:crypto'; + +import { Injectable, Logger } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; + +import { ChatPrompt, PromptService } from './prompt'; +import { + ChatMessage, + ChatMessageSchema, + PromptMessage, + PromptParams, +} from './types'; + +export interface ChatSessionOptions { + userId: string; + workspaceId: string; + docId: string; + promptName: string; +} + +export interface ChatSessionState + extends Omit { + // connect ids + sessionId: string; + // states + prompt: ChatPrompt; + messages: ChatMessage[]; +} + +export class ChatSession implements AsyncDisposable { + constructor( + private readonly state: ChatSessionState, + private readonly dispose?: (state: ChatSessionState) => Promise, + private readonly maxTokenSize = 3840 + ) {} + + get model() { + return this.state.prompt.model; + } + + push(message: ChatMessage) { + this.state.messages.push(message); + } + + pop() { + this.state.messages.pop(); + } + + private takeMessages(): ChatMessage[] { + if (this.state.prompt.action) { + const messages = this.state.messages; + return messages.slice(messages.length - 1); + } + const ret = []; + const messages = this.state.messages.slice(); + + let size = this.state.prompt.tokens; + while (messages.length) { + const message = messages.pop(); + if (!message) break; + + size += this.state.prompt.encode(message.content); + if (size > this.maxTokenSize) { + break; + } + ret.push(message); + } + ret.reverse(); + + return ret; + } + + finish(params: PromptParams): PromptMessage[] { + const messages = this.takeMessages(); + return [...this.state.prompt.finish(params), ...messages]; + } + + async save() { + await this.dispose?.(this.state); + } + + async [Symbol.asyncDispose]() { + this.state.prompt.free(); + await this.save?.(); + } +} + +@Injectable() +export class ChatSessionService { + private readonly logger = new Logger(ChatSessionService.name); + constructor( + private readonly db: PrismaClient, + private readonly prompt: PromptService + ) {} + + private async setSession(state: ChatSessionState): Promise { + await this.db.aiSession.upsert({ + where: { + id: state.sessionId, + }, + update: { + messages: { + create: state.messages.map((m, idx) => ({ idx, ...m })), + }, + }, + create: { + id: state.sessionId, + messages: { create: state.messages }, + // connect + user: { connect: { id: state.userId } }, + workspace: { connect: { id: state.workspaceId } }, + doc: { + connect: { + id_workspaceId: { + id: state.docId, + workspaceId: state.workspaceId, + }, + }, + }, + prompt: { connect: { name: state.prompt.name } }, + }, + }); + } + + private async getSession( + sessionId: string + ): Promise { + return await this.db.aiSession + .findUnique({ + where: { id: sessionId }, + select: { + id: true, + userId: true, + workspaceId: true, + docId: true, + messages: true, + prompt: { + select: { + name: true, + action: true, + model: true, + messages: { + select: { + role: true, + content: true, + }, + orderBy: { + idx: 'asc', + }, + }, + }, + }, + }, + }) + .then(async session => { + if (!session) return; + const messages = ChatMessageSchema.array().safeParse(session.messages); + + return { + sessionId: session.id, + userId: session.userId, + workspaceId: session.workspaceId, + docId: session.docId, + prompt: ChatPrompt.createFromPrompt(session.prompt), + messages: messages.success ? messages.data : [], + }; + }); + } + + async create(options: ChatSessionOptions): Promise { + const sessionId = randomUUID(); + const prompt = await this.prompt.get(options.promptName); + if (!prompt) { + this.logger.error(`Prompt not found: ${options.promptName}`); + throw new Error('Prompt not found'); + } + await this.setSession({ ...options, sessionId, prompt, messages: [] }); + return sessionId; + } + + /** + * usage: + * ``` typescript + * { + * // allocate a session, can be reused chat in about 12 hours with same session + * await using session = await session.get(sessionId); + * session.push(message); + * copilot.generateText(session.finish(), model); + * } + * // session will be disposed after the block + * @param sessionId session id + * @returns + */ + async get(sessionId: string): Promise { + const state = await this.getSession(sessionId); + if (state) { + return new ChatSession(state, async state => { + await this.setSession(state); + }); + } + return null; + } +} diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index b0ab8b16a7..f48ef60bb0 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -1,5 +1,11 @@ import { AiPromptRole } from '@prisma/client'; import type { ClientOptions as OpenAIClientOptions } from 'openai'; +import { + encoding_for_model, + get_encoding, + Tiktoken, + TiktokenModel, +} from 'tiktoken'; import { z } from 'zod'; export interface CopilotConfig { @@ -9,6 +15,76 @@ export interface CopilotConfig { }; } +export enum AvailableModels { + // text to text + Gpt4VisionPreview = 'gpt-4-vision-preview', + Gpt4TurboPreview = 'gpt-4-turbo-preview', + Gpt35Turbo = 'gpt-3.5-turbo', + // embeddings + TextEmbedding3Large = 'text-embedding-3-large', + TextEmbedding3Small = 'text-embedding-3-small', + TextEmbeddingAda002 = 'text-embedding-ada-002', + // moderation + TextModerationLatest = 'text-moderation-latest', + TextModerationStable = 'text-moderation-stable', +} + +export type AvailableModel = keyof typeof AvailableModels; + +export function getTokenEncoder(model?: string | null): Tiktoken | undefined { + if (!model) return undefined; + const modelStr = AvailableModels[model as AvailableModel]; + if (!modelStr) return undefined; + if (modelStr.startsWith('gpt')) { + return encoding_for_model(modelStr as TiktokenModel); + } else if (modelStr.startsWith('dall')) { + // dalle don't need to calc the token + return undefined; + } else { + return get_encoding('cl100k_base'); + } +} + +// ======== ChatMessage ======== + +export const ChatMessageRole = Object.values(AiPromptRole) as [ + 'system', + 'assistant', + 'user', +]; + +export const PromptMessageSchema = z.object({ + role: z.enum(ChatMessageRole), + content: z.string(), + attachments: z.array(z.string()).optional(), + params: z + .record(z.union([z.string(), z.array(z.string())])) + .optional() + .nullable(), +}); + +export type PromptMessage = z.infer; + +export type PromptParams = NonNullable; + +export const ChatMessageSchema = PromptMessageSchema.extend({ + createdAt: z.date(), +}).strict(); + +export type ChatMessage = z.infer; + +export const ChatHistorySchema = z + .object({ + sessionId: z.string(), + tokens: z.number(), + messages: z.array(ChatMessageSchema), + }) + .strict(); + +export type ChatHistory = z.infer; + +// ======== Provider Interface ======== + export enum CopilotProviderType { FAL = 'fal', OpenAI = 'openai', @@ -25,24 +101,26 @@ export interface CopilotProvider { getCapabilities(): CopilotProviderCapability[]; } -export const ChatMessageSchema = z - .object({ - role: z.enum( - Array.from(Object.values(AiPromptRole)) as [ - 'system' | 'assistant' | 'user', - ] - ), - content: z.string(), - }) - .strict(); - -export type ChatMessage = z.infer; - export interface CopilotTextToTextProvider extends CopilotProvider { - generateText(messages: ChatMessage[], model: string): Promise; + generateText( + messages: PromptMessage[], + model: string, + options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } + ): Promise; generateTextStream( - messages: ChatMessage[], - model: string + messages: PromptMessage[], + model: string, + options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } ): AsyncIterable; } diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 5fee16d985..141e6c477b 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -59,12 +59,74 @@ test('should be able to manage prompt', async t => { { role: 'user', content: 'hello' }, ]); t.is((await prompt.list()).length, 1, 'should have one prompt'); - t.is((await prompt.get('test')).length, 2, 'should have two messages'); + t.is( + (await prompt.get('test'))!.finish({}).length, + 2, + 'should have two messages' + ); await prompt.update('test', [{ role: 'system', content: 'hello' }]); - t.is((await prompt.get('test')).length, 1, 'should have one message'); + t.is( + (await prompt.get('test'))!.finish({}).length, + 1, + 'should have one message' + ); await prompt.delete('test'); t.is((await prompt.list()).length, 0, 'should have no prompt'); - t.is((await prompt.get('test')).length, 0, 'should have no messages'); + t.is(await prompt.get('test'), null, 'should not have the prompt'); +}); + +test('should be able to render prompt', async t => { + const { prompt } = t.context; + + const msg = { + role: 'system' as const, + content: 'translate {{src_language}} to {{dest_language}}: {{content}}', + params: { src_language: ['eng'], dest_language: ['chs', 'jpn', 'kor'] }, + }; + const params = { + src_language: 'eng', + dest_language: 'chs', + content: 'hello world', + }; + + await prompt.set('test', [msg]); + const testPrompt = await prompt.get('test'); + t.assert(testPrompt, 'should have prompt'); + t.is( + testPrompt?.finish(params).pop()?.content, + 'translate eng to chs: hello world', + 'should render the prompt' + ); + t.deepEqual( + testPrompt?.paramKeys, + Object.keys(params), + 'should have param keys' + ); + t.deepEqual(testPrompt?.params, msg.params, 'should have params'); + t.throws(() => testPrompt?.finish({ src_language: 'abc' }), { + instanceOf: Error, + }); +}); + +test('should be able to render listed prompt', async t => { + const { prompt } = t.context; + + const msg = { + role: 'system' as const, + content: 'links:\n{{#links}}- {{.}}\n{{/links}}', + }; + const params = { + links: ['https://affine.pro', 'https://github.com/toeverything/affine'], + }; + + await prompt.set('test', [msg]); + const testPrompt = await prompt.get('test'); + + t.is( + testPrompt?.finish(params).pop()?.content, + 'links:\n- https://affine.pro\n- https://github.com/toeverything/affine\n', + 'should render the prompt' + ); }); diff --git a/packages/backend/server/tests/quota.spec.ts b/packages/backend/server/tests/quota.spec.ts index 551107efd1..7c55329a77 100644 --- a/packages/backend/server/tests/quota.spec.ts +++ b/packages/backend/server/tests/quota.spec.ts @@ -49,7 +49,7 @@ test('should be able to set quota', async t => { const q1 = await quota.getUserQuota(u1.id); t.truthy(q1, 'should have quota'); t.is(q1?.feature.name, QuotaType.FreePlanV1, 'should be free plan'); - t.is(q1?.feature.version, 3, 'should be version 2'); + t.is(q1?.feature.version, 3, 'should be version 3'); await quota.switchUserQuota(u1.id, QuotaType.ProPlanV1); diff --git a/yarn.lock b/yarn.lock index 8fa3dac146..4ad6c6d417 100644 --- a/yarn.lock +++ b/yarn.lock @@ -697,6 +697,7 @@ __metadata: "@types/keyv": "npm:^4.2.0" "@types/lodash-es": "npm:^4.17.12" "@types/mixpanel": "npm:^2.14.8" + "@types/mustache": "npm:^4" "@types/node": "npm:^20.11.20" "@types/nodemailer": "npm:^6.4.14" "@types/on-headers": "npm:^1.0.3" @@ -720,6 +721,7 @@ __metadata: keyv: "npm:^4.5.4" lodash-es: "npm:^4.17.21" mixpanel: "npm:^0.18.0" + mustache: "npm:^4.2.0" nanoid: "npm:^5.0.6" nest-commander: "npm:^3.12.5" nestjs-throttler-storage-redis: "npm:^0.4.1" @@ -738,6 +740,7 @@ __metadata: socket.io: "npm:^4.7.4" stripe: "npm:^14.18.0" supertest: "npm:^6.3.4" + tiktoken: "npm:^1.0.13" ts-node: "npm:^10.9.2" typescript: "npm:^5.3.3" ws: "npm:^8.16.0" @@ -14489,6 +14492,13 @@ __metadata: languageName: node linkType: hard +"@types/mustache@npm:^4": + version: 4.2.5 + resolution: "@types/mustache@npm:4.2.5" + checksum: 10/29581027fe420120ae0591e28d44209d0e01adf5175910d03401327777ee9c649a1508e2aa63147c782c7e53fcea4b69b5f9a2fbedcadc5500561d1161ae5ded + languageName: node + linkType: hard + "@types/mute-stream@npm:^0.0.4": version: 0.0.4 resolution: "@types/mute-stream@npm:0.0.4" @@ -33968,6 +33978,13 @@ __metadata: languageName: node linkType: hard +"tiktoken@npm:^1.0.13": + version: 1.0.13 + resolution: "tiktoken@npm:1.0.13" + checksum: 10/4217ffbcd4126dc2dd17503fda35be91cf4be64c514f70e1049982d1bd2b5cea6334e76812411cb284dfa7b412159839d546048ac98220faf3c629e217266ddc + languageName: node + linkType: hard + "time-zone@npm:^1.0.0": version: 1.0.0 resolution: "time-zone@npm:1.0.0"