diff --git a/packages/backend/server/migrations/20250417062046_workspace_level_embed_file_doc/migration.sql b/packages/backend/server/migrations/20250417062046_workspace_level_embed_file_doc/migration.sql
new file mode 100644
index 0000000000..85e311ab8b
--- /dev/null
+++ b/packages/backend/server/migrations/20250417062046_workspace_level_embed_file_doc/migration.sql
@@ -0,0 +1,81 @@
+-- CreateTable
+CREATE TABLE "ai_workspace_ignored_docs" (
+    "workspace_id" VARCHAR NOT NULL,
+    "doc_id" VARCHAR NOT NULL,
+    "created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+
+    CONSTRAINT "ai_workspace_ignored_docs_pkey" PRIMARY KEY ("workspace_id","doc_id")
+);
+
+-- CreateTable
+CREATE TABLE "ai_workspace_files" (
+    "workspace_id" VARCHAR NOT NULL,
+    "file_id" VARCHAR NOT NULL,
+    "file_name" VARCHAR NOT NULL,
+    "mime_type" VARCHAR NOT NULL,
+    "size" INTEGER NOT NULL,
+    "created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+
+    CONSTRAINT "ai_workspace_files_pkey" PRIMARY KEY ("workspace_id","file_id")
+);
+
+-- AddForeignKey
+ALTER TABLE "ai_workspace_ignored_docs" ADD CONSTRAINT "ai_workspace_ignored_docs_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "ai_workspace_files" ADD CONSTRAINT "ai_workspace_files_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+DO $$
+DECLARE error_message TEXT;
+BEGIN -- check if pgvector extension is installed
+  IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
+    BEGIN
+      -- CreateExtension
+      CREATE EXTENSION IF NOT EXISTS "vector";
+    EXCEPTION
+      WHEN OTHERS THEN
+        -- if not found and cannot create extension, raise the exception
+        error_message := 'pgvector extension not found.' || E'\n' ||
+          '****************************************************************************' || E'\n' ||
+          '*                                                                          *' || E'\n' ||
+          '* NOTICE: From AFFiNE 0.20 onwards, the copilot module will depend         *' || E'\n' ||
+          '* on pgvector.                                                             *' || E'\n' ||
+          '*                                                                          *' || E'\n' ||
+          '* 1. If you are using the official PostgreSQL Docker container,            *' || E'\n' ||
+          '*    please switch to the pgvector/pgvector:pg${VERSION} container,        *' || E'\n' ||
+          '*    where ${VERSION} is the major version of your PostgreSQL container.   *' || E'\n' ||
+          '*                                                                          *' || E'\n' ||
+          '* 2. If you are using a self-installed PostgreSQL, please follow the       *' || E'\n' ||
+          '*    official pgvector installation guide to install it into your          *' || E'\n' ||
+          '*    database: https://github.com/pgvector/pgvector?tab=readme-ov-         *' || E'\n' ||
+          '*    file#installation-notes---linux-and-mac                               *' || E'\n' ||
+          '*                                                                          *' || E'\n' ||
+          '****************************************************************************';
+
+        RAISE WARNING '%', error_message;
+    END;
+  END IF;
+  -- check again, initialize the tables if the extension is installed
+  IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
+    -- CreateTable
+    CREATE TABLE "ai_workspace_file_embeddings" (
+        "workspace_id" VARCHAR NOT NULL,
+        "file_id" VARCHAR NOT NULL,
+        "chunk" INTEGER NOT NULL,
+        "content" VARCHAR NOT NULL,
+        "embedding" vector(1024) NOT NULL,
+        "created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+
+        CONSTRAINT "ai_workspace_file_embeddings_pkey" PRIMARY KEY ("workspace_id","file_id","chunk")
+    );
+
+    -- CreateIndex
+    CREATE INDEX "ai_workspace_file_embeddings_idx" ON "ai_workspace_file_embeddings" USING hnsw (embedding vector_cosine_ops);
+
+    -- CreateIndex
+    CREATE UNIQUE INDEX "ai_workspace_file_embeddings_workspace_id_file_id_chunk_key" ON "ai_workspace_file_embeddings"("workspace_id", "file_id", "chunk");
+
+    -- AddForeignKey
+    ALTER TABLE "ai_workspace_file_embeddings" ADD CONSTRAINT "ai_workspace_file_embeddings_workspace_id_file_id_fkey" FOREIGN KEY ("workspace_id", "file_id") REFERENCES "ai_workspace_files"("workspace_id", "file_id") ON DELETE CASCADE ON UPDATE CASCADE;
+  END IF;
+END $$;
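Note on the guarded `DO $$` block above: when pgvector is missing, the migration only raises a warning and skips `ai_workspace_file_embeddings`, so the server has to re-check availability at runtime (see `checkEmbeddingAvailable` later in this diff). A minimal sketch of that extension check using Prisma's raw-query API; the standalone `PrismaClient` instance here is illustrative and not part of this diff:

```ts
import { PrismaClient } from '@prisma/client';

const db = new PrismaClient();

// Mirrors the migration's guard: the "vector" extension must be present in
// pg_extension before the embeddings table can exist.
async function pgvectorInstalled(): Promise<boolean> {
  const rows = await db.$queryRaw<{ extname: string }[]>`
    SELECT extname FROM pg_extension WHERE extname = 'vector'
  `;
  return rows.length === 1;
}
```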
diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma
index 274a20fe3d..d994758f61 100644
--- a/packages/backend/server/schema.prisma
+++ b/packages/backend/server/schema.prisma
@@ -114,11 +114,13 @@ model Workspace {
   name      String? @db.VarChar
   avatarKey String? @map("avatar_key") @db.VarChar
 
-  features       WorkspaceFeature[]
-  docs           WorkspaceDoc[]
-  permissions    WorkspaceUserRole[]
-  docPermissions WorkspaceDocUserRole[]
-  blobs          Blob[]
+  features               WorkspaceFeature[]
+  docs                   WorkspaceDoc[]
+  permissions            WorkspaceUserRole[]
+  docPermissions         WorkspaceDocUserRole[]
+  blobs                  Blob[]
+  AiWorkspaceIgnoredDocs AiWorkspaceIgnoredDocs[]
+  AiWorkspaceFiles       AiWorkspaceFiles[]
 
   @@map("workspaces")
 }
@@ -481,6 +483,53 @@ model AiWorkspaceEmbedding {
   @@map("ai_workspace_embeddings")
 }
 
+model AiWorkspaceIgnoredDocs {
+  workspaceId String @map("workspace_id") @db.VarChar
+  docId       String @map("doc_id") @db.VarChar
+
+  createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
+
+  workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
+
+  @@id([workspaceId, docId])
+  @@map("ai_workspace_ignored_docs")
+}
+
+model AiWorkspaceFiles {
+  workspaceId String @map("workspace_id") @db.VarChar
+  fileId      String @map("file_id") @db.VarChar
+  fileName    String @map("file_name") @db.VarChar
+  mimeType    String @map("mime_type") @db.VarChar
+  size        Int    @db.Integer
+
+  createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
+
+  workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
+
+  embeddings AiWorkspaceFileEmbedding[]
+
+  @@id([workspaceId, fileId])
+  @@map("ai_workspace_files")
+}
+
+model AiWorkspaceFileEmbedding {
+  workspaceId String @map("workspace_id") @db.VarChar
+  fileId      String @map("file_id") @db.VarChar
+  // a file can be divided into multiple chunks and embedded separately.
+  chunk     Int    @db.Integer
+  content   String @db.VarChar
+  embedding Unsupported("vector(1024)")
+
+  createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
+
+  file AiWorkspaceFiles @relation(fields: [workspaceId, fileId], references: [workspaceId, fileId], onDelete: Cascade)
+
+  @@id([workspaceId, fileId, chunk])
+  @@unique([workspaceId, fileId, chunk])
+  @@index([embedding], map: "ai_workspace_file_embeddings_idx")
+  @@map("ai_workspace_file_embeddings")
+}
+
 enum AiJobStatus {
   pending
   running
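Because `embedding` is declared `Unsupported("vector(1024)")`, the generated Prisma client exposes the row type but not that column, so every read or write of the vector has to go through raw SQL; this is why the models below hand-build their queries. A small illustrative read, assuming only the table from the migration above and a `PrismaClient` instance:

```ts
import { PrismaClient } from '@prisma/client';

const db = new PrismaClient();

// Chunks are keyed by (workspace_id, file_id, chunk); the vector column is
// omitted from the SELECT because the generated client cannot map it.
async function listChunks(workspaceId: string, fileId: string) {
  return await db.$queryRaw<{ chunk: number; content: string }[]>`
    SELECT "chunk", "content"
    FROM "ai_workspace_file_embeddings"
    WHERE "workspace_id" = ${workspaceId} AND "file_id" = ${fileId}
    ORDER BY "chunk" ASC
  `;
}
```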
diff --git a/packages/backend/server/src/__tests__/models/copilot-workspace.spec.ts b/packages/backend/server/src/__tests__/models/copilot-workspace.spec.ts
new file mode 100644
index 0000000000..ab349c816d
--- /dev/null
+++ b/packages/backend/server/src/__tests__/models/copilot-workspace.spec.ts
@@ -0,0 +1,149 @@
+import { PrismaClient, User, Workspace } from '@prisma/client';
+import ava, { TestFn } from 'ava';
+
+import { Config } from '../../base';
+import { CopilotWorkspaceConfigModel } from '../../models/copilot-workspace';
+import { UserModel } from '../../models/user';
+import { WorkspaceModel } from '../../models/workspace';
+import { createTestingModule, type TestingModule } from '../utils';
+
+interface Context {
+  config: Config;
+  module: TestingModule;
+  db: PrismaClient;
+  user: UserModel;
+  workspace: WorkspaceModel;
+  copilotWorkspace: CopilotWorkspaceConfigModel;
+}
+
+const test = ava as TestFn<Context>;
+
+test.before(async t => {
+  const module = await createTestingModule();
+  t.context.user = module.get(UserModel);
+  t.context.workspace = module.get(WorkspaceModel);
+  t.context.copilotWorkspace = module.get(CopilotWorkspaceConfigModel);
+  t.context.db = module.get(PrismaClient);
+  t.context.config = module.get(Config);
+  t.context.module = module;
+});
+
+let user: User;
+let workspace: Workspace;
+
+const docId = 'doc1';
+
+test.beforeEach(async t => {
+  await t.context.module.initTestingDB();
+  user = await t.context.user.create({
+    email: 'test@affine.pro',
+  });
+  workspace = await t.context.workspace.create(user.id);
+});
+
+test.after(async t => {
+  await t.context.module.close();
+});
+
+test('should manage copilot workspace ignored docs', async t => {
+  const ignoredDocs = await t.context.copilotWorkspace.listIgnoredDocs(
+    workspace.id
+  );
+  t.deepEqual(ignoredDocs, []);
+
+  {
+    const count = await t.context.copilotWorkspace.updateIgnoredDocs(
+      workspace.id,
+      [docId]
+    );
+    t.is(count, 1, 'should add ignored doc');
+
+    const ret = await t.context.copilotWorkspace.listIgnoredDocs(workspace.id);
+    t.deepEqual(ret, [docId], 'should return added doc');
+
+    const check = await t.context.copilotWorkspace.checkIgnoredDocs(
+      workspace.id,
+      [docId]
+    );
+    t.deepEqual(check, [docId], 'should return ignored docs in workspace');
+  }
+
+  {
+    const count = await t.context.copilotWorkspace.updateIgnoredDocs(
+      workspace.id,
+      [docId]
+    );
+    t.is(count, 1, 'should not add ignored doc again');
+
+    const ret = await t.context.copilotWorkspace.listIgnoredDocs(workspace.id);
+    t.deepEqual(ret, [docId], 'should not add ignored doc again');
+  }
+
+  {
+    const count = await t.context.copilotWorkspace.updateIgnoredDocs(
+      workspace.id,
+      ['new_doc']
+    );
+    t.is(count, 2, 'should add new ignored doc');
+
+    const ret = await t.context.copilotWorkspace.listIgnoredDocs(workspace.id);
+    t.deepEqual(ret, [docId, 'new_doc'], 'should add ignored doc');
+  }
+
+  {
+    await t.context.copilotWorkspace.updateIgnoredDocs(
+      workspace.id,
+      undefined,
+      [docId]
+    );
+
+    const ret = await t.context.copilotWorkspace.listIgnoredDocs(workspace.id);
+    t.deepEqual(ret, ['new_doc'], 'should remove ignored doc');
+  }
+});
+
+test('should insert and search embedding', async t => {
+  {
+    await t.context.copilotWorkspace.addWorkspaceFile(
+      workspace.id,
+      {
+        fileName: 'file1',
+        mimeType: 'text/plain',
+
+        size: 1,
+      },
+      [
+        {
+          index: 0,
+          content: 'content',
+          embedding: Array.from({ length: 1024 }, () => 1),
+        },
+      ]
+    );
+
+    {
+      const ret = await t.context.copilotWorkspace.matchWorkspaceFileEmbedding(
+        workspace.id,
+        Array.from({ length: 1024 }, () => 0.9),
+        1,
+        1
+      );
+      t.is(ret.length, 1);
+      t.is(ret[0].content, 'content');
+    }
+  }
+});
+
+test('should check embedding table', async t => {
+  {
+    const ret = await t.context.copilotWorkspace.checkEmbeddingAvailable();
+    t.true(ret, 'should return true when embedding table is available');
+  }
+
+  // {
+  //   await t.context.db
+  //     .$executeRaw`DROP TABLE IF EXISTS "ai_workspace_file_embeddings"`;
+  //   const ret = await t.context.copilotWorkspace.checkEmbeddingAvailable();
+  //   t.false(ret, 'should return false when embedding table is not available');
+  // }
+});
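Why the match in 'should insert and search embedding' succeeds: pgvector's `<=>` operator returns cosine distance (1 minus cosine similarity), and the stored all-ones vector is parallel to the all-0.9 query vector, so their distance is 0 and the row survives the threshold filter of 1. A reference implementation of the same formula, for intuition only:

```ts
// Cosine distance as pgvector's <=> operator computes it.
function cosineDistance(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return 1 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Parallel vectors have distance 0, which is why the 0.9-filled query matches
// the 1-filled stored embedding under any threshold >= 0.
const ones = Array.from({ length: 1024 }, () => 1);
const ninths = Array.from({ length: 1024 }, () => 0.9);
console.log(cosineDistance(ones, ninths)); // ~0 (up to float rounding)
```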
diff --git a/packages/backend/server/src/models/common/copilot.ts b/packages/backend/server/src/models/common/copilot.ts
index 2ba40e16ff..9fe2a56e67 100644
--- a/packages/backend/server/src/models/common/copilot.ts
+++ b/packages/backend/server/src/models/common/copilot.ts
@@ -105,3 +105,17 @@ export type FileChunkSimilarity = ChunkSimilarity & {
 export type DocChunkSimilarity = ChunkSimilarity & {
   docId: string;
 };
+
+export const CopilotWorkspaceFileSchema = z.object({
+  fileName: z.string(),
+  mimeType: z.string(),
+  size: z.number(),
+});
+
+export type CopilotWorkspaceFile = z.infer<
+  typeof CopilotWorkspaceFileSchema
+> & {
+  workspaceId: string;
+  fileId: string;
+  createdAt: Date;
+};
diff --git a/packages/backend/server/src/models/copilot-context.ts b/packages/backend/server/src/models/copilot-context.ts
index 712527671d..670ac28742 100644
--- a/packages/backend/server/src/models/copilot-context.ts
+++ b/packages/backend/server/src/models/copilot-context.ts
@@ -177,12 +177,12 @@ export class CopilotContextModel extends BaseModel {
     const similarityChunks = await this.db.$queryRaw<
       Array<FileChunkSimilarity>
     >`
-    SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
-    FROM "ai_context_embeddings"
-    WHERE context_id = ${contextId}
-    ORDER BY "distance" ASC
-    LIMIT ${topK};
-  `;
+      SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
+      FROM "ai_context_embeddings"
+      WHERE context_id = ${contextId}
+      ORDER BY "distance" ASC
+      LIMIT ${topK};
+    `;
     return similarityChunks.filter(c => Number(c.distance) <= threshold);
   }
 
@@ -198,11 +198,11 @@ export class CopilotContextModel extends BaseModel {
       false
     );
     await this.db.$executeRaw`
-    INSERT INTO "ai_workspace_embeddings"
-    ("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
-    ON CONFLICT (workspace_id, doc_id, chunk) DO UPDATE SET
-    embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
-  `;
+      INSERT INTO "ai_workspace_embeddings"
+      ("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
+      ON CONFLICT (workspace_id, doc_id, chunk) DO UPDATE SET
+      embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
+    `;
   }
 
   async matchWorkspaceEmbedding(
@@ -209,14 +209,14 @@ export class CopilotContextModel extends BaseModel {
     workspaceId: string,
     embedding: number[],
     topK: number,
     threshold: number
   ): Promise<ChunkSimilarity[]> {
     const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
-    SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
-    FROM "ai_workspace_embeddings"
-    WHERE "workspace_id" = ${workspaceId}
-    ORDER BY "distance" ASC
-    LIMIT ${topK};
-  `;
+      SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
+      FROM "ai_workspace_embeddings"
+      WHERE "workspace_id" = ${workspaceId}
+      ORDER BY "distance" ASC
+      LIMIT ${topK};
+    `;
     return similarityChunks.filter(c => Number(c.distance) <= threshold);
   }
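`CopilotWorkspaceFileSchema` gives callers a runtime guard for file metadata before it reaches `addWorkspaceFile`. A sketch of how an upload handler might use it; the handler and its wiring are hypothetical, only the schema shape comes from this diff (redeclared here so the example is self-contained):

```ts
import { z } from 'zod';

// Same shape as CopilotWorkspaceFileSchema in models/common/copilot.ts.
const CopilotWorkspaceFileSchema = z.object({
  fileName: z.string(),
  mimeType: z.string(),
  size: z.number(),
});

function parseUploadMetadata(input: unknown) {
  const parsed = CopilotWorkspaceFileSchema.safeParse(input);
  if (!parsed.success) {
    throw new Error(`invalid file metadata: ${parsed.error.message}`);
  }
  return parsed.data; // { fileName, mimeType, size }
}
```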
diff --git a/packages/backend/server/src/models/copilot-workspace.ts b/packages/backend/server/src/models/copilot-workspace.ts
new file mode 100644
index 0000000000..8cfd0af254
--- /dev/null
+++ b/packages/backend/server/src/models/copilot-workspace.ts
@@ -0,0 +1,156 @@
+import { randomUUID } from 'node:crypto';
+
+import { Injectable } from '@nestjs/common';
+import { Transactional } from '@nestjs-cls/transactional';
+import { Prisma } from '@prisma/client';
+
+import { BaseModel } from './base';
+import {
+  type CopilotWorkspaceFile,
+  type Embedding,
+  FileChunkSimilarity,
+} from './common';
+
+@Injectable()
+export class CopilotWorkspaceConfigModel extends BaseModel {
+  @Transactional()
+  async updateIgnoredDocs(
+    workspaceId: string,
+    add: string[] = [],
+    remove: string[] = []
+  ) {
+    const removed = new Set(remove);
+    const ignored = await this.listIgnoredDocs(workspaceId).then(
+      r => new Set(r.filter(id => !removed.has(id)))
+    );
+    const added = add.filter(id => !ignored.has(id));
+
+    if (added.length) {
+      await this.db.aiWorkspaceIgnoredDocs.createMany({
+        data: added.map(docId => ({
+          workspaceId,
+          docId,
+        })),
+      });
+    }
+
+    if (removed.size) {
+      await this.db.aiWorkspaceIgnoredDocs.deleteMany({
+        where: {
+          workspaceId,
+          docId: {
+            in: Array.from(removed),
+          },
+        },
+      });
+    }
+
+    return added.length + ignored.size;
+  }
+
+  async listIgnoredDocs(workspaceId: string): Promise<string[]> {
+    const row = await this.db.aiWorkspaceIgnoredDocs.findMany({
+      where: {
+        workspaceId,
+      },
+      select: {
+        docId: true,
+      },
+    });
+    return row.map(r => r.docId);
+  }
+
+  @Transactional()
+  async checkIgnoredDocs(workspaceId: string, docIds: string[]) {
+    const ignored = await this.listIgnoredDocs(workspaceId).then(
+      r => new Set(r)
+    );
+
+    return docIds.filter(id => ignored.has(id));
+  }
+
+  // ================ embeddings ================
+
+  async checkEmbeddingAvailable(): Promise<boolean> {
+    const [{ count }] = await this.db.$queryRaw<
+      { count: number }[]
+    >`SELECT count(1) FROM pg_tables WHERE tablename in ('ai_workspace_file_embeddings')`;
+    return Number(count) === 1;
+  }
+
+  private processEmbeddings(
+    workspaceId: string,
+    fileId: string,
+    embeddings: Embedding[]
+  ) {
+    const groups = embeddings.map(e =>
+      [
+        workspaceId,
+        fileId,
+        e.index,
+        e.content,
+        Prisma.raw(`'[${e.embedding.join(',')}]'`),
+      ].filter(v => v !== undefined)
+    );
+    return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
+  }
+
+  @Transactional()
+  async addWorkspaceFile(
+    workspaceId: string,
+    file: Pick<CopilotWorkspaceFile, 'fileName' | 'mimeType' | 'size'>,
+    embeddings: Embedding[]
+  ): Promise<string> {
+    const fileId = randomUUID();
+    await this.db.aiWorkspaceFiles.create({
+      data: { ...file, workspaceId, fileId },
+    });
+
+    const values = this.processEmbeddings(workspaceId, fileId, embeddings);
+    await this.db.$executeRaw`
+      INSERT INTO "ai_workspace_file_embeddings"
+      ("workspace_id", "file_id", "chunk", "content", "embedding") VALUES ${values}
+      ON CONFLICT (workspace_id, file_id, chunk) DO NOTHING;
+    `;
+    return fileId;
+  }
+
+  async listWorkspaceFiles(
+    workspaceId: string
+  ): Promise<CopilotWorkspaceFile[]> {
+    const files = await this.db.aiWorkspaceFiles.findMany({
+      where: {
+        workspaceId,
+      },
+    });
+    return files;
+  }
+
+  async matchWorkspaceFileEmbedding(
+    workspaceId: string,
+    embedding: number[],
+    topK: number,
+    threshold: number
+  ): Promise<FileChunkSimilarity[]> {
+    const similarityChunks = await this.db.$queryRaw<
+      Array<FileChunkSimilarity>
+    >`
+      SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
+      FROM "ai_workspace_file_embeddings"
+      WHERE workspace_id = ${workspaceId}
+      ORDER BY "distance" ASC
+      LIMIT ${topK};
+    `;
+    return similarityChunks.filter(c => Number(c.distance) <= threshold);
+  }
+
+  async removeWorkspaceFile(workspaceId: string, fileId: string) {
+    // embeddings will be removed by foreign key constraint
+    await this.db.aiWorkspaceFiles.deleteMany({
+      where: {
+        workspaceId,
+        fileId,
+      },
+    });
+  }
+}
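One detail of `processEmbeddings` worth calling out: `Prisma.raw` bypasses parameter binding, so the vector literal is spliced into the statement verbatim. That is safe here only because `e.embedding` is a `number[]` whose `join(',')` cannot smuggle in quotes; the id, index, and content values stay bound through `Prisma.sql`/`Prisma.join`. A reduced sketch of how one VALUES row is assembled:

```ts
import { Prisma } from '@prisma/client';

// The first four entries become bound parameters; only the numeric vector
// literal is inlined via Prisma.raw, matching processEmbeddings above.
const embedding = [0.1, 0.2, 0.3];
const row = Prisma.sql`(${Prisma.join([
  'workspace-1',
  'file-1',
  0,
  'chunk content',
  Prisma.raw(`'[${embedding.join(',')}]'`),
])})`;
```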
+ `; + return fileId; + } + + async listWorkspaceFiles( + workspaceId: string + ): Promise { + const files = await this.db.aiWorkspaceFiles.findMany({ + where: { + workspaceId, + }, + }); + return files; + } + + async matchWorkspaceFileEmbedding( + workspaceId: string, + embedding: number[], + topK: number, + threshold: number + ): Promise { + const similarityChunks = await this.db.$queryRaw< + Array + >` + SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance" + FROM "ai_workspace_file_embeddings" + WHERE workspace_id = ${workspaceId} + ORDER BY "distance" ASC + LIMIT ${topK}; + `; + return similarityChunks.filter(c => Number(c.distance) <= threshold); + } + + async removeWorkspaceFile(workspaceId: string, fileId: string) { + // embeddings will be removed by foreign key constraint + await this.db.aiWorkspaceFiles.deleteMany({ + where: { + workspaceId, + fileId, + }, + }); + } +} diff --git a/packages/backend/server/src/models/index.ts b/packages/backend/server/src/models/index.ts index e0cbb0754b..aea37ad0c0 100644 --- a/packages/backend/server/src/models/index.ts +++ b/packages/backend/server/src/models/index.ts @@ -11,6 +11,7 @@ import { AppConfigModel } from './config'; import { CopilotContextModel } from './copilot-context'; import { CopilotJobModel } from './copilot-job'; import { CopilotSessionModel } from './copilot-session'; +import { CopilotWorkspaceConfigModel } from './copilot-workspace'; import { DocModel } from './doc'; import { DocUserModel } from './doc-user'; import { FeatureModel } from './feature'; @@ -44,6 +45,7 @@ const MODELS = { userSettings: UserSettingsModel, copilotSession: CopilotSessionModel, copilotContext: CopilotContextModel, + copilotWorkspace: CopilotWorkspaceConfigModel, copilotJob: CopilotJobModel, appConfig: AppConfigModel, };