mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-25 18:26:05 +08:00
feat(server): impl context model (#11027)
This commit is contained in:
@@ -71,5 +71,8 @@ BEGIN -- check if pgvector extension is installed
|
|||||||
|
|
||||||
-- AddForeignKey
|
-- AddForeignKey
|
||||||
ALTER TABLE "ai_workspace_embeddings" ADD CONSTRAINT "ai_workspace_embeddings_workspace_id_doc_id_fkey" FOREIGN KEY ("workspace_id", "doc_id") REFERENCES "snapshots"("workspace_id", "guid") ON DELETE CASCADE ON UPDATE CASCADE;
|
ALTER TABLE "ai_workspace_embeddings" ADD CONSTRAINT "ai_workspace_embeddings_workspace_id_doc_id_fkey" FOREIGN KEY ("workspace_id", "doc_id") REFERENCES "snapshots"("workspace_id", "guid") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "ai_workspace_embeddings_workspace_id_doc_id_chunk_key" ON "ai_workspace_embeddings"("workspace_id", "doc_id", "chunk");
|
||||||
END IF;
|
END IF;
|
||||||
END $$;
|
END $$;
|
||||||
|
|||||||
@@ -472,6 +472,7 @@ model AiWorkspaceEmbedding {
|
|||||||
snapshot Snapshot @relation(fields: [workspaceId, docId], references: [workspaceId, id], onDelete: Cascade)
|
snapshot Snapshot @relation(fields: [workspaceId, docId], references: [workspaceId, id], onDelete: Cascade)
|
||||||
|
|
||||||
@@id([workspaceId, docId])
|
@@id([workspaceId, docId])
|
||||||
|
@@unique([workspaceId, docId, chunk])
|
||||||
@@index([embedding], map: "ai_workspace_embeddings_idx")
|
@@index([embedding], map: "ai_workspace_embeddings_idx")
|
||||||
@@map("ai_workspace_embeddings")
|
@@map("ai_workspace_embeddings")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,187 @@
|
|||||||
|
import { randomUUID } from 'node:crypto';
|
||||||
|
|
||||||
|
import { AiSession, PrismaClient, User, Workspace } from '@prisma/client';
|
||||||
|
import ava, { TestFn } from 'ava';
|
||||||
|
|
||||||
|
import { Config } from '../../base';
|
||||||
|
import { CopilotContextModel } from '../../models/copilot-context';
|
||||||
|
import { CopilotSessionModel } from '../../models/copilot-session';
|
||||||
|
import { UserModel } from '../../models/user';
|
||||||
|
import { WorkspaceModel } from '../../models/workspace';
|
||||||
|
import { createTestingModule, type TestingModule } from '../utils';
|
||||||
|
|
||||||
|
interface Context {
|
||||||
|
config: Config;
|
||||||
|
module: TestingModule;
|
||||||
|
db: PrismaClient;
|
||||||
|
user: UserModel;
|
||||||
|
workspace: WorkspaceModel;
|
||||||
|
copilotSession: CopilotSessionModel;
|
||||||
|
copilotContext: CopilotContextModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
const test = ava as TestFn<Context>;
|
||||||
|
|
||||||
|
test.before(async t => {
|
||||||
|
const module = await createTestingModule();
|
||||||
|
t.context.user = module.get(UserModel);
|
||||||
|
t.context.workspace = module.get(WorkspaceModel);
|
||||||
|
t.context.copilotSession = module.get(CopilotSessionModel);
|
||||||
|
t.context.copilotContext = module.get(CopilotContextModel);
|
||||||
|
t.context.db = module.get(PrismaClient);
|
||||||
|
t.context.config = module.get(Config);
|
||||||
|
t.context.module = module;
|
||||||
|
});
|
||||||
|
|
||||||
|
let user: User;
|
||||||
|
let workspace: Workspace;
|
||||||
|
let session: AiSession;
|
||||||
|
let docId = 'doc1';
|
||||||
|
|
||||||
|
test.beforeEach(async t => {
|
||||||
|
await t.context.module.initTestingDB();
|
||||||
|
await t.context.copilotSession.createPrompt('prompt-name', 'gpt-4o');
|
||||||
|
user = await t.context.user.create({
|
||||||
|
email: 'test@affine.pro',
|
||||||
|
});
|
||||||
|
workspace = await t.context.workspace.create(user.id);
|
||||||
|
session = await t.context.copilotSession.create({
|
||||||
|
sessionId: randomUUID(),
|
||||||
|
workspaceId: workspace.id,
|
||||||
|
docId,
|
||||||
|
userId: user.id,
|
||||||
|
promptName: 'prompt-name',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test.after(async t => {
|
||||||
|
await t.context.module.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should create a copilot context', async t => {
|
||||||
|
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||||
|
t.truthy(contextId);
|
||||||
|
|
||||||
|
const context = await t.context.copilotContext.get(contextId);
|
||||||
|
t.is(context?.id, contextId, 'should get context by id');
|
||||||
|
|
||||||
|
const config = await t.context.copilotContext.getConfig(contextId);
|
||||||
|
t.is(config?.workspaceId, workspace.id, 'should get context config');
|
||||||
|
|
||||||
|
const context1 = await t.context.copilotContext.getBySessionId(session.id);
|
||||||
|
t.is(context1?.id, contextId, 'should get context by session id');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should get null for non-exist job', async t => {
|
||||||
|
const job = await t.context.copilotContext.get('non-exist');
|
||||||
|
t.is(job, null);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should update context', async t => {
|
||||||
|
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||||
|
const config = await t.context.copilotContext.getConfig(contextId);
|
||||||
|
|
||||||
|
const doc = {
|
||||||
|
id: docId,
|
||||||
|
createdAt: Date.now(),
|
||||||
|
};
|
||||||
|
config?.docs.push(doc);
|
||||||
|
await t.context.copilotContext.update(contextId, { config });
|
||||||
|
|
||||||
|
const config1 = await t.context.copilotContext.getConfig(contextId);
|
||||||
|
t.deepEqual(config1, config);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should insert embedding by doc id', async t => {
|
||||||
|
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||||
|
|
||||||
|
{
|
||||||
|
await t.context.copilotContext.insertEmbedding(contextId, 'file-id', [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
content: 'content',
|
||||||
|
embedding: Array.from({ length: 512 }, () => 1),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
{
|
||||||
|
const ret = await t.context.copilotContext.matchEmbedding(
|
||||||
|
Array.from({ length: 512 }, () => 0.9),
|
||||||
|
contextId,
|
||||||
|
1,
|
||||||
|
1
|
||||||
|
);
|
||||||
|
t.is(ret.length, 1);
|
||||||
|
t.is(ret[0].content, 'content');
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
|
||||||
|
const ret = await t.context.copilotContext.matchEmbedding(
|
||||||
|
Array.from({ length: 512 }, () => 0.9),
|
||||||
|
contextId,
|
||||||
|
1,
|
||||||
|
1
|
||||||
|
);
|
||||||
|
t.is(ret.length, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
await t.context.db.snapshot.create({
|
||||||
|
data: {
|
||||||
|
workspaceId: workspace.id,
|
||||||
|
id: docId,
|
||||||
|
blob: Buffer.from([1, 1]),
|
||||||
|
state: Buffer.from([1, 1]),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
createdAt: new Date(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await t.context.copilotContext.insertWorkspaceEmbedding(
|
||||||
|
workspace.id,
|
||||||
|
docId,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
content: 'content',
|
||||||
|
embedding: Array.from({ length: 512 }, () => 1),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
{
|
||||||
|
const ret = await t.context.copilotContext.hasWorkspaceEmbedding(
|
||||||
|
workspace.id,
|
||||||
|
[docId]
|
||||||
|
);
|
||||||
|
t.true(ret.has(docId), 'should return true when embedding exists');
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const ret = await t.context.copilotContext.matchWorkspaceEmbedding(
|
||||||
|
Array.from({ length: 512 }, () => 0.9),
|
||||||
|
workspace.id,
|
||||||
|
1,
|
||||||
|
1
|
||||||
|
);
|
||||||
|
t.is(ret.length, 1);
|
||||||
|
t.is(ret[0].content, 'content');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should check embedding table', async t => {
|
||||||
|
{
|
||||||
|
const ret = await t.context.copilotContext.checkEmbeddingAvailable();
|
||||||
|
t.true(ret, 'should return true when embedding table is available');
|
||||||
|
}
|
||||||
|
|
||||||
|
// {
|
||||||
|
// await t.context.db
|
||||||
|
// .$executeRaw`DROP TABLE IF EXISTS "ai_context_embeddings"`;
|
||||||
|
// const ret = await t.context.copilotContext.checkEmbeddingAvailable();
|
||||||
|
// t.false(ret, 'should return false when embedding table is not available');
|
||||||
|
// }
|
||||||
|
});
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import { AiJobStatus, AiJobType } from '@prisma/client';
|
import { AiJobStatus, AiJobType } from '@prisma/client';
|
||||||
import type { JsonValue } from '@prisma/client/runtime/library';
|
import type { JsonValue } from '@prisma/client/runtime/library';
|
||||||
|
import { z } from 'zod';
|
||||||
|
|
||||||
export interface CopilotJob {
|
export interface CopilotJob {
|
||||||
id?: string;
|
id?: string;
|
||||||
@@ -10,3 +11,91 @@ export interface CopilotJob {
|
|||||||
status?: AiJobStatus;
|
status?: AiJobStatus;
|
||||||
payload?: JsonValue;
|
payload?: JsonValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface CopilotContext {
|
||||||
|
id?: string;
|
||||||
|
sessionId: string;
|
||||||
|
config: JsonValue;
|
||||||
|
createdAt: Date;
|
||||||
|
updatedAt: Date;
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum ContextEmbedStatus {
|
||||||
|
processing = 'processing',
|
||||||
|
finished = 'finished',
|
||||||
|
failed = 'failed',
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum ContextCategories {
|
||||||
|
Tag = 'tag',
|
||||||
|
Collection = 'collection',
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ContextDocSchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
createdAt: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const ContextFileSchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
chunkSize: z.number(),
|
||||||
|
name: z.string(),
|
||||||
|
status: z.enum([
|
||||||
|
ContextEmbedStatus.processing,
|
||||||
|
ContextEmbedStatus.finished,
|
||||||
|
ContextEmbedStatus.failed,
|
||||||
|
]),
|
||||||
|
error: z.string().nullable(),
|
||||||
|
blobId: z.string(),
|
||||||
|
createdAt: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const ContextCategorySchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
type: z.enum([ContextCategories.Tag, ContextCategories.Collection]),
|
||||||
|
docs: ContextDocSchema.array(),
|
||||||
|
createdAt: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const ContextConfigSchema = z.object({
|
||||||
|
workspaceId: z.string(),
|
||||||
|
files: ContextFileSchema.array(),
|
||||||
|
docs: ContextDocSchema.array(),
|
||||||
|
categories: ContextCategorySchema.array(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const MinimalContextConfigSchema = ContextConfigSchema.pick({
|
||||||
|
workspaceId: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
export type ContextCategory = z.infer<typeof ContextCategorySchema>;
|
||||||
|
export type ContextDoc = z.infer<typeof ContextDocSchema>;
|
||||||
|
export type ContextFile = z.infer<typeof ContextFileSchema>;
|
||||||
|
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
|
||||||
|
export type ContextListItem = ContextDoc | ContextFile;
|
||||||
|
export type ContextList = ContextListItem[];
|
||||||
|
|
||||||
|
// embeddings
|
||||||
|
|
||||||
|
export type Embedding = {
|
||||||
|
/**
|
||||||
|
* The index of the embedding in the list of embeddings.
|
||||||
|
*/
|
||||||
|
index: number;
|
||||||
|
content: string;
|
||||||
|
embedding: Array<number>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChunkSimilarity = {
|
||||||
|
chunk: number;
|
||||||
|
content: string;
|
||||||
|
distance: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type FileChunkSimilarity = ChunkSimilarity & {
|
||||||
|
fileId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type DocChunkSimilarity = ChunkSimilarity & {
|
||||||
|
docId: string;
|
||||||
|
};
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
export * from './copilot';
|
||||||
export * from './doc';
|
export * from './doc';
|
||||||
export * from './feature';
|
export * from './feature';
|
||||||
export * from './role';
|
export * from './role';
|
||||||
|
|||||||
233
packages/backend/server/src/models/copilot-context.ts
Normal file
233
packages/backend/server/src/models/copilot-context.ts
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
import { randomUUID } from 'node:crypto';
|
||||||
|
|
||||||
|
import { Injectable } from '@nestjs/common';
|
||||||
|
import { Prisma } from '@prisma/client';
|
||||||
|
|
||||||
|
import { CopilotSessionNotFound } from '../base';
|
||||||
|
import { BaseModel } from './base';
|
||||||
|
import {
|
||||||
|
ChunkSimilarity,
|
||||||
|
ContextConfigSchema,
|
||||||
|
ContextDoc,
|
||||||
|
ContextEmbedStatus,
|
||||||
|
CopilotContext,
|
||||||
|
DocChunkSimilarity,
|
||||||
|
Embedding,
|
||||||
|
FileChunkSimilarity,
|
||||||
|
MinimalContextConfigSchema,
|
||||||
|
} from './common/copilot';
|
||||||
|
|
||||||
|
type UpdateCopilotContextInput = Pick<CopilotContext, 'config'>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copilot Job Model
|
||||||
|
*/
|
||||||
|
@Injectable()
|
||||||
|
export class CopilotContextModel extends BaseModel {
|
||||||
|
// contexts
|
||||||
|
|
||||||
|
async create(sessionId: string) {
|
||||||
|
const session = await this.db.aiSession.findFirst({
|
||||||
|
where: { id: sessionId },
|
||||||
|
select: { workspaceId: true },
|
||||||
|
});
|
||||||
|
if (!session) {
|
||||||
|
throw new CopilotSessionNotFound();
|
||||||
|
}
|
||||||
|
|
||||||
|
const row = await this.db.aiContext.create({
|
||||||
|
data: {
|
||||||
|
sessionId,
|
||||||
|
config: {
|
||||||
|
workspaceId: session.workspaceId,
|
||||||
|
docs: [],
|
||||||
|
files: [],
|
||||||
|
categories: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
async get(id: string) {
|
||||||
|
const row = await this.db.aiContext.findFirst({
|
||||||
|
where: { id },
|
||||||
|
});
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
async getConfig(id: string) {
|
||||||
|
const row = await this.get(id);
|
||||||
|
if (row) {
|
||||||
|
const config = ContextConfigSchema.safeParse(row.config);
|
||||||
|
if (config.success) {
|
||||||
|
return config.data;
|
||||||
|
}
|
||||||
|
const minimalConfig = MinimalContextConfigSchema.safeParse(row.config);
|
||||||
|
if (minimalConfig.success) {
|
||||||
|
// fulfill the missing fields
|
||||||
|
return {
|
||||||
|
...minimalConfig.data,
|
||||||
|
docs: [],
|
||||||
|
files: [],
|
||||||
|
categories: [],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
async getBySessionId(sessionId: string) {
|
||||||
|
const row = await this.db.aiContext.findFirst({
|
||||||
|
where: { sessionId },
|
||||||
|
});
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
async mergeDocStatus(
|
||||||
|
workspaceId: string,
|
||||||
|
docs: (ContextDoc & { status?: ContextEmbedStatus | null })[]
|
||||||
|
) {
|
||||||
|
const docIds = Array.from(new Set(docs.map(doc => doc.id)));
|
||||||
|
const finishedDoc = await this.hasWorkspaceEmbedding(workspaceId, docIds);
|
||||||
|
|
||||||
|
for (const doc of docs) {
|
||||||
|
const status = finishedDoc.has(doc.id)
|
||||||
|
? ContextEmbedStatus.finished
|
||||||
|
: null;
|
||||||
|
doc.status = status;
|
||||||
|
}
|
||||||
|
|
||||||
|
return docs;
|
||||||
|
}
|
||||||
|
|
||||||
|
async update(contextId: string, data: UpdateCopilotContextInput) {
|
||||||
|
const ret = await this.db.aiContext.updateMany({
|
||||||
|
where: {
|
||||||
|
id: contextId,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
config: data.config || undefined,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return ret.count > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddings
|
||||||
|
|
||||||
|
async checkEmbeddingAvailable(): Promise<boolean> {
|
||||||
|
const [{ count }] = await this.db.$queryRaw<
|
||||||
|
{ count: number }[]
|
||||||
|
>`SELECT count(1) FROM pg_tables WHERE tablename in ('ai_context_embeddings', 'ai_workspace_embeddings')`;
|
||||||
|
return Number(count) === 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
async hasWorkspaceEmbedding(workspaceId: string, docIds: string[]) {
|
||||||
|
const existsIds = await this.db.aiWorkspaceEmbedding
|
||||||
|
.findMany({
|
||||||
|
where: {
|
||||||
|
workspaceId,
|
||||||
|
docId: { in: docIds },
|
||||||
|
},
|
||||||
|
select: {
|
||||||
|
docId: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.then(r => r.map(r => r.docId));
|
||||||
|
return new Set(existsIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
private processEmbeddings(
|
||||||
|
contextOrWorkspaceId: string,
|
||||||
|
fileOrDocId: string,
|
||||||
|
embeddings: Embedding[],
|
||||||
|
withId = true
|
||||||
|
) {
|
||||||
|
const groups = embeddings.map(e =>
|
||||||
|
[
|
||||||
|
withId ? randomUUID() : undefined,
|
||||||
|
contextOrWorkspaceId,
|
||||||
|
fileOrDocId,
|
||||||
|
e.index,
|
||||||
|
e.content,
|
||||||
|
Prisma.raw(`'[${e.embedding.join(',')}]'`),
|
||||||
|
new Date(),
|
||||||
|
].filter(v => v !== undefined)
|
||||||
|
);
|
||||||
|
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
|
||||||
|
}
|
||||||
|
|
||||||
|
async insertEmbedding(
|
||||||
|
contextId: string,
|
||||||
|
fileId: string,
|
||||||
|
embeddings: Embedding[]
|
||||||
|
) {
|
||||||
|
const values = this.processEmbeddings(contextId, fileId, embeddings);
|
||||||
|
|
||||||
|
await this.db.$executeRaw`
|
||||||
|
INSERT INTO "ai_context_embeddings"
|
||||||
|
("id", "context_id", "file_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
|
||||||
|
ON CONFLICT (context_id, file_id, chunk) DO UPDATE SET
|
||||||
|
content = EXCLUDED.content, embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async matchEmbedding(
|
||||||
|
embedding: number[],
|
||||||
|
contextId: string,
|
||||||
|
topK: number,
|
||||||
|
threshold: number
|
||||||
|
): Promise<ChunkSimilarity[]> {
|
||||||
|
const similarityChunks = await this.db.$queryRaw<
|
||||||
|
Array<FileChunkSimilarity>
|
||||||
|
>`
|
||||||
|
SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
|
||||||
|
FROM "ai_context_embeddings"
|
||||||
|
WHERE context_id = ${contextId}
|
||||||
|
ORDER BY "distance" ASC
|
||||||
|
LIMIT ${topK};
|
||||||
|
`;
|
||||||
|
return similarityChunks.filter(c => Number(c.distance) <= threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
async insertWorkspaceEmbedding(
|
||||||
|
workspaceId: string,
|
||||||
|
docId: string,
|
||||||
|
embeddings: Embedding[]
|
||||||
|
) {
|
||||||
|
const values = this.processEmbeddings(
|
||||||
|
workspaceId,
|
||||||
|
docId,
|
||||||
|
embeddings,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
await this.db.$executeRaw`
|
||||||
|
INSERT INTO "ai_workspace_embeddings"
|
||||||
|
("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
|
||||||
|
ON CONFLICT (workspace_id, doc_id, chunk) DO UPDATE SET
|
||||||
|
embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async matchWorkspaceEmbedding(
|
||||||
|
embedding: number[],
|
||||||
|
workspaceId: string,
|
||||||
|
topK: number,
|
||||||
|
threshold: number
|
||||||
|
): Promise<ChunkSimilarity[]> {
|
||||||
|
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
|
||||||
|
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
|
||||||
|
FROM "ai_workspace_embeddings"
|
||||||
|
WHERE "workspace_id" = ${workspaceId}
|
||||||
|
ORDER BY "distance" ASC
|
||||||
|
LIMIT ${topK};
|
||||||
|
`;
|
||||||
|
return similarityChunks.filter(c => Number(c.distance) <= threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
async deleteEmbedding(contextId: string, fileId: string) {
|
||||||
|
await this.db.aiContextEmbedding.deleteMany({
|
||||||
|
where: { contextId, fileId },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
36
packages/backend/server/src/models/copilot-session.ts
Normal file
36
packages/backend/server/src/models/copilot-session.ts
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import { Injectable } from '@nestjs/common';
|
||||||
|
|
||||||
|
import { BaseModel } from './base';
|
||||||
|
|
||||||
|
interface ChatSessionState {
|
||||||
|
sessionId: string;
|
||||||
|
workspaceId: string;
|
||||||
|
docId: string;
|
||||||
|
// connect ids
|
||||||
|
userId: string;
|
||||||
|
promptName: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(@darkskygit): not ready to replace business codes yet, just for test
|
||||||
|
@Injectable()
|
||||||
|
export class CopilotSessionModel extends BaseModel {
|
||||||
|
async create(state: ChatSessionState) {
|
||||||
|
const row = await this.db.aiSession.create({
|
||||||
|
data: {
|
||||||
|
id: state.sessionId,
|
||||||
|
workspaceId: state.workspaceId,
|
||||||
|
docId: state.docId,
|
||||||
|
// connect
|
||||||
|
userId: state.userId,
|
||||||
|
promptName: state.promptName,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
async createPrompt(name: string, model: string) {
|
||||||
|
await this.db.aiPrompt.create({
|
||||||
|
data: { name, model },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,9 @@ import {
|
|||||||
import { ModuleRef } from '@nestjs/core';
|
import { ModuleRef } from '@nestjs/core';
|
||||||
|
|
||||||
import { ApplyType } from '../base';
|
import { ApplyType } from '../base';
|
||||||
|
import { CopilotContextModel } from './copilot-context';
|
||||||
import { CopilotJobModel } from './copilot-job';
|
import { CopilotJobModel } from './copilot-job';
|
||||||
|
import { CopilotSessionModel } from './copilot-session';
|
||||||
import { DocModel } from './doc';
|
import { DocModel } from './doc';
|
||||||
import { DocUserModel } from './doc-user';
|
import { DocUserModel } from './doc-user';
|
||||||
import { FeatureModel } from './feature';
|
import { FeatureModel } from './feature';
|
||||||
@@ -39,6 +41,8 @@ const MODELS = {
|
|||||||
history: HistoryModel,
|
history: HistoryModel,
|
||||||
notification: NotificationModel,
|
notification: NotificationModel,
|
||||||
settings: SettingsModel,
|
settings: SettingsModel,
|
||||||
|
copilotSession: CopilotSessionModel,
|
||||||
|
copilotContext: CopilotContextModel,
|
||||||
copilotJob: CopilotJobModel,
|
copilotJob: CopilotJobModel,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -92,6 +96,7 @@ const ModelsSymbolProvider: ExistingProvider = {
|
|||||||
export class ModelsModule {}
|
export class ModelsModule {}
|
||||||
|
|
||||||
export * from './common';
|
export * from './common';
|
||||||
|
export * from './copilot-context';
|
||||||
export * from './copilot-job';
|
export * from './copilot-job';
|
||||||
export * from './doc';
|
export * from './doc';
|
||||||
export * from './doc-user';
|
export * from './doc-user';
|
||||||
|
|||||||
Reference in New Issue
Block a user