feat: add session impl (#6254)

This commit is contained in:
darkskygit
2024-04-10 11:15:25 +00:00
parent 8a02c81745
commit 46a368d7f1
16 changed files with 1033 additions and 88 deletions

View File

@@ -1,11 +1,12 @@
-- CreateTable
CREATE TABLE "ai_sessions" (
"id" VARCHAR NOT NULL,
"id" VARCHAR(36) NOT NULL,
"user_id" VARCHAR NOT NULL,
"workspace_id" VARCHAR NOT NULL,
"doc_id" VARCHAR NOT NULL,
"prompt_name" VARCHAR NOT NULL,
"action" BOOLEAN NOT NULL,
"flavor" VARCHAR NOT NULL,
"model" VARCHAR NOT NULL,
"messages" JSON NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,

View File

@@ -0,0 +1,90 @@
/*
Warnings:
- You are about to drop the `ai_prompts` table. If the table is not empty, all the data it contains will be lost.
- You are about to drop the `ai_sessions` table. If the table is not empty, all the data it contains will be lost.
*/
-- Migration: split the monolithic ai_prompts / ai_sessions tables into
-- normalized metadata + messages tables (one row per message), so prompts
-- and chat history no longer live in a single JSON column.
-- DropForeignKey
ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey";
-- DropForeignKey
ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_user_id_fkey";
-- DropForeignKey
ALTER TABLE "ai_sessions" DROP CONSTRAINT "ai_sessions_workspace_id_fkey";
-- DropTable
DROP TABLE "ai_prompts";
-- DropTable
DROP TABLE "ai_sessions";
-- CreateTable
-- Individual template sentences of a prompt, ordered within the prompt by "idx".
-- NOTE(review): this table declares no PRIMARY KEY — only the
-- UNIQUE(prompt_id, idx) index below. Confirm that is intentional.
CREATE TABLE "ai_prompts_messages" (
"prompt_id" INTEGER NOT NULL,
"idx" INTEGER NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"params" JSON,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- CreateTable
-- One row per named prompt template; "name" is the external identifier
-- (unique index below) referenced by sessions.
CREATE TABLE "ai_prompts_metadata" (
"id" SERIAL NOT NULL,
"name" VARCHAR(32) NOT NULL,
"action" VARCHAR,
"model" VARCHAR,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "ai_prompts_metadata_pkey" PRIMARY KEY ("id")
);
-- CreateTable
-- Chat history: one row per message exchanged within a session.
CREATE TABLE "ai_sessions_messages" (
"id" VARCHAR(36) NOT NULL,
"session_id" VARCHAR(36) NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,
CONSTRAINT "ai_sessions_messages_pkey" PRIMARY KEY ("id")
);
-- CreateTable
-- One row per chat session; ties a user, workspace, doc and prompt together.
CREATE TABLE "ai_sessions_metadata" (
"id" VARCHAR(36) NOT NULL,
"user_id" VARCHAR(36) NOT NULL,
"workspace_id" VARCHAR(36) NOT NULL,
"doc_id" VARCHAR(36) NOT NULL,
"prompt_name" VARCHAR(32) NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "ai_sessions_metadata_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "ai_prompts_messages_prompt_id_idx_key" ON "ai_prompts_messages"("prompt_id", "idx");
-- CreateIndex
CREATE UNIQUE INDEX "ai_prompts_metadata_name_key" ON "ai_prompts_metadata"("name");
-- AddForeignKey
-- All FKs cascade on delete: removing a prompt/user/workspace/doc removes
-- its dependent messages and sessions.
ALTER TABLE "ai_prompts_messages" ADD CONSTRAINT "ai_prompts_messages_prompt_id_fkey" FOREIGN KEY ("prompt_id") REFERENCES "ai_prompts_metadata"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions_messages" ADD CONSTRAINT "ai_sessions_messages_session_id_fkey" FOREIGN KEY ("session_id") REFERENCES "ai_sessions_metadata"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
-- Composite FK: a session's doc is identified by (doc_id, workspace_id)
-- against the snapshots table's (guid, workspace_id) pair.
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
-- Sessions reference prompts by name (not id); ON UPDATE CASCADE keeps
-- sessions attached if a prompt is renamed.
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -72,6 +72,7 @@
"keyv": "^4.5.4",
"lodash-es": "^4.17.21",
"mixpanel": "^0.18.0",
"mustache": "^4.2.0",
"nanoid": "^5.0.6",
"nest-commander": "^3.12.5",
"nestjs-throttler-storage-redis": "^0.4.1",
@@ -87,6 +88,7 @@
"semver": "^7.6.0",
"socket.io": "^4.7.4",
"stripe": "^14.18.0",
"tiktoken": "^1.0.13",
"ts-node": "^10.9.2",
"typescript": "^5.3.3",
"ws": "^8.16.0",
@@ -105,6 +107,7 @@
"@types/keyv": "^4.2.0",
"@types/lodash-es": "^4.17.12",
"@types/mixpanel": "^2.14.8",
"@types/mustache": "^4",
"@types/node": "^20.11.20",
"@types/nodemailer": "^6.4.14",
"@types/on-headers": "^1.0.3",

View File

@@ -30,7 +30,7 @@ model User {
pagePermissions WorkspacePageUserPermission[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
AiSession AiSession[]
aiSessions AiSession[]
@@map("users")
}
@@ -97,7 +97,7 @@ model Workspace {
permissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
features WorkspaceFeatures[]
AiSession AiSession[]
aiSessions AiSession[]
@@map("workspaces")
}
@@ -323,7 +323,7 @@ model Snapshot {
// but the created time of last seen update that has been merged into snapshot.
updatedAt DateTime @map("updated_at") @db.Timestamptz(6)
AiSession AiSession[]
aiSessions AiSession[]
@@id([id, workspaceId])
@@map("snapshots")
@@ -432,39 +432,66 @@ enum AiPromptRole {
user
}
model AiPrompt {
id String @id @default(uuid()) @db.VarChar
// prompt name
name String @db.VarChar(20)
model AiPromptMessage {
promptId Int @map("prompt_id") @db.Integer
// if a group of prompts contains multiple sentences, idx specifies the order of each sentence
idx Int @db.Integer
// system/assistant/user
role AiPromptRole
// prompt content
content String @db.Text
params Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
@@unique([name, idx])
@@map("ai_prompts")
prompt AiPrompt @relation(fields: [promptId], references: [id], onDelete: Cascade)
@@unique([promptId, idx])
@@map("ai_prompts_messages")
}
model AiPrompt {
id Int @id @default(autoincrement()) @db.Integer
name String @unique @db.VarChar(32)
// a mark identifying which view to use to display the session
// it is only used in the frontend and does not affect the backend
action String? @db.VarChar
model String? @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
messages AiPromptMessage[]
sessions AiSession[]
@@map("ai_prompts_metadata")
}
model AiSessionMessage {
id String @id @default(uuid()) @db.VarChar(36)
sessionId String @map("session_id") @db.VarChar(36)
role AiPromptRole
content String @db.Text
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
session AiSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
@@map("ai_sessions_messages")
}
model AiSession {
id String @id @default(uuid()) @db.VarChar
userId String @map("user_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
docId String @map("doc_id") @db.VarChar
promptName String @map("prompt_name") @db.VarChar
action Boolean @db.Boolean
model String @db.VarChar
messages Json @db.Json
id String @id @default(uuid()) @db.VarChar(36)
userId String @map("user_id") @db.VarChar(36)
workspaceId String @map("workspace_id") @db.VarChar(36)
docId String @map("doc_id") @db.VarChar(36)
promptName String @map("prompt_name") @db.VarChar(32)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
messages AiSessionMessage[]
@@map("ai_sessions")
@@map("ai_sessions_metadata")
}
model DataMigration {

View File

@@ -19,6 +19,7 @@ AFFiNE.ENV_MAP = {
MAILER_SECURE: ['mailer.secure', 'boolean'],
THROTTLE_TTL: ['rateLimiter.ttl', 'int'],
THROTTLE_LIMIT: ['rateLimiter.limit', 'int'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
REDIS_SERVER_HOST: 'plugins.redis.host',
REDIS_SERVER_PORT: ['plugins.redis.port', 'int'],
REDIS_SERVER_USER: 'plugins.redis.username',

View File

@@ -39,9 +39,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
}
AFFiNE.plugins.use('copilot', {
openai: {
apiKey: 'test',
},
openai: {},
});
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment', {

View File

@@ -0,0 +1,33 @@
import { PrismaClient } from '@prisma/client';
import { prompts } from './utils/prompts';
/**
 * Data migration: seed the built-in AI prompt templates into the new
 * `ai_prompts_metadata` / `ai_prompts_messages` tables.
 */
export class Prompts1712068777394 {
// do the migration: insert every known prompt (with its ordered messages)
// inside a single transaction so a partial seed is rolled back
static async up(db: PrismaClient) {
await db.$transaction(async tx => {
await Promise.all(
prompts.map(prompt =>
tx.aiPrompt.create({
data: {
name: prompt.name,
action: prompt.action,
model: prompt.model,
messages: {
// nested create; the array index doubles as the message order (idx)
create: prompt.messages.map((message, idx) => ({
idx,
role: message.role,
content: message.content,
params: message.params,
})),
},
},
})
)
);
});
}
// revert the migration — intentionally a no-op: seeded prompts are kept
static async down(_db: PrismaClient) {}
}

View File

@@ -0,0 +1,16 @@
import { PrismaClient } from '@prisma/client';
import { Quotas } from '../../core/quota';
import { upgradeQuotaVersion } from './utils/user-quotas';
/**
 * Quota migration: refresh users on the free plan to the updated free plan
 * quota definition (v1.1).
 */
export class RefreshFreePlan1712224382221 {
// do the migration
static async up(db: PrismaClient) {
// free plan 1.0
// NOTE(review): Quotas[4] is a positional reference into the quota list —
// presumably the free plan 1.1 entry; confirm the index against core/quota
const quota = Quotas[4];
await upgradeQuotaVersion(db, quota, 'free plan 1.1 migration');
}
// revert the migration — intentionally a no-op
static async down(_db: PrismaClient) {}
}

View File

@@ -0,0 +1,275 @@
import { AiPromptRole } from '@prisma/client';
// A single template message inside a prompt. `params` maps a template key
// to either a fixed value or the list of allowed values for that key
// (matching the {{placeholders}} used in `content`).
type PromptMessage = {
role: AiPromptRole;
content: string;
params?: Record<string, string | string[]>;
};
// A named prompt template. `action` marks one-shot "action" prompts (set for
// everything except plain chat prompts); `model` is the model used to run it.
type Prompt = {
name: string;
action?: string;
model: string;
messages: PromptMessage[];
};
export const prompts: Prompt[] = [
{
name: 'debug:chat:gpt4',
model: 'gpt-4-turbo-preview',
messages: [],
},
{
name: 'debug:action:gpt4',
action: 'text',
model: 'gpt-4-turbo-preview',
messages: [],
},
{
name: 'debug:action:vision4',
action: 'text',
model: 'gpt-4-vision-preview',
messages: [],
},
{
name: 'Summary',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Summarize the key points from the following content in a clear and concise manner, suitable for a reader who is seeking a quick understanding of the original content. Ensure to capture the main ideas and any significant details without unnecessary elaboration:\n\n{{content}}',
},
],
},
{
name: 'Summary the webpage',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Summarize the insights from the following webpage content:\n\nFirst, provide a brief summary of the webpage content below. Then, list the insights derived from it, one by one.\n\n{{#links}}\n- {{.}}\n{{/links}}',
},
],
},
{
name: 'Explain this image',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Describe the scene captured in this image, focusing on the details, colors, emotions, and any interactions between subjects or objects present.\n\n{{image}}',
},
],
},
{
name: 'Explain this code',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Analyze and explain the functionality of the following code snippet, highlighting its purpose, the logic behind its operations, and its potential output:\n\n{{code}}',
},
],
},
{
name: 'Translate to',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Please translate the following content into {{language}} and return it to us, adhering to the original format of the content:\n\n{{content}}',
params: {
language: [
'English',
'Spanish',
'German',
'French',
'Italian',
'Simplified Chinese',
'Traditional Chinese',
'Japanese',
'Russian',
'Korean',
],
},
},
],
},
{
name: 'Write an article about this',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content: 'Write an article about following content:\n\n{{content}}',
},
],
},
{
name: 'Write a twitter about this',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content: 'Write a twitter about following content:\n\n{{content}}',
},
],
},
{
name: 'Write a poem about this',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content: 'Write a poem about following content:\n\n{{content}}',
},
],
},
{
name: 'Write a blog post about this',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content: 'Write a blog post about following content:\n\n{{content}}',
},
],
},
{
name: 'Change tone to',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Please rephrase the following content to convey a more {{tone}} tone:\n\n{{content}}',
params: { tone: ['professional', 'informal', 'friendly', 'critical'] },
},
],
},
{
name: 'Brainstorm ideas about this',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Using the information following content, brainstorm ideas and output your thoughts in a bulleted points format.\n\n{{content}}',
},
],
},
{
name: 'Improve writing for it',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Please rewrite the following content to enhance its clarity, coherence, and overall quality, ensuring that the message is effectively communicated and free of any grammatical errors. Provide a refined version that maintains the original intent but exhibits improved structure and readability:\n\n{{content}}',
},
],
},
{
name: 'Improve grammar for it',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Please correct the grammar in the following content to ensure that it is free from any grammatical errors, maintaining proper sentence structure, correct tense usage, and accurate punctuation. Ensure that the final content is grammatically sound while preserving the original message:\n\n{{content}}',
},
],
},
{
name: 'Fix spelling for it',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
"Please carefully review the following content and correct all spelling mistakes. Ensure that each word is spelled correctly, adhering to standard {{language}} spelling conventions. The content's meaning should remain unchanged; only the spelling errors need to be addressed:\n\n{{content}}",
params: {
language: [
'English',
'Spanish',
'German',
'French',
'Italian',
'Simplified Chinese',
'Traditional Chinese',
'Japanese',
'Russian',
'Korean',
],
},
},
],
},
{
name: 'Find action items from it',
action: 'todo-list',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Identify action items from the following content and return them as a to-do list in Markdown format:\n\n{{content}}',
},
],
},
{
name: 'Check code error',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Review the following code snippet for any syntax errors and list them individually:\n\n{{content}}',
},
],
},
{
name: 'Create a presentation',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'I want to write a PPT, that has many pages, each page has 1 to 4 sections,\neach section has a title of no more than 30 words and no more than 500 words of content,\nbut also need some keywords that match the content of the paragraph used to generate images,\nTry to have a different number of section per page\nThe first page is the cover, which generates a general title (no more than 4 words) and description based on the topic\nthis is a template:\n- page name\n - title\n - keywords\n - description\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n - section name\n - keywords\n - content\n- page name\n - section name\n - keywords\n - content\n\n\nplease help me to write this ppt, do not output any content that does not belong to the ppt content itself outside of the content, Directly output the title content keywords without prefix like Title:xxx, Content: xxx, Keywords: xxx\nThe PPT is based on the following topics:\n\n{{content}}',
},
],
},
{
name: 'Create headings',
action: 'text',
model: 'gpt-3.5-turbo',
messages: [
{
role: 'assistant',
content:
'Craft a distilled heading from the following content, maximum 10 words, format: H1.\n\n{{content}}',
},
],
},
];

