feat(server): impl context model (#11027)

darkskygit
2025-03-20 10:24:28 +00:00
parent c1ec17ccba
commit b24376a9f7
8 changed files with 555 additions and 0 deletions

View File

@@ -0,0 +1,187 @@
import { randomUUID } from 'node:crypto';
import { AiSession, PrismaClient, User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { Config } from '../../base';
import { CopilotContextModel } from '../../models/copilot-context';
import { CopilotSessionModel } from '../../models/copilot-session';
import { UserModel } from '../../models/user';
import { WorkspaceModel } from '../../models/workspace';
import { createTestingModule, type TestingModule } from '../utils';
interface Context {
config: Config;
module: TestingModule;
db: PrismaClient;
user: UserModel;
workspace: WorkspaceModel;
copilotSession: CopilotSessionModel;
copilotContext: CopilotContextModel;
}
const test = ava as TestFn<Context>;
test.before(async t => {
const module = await createTestingModule();
t.context.user = module.get(UserModel);
t.context.workspace = module.get(WorkspaceModel);
t.context.copilotSession = module.get(CopilotSessionModel);
t.context.copilotContext = module.get(CopilotContextModel);
t.context.db = module.get(PrismaClient);
t.context.config = module.get(Config);
t.context.module = module;
});
let user: User;
let workspace: Workspace;
let session: AiSession;
const docId = 'doc1';
test.beforeEach(async t => {
await t.context.module.initTestingDB();
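// sessions reference a prompt by name, so register one before creating the session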
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o');
user = await t.context.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.workspace.create(user.id);
session = await t.context.copilotSession.create({
sessionId: randomUUID(),
workspaceId: workspace.id,
docId,
userId: user.id,
promptName: 'prompt-name',
});
});
test.after(async t => {
await t.context.module.close();
});
test('should create a copilot context', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
t.truthy(contextId);
const context = await t.context.copilotContext.get(contextId);
t.is(context?.id, contextId, 'should get context by id');
const config = await t.context.copilotContext.getConfig(contextId);
t.is(config?.workspaceId, workspace.id, 'should get context config');
const context1 = await t.context.copilotContext.getBySessionId(session.id);
t.is(context1?.id, contextId, 'should get context by session id');
});
test('should get null for a non-existent context', async t => {
const context = await t.context.copilotContext.get('non-exist');
t.is(context, null);
});
test('should update context', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
const config = await t.context.copilotContext.getConfig(contextId);
const doc = {
id: docId,
createdAt: Date.now(),
};
config?.docs.push(doc);
await t.context.copilotContext.update(contextId, { config });
const config1 = await t.context.copilotContext.getConfig(contextId);
t.deepEqual(config1, config);
});
test('should insert embedding by doc id', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
{
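// context-scoped embeddings: inserted and deleted per attached file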
await t.context.copilotContext.insertEmbedding(contextId, 'file-id', [
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
},
]);
{
const ret = await t.context.copilotContext.matchEmbedding(
Array.from({ length: 512 }, () => 0.9),
contextId,
1,
1
);
t.is(ret.length, 1);
t.is(ret[0].content, 'content');
}
{
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
const ret = await t.context.copilotContext.matchEmbedding(
Array.from({ length: 512 }, () => 0.9),
contextId,
1,
1
);
t.is(ret.length, 0);
}
}
{
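// workspace-scoped embeddings: create the doc snapshot the embeddings refer to first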
await t.context.db.snapshot.create({
data: {
workspaceId: workspace.id,
id: docId,
blob: Buffer.from([1, 1]),
state: Buffer.from([1, 1]),
updatedAt: new Date(),
createdAt: new Date(),
},
});
await t.context.copilotContext.insertWorkspaceEmbedding(
workspace.id,
docId,
[
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
},
]
);
{
const ret = await t.context.copilotContext.hasWorkspaceEmbedding(
workspace.id,
[docId]
);
t.true(ret.has(docId), 'should return true when embedding exists');
}
{
const ret = await t.context.copilotContext.matchWorkspaceEmbedding(
Array.from({ length: 512 }, () => 0.9),
workspace.id,
1,
1
);
t.is(ret.length, 1);
t.is(ret[0].content, 'content');
}
}
});
test('should check embedding table', async t => {
{
const ret = await t.context.copilotContext.checkEmbeddingAvailable();
t.true(ret, 'should return true when embedding table is available');
}
// {
// await t.context.db
// .$executeRaw`DROP TABLE IF EXISTS "ai_context_embeddings"`;
// const ret = await t.context.copilotContext.checkEmbeddingAvailable();
// t.false(ret, 'should return false when embedding table is not available');
// }
});

View File

@@ -1,5 +1,6 @@
import { AiJobStatus, AiJobType } from '@prisma/client';
import type { JsonValue } from '@prisma/client/runtime/library';
import { z } from 'zod';
export interface CopilotJob {
id?: string;
@@ -10,3 +11,91 @@ export interface CopilotJob {
status?: AiJobStatus;
payload?: JsonValue;
}
export interface CopilotContext {
id?: string;
sessionId: string;
config: JsonValue;
createdAt: Date;
updatedAt: Date;
}
export enum ContextEmbedStatus {
processing = 'processing',
finished = 'finished',
failed = 'failed',
}
export enum ContextCategories {
Tag = 'tag',
Collection = 'collection',
}
export const ContextDocSchema = z.object({
id: z.string(),
createdAt: z.number(),
});
export const ContextFileSchema = z.object({
id: z.string(),
chunkSize: z.number(),
name: z.string(),
status: z.enum([
ContextEmbedStatus.processing,
ContextEmbedStatus.finished,
ContextEmbedStatus.failed,
]),
error: z.string().nullable(),
blobId: z.string(),
createdAt: z.number(),
});
export const ContextCategorySchema = z.object({
id: z.string(),
type: z.enum([ContextCategories.Tag, ContextCategories.Collection]),
docs: ContextDocSchema.array(),
createdAt: z.number(),
});
export const ContextConfigSchema = z.object({
workspaceId: z.string(),
files: ContextFileSchema.array(),
docs: ContextDocSchema.array(),
categories: ContextCategorySchema.array(),
});
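// Illustrative example of a value accepted by ContextConfigSchema (values are made up):
// {
//   workspaceId: 'ws-1',
//   files: [],
//   docs: [{ id: 'doc-1', createdAt: 1710000000000 }],
//   categories: [],
// }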
export const MinimalContextConfigSchema = ContextConfigSchema.pick({
workspaceId: true,
});
export type ContextCategory = z.infer<typeof ContextCategorySchema>;
export type ContextDoc = z.infer<typeof ContextDocSchema>;
export type ContextFile = z.infer<typeof ContextFileSchema>;
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
export type ContextListItem = ContextDoc | ContextFile;
export type ContextList = ContextListItem[];
// embeddings
export type Embedding = {
/**
* The index of the embedding in the list of embeddings.
*/
index: number;
content: string;
embedding: Array<number>;
};
export type ChunkSimilarity = {
chunk: number;
content: string;
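// distance reported by the similarity query; smaller means a closer match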
distance: number | null;
};
export type FileChunkSimilarity = ChunkSimilarity & {
fileId: string;
};
export type DocChunkSimilarity = ChunkSimilarity & {
docId: string;
};

View File

@@ -1,3 +1,4 @@
export * from './copilot';
export * from './doc';
export * from './feature';
export * from './role';

View File

@@ -0,0 +1,233 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { Prisma } from '@prisma/client';
import { CopilotSessionNotFound } from '../base';
import { BaseModel } from './base';
import {
ChunkSimilarity,
ContextConfigSchema,
ContextDoc,
ContextEmbedStatus,
CopilotContext,
DocChunkSimilarity,
Embedding,
FileChunkSimilarity,
MinimalContextConfigSchema,
} from './common/copilot';
type UpdateCopilotContextInput = Pick<CopilotContext, 'config'>;
/**
 * Copilot Context Model
*/
@Injectable()
export class CopilotContextModel extends BaseModel {
// contexts
async create(sessionId: string) {
const session = await this.db.aiSession.findFirst({
where: { id: sessionId },
select: { workspaceId: true },
});
if (!session) {
throw new CopilotSessionNotFound();
}
const row = await this.db.aiContext.create({
data: {
sessionId,
config: {
workspaceId: session.workspaceId,
docs: [],
files: [],
categories: [],
},
},
});
return row;
}
async get(id: string) {
const row = await this.db.aiContext.findFirst({
where: { id },
});
return row;
}
async getConfig(id: string) {
const row = await this.get(id);
if (row) {
const config = ContextConfigSchema.safeParse(row.config);
if (config.success) {
return config.data;
}
const minimalConfig = MinimalContextConfigSchema.safeParse(row.config);
if (minimalConfig.success) {
// fill in the missing fields with defaults
return {
...minimalConfig.data,
docs: [],
files: [],
categories: [],
};
}
}
return null;
}
async getBySessionId(sessionId: string) {
const row = await this.db.aiContext.findFirst({
where: { sessionId },
});
return row;
}
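/**
 * Mark docs that already have workspace embeddings as `finished`;
 * docs without embeddings keep a null status.
 */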
async mergeDocStatus(
workspaceId: string,
docs: (ContextDoc & { status?: ContextEmbedStatus | null })[]
) {
const docIds = Array.from(new Set(docs.map(doc => doc.id)));
const finishedDoc = await this.hasWorkspaceEmbedding(workspaceId, docIds);
for (const doc of docs) {
const status = finishedDoc.has(doc.id)
? ContextEmbedStatus.finished
: null;
doc.status = status;
}
return docs;
}
async update(contextId: string, data: UpdateCopilotContextInput) {
const ret = await this.db.aiContext.updateMany({
where: {
id: contextId,
},
data: {
config: data.config || undefined,
},
});
return ret.count > 0;
}
// embeddings
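// The embedding tables use pgvector (`::vector` casts and the `<=>` operator below),
// so they may be missing when the extension or its migration is unavailable;
// verify both tables exist in pg_tables before relying on them.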
async checkEmbeddingAvailable(): Promise<boolean> {
const [{ count }] = await this.db.$queryRaw<
{ count: number }[]
>`SELECT count(1) FROM pg_tables WHERE tablename in ('ai_context_embeddings', 'ai_workspace_embeddings')`;
return Number(count) === 2;
}
async hasWorkspaceEmbedding(workspaceId: string, docIds: string[]) {
const existsIds = await this.db.aiWorkspaceEmbedding
.findMany({
where: {
workspaceId,
docId: { in: docIds },
},
select: {
docId: true,
},
})
.then(rows => rows.map(row => row.docId));
return new Set(existsIds);
}
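// Build the VALUES tuples for the raw INSERTs below: each row carries an optional
// generated id, the context/workspace id, the file/doc id, the chunk index, the content,
// the embedding rendered as a pgvector array literal via Prisma.raw, and an updated_at timestamp.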
private processEmbeddings(
contextOrWorkspaceId: string,
fileOrDocId: string,
embeddings: Embedding[],
withId = true
) {
const groups = embeddings.map(e =>
[
withId ? randomUUID() : undefined,
contextOrWorkspaceId,
fileOrDocId,
e.index,
e.content,
Prisma.raw(`'[${e.embedding.join(',')}]'`),
new Date(),
].filter(v => v !== undefined)
);
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
}
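// upsert: re-embedding the same (context, file, chunk) replaces the stored content and vector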
async insertEmbedding(
contextId: string,
fileId: string,
embeddings: Embedding[]
) {
const values = this.processEmbeddings(contextId, fileId, embeddings);
await this.db.$executeRaw`
INSERT INTO "ai_context_embeddings"
("id", "context_id", "file_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
ON CONFLICT (context_id, file_id, chunk) DO UPDATE SET
content = EXCLUDED.content, embedding = EXCLUDED.embedding, updated_at = EXCLUDED.updated_at;
`;
}
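// `<=>` is pgvector's cosine-distance operator; rows are ordered by distance and
// anything above the caller-supplied threshold is filtered out after the query.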
async matchEmbedding(
embedding: number[],
contextId: string,
topK: number,
threshold: number
): Promise<ChunkSimilarity[]> {
const similarityChunks = await this.db.$queryRaw<
Array<FileChunkSimilarity>
>`
SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_context_embeddings"
WHERE context_id = ${contextId}
ORDER BY "distance" ASC
LIMIT ${topK};
`;
return similarityChunks.filter(c => Number(c.distance) <= threshold);
}
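// same upsert pattern keyed on (workspace_id, doc_id, chunk); no id column is supplied here (withId = false)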
async insertWorkspaceEmbedding(
workspaceId: string,
docId: string,
embeddings: Embedding[]
) {
const values = this.processEmbeddings(
workspaceId,
docId,
embeddings,
false
);
await this.db.$executeRaw`
INSERT INTO "ai_workspace_embeddings"
("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
ON CONFLICT (workspace_id, doc_id, chunk) DO UPDATE SET
embedding = EXCLUDED.embedding, updated_at = EXCLUDED.updated_at;
`;
}
async matchWorkspaceEmbedding(
embedding: number[],
workspaceId: string,
topK: number,
threshold: number
): Promise<ChunkSimilarity[]> {
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_workspace_embeddings"
WHERE "workspace_id" = ${workspaceId}
ORDER BY "distance" ASC
LIMIT ${topK};
`;
return similarityChunks.filter(c => Number(c.distance) <= threshold);
}
async deleteEmbedding(contextId: string, fileId: string) {
await this.db.aiContextEmbedding.deleteMany({
where: { contextId, fileId },
});
}
}
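
A minimal sketch of how a consumer might drive this model end to end, assuming only the method signatures above; the helper name and the numeric arguments are made up for illustration:

import { CopilotContextModel } from './copilot-context';

// Hypothetical helper; only the CopilotContextModel calls come from this diff.
async function attachDocAndSearch(
  copilotContext: CopilotContextModel,
  sessionId: string,
  docId: string,
  queryEmbedding: number[]
) {
  // create a context bound to the chat session and attach one doc to it
  const { id: contextId } = await copilotContext.create(sessionId);
  const config = await copilotContext.getConfig(contextId);
  if (!config) return [];
  config.docs.push({ id: docId, createdAt: Date.now() });
  await copilotContext.update(contextId, { config });

  // retrieve the 5 closest chunks within a cosine distance of 0.5
  return copilotContext.matchWorkspaceEmbedding(
    queryEmbedding,
    config.workspaceId,
    5,
    0.5
  );
}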

View File

@@ -0,0 +1,36 @@
import { Injectable } from '@nestjs/common';
import { BaseModel } from './base';
interface ChatSessionState {
sessionId: string;
workspaceId: string;
docId: string;
// connect ids
userId: string;
promptName: string;
}
// TODO(@darkskygit): not ready to replace the business code yet, only used in tests
@Injectable()
export class CopilotSessionModel extends BaseModel {
async create(state: ChatSessionState) {
const row = await this.db.aiSession.create({
data: {
id: state.sessionId,
workspaceId: state.workspaceId,
docId: state.docId,
// connect
userId: state.userId,
promptName: state.promptName,
},
});
return row;
}
async createPrompt(name: string, model: string) {
await this.db.aiPrompt.create({
data: { name, model },
});
}
}

View File

@@ -7,7 +7,9 @@ import {
import { ModuleRef } from '@nestjs/core';
import { ApplyType } from '../base';
import { CopilotContextModel } from './copilot-context';
import { CopilotJobModel } from './copilot-job';
import { CopilotSessionModel } from './copilot-session';
import { DocModel } from './doc';
import { DocUserModel } from './doc-user';
import { FeatureModel } from './feature';
@@ -39,6 +41,8 @@ const MODELS = {
history: HistoryModel,
notification: NotificationModel,
settings: SettingsModel,
copilotSession: CopilotSessionModel,
copilotContext: CopilotContextModel,
copilotJob: CopilotJobModel,
};
@@ -92,6 +96,7 @@ const ModelsSymbolProvider: ExistingProvider = {
export class ModelsModule {}
export * from './common';
export * from './copilot-context';
export * from './copilot-job';
export * from './doc';
export * from './doc-user';