feat: add prompt service (#6241)

fix CLOUD-19
This commit is contained in:
darkskygit
2024-04-02 07:04:54 +00:00
parent 593161dccb
commit 3c01d944fb
8 changed files with 246 additions and 8 deletions
@@ -0,0 +1,16 @@
-- CreateEnum
CREATE TYPE "AiPromptRole" AS ENUM ('system', 'assistant', 'user');
-- CreateTable
CREATE TABLE "ai_prompts" (
"id" VARCHAR NOT NULL,
"name" VARCHAR(20) NOT NULL,
"idx" INTEGER NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "ai_prompts_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "ai_prompts_name_idx_key" ON "ai_prompts"("name", "idx");
@@ -0,0 +1,24 @@
-- CreateTable
CREATE TABLE "ai_sessions" (
"id" VARCHAR NOT NULL,
"user_id" VARCHAR NOT NULL,
"workspace_id" VARCHAR NOT NULL,
"doc_id" VARCHAR NOT NULL,
"prompt_name" VARCHAR NOT NULL,
"action" BOOLEAN NOT NULL,
"model" VARCHAR NOT NULL,
"messages" JSON NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,
CONSTRAINT "ai_sessions_pkey" PRIMARY KEY ("id")
);
-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;
+45
View File
@@ -30,6 +30,7 @@ model User {
pagePermissions WorkspacePageUserPermission[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
AiSession AiSession[]
@@map("users")
}
@@ -96,6 +97,7 @@ model Workspace {
permissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
features WorkspaceFeatures[]
AiSession AiSession[]
@@map("workspaces")
}
@@ -321,6 +323,8 @@ model Snapshot {
// but the created time of last seen update that has been merged into snapshot.
updatedAt DateTime @map("updated_at") @db.Timestamptz(6)
AiSession AiSession[]
@@id([id, workspaceId])
@@map("snapshots")
}
@@ -422,6 +426,47 @@ model UserInvoice {
@@map("user_invoices")
}
enum AiPromptRole {
system
assistant
user
}
model AiPrompt {
id String @id @default(uuid()) @db.VarChar
// prompt name
name String @db.VarChar(20)
// if a group of prompts contains multiple sentences, idx specifies the order of each sentence
idx Int @db.Integer
// system/assistant/user
role AiPromptRole
// prompt content
content String @db.Text
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
@@unique([name, idx])
@@map("ai_prompts")
}
model AiSession {
id String @id @default(uuid()) @db.VarChar
userId String @map("user_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
docId String @map("doc_id") @db.VarChar
promptName String @map("prompt_name") @db.VarChar
action Boolean @db.Boolean
model String @db.VarChar
messages Json @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)
@@map("ai_sessions")
}
model DataMigration {
id String @id @default(uuid()) @db.VarChar(36)
name String @db.VarChar
@@ -1,10 +1,11 @@
import { ServerFeature } from '../../core/config';
import { Plugin } from '../registry';
import { assertProvidersConfigs, CopilotProviderService } from './provider';
import { PromptService } from './prompt';
import { assertProvidersConfigs, CopilotProviderService } from './providers';
@Plugin({
name: 'copilot',
providers: [CopilotProviderService],
providers: [PromptService, CopilotProviderService],
contributesTo: ServerFeature.Copilot,
if: config => {
if (config.flavor.graphql) {
@@ -0,0 +1,72 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { ChatMessage } from './types';
@Injectable()
export class PromptService {
constructor(private readonly db: PrismaClient) {}
/**
* list prompt names
* @returns prompt names
*/
async list() {
return this.db.aiPrompt
.findMany({ select: { name: true } })
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
}
/**
* get prompt messages by prompt name
* @param name prompt name
* @returns prompt messages
*/
async get(name: string): Promise<ChatMessage[]> {
return this.db.aiPrompt.findMany({
where: {
name,
},
select: {
role: true,
content: true,
},
orderBy: {
idx: 'asc',
},
});
}
async set(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
const prompts = await tx.aiPrompt.count({ where: { name } });
if (prompts > 0) {
return 0;
}
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
}
async update(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
await tx.aiPrompt.deleteMany({ where: { name } });
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
}
async delete(name: string) {
return this.db.aiPrompt
.deleteMany({
where: { name },
})
.then(ret => ret.count);
}
}
@@ -2,14 +2,14 @@ import assert from 'node:assert';
import { Injectable, Logger } from '@nestjs/common';
import { Config } from '../../fundamentals';
import { Config } from '../../../fundamentals';
import {
CapabilityToCopilotProvider,
CopilotConfig,
CopilotProvider,
CopilotProviderCapability,
CopilotProviderType,
} from './types';
} from '../types';
type CopilotProviderConfig = CopilotConfig[keyof CopilotConfig];
@@ -1,4 +1,6 @@
import { AiPromptRole } from '@prisma/client';
import type { ClientOptions as OpenAIClientOptions } from 'openai';
import { z } from 'zod';
export interface CopilotConfig {
openai: OpenAIClientOptions;
@@ -23,10 +25,18 @@ export interface CopilotProvider {
getCapabilities(): CopilotProviderCapability[];
}
export type ChatMessage = {
role: 'system' | 'assistant' | 'user';
content: string;
};
export const ChatMessageSchema = z
.object({
role: z.enum(
Array.from(Object.values(AiPromptRole)) as [
'system' | 'assistant' | 'user',
]
),
content: z.string(),
})
.strict();
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export interface CopilotTextToTextProvider extends CopilotProvider {
generateText(messages: ChatMessage[], model: string): Promise<string>;
@@ -0,0 +1,70 @@
/// <reference types="../src/global.d.ts" />
import { TestingModule } from '@nestjs/testing';
import type { TestFn } from 'ava';
import ava from 'ava';
import { AuthService } from '../src/core/auth';
import { QuotaManagementService, QuotaModule } from '../src/core/quota';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import { createTestingModule } from './utils';
const test = ava as TestFn<{
auth: AuthService;
quotaManager: QuotaManagementService;
module: TestingModule;
prompt: PromptService;
}>;
test.beforeEach(async t => {
const module = await createTestingModule({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
openai: {
apiKey: '1',
},
},
},
}),
QuotaModule,
CopilotModule,
],
});
const quotaManager = module.get(QuotaManagementService);
const auth = module.get(AuthService);
const prompt = module.get(PromptService);
t.context.module = module;
t.context.quotaManager = quotaManager;
t.context.auth = auth;
t.context.prompt = prompt;
});
test.afterEach.always(async t => {
await t.context.module.close();
});
test('should be able to manage prompt', async t => {
const { prompt } = t.context;
t.is((await prompt.list()).length, 0, 'should have no prompt');
await prompt.set('test', [
{ role: 'system', content: 'hello' },
{ role: 'user', content: 'hello' },
]);
t.is((await prompt.list()).length, 1, 'should have one prompt');
t.is((await prompt.get('test')).length, 2, 'should have two messages');
await prompt.update('test', [{ role: 'system', content: 'hello' }]);
t.is((await prompt.get('test')).length, 1, 'should have one message');
await prompt.delete('test');
t.is((await prompt.list()).length, 0, 'should have no prompt');
t.is((await prompt.get('test')).length, 0, 'should have no messages');
});