feat(server): adapt context model (#11028)

expose more field in listContextObject
This commit is contained in:
darkskygit
2025-03-21 05:36:45 +00:00
parent a5b975ac46
commit 5acba9d5a0
25 changed files with 537 additions and 377 deletions

View File

@@ -37,7 +37,7 @@ BEGIN -- check if pgvector extension is installed
"file_id" VARCHAR NOT NULL,
"chunk" INTEGER NOT NULL,
"content" VARCHAR NOT NULL,
"embedding" vector(512) NOT NULL,
"embedding" vector(1024) NOT NULL,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
@@ -50,7 +50,7 @@ BEGIN -- check if pgvector extension is installed
"doc_id" VARCHAR NOT NULL,
"chunk" INTEGER NOT NULL,
"content" VARCHAR NOT NULL,
"embedding" vector(512) NOT NULL,
"embedding" vector(1024) NOT NULL,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,

View File

@@ -1,6 +1,7 @@
import { randomUUID } from 'node:crypto';
import { ProjectRoot } from '@affine-tools/utils/path';
import { PrismaClient } from '@prisma/client';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
@@ -9,6 +10,7 @@ import { EventBus } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { ContextCategories } from '../models';
import { CopilotModule } from '../plugins/copilot';
import {
CopilotContextDocJob,
@@ -54,6 +56,7 @@ import { MockCopilotTestProvider, WorkflowTestCases } from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
module: TestingModule;
db: PrismaClient;
event: EventBus;
context: CopilotContextService;
prompt: PromptService;
@@ -95,6 +98,7 @@ test.before(async t => {
});
const auth = module.get(AuthService);
const db = module.get(PrismaClient);
const event = module.get(EventBus);
const context = module.get(CopilotContextService);
const prompt = module.get(PromptService);
@@ -106,6 +110,7 @@ test.before(async t => {
t.context.module = module;
t.context.auth = auth;
t.context.db = db;
t.context.event = event;
t.context.context = context;
t.context.prompt = prompt;
@@ -1338,47 +1343,112 @@ test('should be able to manage context', async t => {
{
const session = await context.create(chatSession);
await storage.put(userId, session.workspaceId, 'blob', buffer);
// file record
{
await storage.put(userId, session.workspaceId, 'blob', buffer);
const file = await session.addFile('blob', 'sample.pdf');
const file = await session.addFile('blob', 'sample.pdf');
const handler = Sinon.spy(event, 'emit');
const handler = Sinon.spy(event, 'emit');
await jobs.embedPendingFile({
userId,
workspaceId: session.workspaceId,
contextId: session.id,
blobId: file.blobId,
fileId: file.id,
fileName: file.name,
});
await jobs.embedPendingFile({
userId,
workspaceId: session.workspaceId,
contextId: session.id,
blobId: file.blobId,
fileId: file.id,
fileName: file.name,
t.deepEqual(handler.lastCall.args, [
'workspace.file.embed.finished',
{
contextId: session.id,
fileId: file.id,
chunkSize: 1,
},
]);
const list = session.files;
t.deepEqual(
list.map(f => f.id),
[file.id],
'should list file id'
);
const result = await session.matchFileChunks('test', 1, undefined, 1);
t.is(result.length, 1, 'should match context');
t.is(result[0].fileId, file.id, 'should match file id');
}
// doc record
const docId = randomUUID();
await t.context.db.snapshot.create({
data: {
workspaceId: session.workspaceId,
id: docId,
blob: Buffer.from([1, 1]),
state: Buffer.from([1, 1]),
updatedAt: new Date(),
createdAt: new Date(),
},
});
t.deepEqual(handler.lastCall.args, [
'workspace.file.embed.finished',
{
contextId: session.id,
fileId: file.id,
chunkSize: 1,
},
]);
{
await session.addDocRecord(docId);
const docs = session.docs.map(d => d.id);
t.deepEqual(docs, [docId], 'should list doc id');
const list = session.listFiles();
t.deepEqual(
list.map(f => f.id),
[file.id],
'should list file id'
);
await session.removeDocRecord(docId);
t.deepEqual(session.docs, [], 'should remove doc id');
}
const docId = randomUUID();
await session.addDocRecord(docId);
const docs = session.listDocs().map(d => d.id);
t.deepEqual(docs, [docId], 'should list doc id');
// tag record
{
const tagId = randomUUID();
await session.addCategoryRecord(ContextCategories.Tag, tagId, [docId]);
const tags = session.tags.map(t => t.id);
t.deepEqual(tags, [tagId], 'should list tag id');
await session.removeDocRecord(docId);
t.deepEqual(session.listDocs(), [], 'should remove doc id');
await session.removeCategoryRecord(ContextCategories.Tag, tagId);
t.deepEqual(session.tags, [], 'should remove tag id');
const result = await session.matchFileChunks('test', 1, undefined, 1);
t.is(result.length, 1, 'should match context');
t.is(result[0].fileId, file.id, 'should match file id');
await t.throwsAsync(
session.addCategoryRecord(ContextCategories.Tag, tagId, [
'not-exists-doc',
]),
{
instanceOf: Error,
},
'should throw error if doc id not exists'
);
}
// collection record
{
const collectionId = randomUUID();
await session.addCategoryRecord(
ContextCategories.Collection,
collectionId,
[docId]
);
const collection = session.collections.map(l => l.id);
t.deepEqual(collection, [collectionId], 'should list collection id');
await session.removeCategoryRecord(
ContextCategories.Collection,
collectionId
);
t.deepEqual(session.collections, [], 'should remove collection id');
await t.throwsAsync(
session.addCategoryRecord(ContextCategories.Collection, collectionId, [
'not-exists-doc',
]),
{
instanceOf: Error,
},
'should throw error if doc id not exists'
);
}
}
});

View File

@@ -84,6 +84,7 @@ test('should update context', async t => {
const doc = {
id: docId,
createdAt: Date.now(),
status: null,
};
config?.docs.push(doc);
await t.context.copilotContext.update(contextId, { config });
@@ -96,16 +97,20 @@ test('should insert embedding by doc id', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
{
await t.context.copilotContext.insertEmbedding(contextId, 'file-id', [
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
},
]);
await t.context.copilotContext.insertContentEmbedding(
contextId,
'file-id',
[
{
index: 0,
content: 'content',
embedding: Array.from({ length: 512 }, () => 1),
},
]
);
{
const ret = await t.context.copilotContext.matchEmbedding(
const ret = await t.context.copilotContext.matchContentEmbedding(
Array.from({ length: 512 }, () => 0.9),
contextId,
1,
@@ -117,7 +122,7 @@ test('should insert embedding by doc id', async t => {
{
await t.context.copilotContext.deleteEmbedding(contextId, 'file-id');
const ret = await t.context.copilotContext.matchEmbedding(
const ret = await t.context.copilotContext.matchContentEmbedding(
Array.from({ length: 512 }, () => 0.9),
contextId,
1,

View File

@@ -650,6 +650,10 @@ export const USER_FRIENDLY_ERRORS = {
args: { docId: 'string' },
message: ({ docId }) => `Doc ${docId} not found.`,
},
copilot_docs_not_found: {
type: 'resource_not_found',
message: () => `Some docs not found.`,
},
copilot_message_not_found: {
type: 'resource_not_found',
args: { messageId: 'string' },

View File

@@ -664,6 +664,12 @@ export class CopilotDocNotFound extends UserFriendlyError {
super('resource_not_found', 'copilot_doc_not_found', message, args);
}
}
export class CopilotDocsNotFound extends UserFriendlyError {
constructor(message?: string) {
super('resource_not_found', 'copilot_docs_not_found', message);
}
}
@ObjectType()
class CopilotMessageNotFoundDataType {
@Field() messageId!: string
@@ -997,6 +1003,7 @@ export enum ErrorNames {
UNSPLASH_IS_NOT_CONFIGURED,
COPILOT_ACTION_TAKEN,
COPILOT_DOC_NOT_FOUND,
COPILOT_DOCS_NOT_FOUND,
COPILOT_MESSAGE_NOT_FOUND,
COPILOT_PROMPT_NOT_FOUND,
COPILOT_PROMPT_INVALID,

View File

@@ -34,6 +34,13 @@ export enum ContextCategories {
export const ContextDocSchema = z.object({
id: z.string(),
createdAt: z.number(),
status: z
.enum([
ContextEmbedStatus.processing,
ContextEmbedStatus.finished,
ContextEmbedStatus.failed,
])
.nullable(),
});
export const ContextFileSchema = z.object({

View File

@@ -6,7 +6,6 @@ import { Prisma } from '@prisma/client';
import { CopilotSessionNotFound } from '../base';
import { BaseModel } from './base';
import {
ChunkSimilarity,
ContextConfigSchema,
ContextDoc,
ContextEmbedStatus,
@@ -24,7 +23,7 @@ type UpdateCopilotContextInput = Pick<CopilotContext, 'config'>;
*/
@Injectable()
export class CopilotContextModel extends BaseModel {
// contexts
// ================ contexts ================
async create(sessionId: string) {
const session = await this.db.aiSession.findFirst({
@@ -113,7 +112,7 @@ export class CopilotContextModel extends BaseModel {
return ret.count > 0;
}
// embeddings
// ================ embeddings ================
async checkEmbeddingAvailable(): Promise<boolean> {
const [{ count }] = await this.db.$queryRaw<
@@ -157,7 +156,7 @@ export class CopilotContextModel extends BaseModel {
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
}
async insertEmbedding(
async insertContentEmbedding(
contextId: string,
fileId: string,
embeddings: Embedding[]
@@ -172,12 +171,12 @@ export class CopilotContextModel extends BaseModel {
`;
}
async matchEmbedding(
async matchContentEmbedding(
embedding: number[],
contextId: string,
topK: number,
threshold: number
): Promise<ChunkSimilarity[]> {
): Promise<FileChunkSimilarity[]> {
const similarityChunks = await this.db.$queryRaw<
Array<FileChunkSimilarity>
>`
@@ -214,7 +213,7 @@ export class CopilotContextModel extends BaseModel {
workspaceId: string,
topK: number,
threshold: number
): Promise<ChunkSimilarity[]> {
): Promise<DocChunkSimilarity[]> {
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_workspace_embeddings"

View File

@@ -185,6 +185,23 @@ export class DocModel extends BaseModel {
});
}
/**
* Check if all doc exists in the workspace.
* Ignore pending updates.
*/
async existsAll(workspaceId: string, docIds: string[]) {
const count = await this.db.snapshot.count({
where: {
workspaceId,
id: { in: docIds },
},
});
if (count === docIds.length) {
return true;
}
return false;
}
/**
* Detect a doc exists or not, including updates
*/

View File

@@ -1,6 +1,7 @@
import OpenAI from 'openai';
import { Embedding, EmbeddingClient } from './types';
import { Embedding } from '../../../models';
import { EmbeddingClient } from './types';
export class OpenAIEmbeddingClient extends EmbeddingClient {
constructor(private readonly client: OpenAI) {
@@ -15,6 +16,7 @@ export class OpenAIEmbeddingClient extends EmbeddingClient {
{
input,
model: 'text-embedding-3-large',
dimensions: 1024,
encoding_format: 'float',
},
{ signal }

View File

@@ -1,7 +1,3 @@
export { CopilotContextDocJob } from './job';
export { CopilotContextResolver, CopilotContextRootResolver } from './resolver';
export { CopilotContextService } from './service';
export {
type ContextFile,
ContextEmbedStatus as ContextFileStatus,
} from './types';

View File

@@ -1,7 +1,4 @@
import { randomUUID } from 'node:crypto';
import { Injectable, OnModuleInit } from '@nestjs/common';
import { Prisma, PrismaClient } from '@prisma/client';
import OpenAI from 'openai';
import {
@@ -15,10 +12,11 @@ import {
OnJob,
} from '../../../base';
import { DocReader } from '../../../core/doc';
import { Models } from '../../../models';
import { CopilotStorage } from '../storage';
import { OpenAIEmbeddingClient } from './embedding';
import { Embedding, EmbeddingClient } from './types';
import { checkEmbeddingAvailable, readStream } from './utils';
import { EmbeddingClient } from './types';
import { readStream } from './utils';
declare global {
interface Jobs {
@@ -45,10 +43,10 @@ export class CopilotContextDocJob implements OnModuleInit {
constructor(
config: Config,
private readonly db: PrismaClient,
private readonly doc: DocReader,
private readonly event: EventBus,
private readonly logger: AFFiNELogger,
private readonly models: Models,
private readonly queue: JobQueue,
private readonly storage: CopilotStorage
) {
@@ -60,7 +58,8 @@ export class CopilotContextDocJob implements OnModuleInit {
}
async onModuleInit() {
this.supportEmbedding = await checkEmbeddingAvailable(this.db);
this.supportEmbedding =
await this.models.copilotContext.checkEmbeddingAvailable();
}
// public this client to allow overriding in tests
@@ -91,23 +90,6 @@ export class CopilotContextDocJob implements OnModuleInit {
}
}
private processEmbeddings(
contextOrWorkspaceId: string,
fileOrDocId: string,
embeddings: Embedding[]
) {
const groups = embeddings.map(e => [
randomUUID(),
contextOrWorkspaceId,
fileOrDocId,
e.index,
e.content,
Prisma.raw(`'[${e.embedding.join(',')}]'`),
new Date(),
]);
return Prisma.join(groups.map(row => Prisma.sql`(${Prisma.join(row)})`));
}
async readCopilotBlob(
userId: string,
workspaceId: string,
@@ -145,14 +127,11 @@ export class CopilotContextDocJob implements OnModuleInit {
for (const chunk of chunks) {
const embeddings = await this.embeddingClient.generateEmbeddings(chunk);
const values = this.processEmbeddings(contextId, fileId, embeddings);
await this.db.$executeRaw`
INSERT INTO "ai_context_embeddings"
("id", "context_id", "file_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
ON CONFLICT (context_id, file_id, chunk) DO UPDATE SET
content = EXCLUDED.content, embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
`;
await this.models.copilotContext.insertContentEmbedding(
contextId,
fileId,
embeddings
);
}
this.event.emit('workspace.file.embed.finished', {
@@ -188,13 +167,11 @@ export class CopilotContextDocJob implements OnModuleInit {
);
for (const chunks of embeddings) {
const values = this.processEmbeddings(workspaceId, docId, chunks);
await this.db.$executeRaw`
INSERT INTO "ai_workspace_embeddings"
("workspace_id", "doc_id", "chunk", "content", "embedding", "updated_at") VALUES ${values}
ON CONFLICT (context_id, file_id, chunk) DO UPDATE SET
embedding = EXCLUDED.embedding, updated_at = excluded.updated_at;
`;
await this.models.copilotContext.insertWorkspaceEmbedding(
workspaceId,
docId,
chunks
);
}
}
} catch (e: any) {

View File

@@ -34,25 +34,41 @@ import {
} from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { AccessController } from '../../../core/permission';
import { COPILOT_LOCKER, CopilotType } from '../resolver';
import { ChatSessionService } from '../session';
import { CopilotStorage } from '../storage';
import { CopilotContextDocJob } from './job';
import { CopilotContextService } from './service';
import {
ContextCategories,
ContextCategory,
ContextDoc,
ContextEmbedStatus,
type ContextFile,
ContextFile,
DocChunkSimilarity,
FileChunkSimilarity,
MAX_EMBEDDABLE_SIZE,
} from './types';
Models,
} from '../../../models';
import { COPILOT_LOCKER, CopilotType } from '../resolver';
import { ChatSessionService } from '../session';
import { CopilotStorage } from '../storage';
import { CopilotContextDocJob } from './job';
import { CopilotContextService } from './service';
import { MAX_EMBEDDABLE_SIZE } from './types';
import { readStream } from './utils';
@InputType()
class AddRemoveContextCategoryInput {
class AddContextCategoryInput {
@Field(() => String)
contextId!: string;
@Field(() => ContextCategories)
type!: ContextCategories;
@Field(() => String)
categoryId!: string;
@Field(() => [String], { nullable: true })
docs!: string[] | null;
}
@InputType()
class RemoveContextCategoryInput {
@Field(() => String)
contextId!: string;
@@ -111,21 +127,7 @@ export class CopilotContextType {
registerEnumType(ContextCategories, { name: 'ContextCategories' });
@ObjectType()
class CopilotContextCategory implements ContextCategory {
@Field(() => ID)
id!: string;
@Field(() => ContextCategories)
type!: ContextCategories;
@Field(() => SafeIntResolver)
createdAt!: number;
}
registerEnumType(ContextEmbedStatus, { name: 'ContextEmbedStatus' });
@ObjectType()
class CopilotContextDoc implements ContextDoc {
class CopilotDocType implements ContextDoc {
@Field(() => ID)
id!: string;
@@ -136,6 +138,29 @@ class CopilotContextDoc implements ContextDoc {
createdAt!: number;
}
@ObjectType()
class CopilotContextCategory implements Omit<ContextCategory, 'docs'> {
@Field(() => ID)
id!: string;
@Field(() => ContextCategories)
type!: ContextCategories;
@Field(() => [CopilotDocType])
docs!: CopilotDocType[];
@Field(() => SafeIntResolver)
createdAt!: number;
}
registerEnumType(ContextEmbedStatus, { name: 'ContextEmbedStatus' });
@ObjectType()
class CopilotContextDoc extends CopilotDocType {
@Field(() => String, { nullable: true })
error!: string | null;
}
@ObjectType()
class CopilotContextFile implements ContextFile {
@Field(() => ID)
@@ -338,6 +363,7 @@ export class CopilotContextRootResolver {
export class CopilotContextResolver {
constructor(
private readonly ac: AccessController,
private readonly models: Models,
private readonly mutex: RequestMutex,
private readonly context: CopilotContextService,
private readonly jobs: CopilotContextDocJob,
@@ -354,13 +380,61 @@ export class CopilotContextResolver {
return controller.signal;
}
@ResolveField(() => [CopilotContextCategory], {
description: 'list collections in context',
})
@CallMetric('ai', 'context_file_list')
async collections(
@Parent() context: CopilotContextType
): Promise<ContextCategory[]> {
const session = await this.context.get(context.id);
const collections = session.collections;
await this.models.copilotContext.mergeDocStatus(
session.workspaceId,
collections.flatMap(c => c.docs)
);
return collections;
}
@ResolveField(() => [CopilotContextCategory], {
description: 'list tags in context',
})
@CallMetric('ai', 'context_file_list')
async tags(
@Parent() context: CopilotContextType
): Promise<ContextCategory[]> {
const session = await this.context.get(context.id);
const tags = session.tags;
await this.models.copilotContext.mergeDocStatus(
session.workspaceId,
tags.flatMap(c => c.docs)
);
return tags;
}
@ResolveField(() => [CopilotContextDoc], {
description: 'list files in context',
})
@CallMetric('ai', 'context_file_list')
async docs(@Parent() context: CopilotContextType): Promise<ContextDoc[]> {
const session = await this.context.get(context.id);
return session.listDocs();
const docs = session.docs;
await this.models.copilotContext.mergeDocStatus(session.workspaceId, docs);
return docs;
}
@ResolveField(() => [CopilotContextFile], {
description: 'list files in context',
})
@CallMetric('ai', 'context_file_list')
async files(
@Parent() context: CopilotContextType
): Promise<CopilotContextFile[]> {
const session = await this.context.get(context.id);
return session.files;
}
@Mutation(() => CopilotContextCategory, {
@@ -368,18 +442,33 @@ export class CopilotContextResolver {
})
@CallMetric('ai', 'context_category_add')
async addContextCategory(
@Args({ name: 'options', type: () => AddRemoveContextCategoryInput })
options: AddRemoveContextCategoryInput
) {
@Args({ name: 'options', type: () => AddContextCategoryInput })
options: AddContextCategoryInput
): Promise<CopilotContextCategory> {
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
throw new TooManyRequest('Server is busy');
}
const session = await this.context.get(options.contextId);
try {
return await session.addCategoryRecord(options.type, options.categoryId);
const records = await session.addCategoryRecord(
options.type,
options.categoryId,
options.docs || []
);
if (options.docs) {
await this.jobs.addDocEmbeddingQueue(
options.docs.map(docId => ({
workspaceId: session.workspaceId,
docId,
}))
);
}
return records;
} catch (e: any) {
throw new CopilotFailedToModifyContext({
contextId: options.contextId,
@@ -393,8 +482,8 @@ export class CopilotContextResolver {
})
@CallMetric('ai', 'context_category_remove')
async removeContextCategory(
@Args({ name: 'options', type: () => AddRemoveContextCategoryInput })
options: AddRemoveContextCategoryInput
@Args({ name: 'options', type: () => RemoveContextCategoryInput })
options: RemoveContextCategoryInput
) {
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
await using lock = await this.mutex.acquire(lockFlag);
@@ -432,7 +521,16 @@ export class CopilotContextResolver {
const session = await this.context.get(options.contextId);
try {
return await session.addDocRecord(options.docId);
const record = await session.addDocRecord(options.docId);
await this.jobs.addDocEmbeddingQueue([
{
workspaceId: session.workspaceId,
docId: options.docId,
},
]);
return record;
} catch (e: any) {
throw new CopilotFailedToModifyContext({
contextId: options.contextId,
@@ -466,17 +564,6 @@ export class CopilotContextResolver {
}
}
@ResolveField(() => [CopilotContextFile], {
description: 'list files in context',
})
@CallMetric('ai', 'context_file_list')
async files(
@Parent() context: CopilotContextType
): Promise<CopilotContextFile[]> {
const session = await this.context.get(context.id);
return session.listFiles();
}
@Mutation(() => CopilotContextFile, {
description: 'add a file to context',
})

View File

@@ -1,27 +1,23 @@
import { Injectable, OnModuleInit } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import OpenAI from 'openai';
import {
Cache,
Config,
CopilotInvalidContext,
CopilotSessionNotFound,
NoCopilotProviderAvailable,
OnEvent,
PrismaTransaction,
} from '../../../base';
import { OpenAIEmbeddingClient } from './embedding';
import { ContextSession } from './session';
import {
ContextConfig,
ContextConfigSchema,
ContextEmbedStatus,
ContextFile,
EmbeddingClient,
MinimalContextConfigSchema,
} from './types';
import { checkEmbeddingAvailable } from './utils';
Models,
} from '../../../models';
import { OpenAIEmbeddingClient } from './embedding';
import { ContextSession } from './session';
import { EmbeddingClient } from './types';
const CONTEXT_SESSION_KEY = 'context-session';
@@ -33,7 +29,7 @@ export class CopilotContextService implements OnModuleInit {
constructor(
config: Config,
private readonly cache: Cache,
private readonly db: PrismaClient
private readonly models: Models
) {
const configure = config.plugins.copilot.openai;
if (configure) {
@@ -42,7 +38,8 @@ export class CopilotContextService implements OnModuleInit {
}
async onModuleInit() {
const supportEmbedding = await checkEmbeddingAvailable(this.db);
const supportEmbedding =
await this.models.copilotContext.checkEmbeddingAvailable();
if (supportEmbedding) {
this.supportEmbedding = true;
}
@@ -60,15 +57,10 @@ export class CopilotContextService implements OnModuleInit {
private async saveConfig(
contextId: string,
config: ContextConfig,
tx?: PrismaTransaction,
refreshCache = false
): Promise<void> {
if (!refreshCache) {
const executor = tx || this.db;
await executor.aiContext.update({
where: { id: contextId },
data: { config },
});
await this.models.copilotContext.update(contextId, { config });
}
await this.cache.set(`${CONTEXT_SESSION_KEY}:${contextId}`, config);
}
@@ -86,7 +78,7 @@ export class CopilotContextService implements OnModuleInit {
this.embeddingClient,
contextId,
config.data,
this.db,
this.models,
this.saveConfig.bind(this, contextId)
);
}
@@ -103,41 +95,22 @@ export class CopilotContextService implements OnModuleInit {
config: ContextConfig
): Promise<ContextSession> {
const dispatcher = this.saveConfig.bind(this, contextId);
await dispatcher(config, undefined, true);
await dispatcher(config, true);
return new ContextSession(
this.embeddingClient,
contextId,
config,
this.db,
this.models,
dispatcher
);
}
async create(sessionId: string): Promise<ContextSession> {
const session = await this.db.aiSession.findFirst({
where: { id: sessionId },
select: { workspaceId: true },
});
if (!session) {
throw new CopilotSessionNotFound();
}
// keep the context unique per session
const existsContext = await this.getBySessionId(sessionId);
if (existsContext) return existsContext;
const context = await this.db.aiContext.create({
data: {
sessionId,
config: {
workspaceId: session.workspaceId,
docs: [],
files: [],
categories: [],
},
},
});
const context = await this.models.copilotContext.create(sessionId);
const config = ContextConfigSchema.parse(context.config);
return await this.cacheSession(context.id, config);
}
@@ -149,34 +122,16 @@ export class CopilotContextService implements OnModuleInit {
const context = await this.getCachedSession(id);
if (context) return context;
const ret = await this.db.aiContext.findUnique({
where: { id },
select: { config: true },
});
if (ret) {
const config = ContextConfigSchema.safeParse(ret.config);
if (config.success) {
return this.cacheSession(id, config.data);
}
const minimalConfig = MinimalContextConfigSchema.safeParse(ret.config);
if (minimalConfig.success) {
// fulfill the missing fields
return this.cacheSession(id, {
...minimalConfig.data,
docs: [],
files: [],
categories: [],
});
}
const config = await this.models.copilotContext.getConfig(id);
if (config) {
return this.cacheSession(id, config);
}
throw new CopilotInvalidContext({ contextId: id });
}
async getBySessionId(sessionId: string): Promise<ContextSession | null> {
const existsContext = await this.db.aiContext.findFirst({
where: { sessionId },
select: { id: true },
});
const existsContext =
await this.models.copilotContext.getBySessionId(sessionId);
if (existsContext) return this.get(existsContext.id);
return null;
}

View File

@@ -1,30 +1,25 @@
import { PrismaClient } from '@prisma/client';
import { nanoid } from 'nanoid';
import { PrismaTransaction } from '../../../base';
import { CopilotDocsNotFound } from '../../../base';
import {
ChunkSimilarity,
ContextCategories,
ContextCategory,
ContextConfig,
ContextDoc,
ContextEmbedStatus,
ContextFile,
ContextList,
DocChunkSimilarity,
EmbeddingClient,
FileChunkSimilarity,
} from './types';
Models,
} from '../../../models';
import { EmbeddingClient } from './types';
export class ContextSession implements AsyncDisposable {
constructor(
private readonly client: EmbeddingClient,
private readonly contextId: string,
private readonly config: ContextConfig,
private readonly db: PrismaClient,
private readonly dispatcher?: (
config: ContextConfig,
tx?: PrismaTransaction
) => Promise<void>
private readonly models: Models,
private readonly dispatcher?: (config: ContextConfig) => Promise<void>
) {}
get id() {
@@ -35,11 +30,28 @@ export class ContextSession implements AsyncDisposable {
return this.config.workspaceId;
}
listDocs(): ContextDoc[] {
return [...this.config.docs];
get categories(): ContextCategory[] {
return this.config.categories.map(c => ({
...c,
docs: c.docs.map(d => ({ ...d })),
}));
}
listFiles() {
get tags() {
const categories = this.config.categories;
return categories.filter(c => c.type === ContextCategories.Tag);
}
get collections() {
const categories = this.config.categories;
return categories.filter(c => c.type === ContextCategories.Collection);
}
get docs(): ContextDoc[] {
return this.config.docs.map(d => ({ ...d }));
}
get files() {
return this.config.files.map(f => ({ ...f }));
}
@@ -50,14 +62,25 @@ export class ContextSession implements AsyncDisposable {
) as ContextList;
}
async addCategoryRecord(type: ContextCategories, id: string) {
async addCategoryRecord(type: ContextCategories, id: string, docs: string[]) {
const existDocs = await this.models.doc.existsAll(this.workspaceId, docs);
if (!existDocs) {
throw new CopilotDocsNotFound();
}
const category = this.config.categories.find(
c => c.type === type && c.id === id
);
if (category) {
return category;
}
const record = { id, type, createdAt: Date.now() };
const createdAt = Date.now();
const record = {
id,
type,
docs: docs.map(id => ({ id, createdAt, status: null })),
createdAt,
};
this.config.categories.push(record);
await this.save();
return record;
@@ -122,14 +145,10 @@ export class ContextSession implements AsyncDisposable {
}
async removeFile(fileId: string): Promise<boolean> {
return await this.db.$transaction(async tx => {
await tx.aiContextEmbedding.deleteMany({
where: { contextId: this.contextId, fileId },
});
this.config.files = this.config.files.filter(f => f.id !== fileId);
await this.save(tx);
return true;
});
await this.models.copilotContext.deleteEmbedding(this.contextId, fileId);
this.config.files = this.config.files.filter(f => f.id !== fileId);
await this.save();
return true;
}
/**
@@ -145,21 +164,18 @@ export class ContextSession implements AsyncDisposable {
topK: number = 5,
signal?: AbortSignal,
threshold: number = 0.7
): Promise<FileChunkSimilarity[]> {
) {
const embedding = await this.client
.getEmbeddings([content], signal)
.then(r => r?.[0]?.embedding);
if (!embedding) return [];
const similarityChunks = await this.db.$queryRaw<
Array<FileChunkSimilarity>
>`
SELECT "file_id" as "fileId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_context_embeddings"
WHERE context_id = ${this.id}
ORDER BY "distance" ASC
LIMIT ${topK};
`;
return similarityChunks.filter(c => Number(c.distance) <= threshold);
return this.models.copilotContext.matchContentEmbedding(
embedding,
this.id,
topK,
threshold
);
}
/**
@@ -175,19 +191,18 @@ export class ContextSession implements AsyncDisposable {
topK: number = 5,
signal?: AbortSignal,
threshold: number = 0.7
): Promise<ChunkSimilarity[]> {
) {
const embedding = await this.client
.getEmbeddings([content], signal)
.then(r => r?.[0]?.embedding);
if (!embedding) return [];
const similarityChunks = await this.db.$queryRaw<Array<DocChunkSimilarity>>`
SELECT "doc_id" as "docId", "chunk", "content", "embedding" <=> ${embedding}::vector as "distance"
FROM "ai_workspace_embeddings"
WHERE "workspace_id" = ${this.workspaceId}
ORDER BY "distance" ASC
LIMIT ${topK};
`;
return similarityChunks.filter(c => Number(c.distance) <= threshold);
return this.models.copilotContext.matchWorkspaceEmbedding(
embedding,
this.id,
topK,
threshold
);
}
async saveFileRecord(
@@ -195,8 +210,7 @@ export class ContextSession implements AsyncDisposable {
cb: (
record: Pick<ContextFile, 'id' | 'status'> &
Partial<Omit<ContextFile, 'id' | 'status'>>
) => ContextFile,
tx?: PrismaTransaction
) => ContextFile
) {
const files = this.config.files;
const file = files.find(f => f.id === fileId);
@@ -206,11 +220,11 @@ export class ContextSession implements AsyncDisposable {
const file = { id: fileId, status: ContextEmbedStatus.processing };
files.push(cb(file));
}
await this.save(tx);
await this.save();
}
async save(tx?: PrismaTransaction) {
await this.dispatcher?.(this.config, tx);
async save() {
await this.dispatcher?.(this.config);
}
async [Symbol.asyncDispose]() {

View File

@@ -1,8 +1,7 @@
import { File } from 'node:buffer';
import { z } from 'zod';
import { CopilotContextFileNotSupported, OneMB } from '../../../base';
import { Embedding } from '../../../models';
import { parseDoc } from '../../../native';
declare global {
@@ -26,99 +25,11 @@ declare global {
export const MAX_EMBEDDABLE_SIZE = 50 * OneMB;
export enum ContextEmbedStatus {
processing = 'processing',
finished = 'finished',
failed = 'failed',
}
export enum ContextCategories {
Tag = 'tag',
Collection = 'collection',
}
export const ContextConfigSchema = z.object({
workspaceId: z.string(),
files: z
.object({
id: z.string(),
chunkSize: z.number(),
name: z.string(),
status: z.enum([
ContextEmbedStatus.processing,
ContextEmbedStatus.finished,
ContextEmbedStatus.failed,
]),
error: z.string().nullable(),
blobId: z.string(),
createdAt: z.number(),
})
.array(),
docs: z
.object({
id: z.string(),
// status for workspace doc embedding progress
// only exists when the client submits the doc embedding task
status: z
.enum([
ContextEmbedStatus.processing,
ContextEmbedStatus.finished,
ContextEmbedStatus.failed,
])
.nullable(),
createdAt: z.number(),
})
.array(),
categories: z
.object({
id: z.string(),
type: z.enum([ContextCategories.Tag, ContextCategories.Collection]),
createdAt: z.number(),
})
.array(),
});
export const MinimalContextConfigSchema = ContextConfigSchema.pick({
workspaceId: true,
});
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
export type ContextCategory = z.infer<
typeof ContextConfigSchema
>['categories'][number];
export type ContextDoc = z.infer<typeof ContextConfigSchema>['docs'][number];
export type ContextFile = z.infer<typeof ContextConfigSchema>['files'][number];
export type ContextListItem = ContextDoc | ContextFile;
export type ContextList = ContextListItem[];
export type Chunk = {
index: number;
content: string;
};
export type ChunkSimilarity = {
chunk: number;
content: string;
distance: number | null;
};
export type FileChunkSimilarity = ChunkSimilarity & {
fileId: string;
};
export type DocChunkSimilarity = ChunkSimilarity & {
docId: string;
};
export type Embedding = {
/**
* The index of the embedding in the list of embeddings.
*/
index: number;
content: string;
embedding: Array<number>;
};
export abstract class EmbeddingClient {
async getFileEmbeddings(
file: File,

View File

@@ -1,7 +1,5 @@
import { Readable } from 'node:stream';
import { PrismaClient } from '@prisma/client';
import { readBufferWithLimit } from '../../../base';
import { MAX_EMBEDDABLE_SIZE } from './types';
@@ -17,17 +15,6 @@ export class GqlSignal implements AsyncDisposable {
}
}
export async function checkEmbeddingAvailable(
db: PrismaClient
): Promise<boolean> {
const [{ count }] = await db.$queryRaw<
{
count: number;
}[]
>`SELECT count(1) FROM pg_tables WHERE tablename in ('ai_context_embeddings', 'ai_workspace_embeddings')`;
return Number(count) === 2;
}
export function readStream(
readable: Readable,
maxSize = MAX_EMBEDDABLE_SIZE

View File

@@ -2,6 +2,13 @@
# THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY)
# ------------------------------------------------------
input AddContextCategoryInput {
categoryId: String!
contextId: String!
docs: [String!]
type: ContextCategories!
}
input AddContextDocInput {
contextId: String!
docId: String!
@@ -12,12 +19,6 @@ input AddContextFileInput {
contextId: String!
}
input AddRemoveContextCategoryInput {
categoryId: String!
contextId: String!
type: ContextCategories!
}
enum AiJobStatus {
claimed
failed
@@ -98,6 +99,9 @@ type Copilot {
}
type CopilotContext {
"""list collections in context"""
collections: [CopilotContextCategory!]!
"""list files in context"""
docs: [CopilotContextDoc!]!
@@ -110,17 +114,22 @@ type CopilotContext {
"""match workspace doc content"""
matchWorkspaceContext(content: String!, limit: SafeInt): ContextMatchedDocChunk!
"""list tags in context"""
tags: [CopilotContextCategory!]!
workspaceId: String!
}
type CopilotContextCategory {
createdAt: SafeInt!
docs: [CopilotDocType!]!
id: ID!
type: ContextCategories!
}
type CopilotContextDoc {
createdAt: SafeInt!
error: String
id: ID!
status: ContextEmbedStatus
}
@@ -144,6 +153,12 @@ type CopilotDocNotFoundDataType {
docId: String!
}
type CopilotDocType {
createdAt: SafeInt!
id: ID!
status: ContextEmbedStatus
}
type CopilotFailedToMatchContextDataType {
content: String!
contextId: String!
@@ -405,6 +420,7 @@ enum ErrorNames {
CAPTCHA_VERIFICATION_FAILED
COPILOT_ACTION_TAKEN
COPILOT_CONTEXT_FILE_NOT_SUPPORTED
COPILOT_DOCS_NOT_FOUND
COPILOT_DOC_NOT_FOUND
COPILOT_EMBEDDING_UNAVAILABLE
COPILOT_FAILED_TO_CREATE_MESSAGE
@@ -838,7 +854,7 @@ type Mutation {
activateLicense(license: String!, workspaceId: String!): License!
"""add a category to context"""
addContextCategory(options: AddRemoveContextCategoryInput!): CopilotContextCategory!
addContextCategory(options: AddContextCategoryInput!): CopilotContextCategory!
"""add a doc to context"""
addContextDoc(options: AddContextDocInput!): CopilotContextDoc!
@@ -926,7 +942,7 @@ type Mutation {
removeAvatar: RemoveAvatar!
"""remove a category from context"""
removeContextCategory(options: AddRemoveContextCategoryInput!): Boolean!
removeContextCategory(options: RemoveContextCategoryInput!): Boolean!
"""remove a doc from context"""
removeContextDoc(options: RemoveContextDocInput!): Boolean!
@@ -1192,6 +1208,12 @@ type RemoveAvatar {
success: Boolean!
}
input RemoveContextCategoryInput {
categoryId: String!
contextId: String!
type: ContextCategories!
}
input RemoveContextDocInput {
contextId: String!
docId: String!