mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
feat(server): context awareness for copilot (#9611)
fix PD-2167 fix PD-2169 fix PD-2190
This commit is contained in:
@@ -1,30 +1,53 @@
|
||||
import {
|
||||
Args,
|
||||
Context,
|
||||
Field,
|
||||
Float,
|
||||
ID,
|
||||
InputType,
|
||||
Mutation,
|
||||
ObjectType,
|
||||
Parent,
|
||||
Query,
|
||||
registerEnumType,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import type { Request } from 'express';
|
||||
import { SafeIntResolver } from 'graphql-scalars';
|
||||
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
|
||||
|
||||
import {
|
||||
BlobQuotaExceeded,
|
||||
CallMetric,
|
||||
CopilotEmbeddingUnavailable,
|
||||
CopilotFailedToMatchContext,
|
||||
CopilotFailedToModifyContext,
|
||||
CopilotSessionNotFound,
|
||||
EventBus,
|
||||
type FileUpload,
|
||||
RequestMutex,
|
||||
Throttle,
|
||||
TooManyRequest,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import { CurrentUser } from '../../../core/auth';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { COPILOT_LOCKER, CopilotType } from '../resolver';
|
||||
import { ChatSessionService } from '../session';
|
||||
import { CopilotStorage } from '../storage';
|
||||
import { CopilotContextDocJob } from './job';
|
||||
import { CopilotContextService } from './service';
|
||||
import { ContextDoc, type ContextFile, ContextFileStatus } from './types';
|
||||
import {
|
||||
ContextDoc,
|
||||
type ContextFile,
|
||||
ContextFileStatus,
|
||||
DocChunkSimilarity,
|
||||
FileChunkSimilarity,
|
||||
MAX_EMBEDDABLE_SIZE,
|
||||
} from './types';
|
||||
import { readStream } from './utils';
|
||||
|
||||
@InputType()
|
||||
class AddContextDocInput {
|
||||
@@ -44,6 +67,24 @@ class RemoveContextDocInput {
|
||||
docId!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class AddContextFileInput {
|
||||
@Field(() => String)
|
||||
contextId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
blobId!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class RemoveContextFileInput {
|
||||
@Field(() => String)
|
||||
contextId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
fileId!: string;
|
||||
}
|
||||
|
||||
@ObjectType('CopilotContext')
|
||||
export class CopilotContextType {
|
||||
@Field(() => ID)
|
||||
@@ -78,6 +119,9 @@ class CopilotContextFile implements ContextFile {
|
||||
@Field(() => ContextFileStatus)
|
||||
status!: ContextFileStatus;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
error!: string | null;
|
||||
|
||||
@Field(() => String)
|
||||
blobId!: string;
|
||||
|
||||
@@ -86,30 +130,51 @@ class CopilotContextFile implements ContextFile {
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotContextListItem {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
class ContextMatchedFileChunk implements FileChunkSimilarity {
|
||||
@Field(() => String)
|
||||
fileId!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
createdAt!: number;
|
||||
chunk!: number;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
name!: string;
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => SafeIntResolver, { nullable: true })
|
||||
chunkSize!: number;
|
||||
@Field(() => Float, { nullable: true })
|
||||
distance!: number | null;
|
||||
}
|
||||
|
||||
@Field(() => ContextFileStatus, { nullable: true })
|
||||
status!: ContextFileStatus;
|
||||
@ObjectType()
|
||||
class ContextWorkspaceEmbeddingStatus {
|
||||
@Field(() => SafeIntResolver)
|
||||
total!: number;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
blobId!: string;
|
||||
@Field(() => SafeIntResolver)
|
||||
embedded!: number;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class ContextMatchedDocChunk implements DocChunkSimilarity {
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
chunk!: number;
|
||||
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
distance!: number | null;
|
||||
}
|
||||
|
||||
@Throttle()
|
||||
@Resolver(() => CopilotType)
|
||||
export class CopilotContextRootResolver {
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly ac: AccessController,
|
||||
private readonly event: EventBus,
|
||||
private readonly mutex: RequestMutex,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly context: CopilotContextService
|
||||
@@ -138,27 +203,30 @@ export class CopilotContextRootResolver {
|
||||
async contexts(
|
||||
@Parent() copilot: CopilotType,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('sessionId') sessionId: string,
|
||||
@Args('sessionId', { nullable: true }) sessionId?: string,
|
||||
@Args('contextId', { nullable: true }) contextId?: string
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
await this.checkChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
copilot.workspaceId || undefined
|
||||
);
|
||||
if (sessionId || contextId) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId || contextId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
if (contextId) {
|
||||
const context = await this.context.get(contextId);
|
||||
if (context) return [context];
|
||||
} else {
|
||||
const context = await this.context.getBySessionId(sessionId);
|
||||
if (context) return [context];
|
||||
if (contextId) {
|
||||
const context = await this.context.get(contextId);
|
||||
if (context) return [context];
|
||||
} else if (sessionId) {
|
||||
await this.checkChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
copilot.workspaceId || undefined
|
||||
);
|
||||
const context = await this.context.getBySessionId(sessionId);
|
||||
if (context) return [context];
|
||||
}
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
@@ -181,17 +249,80 @@ export class CopilotContextRootResolver {
|
||||
const context = await this.context.create(sessionId);
|
||||
return context.id;
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean, {
|
||||
description: 'queue workspace doc embedding',
|
||||
})
|
||||
@CallMetric('ai', 'context_queue_workspace_doc')
|
||||
async queueWorkspaceEmbedding(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string,
|
||||
@Args('docId', { type: () => [String] }) docIds: string[]
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
|
||||
if (this.context.canEmbedding) {
|
||||
this.event.emit(
|
||||
'workspace.doc.embedding',
|
||||
docIds.map(docId => ({ workspaceId, docId }))
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@Query(() => ContextWorkspaceEmbeddingStatus, {
|
||||
description: 'query workspace embedding status',
|
||||
})
|
||||
@CallMetric('ai', 'context_query_workspace_embedding_status')
|
||||
async queryWorkspaceEmbeddingStatus(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
|
||||
if (this.context.canEmbedding) {
|
||||
const total = await this.db.snapshot.count({ where: { workspaceId } });
|
||||
const embedded = await this.db.snapshot.count({
|
||||
where: { workspaceId, embedding: { isNot: null } },
|
||||
});
|
||||
return { total, embedded };
|
||||
}
|
||||
|
||||
return { total: 0, embedded: 0 };
|
||||
}
|
||||
}
|
||||
|
||||
@Throttle()
|
||||
@Resolver(() => CopilotContextType)
|
||||
export class CopilotContextResolver {
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly mutex: RequestMutex,
|
||||
|
||||
private readonly context: CopilotContextService
|
||||
private readonly context: CopilotContextService,
|
||||
private readonly jobs: CopilotContextDocJob,
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
const controller = new AbortController();
|
||||
req.socket.on('close', hasError => {
|
||||
if (hasError) {
|
||||
controller.abort();
|
||||
}
|
||||
});
|
||||
return controller.signal;
|
||||
}
|
||||
|
||||
@ResolveField(() => [CopilotContextDoc], {
|
||||
description: 'list files in context',
|
||||
})
|
||||
@@ -201,7 +332,7 @@ export class CopilotContextResolver {
|
||||
return session.listDocs();
|
||||
}
|
||||
|
||||
@Mutation(() => [CopilotContextListItem], {
|
||||
@Mutation(() => CopilotContextDoc, {
|
||||
description: 'add a doc to context',
|
||||
})
|
||||
@CallMetric('ai', 'context_doc_add')
|
||||
@@ -261,4 +392,175 @@ export class CopilotContextResolver {
|
||||
const session = await this.context.get(context.id);
|
||||
return session.listFiles();
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotContextFile, {
|
||||
description: 'add a file to context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_add')
|
||||
async addContextFile(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Context() ctx: { req: Request },
|
||||
@Args({ name: 'options', type: () => AddContextFileInput })
|
||||
options: AddContextFileInput,
|
||||
@Args({ name: 'content', type: () => GraphQLUpload })
|
||||
content: FileUpload
|
||||
) {
|
||||
if (!this.context.canEmbedding) {
|
||||
throw new CopilotEmbeddingUnavailable();
|
||||
}
|
||||
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
const length = Number(ctx.req.headers['content-length']);
|
||||
if (length && length >= MAX_EMBEDDABLE_SIZE) {
|
||||
throw new BlobQuotaExceeded();
|
||||
}
|
||||
|
||||
const session = await this.context.get(options.contextId);
|
||||
|
||||
try {
|
||||
const file = await session.addFile(options.blobId, content.filename);
|
||||
|
||||
const buffer = await readStream(content.createReadStream());
|
||||
await this.storage.put(
|
||||
user.id,
|
||||
session.workspaceId,
|
||||
options.blobId,
|
||||
buffer
|
||||
);
|
||||
|
||||
await this.jobs.addFileEmbeddingQueue({
|
||||
userId: user.id,
|
||||
workspaceId: session.workspaceId,
|
||||
contextId: session.id,
|
||||
blobId: file.blobId,
|
||||
fileId: file.id,
|
||||
fileName: file.name,
|
||||
});
|
||||
|
||||
return file;
|
||||
} catch (e: any) {
|
||||
// passthrough user friendly error
|
||||
if (e instanceof UserFriendlyError) {
|
||||
throw e;
|
||||
}
|
||||
throw new CopilotFailedToModifyContext({
|
||||
contextId: options.contextId,
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean, {
|
||||
description: 'remove a file from context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_remove')
|
||||
async removeContextFile(
|
||||
@Args({ name: 'options', type: () => RemoveContextFileInput })
|
||||
options: RemoveContextFileInput
|
||||
) {
|
||||
if (!this.context.canEmbedding) {
|
||||
throw new CopilotEmbeddingUnavailable();
|
||||
}
|
||||
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.context.get(options.contextId);
|
||||
|
||||
try {
|
||||
return await session.removeFile(options.fileId);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToModifyContext({
|
||||
contextId: options.contextId,
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ResolveField(() => [ContextMatchedFileChunk], {
|
||||
description: 'match file context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_remove')
|
||||
async matchContext(
|
||||
@Context() ctx: { req: Request },
|
||||
@Parent() context: CopilotContextType,
|
||||
@Args('content') content: string,
|
||||
@Args('limit', { type: () => SafeIntResolver, nullable: true })
|
||||
limit?: number,
|
||||
@Args('threshold', { type: () => Float, nullable: true })
|
||||
threshold?: number
|
||||
) {
|
||||
if (!this.context.canEmbedding) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${context.id}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.context.get(context.id);
|
||||
|
||||
try {
|
||||
return await session.matchFileChunks(
|
||||
content,
|
||||
limit,
|
||||
this.getSignal(ctx.req),
|
||||
threshold
|
||||
);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToMatchContext({
|
||||
contextId: context.id,
|
||||
// don't record the large content
|
||||
content: content.slice(0, 512),
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ResolveField(() => ContextMatchedDocChunk, {
|
||||
description: 'match workspace doc content',
|
||||
})
|
||||
@CallMetric('ai', 'context_match_workspace_doc')
|
||||
async matchWorkspaceContext(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Context() ctx: { req: Request },
|
||||
@Parent() context: CopilotContextType,
|
||||
@Args('content') content: string,
|
||||
@Args('limit', { type: () => SafeIntResolver, nullable: true })
|
||||
limit?: number
|
||||
) {
|
||||
if (!this.context.canEmbedding) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const session = await this.context.get(context.id);
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(session.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
|
||||
try {
|
||||
return await session.matchWorkspaceChunks(
|
||||
content,
|
||||
limit,
|
||||
this.getSignal(ctx.req)
|
||||
);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToMatchContext({
|
||||
contextId: context.id,
|
||||
// don't record the large content
|
||||
content: content.slice(0, 512),
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user