mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-07-04 19:15:33 +08:00
@@ -0,0 +1,16 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "AiPromptRole" AS ENUM ('system', 'assistant', 'user');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ai_prompts" (
|
||||
"id" VARCHAR NOT NULL,
|
||||
"name" VARCHAR(20) NOT NULL,
|
||||
"idx" INTEGER NOT NULL,
|
||||
"role" "AiPromptRole" NOT NULL,
|
||||
"content" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT "ai_prompts_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ai_prompts_name_idx_key" ON "ai_prompts"("name", "idx");
|
||||
@@ -0,0 +1,24 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "ai_sessions" (
|
||||
"id" VARCHAR NOT NULL,
|
||||
"user_id" VARCHAR NOT NULL,
|
||||
"workspace_id" VARCHAR NOT NULL,
|
||||
"doc_id" VARCHAR NOT NULL,
|
||||
"prompt_name" VARCHAR NOT NULL,
|
||||
"action" BOOLEAN NOT NULL,
|
||||
"model" VARCHAR NOT NULL,
|
||||
"messages" JSON NOT NULL,
|
||||
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ(6) NOT NULL,
|
||||
|
||||
CONSTRAINT "ai_sessions_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -30,6 +30,7 @@ model User {
|
||||
pagePermissions WorkspacePageUserPermission[]
|
||||
connectedAccounts ConnectedAccount[]
|
||||
sessions UserSession[]
|
||||
AiSession AiSession[]
|
||||
|
||||
@@map("users")
|
||||
}
|
||||
@@ -96,6 +97,7 @@ model Workspace {
|
||||
permissions WorkspaceUserPermission[]
|
||||
pagePermissions WorkspacePageUserPermission[]
|
||||
features WorkspaceFeatures[]
|
||||
AiSession AiSession[]
|
||||
|
||||
@@map("workspaces")
|
||||
}
|
||||
@@ -321,6 +323,8 @@ model Snapshot {
|
||||
// but the created time of last seen update that has been merged into snapshot.
|
||||
updatedAt DateTime @map("updated_at") @db.Timestamptz(6)
|
||||
|
||||
AiSession AiSession[]
|
||||
|
||||
@@id([id, workspaceId])
|
||||
@@map("snapshots")
|
||||
}
|
||||
@@ -422,6 +426,47 @@ model UserInvoice {
|
||||
@@map("user_invoices")
|
||||
}
|
||||
|
||||
enum AiPromptRole {
|
||||
system
|
||||
assistant
|
||||
user
|
||||
}
|
||||
|
||||
model AiPrompt {
|
||||
id String @id @default(uuid()) @db.VarChar
|
||||
// prompt name
|
||||
name String @db.VarChar(20)
|
||||
// if a group of prompts contains multiple sentences, idx specifies the order of each sentence
|
||||
idx Int @db.Integer
|
||||
// system/assistant/user
|
||||
role AiPromptRole
|
||||
// prompt content
|
||||
content String @db.Text
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
|
||||
|
||||
@@unique([name, idx])
|
||||
@@map("ai_prompts")
|
||||
}
|
||||
|
||||
model AiSession {
|
||||
id String @id @default(uuid()) @db.VarChar
|
||||
userId String @map("user_id") @db.VarChar
|
||||
workspaceId String @map("workspace_id") @db.VarChar
|
||||
docId String @map("doc_id") @db.VarChar
|
||||
promptName String @map("prompt_name") @db.VarChar
|
||||
action Boolean @db.Boolean
|
||||
model String @db.VarChar
|
||||
messages Json @db.Json
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
|
||||
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)
|
||||
|
||||
@@map("ai_sessions")
|
||||
}
|
||||
|
||||
model DataMigration {
|
||||
id String @id @default(uuid()) @db.VarChar(36)
|
||||
name String @db.VarChar
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { ServerFeature } from '../../core/config';
|
||||
import { Plugin } from '../registry';
|
||||
import { assertProvidersConfigs, CopilotProviderService } from './provider';
|
||||
import { PromptService } from './prompt';
|
||||
import { assertProvidersConfigs, CopilotProviderService } from './providers';
|
||||
|
||||
@Plugin({
|
||||
name: 'copilot',
|
||||
providers: [CopilotProviderService],
|
||||
providers: [PromptService, CopilotProviderService],
|
||||
contributesTo: ServerFeature.Copilot,
|
||||
if: config => {
|
||||
if (config.flavor.graphql) {
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
import { ChatMessage } from './types';
|
||||
|
||||
@Injectable()
|
||||
export class PromptService {
|
||||
constructor(private readonly db: PrismaClient) {}
|
||||
|
||||
/**
|
||||
* list prompt names
|
||||
* @returns prompt names
|
||||
*/
|
||||
async list() {
|
||||
return this.db.aiPrompt
|
||||
.findMany({ select: { name: true } })
|
||||
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
|
||||
}
|
||||
|
||||
/**
|
||||
* get prompt messages by prompt name
|
||||
* @param name prompt name
|
||||
* @returns prompt messages
|
||||
*/
|
||||
async get(name: string): Promise<ChatMessage[]> {
|
||||
return this.db.aiPrompt.findMany({
|
||||
where: {
|
||||
name,
|
||||
},
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
},
|
||||
orderBy: {
|
||||
idx: 'asc',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async set(name: string, messages: ChatMessage[]) {
|
||||
return this.db.$transaction(async tx => {
|
||||
const prompts = await tx.aiPrompt.count({ where: { name } });
|
||||
if (prompts > 0) {
|
||||
return 0;
|
||||
}
|
||||
return tx.aiPrompt
|
||||
.createMany({
|
||||
data: messages.map((m, idx) => ({ name, idx, ...m })),
|
||||
})
|
||||
.then(ret => ret.count);
|
||||
});
|
||||
}
|
||||
|
||||
async update(name: string, messages: ChatMessage[]) {
|
||||
return this.db.$transaction(async tx => {
|
||||
await tx.aiPrompt.deleteMany({ where: { name } });
|
||||
return tx.aiPrompt
|
||||
.createMany({
|
||||
data: messages.map((m, idx) => ({ name, idx, ...m })),
|
||||
})
|
||||
.then(ret => ret.count);
|
||||
});
|
||||
}
|
||||
|
||||
async delete(name: string) {
|
||||
return this.db.aiPrompt
|
||||
.deleteMany({
|
||||
where: { name },
|
||||
})
|
||||
.then(ret => ret.count);
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -2,14 +2,14 @@ import assert from 'node:assert';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { Config } from '../../fundamentals';
|
||||
import { Config } from '../../../fundamentals';
|
||||
import {
|
||||
CapabilityToCopilotProvider,
|
||||
CopilotConfig,
|
||||
CopilotProvider,
|
||||
CopilotProviderCapability,
|
||||
CopilotProviderType,
|
||||
} from './types';
|
||||
} from '../types';
|
||||
|
||||
type CopilotProviderConfig = CopilotConfig[keyof CopilotConfig];
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
import type { ClientOptions as OpenAIClientOptions } from 'openai';
|
||||
import { z } from 'zod';
|
||||
|
||||
export interface CopilotConfig {
|
||||
openai: OpenAIClientOptions;
|
||||
@@ -23,10 +25,18 @@ export interface CopilotProvider {
|
||||
getCapabilities(): CopilotProviderCapability[];
|
||||
}
|
||||
|
||||
export type ChatMessage = {
|
||||
role: 'system' | 'assistant' | 'user';
|
||||
content: string;
|
||||
};
|
||||
export const ChatMessageSchema = z
|
||||
.object({
|
||||
role: z.enum(
|
||||
Array.from(Object.values(AiPromptRole)) as [
|
||||
'system' | 'assistant' | 'user',
|
||||
]
|
||||
),
|
||||
content: z.string(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
|
||||
|
||||
export interface CopilotTextToTextProvider extends CopilotProvider {
|
||||
generateText(messages: ChatMessage[], model: string): Promise<string>;
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
/// <reference types="../src/global.d.ts" />
|
||||
|
||||
import { TestingModule } from '@nestjs/testing';
|
||||
import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { QuotaManagementService, QuotaModule } from '../src/core/quota';
|
||||
import { ConfigModule } from '../src/fundamentals/config';
|
||||
import { CopilotModule } from '../src/plugins/copilot';
|
||||
import { PromptService } from '../src/plugins/copilot/prompt';
|
||||
import { createTestingModule } from './utils';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
quotaManager: QuotaManagementService;
|
||||
module: TestingModule;
|
||||
prompt: PromptService;
|
||||
}>;
|
||||
|
||||
test.beforeEach(async t => {
|
||||
const module = await createTestingModule({
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
plugins: {
|
||||
copilot: {
|
||||
openai: {
|
||||
apiKey: '1',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
QuotaModule,
|
||||
CopilotModule,
|
||||
],
|
||||
});
|
||||
|
||||
const quotaManager = module.get(QuotaManagementService);
|
||||
const auth = module.get(AuthService);
|
||||
const prompt = module.get(PromptService);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.quotaManager = quotaManager;
|
||||
t.context.auth = auth;
|
||||
t.context.prompt = prompt;
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
test('should be able to manage prompt', async t => {
|
||||
const { prompt } = t.context;
|
||||
|
||||
t.is((await prompt.list()).length, 0, 'should have no prompt');
|
||||
|
||||
await prompt.set('test', [
|
||||
{ role: 'system', content: 'hello' },
|
||||
{ role: 'user', content: 'hello' },
|
||||
]);
|
||||
t.is((await prompt.list()).length, 1, 'should have one prompt');
|
||||
t.is((await prompt.get('test')).length, 2, 'should have two messages');
|
||||
|
||||
await prompt.update('test', [{ role: 'system', content: 'hello' }]);
|
||||
t.is((await prompt.get('test')).length, 1, 'should have one message');
|
||||
|
||||
await prompt.delete('test');
|
||||
t.is((await prompt.list()).length, 0, 'should have no prompt');
|
||||
t.is((await prompt.get('test')).length, 0, 'should have no messages');
|
||||
});
|
||||
Reference in New Issue
Block a user