View File

@@ -2,10 +2,11 @@ import { ServerFeature } from '../../core/config';
import { Plugin } from '../registry';
import { PromptService } from './prompt';
import { assertProvidersConfigs, CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
@Plugin({
name: 'copilot',
providers: [PromptService, CopilotProviderService],
providers: [ChatSessionService, PromptService, CopilotProviderService],
contributesTo: ServerFeature.Copilot,
if: config => {
if (config.flavor.graphql) {

View File

@@ -1,7 +1,124 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { AiPrompt, PrismaClient } from '@prisma/client';
import Mustache from 'mustache';
import { Tiktoken } from 'tiktoken';
import { ChatMessage } from './types';
import {
getTokenEncoder,
PromptMessage,
PromptMessageSchema,
PromptParams,
} from './types';
// disable escaping
Mustache.escape = (text: string) => text;
/**
 * Extract the distinct parameter names referenced by a mustache template.
 *
 * Matches every `{{ param }}` occurrence (whitespace around the name is
 * ignored) and returns the unique names in first-seen order.
 *
 * Fix: the previous greedy capture (`[^{}]+`) swallowed trailing whitespace,
 * so `{{ name }}` and `{{name}}` produced two distinct params. The lazy
 * quantifier plus trim normalizes both to the same key.
 *
 * @param template mustache template string, e.g. 'Hello {{name}}'
 * @returns de-duplicated list of parameter names
 */
function extractMustacheParams(template: string) {
const regex = /\{\{\s*([^{}]+?)\s*\}\}/g;
const params = [];
let match;
while ((match = regex.exec(template)) !== null) {
// trim is belt-and-braces against any residual surrounding whitespace
params.push(match[1].trim());
}
return Array.from(new Set(params));
}
/**
 * In-memory representation of a prompt template: holds the ordered template
 * messages, a tiktoken encoder for the target model, and the mustache params
 * declared by the messages. Render with `finish()`; call `free()` when done
 * to release the native encoder.
 */
export class ChatPrompt {
public readonly encoder?: Tiktoken;
// token count of the raw (unrendered) template messages
private readonly promptTokenSize: number;
// distinct mustache keys found across all template contents
private readonly templateParamKeys: string[] = [];
// allowed values per key, merged from every message's `params`
private readonly templateParams: PromptParams = {};
// build a ChatPrompt from an AiPrompt row plus its parsed messages
static createFromPrompt(
options: Omit<AiPrompt, 'id' | 'createdAt'> & {
messages: PromptMessage[];
}
) {
return new ChatPrompt(
options.name,
options.action,
options.model,
options.messages
);
}
constructor(
public readonly name: string,
public readonly action: string | null,
public readonly model: string | null,
private readonly messages: PromptMessage[]
) {
// encoder is undefined for unknown models; token sizes then fall back to 0
this.encoder = getTokenEncoder(model);
this.promptTokenSize =
this.encoder?.encode_ordinary(messages.map(m => m.content).join('') || '')
.length || 0;
this.templateParamKeys = extractMustacheParams(
messages.map(m => m.content).join('')
);
// note: later messages overwrite earlier ones on duplicate param keys
this.templateParams = messages.reduce(
(acc, m) => Object.assign(acc, m.params),
{} as PromptParams
);
}
/**
* get prompt token size (of the unrendered template)
*/
get tokens() {
return this.promptTokenSize;
}
/**
* get prompt param keys in template
*/
get paramKeys() {
return this.templateParamKeys.slice();
}
/**
* get prompt params (shallow copy; callers cannot mutate internal state)
*/
get params() {
return { ...this.templateParams };
}
// token count of an arbitrary message for this prompt's model
encode(message: string) {
return this.encoder?.encode_ordinary(message).length || 0;
}
// validate incoming params against the values declared by the template:
// throws when a declared key is missing, not a string, or (for list-typed
// params) not one of the allowed options
private checkParams(params: PromptParams) {
const selfParams = this.templateParams;
for (const key of Object.keys(selfParams)) {
const options = selfParams[key];
const income = params[key];
if (
typeof income !== 'string' ||
(Array.isArray(options) && !options.includes(income))
) {
throw new Error(`Invalid param: ${key}`);
}
}
}
/**
* render prompt messages with params
* @param params record of params, e.g. { name: 'Alice' }
* @returns e.g. [{ role: 'system', content: 'Hello, {{name}}' }] => [{ role: 'system', content: 'Hello, Alice' }]
*/
finish(params: PromptParams) {
this.checkParams(params);
return this.messages.map(m => ({
...m,
content: Mustache.render(m.content, params),
}));
}
// release the native tiktoken encoder; the prompt must not be used after this
free() {
this.encoder?.free();
}
}
@Injectable()
export class PromptService {
@@ -22,51 +139,74 @@ export class PromptService {
* @param name prompt name
* @returns prompt messages
*/
async get(name: string): Promise<ChatMessage[]> {
return this.db.aiPrompt.findMany({
where: {
name,
},
select: {
role: true,
content: true,
},
orderBy: {
idx: 'asc',
},
});
async get(name: string): Promise<ChatPrompt | null> {
return this.db.aiPrompt
.findUnique({
where: {
name,
},
select: {
name: true,
action: true,
model: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
},
},
})
.then(p => {
const messages = PromptMessageSchema.array().safeParse(p?.messages);
if (p && messages.success) {
return ChatPrompt.createFromPrompt({ ...p, messages: messages.data });
}
return null;
});
}
async set(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
const prompts = await tx.aiPrompt.count({ where: { name } });
if (prompts > 0) {
return 0;
}
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
async set(name: string, messages: PromptMessage[]) {
return await this.db.aiPrompt
.create({
data: {
name,
messages: {
create: messages.map((m, idx) => ({
idx,
...m,
params: m.params || undefined,
})),
},
},
})
.then(ret => ret.id);
}
async update(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
await tx.aiPrompt.deleteMany({ where: { name } });
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
async update(name: string, messages: PromptMessage[]) {
return this.db.aiPrompt
.update({
where: { name },
data: {
messages: {
// cleanup old messages
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
params: m.params || undefined,
})),
},
},
})
.then(ret => ret.id);
}
async delete(name: string) {
return this.db.aiPrompt
.deleteMany({
where: { name },
})
.then(ret => ret.count);
return this.db.aiPrompt.delete({ where: { name } }).then(ret => ret.id);
}
}

View File

@@ -0,0 +1,203 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { ChatPrompt, PromptService } from './prompt';
import {
ChatMessage,
ChatMessageSchema,
PromptMessage,
PromptParams,
} from './types';
export interface ChatSessionOptions {
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
// states
prompt: ChatPrompt;
messages: ChatMessage[];
}
/**
 * A live chat session. Wraps the persisted session state, trims history to
 * the model's token budget, and persists new messages when saved/disposed
 * (supports `await using` via AsyncDisposable).
 */
export class ChatSession implements AsyncDisposable {
constructor(
private readonly state: ChatSessionState,
// callback used to persist the state on save/dispose
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
// token budget for prompt template + history combined
private readonly maxTokenSize = 3840
) {}
get model() {
return this.state.prompt.model;
}
// append a message to the in-memory history
push(message: ChatMessage) {
this.state.messages.push(message);
}
// drop the most recent message from the in-memory history
pop() {
this.state.messages.pop();
}
// select the history to send to the provider:
// - action prompts are one-shot: only the latest message is taken
// - chat prompts: take the newest messages that still fit the token budget
//   (oldest messages are dropped first)
private takeMessages(): ChatMessage[] {
if (this.state.prompt.action) {
const messages = this.state.messages;
return messages.slice(messages.length - 1);
}
const ret = [];
const messages = this.state.messages.slice();
// start from the template's own token size so prompt + history fit together
let size = this.state.prompt.tokens;
while (messages.length) {
const message = messages.pop();
if (!message) break;
size += this.state.prompt.encode(message.content);
if (size > this.maxTokenSize) {
break;
}
ret.push(message);
}
// messages were collected newest-first; restore chronological order
ret.reverse();
return ret;
}
// rendered prompt template followed by the trimmed chat history
finish(params: PromptParams): PromptMessage[] {
const messages = this.takeMessages();
return [...this.state.prompt.finish(params), ...messages];
}
// persist the current state via the dispose callback
async save() {
await this.dispose?.(this.state);
}
async [Symbol.asyncDispose]() {
// free the tiktoken encoder first, then persist the state
this.state.prompt.free();
await this.save?.();
}
}
@Injectable()
/**
 * Manages the lifecycle of chat sessions: creation, loading from the
 * database, and persistence of session state.
 */
export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
constructor(
private readonly db: PrismaClient,
private readonly prompt: PromptService
) {}
// persist a session: creates the row (with its relations) on first save,
// otherwise appends the in-memory messages to the existing session.
// NOTE(review): the update branch passes an `idx` field, but the
// AiSessionMessage model (ai_sessions_messages) declares no `idx` column —
// confirm this nested create matches the schema. Also, every save re-creates
// all in-memory messages, which looks like it would duplicate rows on
// repeated saves — verify.
private async setSession(state: ChatSessionState): Promise<void> {
await this.db.aiSession.upsert({
where: {
id: state.sessionId,
},
update: {
messages: {
create: state.messages.map((m, idx) => ({ idx, ...m })),
},
},
create: {
id: state.sessionId,
messages: { create: state.messages },
// connect
user: { connect: { id: state.userId } },
workspace: { connect: { id: state.workspaceId } },
doc: {
connect: {
id_workspaceId: {
id: state.docId,
workspaceId: state.workspaceId,
},
},
},
prompt: { connect: { name: state.prompt.name } },
},
});
}
// load a session plus its prompt (with ordered template messages) from the
// database; resolves to undefined when the session does not exist.
// NOTE(review): messages failing schema validation are silently replaced by
// an empty history — confirm that is the intended recovery behavior.
private async getSession(
sessionId: string
): Promise<ChatSessionState | undefined> {
return await this.db.aiSession
.findUnique({
where: { id: sessionId },
select: {
id: true,
userId: true,
workspaceId: true,
docId: true,
messages: true,
prompt: {
select: {
name: true,
action: true,
model: true,
messages: {
select: {
role: true,
content: true,
},
orderBy: {
idx: 'asc',
},
},
},
},
},
})
.then(async session => {
if (!session) return;
const messages = ChatMessageSchema.array().safeParse(session.messages);
return {
sessionId: session.id,
userId: session.userId,
workspaceId: session.workspaceId,
docId: session.docId,
prompt: ChatPrompt.createFromPrompt(session.prompt),
messages: messages.success ? messages.data : [],
};
});
}
// create a new empty session bound to the named prompt; returns its id.
// Throws when the prompt does not exist.
async create(options: ChatSessionOptions): Promise<string> {
const sessionId = randomUUID();
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new Error('Prompt not found');
}
await this.setSession({ ...options, sessionId, prompt, messages: [] });
return sessionId;
}
/**
* usage:
* ``` typescript
* {
* // allocate a session; it can be reused for chat within about 12 hours
* // with the same session id
* await using session = await session.get(sessionId);
* session.push(message);
* copilot.generateText(session.finish(), model);
* }
* // session will be disposed (and persisted) after the block
* ```
* @param sessionId session id
* @returns the session, or null when it does not exist
*/
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSession(sessionId);
if (state) {
return new ChatSession(state, async state => {
await this.setSession(state);
});
}
return null;
}
}

View File

@@ -1,5 +1,11 @@
import { AiPromptRole } from '@prisma/client';
import type { ClientOptions as OpenAIClientOptions } from 'openai';
import {
encoding_for_model,
get_encoding,
Tiktoken,
TiktokenModel,
} from 'tiktoken';
import { z } from 'zod';
export interface CopilotConfig {
@@ -9,6 +15,76 @@ export interface CopilotConfig {
};
}
export enum AvailableModels {
// text to text
Gpt4VisionPreview = 'gpt-4-vision-preview',
Gpt4TurboPreview = 'gpt-4-turbo-preview',
Gpt35Turbo = 'gpt-3.5-turbo',
// embeddings
TextEmbedding3Large = 'text-embedding-3-large',
TextEmbedding3Small = 'text-embedding-3-small',
TextEmbeddingAda002 = 'text-embedding-ada-002',
// moderation
TextModerationLatest = 'text-moderation-latest',
TextModerationStable = 'text-moderation-stable',
}
export type AvailableModel = keyof typeof AvailableModels;
/**
 * Pick a tiktoken encoder for the given model.
 *
 * Returns undefined for missing/unknown models and for image models (no
 * token counting needed); gpt-* models get their dedicated encoding, any
 * other known model falls back to cl100k_base.
 */
export function getTokenEncoder(model?: string | null): Tiktoken | undefined {
if (!model) return undefined;
// NOTE(review): AvailableModels is a string enum keyed by PascalCase names
// ('Gpt35Turbo' -> 'gpt-3.5-turbo'), so indexing with a raw model id like
// 'gpt-3.5-turbo' yields undefined — confirm callers pass enum keys, not
// model id strings, otherwise this always returns undefined.
const modelStr = AvailableModels[model as AvailableModel];
if (!modelStr) return undefined;
if (modelStr.startsWith('gpt')) {
return encoding_for_model(modelStr as TiktokenModel);
} else if (modelStr.startsWith('dall')) {
// DALL·E models don't need token counting
return undefined;
} else {
return get_encoding('cl100k_base');
}
}
// ======== ChatMessage ========
export const ChatMessageRole = Object.values(AiPromptRole) as [
'system',
'assistant',
'user',
];
export const PromptMessageSchema = z.object({
role: z.enum(ChatMessageRole),
content: z.string(),
attachments: z.array(z.string()).optional(),
params: z
.record(z.union([z.string(), z.array(z.string())]))
.optional()
.nullable(),
});
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export const ChatMessageSchema = PromptMessageSchema.extend({
createdAt: z.date(),
}).strict();
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const ChatHistorySchema = z
.object({
sessionId: z.string(),
tokens: z.number(),
messages: z.array(ChatMessageSchema),
})
.strict();
export type ChatHistory = z.infer<typeof ChatHistorySchema>;
// ======== Provider Interface ========
export enum CopilotProviderType {
FAL = 'fal',
OpenAI = 'openai',
@@ -25,24 +101,26 @@ export interface CopilotProvider {
getCapabilities(): CopilotProviderCapability[];
}
export const ChatMessageSchema = z
.object({
role: z.enum(
Array.from(Object.values(AiPromptRole)) as [
'system' | 'assistant' | 'user',
]
),
content: z.string(),
})
.strict();
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export interface CopilotTextToTextProvider extends CopilotProvider {
generateText(messages: ChatMessage[], model: string): Promise<string>;
generateText(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): Promise<string>;
generateTextStream(
messages: ChatMessage[],
model: string
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}

View File

@@ -59,12 +59,74 @@ test('should be able to manage prompt', async t => {
{ role: 'user', content: 'hello' },
]);
t.is((await prompt.list()).length, 1, 'should have one prompt');
t.is((await prompt.get('test')).length, 2, 'should have two messages');
t.is(
(await prompt.get('test'))!.finish({}).length,
2,
'should have two messages'
);
await prompt.update('test', [{ role: 'system', content: 'hello' }]);
t.is((await prompt.get('test')).length, 1, 'should have one message');
t.is(
(await prompt.get('test'))!.finish({}).length,
1,
'should have one message'
);
await prompt.delete('test');
t.is((await prompt.list()).length, 0, 'should have no prompt');
t.is((await prompt.get('test')).length, 0, 'should have no messages');
t.is(await prompt.get('test'), null, 'should not have the prompt');
});
test('should be able to render prompt', async t => {
const { prompt } = t.context;
const msg = {
role: 'system' as const,
content: 'translate {{src_language}} to {{dest_language}}: {{content}}',
params: { src_language: ['eng'], dest_language: ['chs', 'jpn', 'kor'] },
};
const params = {
src_language: 'eng',
dest_language: 'chs',
content: 'hello world',
};
await prompt.set('test', [msg]);
const testPrompt = await prompt.get('test');
t.assert(testPrompt, 'should have prompt');
t.is(
testPrompt?.finish(params).pop()?.content,
'translate eng to chs: hello world',
'should render the prompt'
);
t.deepEqual(
testPrompt?.paramKeys,
Object.keys(params),
'should have param keys'
);
t.deepEqual(testPrompt?.params, msg.params, 'should have params');
t.throws(() => testPrompt?.finish({ src_language: 'abc' }), {
instanceOf: Error,
});
});
test('should be able to render listed prompt', async t => {
const { prompt } = t.context;
const msg = {
role: 'system' as const,
content: 'links:\n{{#links}}- {{.}}\n{{/links}}',
};
const params = {
links: ['https://affine.pro', 'https://github.com/toeverything/affine'],
};
await prompt.set('test', [msg]);
const testPrompt = await prompt.get('test');
t.is(
testPrompt?.finish(params).pop()?.content,
'links:\n- https://affine.pro\n- https://github.com/toeverything/affine\n',
'should render the prompt'
);
});

View File

@@ -49,7 +49,7 @@ test('should be able to set quota', async t => {
const q1 = await quota.getUserQuota(u1.id);
t.truthy(q1, 'should have quota');
t.is(q1?.feature.name, QuotaType.FreePlanV1, 'should be free plan');
t.is(q1?.feature.version, 3, 'should be version 2');
t.is(q1?.feature.version, 3, 'should be version 3');
await quota.switchUserQuota(u1.id, QuotaType.ProPlanV1);

View File

@@ -697,6 +697,7 @@ __metadata:
"@types/keyv": "npm:^4.2.0"
"@types/lodash-es": "npm:^4.17.12"
"@types/mixpanel": "npm:^2.14.8"
"@types/mustache": "npm:^4"
"@types/node": "npm:^20.11.20"
"@types/nodemailer": "npm:^6.4.14"
"@types/on-headers": "npm:^1.0.3"
@@ -720,6 +721,7 @@ __metadata:
keyv: "npm:^4.5.4"
lodash-es: "npm:^4.17.21"
mixpanel: "npm:^0.18.0"
mustache: "npm:^4.2.0"
nanoid: "npm:^5.0.6"
nest-commander: "npm:^3.12.5"
nestjs-throttler-storage-redis: "npm:^0.4.1"
@@ -738,6 +740,7 @@ __metadata:
socket.io: "npm:^4.7.4"
stripe: "npm:^14.18.0"
supertest: "npm:^6.3.4"
tiktoken: "npm:^1.0.13"
ts-node: "npm:^10.9.2"
typescript: "npm:^5.3.3"
ws: "npm:^8.16.0"
@@ -14489,6 +14492,13 @@ __metadata:
languageName: node
linkType: hard
"@types/mustache@npm:^4":
version: 4.2.5
resolution: "@types/mustache@npm:4.2.5"
checksum: 10/29581027fe420120ae0591e28d44209d0e01adf5175910d03401327777ee9c649a1508e2aa63147c782c7e53fcea4b69b5f9a2fbedcadc5500561d1161ae5ded
languageName: node
linkType: hard
"@types/mute-stream@npm:^0.0.4":
version: 0.0.4
resolution: "@types/mute-stream@npm:0.0.4"
@@ -33968,6 +33978,13 @@ __metadata:
languageName: node
linkType: hard
"tiktoken@npm:^1.0.13":
version: 1.0.13
resolution: "tiktoken@npm:1.0.13"
checksum: 10/4217ffbcd4126dc2dd17503fda35be91cf4be64c514f70e1049982d1bd2b5cea6334e76812411cb284dfa7b412159839d546048ac98220faf3c629e217266ddc
languageName: node
linkType: hard
"time-zone@npm:^1.0.0":
version: 1.0.0
resolution: "time-zone@npm:1.0.0"