mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 04:48:53 +00:00
feat(server): impl context model (#11027)
This commit is contained in:
@@ -0,0 +1,187 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { AiSession, PrismaClient, User, Workspace } from '@prisma/client';
|
||||
import ava, { TestFn } from 'ava';
|
||||
|
||||
import { Config } from '../../base';
|
||||
import { CopilotContextModel } from '../../models/copilot-context';
|
||||
import { CopilotSessionModel } from '../../models/copilot-session';
|
||||
import { UserModel } from '../../models/user';
|
||||
import { WorkspaceModel } from '../../models/workspace';
|
||||
import { createTestingModule, type TestingModule } from '../utils';
|
||||
|
||||
// Shared ava test context: DI-resolved models plus infrastructure handles,
// populated once in `test.before`.
interface Context {
  config: Config;
  module: TestingModule;
  db: PrismaClient;
  user: UserModel;
  workspace: WorkspaceModel;
  copilotSession: CopilotSessionModel;
  copilotContext: CopilotContextModel;
}

// Bind ava's test function to the typed context above.
const test = ava as TestFn<Context>;
|
||||
|
||||
test.before(async t => {
|
||||
const module = await createTestingModule();
|
||||
t.context.user = module.get(UserModel);
|
||||
t.context.workspace = module.get(WorkspaceModel);
|
||||
t.context.copilotSession = module.get(CopilotSessionModel);
|
||||
t.context.copilotContext = module.get(CopilotContextModel);
|
||||
t.context.db = module.get(PrismaClient);
|
||||
t.context.config = module.get(Config);
|
||||
t.context.module = module;
|
||||
});
|
||||
|
||||
let user: User;
|
||||
let workspace: Workspace;
|
||||
let session: AiSession;
|
||||
let docId = 'doc1';
|
||||
|
||||
test.beforeEach(async t => {
  // Reset the database so each test starts from a clean slate.
  await t.context.module.initTestingDB();
  // The prompt must exist before the session below references it by name.
  await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o');
  user = await t.context.user.create({
    email: 'test@affine.pro',
  });
  workspace = await t.context.workspace.create(user.id);
  session = await t.context.copilotSession.create({
    sessionId: randomUUID(),
    workspaceId: workspace.id,
    docId,
    userId: user.id,
    promptName: 'prompt-name',
  });
});
|
||||
|
||||
test.after(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
test('should create a copilot context', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
t.truthy(contextId);
|
||||
|
||||
const context = await t.context.copilotContext.get(contextId);
|
||||
t.is(context?.id, contextId, 'should get context by id');
|
||||
|
||||
const config = await t.context.copilotContext.getConfig(contextId);
|
||||
t.is(config?.workspaceId, workspace.id, 'should get context config');
|
||||
|
||||
const context1 = await t.context.copilotContext.getBySessionId(session.id);
|
||||
t.is(context1?.id, contextId, 'should get context by session id');
|
||||
});
|
||||
|
||||
test('should get null for non-exist job', async t => {
|
||||
const job = await t.context.copilotContext.get('non-exist');
|
||||
t.is(job, null);
|
||||
});
|
||||
|
||||
test('should update context', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
const config = await t.context.copilotContext.getConfig(contextId);
|
||||
|
||||
const doc = {
|
||||
id: docId,
|
||||
createdAt: Date.now(),
|
||||
};
|
||||
config?.docs.push(doc);
|
||||
await t.context.copilotContext.update(contextId, { config });
|
||||
|
||||
const config1 = await t.context.copilotContext.getConfig(contextId);
|
||||
t.deepEqual(config1, config);
|
||||
});
|
||||
|
||||
// Exercises both embedding stores: context-scoped (keyed by file id)
// and workspace-scoped (keyed by doc id).
test('should insert embedding by doc id', async t => {
  const { id: contextId } = await t.context.copilotContext.create(session.id);

  // --- context-scoped (file) embeddings ---
  {
    // Insert one 512-dim chunk for a file attached to the context.
    await t.context.copilotContext.insertEmbedding(contextId, 'file-id', [
      {
        index: 0,
        content: 'content',
        embedding: Array.from({ length: 512 }, () => 1),
      },
    ]);

    {
      // A nearby query vector should match the inserted chunk.
      const ret = await t.context.copilotContext.matchEmbedding(
        Array.from({ length: 512 }, () => 0.9),
        contextId,
        1,
        1
      );
      t.is(ret.length, 1);
      t.is(ret[0].content, 'content');
    }

    {
      // After deleting the file's embeddings the same query matches nothing.
      await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
      const ret = await t.context.copilotContext.matchEmbedding(
        Array.from({ length: 512 }, () => 0.9),
        contextId,
        1,
        1
      );
      t.is(ret.length, 0);
    }
  }

  // --- workspace-scoped (doc) embeddings ---
  {
    // The doc snapshot must exist before workspace embeddings reference it.
    await t.context.db.snapshot.create({
      data: {
        workspaceId: workspace.id,
        id: docId,
        blob: Buffer.from([1, 1]),
        state: Buffer.from([1, 1]),
        updatedAt: new Date(),
        createdAt: new Date(),
      },
    });

    await t.context.copilotContext.insertWorkspaceEmbedding(
      workspace.id,
      docId,
      [
        {
          index: 0,
          content: 'content',
          embedding: Array.from({ length: 512 }, () => 1),
        },
      ]
    );

    {
      // Membership check: the doc id should now be reported as embedded.
      const ret = await t.context.copilotContext.hasWorkspaceEmbedding(
        workspace.id,
        [docId]
      );
      t.true(ret.has(docId), 'should return true when embedding exists');
    }

    {
      // Similarity search over the workspace store finds the chunk.
      const ret = await t.context.copilotContext.matchWorkspaceEmbedding(
        Array.from({ length: 512 }, () => 0.9),
        workspace.id,
        1,
        1
      );
      t.is(ret.length, 1);
      t.is(ret[0].content, 'content');
    }
  }
});
|
||||
|
||||
test('should check embedding table', async t => {
  {
    // Both pgvector-backed tables exist in the test database.
    const ret = await t.context.copilotContext.checkEmbeddingAvailable();
    t.true(ret, 'should return true when embedding table is available');
  }

  // Negative case intentionally disabled: dropping the table would
  // poison the shared test database for subsequent tests.
  // {
  //   await t.context.db
  //     .$executeRaw`DROP TABLE IF EXISTS "ai_context_embeddings"`;
  //   const ret = await t.context.copilotContext.checkEmbeddingAvailable();
  //   t.false(ret, 'should return false when embedding table is not available');
  // }
});
|
||||
@@ -1,5 +1,6 @@
|
||||
import { AiJobStatus, AiJobType } from '@prisma/client';
|
||||
import type { JsonValue } from '@prisma/client/runtime/library';
|
||||
import { z } from 'zod';
|
||||
|
||||
export interface CopilotJob {
|
||||
id?: string;
|
||||
@@ -10,3 +11,91 @@ export interface CopilotJob {
|
||||
status?: AiJobStatus;
|
||||
payload?: JsonValue;
|
||||
}
|
||||
|
||||
// A copilot context attached to a chat session. `config` is stored as
// raw JSON and validated lazily against ContextConfigSchema when read.
export interface CopilotContext {
  id?: string;
  sessionId: string;
  config: JsonValue;
  createdAt: Date;
  updatedAt: Date;
}
|
||||
|
||||
// Lifecycle states of a file/doc embedding job.
export enum ContextEmbedStatus {
  processing = 'processing',
  finished = 'finished',
  failed = 'failed',
}
|
||||
|
||||
// Kinds of doc groupings a context can reference.
export enum ContextCategories {
  Tag = 'tag',
  Collection = 'collection',
}
|
||||
|
||||
// A doc attached to a context; `createdAt` is an epoch-millis number
// (callers use `Date.now()`).
export const ContextDocSchema = z.object({
  id: z.string(),
  createdAt: z.number(),
});
|
||||
|
||||
export const ContextFileSchema = z.object({
|
||||
id: z.string(),
|
||||
chunkSize: z.number(),
|
||||
name: z.string(),
|
||||
status: z.enum([
|
||||
ContextEmbedStatus.processing,
|
||||
ContextEmbedStatus.finished,
|
||||
ContextEmbedStatus.failed,
|
||||
]),
|
||||
error: z.string().nullable(),
|
||||
blobId: z.string(),
|
||||
createdAt: z.number(),
|
||||
});
|
||||
|
||||
export const ContextCategorySchema = z.object({
|
||||
id: z.string(),
|
||||
type: z.enum([ContextCategories.Tag, ContextCategories.Collection]),
|
||||
docs: ContextDocSchema.array(),
|
||||
createdAt: z.number(),
|
||||
});
|
||||
|
||||
// Full persisted shape of a context's `config` JSON column.
export const ContextConfigSchema = z.object({
  workspaceId: z.string(),
  files: ContextFileSchema.array(),
  docs: ContextDocSchema.array(),
  categories: ContextCategorySchema.array(),
});

// Fallback for legacy/partial configs that only carry the workspace id;
// the model back-fills the missing arrays when it reads this shape.
export const MinimalContextConfigSchema = ContextConfigSchema.pick({
  workspaceId: true,
});
|
||||
|
||||
// Inferred TypeScript counterparts of the schemas above.
export type ContextCategory = z.infer<typeof ContextCategorySchema>;
export type ContextDoc = z.infer<typeof ContextDocSchema>;
export type ContextFile = z.infer<typeof ContextFileSchema>;
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
// An item listed in a context is either an attached doc or a file.
export type ContextListItem = ContextDoc | ContextFile;
export type ContextList = ContextListItem[];
|
||||
|
||||
// embeddings

// One chunk of text plus its vector, ready for insertion.
export type Embedding = {
  /**
   * The index of the embedding in the list of embeddings.
   */
  index: number;
  content: string;
  embedding: Array<number>;
};

// One similarity-search hit; `distance` comes from pgvector's `<=>`
// operator (smaller is closer), null when not computed.
export type ChunkSimilarity = {
  chunk: number;
  content: string;
  distance: number | null;
};

// Context-scoped hit: chunks belong to an attached file.
export type FileChunkSimilarity = ChunkSimilarity & {
  fileId: string;
};

// Workspace-scoped hit: chunks belong to a doc.
export type DocChunkSimilarity = ChunkSimilarity & {
  docId: string;
};
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
export * from './copilot';
|
||||
export * from './doc';
|
||||
export * from './feature';
|
||||
export * from './role';
|
||||
|
||||
233 packages/backend/server/src/models/copilot-context.ts Normal file
@@ -0,0 +1,233 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Prisma } from '@prisma/client';
|
||||
|
||||
import { CopilotSessionNotFound } from '../base';
|
||||
import { BaseModel } from './base';
|
||||
import {
|
||||
ChunkSimilarity,
|
||||
ContextConfigSchema,
|
||||
ContextDoc,
|
||||
ContextEmbedStatus,
|
||||
CopilotContext,
|
||||
DocChunkSimilarity,
|
||||
Embedding,
|
||||
FileChunkSimilarity,
|
||||
MinimalContextConfigSchema,
|
||||
} from './common/copilot';
|
||||
|
||||
type UpdateCopilotContextInput = Pick<CopilotContext, 'config'>;
|
||||
|
||||
/**
|
||||
* Copilot Job Model
|
||||
*/
|
||||
@Injectable()
|
||||
export class CopilotContextModel extends BaseModel {
|
||||
// contexts
|
||||
|
||||
async create(sessionId: string) {
|
||||
const session = await this.db.aiSession.findFirst({
|
||||
where: { id: sessionId },
|
||||
select: { workspaceId: true },
|
||||
});
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const row = await this.db.aiContext.create({
|
||||
data: {
|
||||
sessionId,
|
||||
config: {
|
||||
workspaceId: session.workspaceId,
|
||||
docs: [],
|
||||
files: [],
|
||||
categories: [],
|
||||
},
|
||||
},
|
||||
});
|
||||
return row;
|
||||
}
|
||||
|
||||
async get(id: string) {
|
||||
const row = await this.db.aiContext.findFirst({
|
||||
where: { id },
|
||||
});
|
||||
return row;
|
||||
}
|
||||
|
||||
async getConfig(id: string) {
|
||||
const row = await this.get(id);
|
||||
if (row) {
|
||||
const config = ContextConfigSchema.safeParse(row.config);
|
||||
if (config.success) {
|
||||
return config.data;
|
||||
}
|
||||
const minimalConfig = MinimalContextConfigSchema.safeParse(row.config);
|
||||
if (minimalConfig.success) {
|
||||
// fulfill the missing fields
|
||||
return {
|
||||
...minimalConfig.data,
|
||||
docs: [],
|
||||
files: [],
|
||||
categories: [],
|
||||
};
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async getBySessionId(sessionId: string) {
|
||||
const row = await this.db.aiContext.findFirst({
|
||||
where: { sessionId },
|
||||
});
|
||||
return row;
|
||||
}
|
||||
|
||||
async mergeDocStatus(
|
||||
workspaceId: string,
|
||||
docs: (ContextDoc & { status?: ContextEmbedStatus | null })[]
|
||||
) {
|
||||
const docIds = Array.from(new Set(docs.map(doc => doc.id)));
|
||||
const finishedDoc = await this.hasWorkspaceEmbedding(workspaceId, docIds);
|
||||
|
||||
for (const doc of docs) {
|
||||
const status = finishedDoc.has(doc.id)
|
||||
? ContextEmbedStatus.finished
|
||||
: null;
|
||||
doc.status = status;
|
||||
}
|
||||
|
||||
return docs;
|
||||
}
|
||||
|
||||
async update(contextId: string, data: UpdateCopilotContextInput) {
|
||||
const ret = await this.db.aiContext.updateMany({
|
||||
where: {
|
||||
id: contextId,
|
||||
},
|
||||
data: {
|
||||
config: data.config || undefined,
|
||||
},
|
||||
});
|
||||
return ret.count > 0;
|
||||
}
|
||||
|
||||
// embeddings
|
||||
|
||||
async checkEmbeddingAvailable(): Promise<boolean> {
|
||||
const [{ count }] = await this.db.$queryRaw<
|
||||
{ count: number }[]
|
||||
>`SELECT count(1) FROM pg_tables WHERE tablename in ('ai_context_embeddings', 'ai_workspace_embeddings')`;
|
||||
return Number(count) === 2;
|
||||
}
|
||||
|
||||
async hasWorkspaceEmbedding(workspaceId: string, docIds: string[]) {
|
||||
const existsIds = await this.db.aiWorkspaceEmbedding
|
||||
.findMany({
|
||||
where: {
|
||||
workspaceId,
|
||||
docId: { in: docIds },
|
||||
},
|
||||
select: {
|
||||
docId: true,
|
||||
},
|
||||
})
|
||||
.then(r => r.map(r => r.docId));
|
||||
return new Set(existsIds);
|
||||
}
|
||||
|
||||
private processEmbeddings(
|
||||
contextOrWorkspaceId: string,
|
||||
fileOrDocId: string,
|
||||
embeddings: Embedding[],
|
||||
withId = true
|
||||
) {
|
||||
const groups = embeddings.map(e =>
|
||||
[
|
||||
withId ? randomUUID() : undefined,
|
||||
contextOrWorkspaceId,
|
||||
fileOrDocId,
|
||||
e.index,
|
||||
e.content,
|
||||
Prisma.raw(`'[${e.embedding.join(',')}]'`),
|
||||
new Date(),
|
||||
].filter(v => v !== undefined)
|
||||
);
|
||||
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
|
||||
}
|
||||
|
||||
async insertEmbedding(
|
||||
contextId: string,
|
||||
fileId: string,
|
||||
embeddings: Embedding[]
|
||||
) {
|
||||
const values = this.processEmbeddings(contextId, fileId, embeddings);
|
||||
|
||||
await this.db.$executeRaw`
|
||||
INSERT INTO "ai_context_embeddings"
|
||||
("id", "context_id", "file_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
|
||||
ON CONFLICT (context_id, file_id, chunk) DO UPDATE SET
|
||||
content = EXCLUDED.content, embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
|
||||
`;
|
||||
}
|
||||
|
||||
async matchEmbedding(
|
||||
embedding: number[],
|
||||
contextId: string,
|
||||
topK: number,
|
||||
threshold: number
|
||||
): Promise<ChunkSimilarity[]> {
|
||||
const similarityChunks = await this.db.$queryRaw<
|
||||
Array<FileChunkSimilarity>
|
||||
>`
|
||||
SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
|
||||
FROM "ai_context_embeddings"
|
||||
WHERE context_id = ${contextId}
|
||||
ORDER BY "distance" ASC
|
||||
LIMIT ${topK};
|
||||
`;
|
||||
return similarityChunks.filter(c => Number(c.distance) <= threshold);
|
||||
}
|
||||
|
||||
async insertWorkspaceEmbedding(
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
embeddings: Embedding[]
|
||||
) {
|
||||
const values = this.processEmbeddings(
|
||||
workspaceId,
|
||||
docId,
|
||||
embeddings,
|
||||
false
|
||||
);
|
||||
await this.db.$executeRaw`
|
||||
INSERT INTO "ai_workspace_embeddings"
|
||||
("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
|
||||
ON CONFLICT (workspace_id, doc_id, chunk) DO UPDATE SET
|
||||
embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
|
||||
`;
|
||||
}
|
||||
|
||||
async matchWorkspaceEmbedding(
|
||||
embedding: number[],
|
||||
workspaceId: string,
|
||||
topK: number,
|
||||
threshold: number
|
||||
): Promise<ChunkSimilarity[]> {
|
||||
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
|
||||
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
|
||||
FROM "ai_workspace_embeddings"
|
||||
WHERE "workspace_id" = ${workspaceId}
|
||||
ORDER BY "distance" ASC
|
||||
LIMIT ${topK};
|
||||
`;
|
||||
return similarityChunks.filter(c => Number(c.distance) <= threshold);
|
||||
}
|
||||
|
||||
async deleteEmbedding(contextId: string, fileId: string) {
|
||||
await this.db.aiContextEmbedding.deleteMany({
|
||||
where: { contextId, fileId },
|
||||
});
|
||||
}
|
||||
}
|
||||
36 packages/backend/server/src/models/copilot-session.ts Normal file
@@ -0,0 +1,36 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { BaseModel } from './base';
|
||||
|
||||
// Minimal state required to persist one chat session row.
interface ChatSessionState {
  sessionId: string;
  workspaceId: string;
  docId: string;
  // connect ids
  userId: string;
  promptName: string;
}
|
||||
|
||||
// TODO(@darkskygit): not ready to replace business codes yet, just for test
|
||||
@Injectable()
|
||||
export class CopilotSessionModel extends BaseModel {
|
||||
async create(state: ChatSessionState) {
|
||||
const row = await this.db.aiSession.create({
|
||||
data: {
|
||||
id: state.sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.promptName,
|
||||
},
|
||||
});
|
||||
return row;
|
||||
}
|
||||
|
||||
async createPrompt(name: string, model: string) {
|
||||
await this.db.aiPrompt.create({
|
||||
data: { name, model },
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
|
||||
import { ApplyType } from '../base';
|
||||
import { CopilotContextModel } from './copilot-context';
|
||||
import { CopilotJobModel } from './copilot-job';
|
||||
import { CopilotSessionModel } from './copilot-session';
|
||||
import { DocModel } from './doc';
|
||||
import { DocUserModel } from './doc-user';
|
||||
import { FeatureModel } from './feature';
|
||||
@@ -39,6 +41,8 @@ const MODELS = {
|
||||
history: HistoryModel,
|
||||
notification: NotificationModel,
|
||||
settings: SettingsModel,
|
||||
copilotSession: CopilotSessionModel,
|
||||
copilotContext: CopilotContextModel,
|
||||
copilotJob: CopilotJobModel,
|
||||
};
|
||||
|
||||
@@ -92,6 +96,7 @@ const ModelsSymbolProvider: ExistingProvider = {
|
||||
export class ModelsModule {}
|
||||
|
||||
export * from './common';
|
||||
export * from './copilot-context';
|
||||
export * from './copilot-job';
|
||||
export * from './doc';
|
||||
export * from './doc-user';
|
||||
|
||||
Reference in New Issue
Block a user