feat(server): basic context api (#10056)

fix CLOUD-97
fix CLOUD-98
This commit is contained in:
darkskygit
2025-02-11 10:45:00 +00:00
parent a47369bf9b
commit a725df6ebe
41 changed files with 1698 additions and 374 deletions

View File

@@ -0,0 +1,3 @@
export { CopilotContextResolver, CopilotContextRootResolver } from './resolver';
export { CopilotContextService } from './service';
export { type ContextFile, ContextFileStatus } from './types';

View File

@@ -0,0 +1,260 @@
import {
Args,
Field,
ID,
InputType,
Mutation,
ObjectType,
Parent,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import {
CallMetric,
CopilotFailedToModifyContext,
CopilotSessionNotFound,
RequestMutex,
Throttle,
TooManyRequest,
} from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { COPILOT_LOCKER, CopilotType } from '../resolver';
import { ChatSessionService } from '../session';
import { CopilotContextService } from './service';
import { ContextDoc, type ContextFile, ContextFileStatus } from './types';
@InputType()
class AddContextDocInput {
@Field(() => String)
contextId!: string;
@Field(() => String)
docId!: string;
}
@InputType()
class RemoveContextFileInput {
@Field(() => String)
contextId!: string;
@Field(() => String)
fileId!: string;
}
@ObjectType('CopilotContext')
export class CopilotContextType {
@Field(() => ID)
id!: string;
@Field(() => String)
workspaceId!: string;
}
registerEnumType(ContextFileStatus, { name: 'ContextFileStatus' });
@ObjectType()
class CopilotContextDoc implements ContextDoc {
@Field(() => ID)
id!: string;
@Field(() => SafeIntResolver)
createdAt!: number;
}
@ObjectType()
class CopilotContextFile implements ContextFile {
@Field(() => ID)
id!: string;
@Field(() => String)
name!: string;
@Field(() => SafeIntResolver)
chunkSize!: number;
@Field(() => ContextFileStatus)
status!: ContextFileStatus;
@Field(() => String)
blobId!: string;
@Field(() => SafeIntResolver)
createdAt!: number;
}
@ObjectType()
class CopilotContextListItem {
@Field(() => ID)
id!: string;
@Field(() => SafeIntResolver)
createdAt!: number;
@Field(() => String, { nullable: true })
name!: string;
@Field(() => SafeIntResolver, { nullable: true })
chunkSize!: number;
@Field(() => ContextFileStatus, { nullable: true })
status!: ContextFileStatus;
@Field(() => String, { nullable: true })
blobId!: string;
}
@Throttle()
@Resolver(() => CopilotType)
export class CopilotContextRootResolver {
constructor(
private readonly mutex: RequestMutex,
private readonly chatSession: ChatSessionService,
private readonly context: CopilotContextService
) {}
private async checkChatSession(
user: CurrentUser,
sessionId: string,
workspaceId?: string
): Promise<void> {
const session = await this.chatSession.get(sessionId);
if (
!session ||
session.config.workspaceId !== workspaceId ||
session.config.userId !== user.id
) {
throw new CopilotSessionNotFound();
}
}
@ResolveField(() => [CopilotContextType], {
description: 'Get the context list of a session',
complexity: 2,
})
@CallMetric('ai', 'context_create')
async contexts(
@Parent() copilot: CopilotType,
@CurrentUser() user: CurrentUser,
@Args('sessionId') sessionId: string,
@Args('contextId', { nullable: true }) contextId?: string
) {
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
await this.checkChatSession(user, sessionId, copilot.workspaceId);
if (contextId) {
const context = await this.context.get(contextId);
if (context) return [context];
} else {
const context = await this.context.getBySessionId(sessionId);
if (context) return [context];
}
return [];
}
@Mutation(() => String, {
description: 'Create a context session',
})
@CallMetric('ai', 'context_create')
async createCopilotContext(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('sessionId') sessionId: string
) {
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
await this.checkChatSession(user, sessionId, workspaceId);
const context = await this.context.create(sessionId);
return context.id;
}
}
@Throttle()
@Resolver(() => CopilotContextType)
export class CopilotContextResolver {
constructor(
private readonly mutex: RequestMutex,
private readonly context: CopilotContextService
) {}
@ResolveField(() => [CopilotContextDoc], {
description: 'list files in context',
})
@CallMetric('ai', 'context_file_list')
async docs(@Parent() context: CopilotContextType): Promise<ContextDoc[]> {
const session = await this.context.get(context.id);
return session.listDocs();
}
@Mutation(() => [CopilotContextListItem], {
description: 'add a doc to context',
})
@CallMetric('ai', 'context_doc_add')
async addContextDoc(
@Args({ name: 'options', type: () => AddContextDocInput })
options: AddContextDocInput
) {
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
const session = await this.context.get(options.contextId);
try {
return await session.addDocRecord(options.docId);
} catch (e: any) {
throw new CopilotFailedToModifyContext({
contextId: options.contextId,
message: e.message,
});
}
}
@Mutation(() => Boolean, {
description: 'remove a doc from context',
})
@CallMetric('ai', 'context_doc_remove')
async removeContextDoc(
@Args({ name: 'options', type: () => RemoveContextFileInput })
options: RemoveContextFileInput
) {
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
const session = await this.context.get(options.contextId);
try {
return await session.removeDocRecord(options.fileId);
} catch (e: any) {
throw new CopilotFailedToModifyContext({
contextId: options.contextId,
message: e.message,
});
}
}
@ResolveField(() => [CopilotContextFile], {
description: 'list files in context',
})
@CallMetric('ai', 'context_file_list')
async files(
@Parent() context: CopilotContextType
): Promise<CopilotContextFile[]> {
const session = await this.context.get(context.id);
return session.listFiles();
}
}

View File

@@ -0,0 +1,113 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import {
Cache,
CopilotInvalidContext,
CopilotSessionNotFound,
} from '../../../base';
import { ContextSession } from './session';
import { ContextConfig, ContextConfigSchema } from './types';
const CONTEXT_SESSION_KEY = 'context-session';
@Injectable()
export class CopilotContextService {
constructor(
private readonly cache: Cache,
private readonly db: PrismaClient
) {}
private async saveConfig(
contextId: string,
config: ContextConfig,
refreshCache = false
): Promise<void> {
if (!refreshCache) {
await this.db.aiContext.update({
where: { id: contextId },
data: { config },
});
}
await this.cache.set(`${CONTEXT_SESSION_KEY}:${contextId}`, config);
}
private async getCachedSession(
contextId: string
): Promise<ContextSession | undefined> {
const cachedSession = await this.cache.get(
`${CONTEXT_SESSION_KEY}:${contextId}`
);
if (cachedSession) {
const config = ContextConfigSchema.safeParse(cachedSession);
if (config.success) {
return new ContextSession(
contextId,
config.data,
this.saveConfig.bind(this, contextId)
);
}
}
return undefined;
}
// NOTE: we only cache config to avoid frequent database queries
// but we do not need to cache session instances because a distributed
// lock is already apply to mutation operation for the same context in
// the resolver, so there will be no simultaneous writing to the config
private async cacheSession(
contextId: string,
config: ContextConfig
): Promise<ContextSession> {
const dispatcher = this.saveConfig.bind(this, contextId);
await dispatcher(config, true);
return new ContextSession(contextId, config, dispatcher);
}
async create(sessionId: string): Promise<ContextSession> {
const session = await this.db.aiSession.findFirst({
where: { id: sessionId },
select: { workspaceId: true },
});
if (!session) {
throw new CopilotSessionNotFound();
}
// keep the context unique per session
const existsContext = await this.getBySessionId(sessionId);
if (existsContext) return existsContext;
const context = await this.db.aiContext.create({
data: {
sessionId,
config: { workspaceId: session.workspaceId, docs: [], files: [] },
},
});
const config = ContextConfigSchema.parse(context.config);
return await this.cacheSession(context.id, config);
}
async get(id: string): Promise<ContextSession> {
const context = await this.getCachedSession(id);
if (context) return context;
const ret = await this.db.aiContext.findUnique({
where: { id },
select: { config: true },
});
if (ret) {
const config = ContextConfigSchema.safeParse(ret.config);
if (config.success) return this.cacheSession(id, config.data);
}
throw new CopilotInvalidContext({ contextId: id });
}
async getBySessionId(sessionId: string): Promise<ContextSession | null> {
const existsContext = await this.db.aiContext.findFirst({
where: { sessionId },
select: { id: true },
});
if (existsContext) return this.get(existsContext.id);
return null;
}
}

View File

@@ -0,0 +1,58 @@
import { ContextConfig, ContextDoc, ContextList } from './types';
export class ContextSession implements AsyncDisposable {
constructor(
private readonly contextId: string,
private readonly config: ContextConfig,
private readonly dispatcher?: (config: ContextConfig) => Promise<void>
) {}
get id() {
return this.contextId;
}
get workspaceId() {
return this.config.workspaceId;
}
listDocs(): ContextDoc[] {
return [...this.config.docs];
}
listFiles() {
return this.config.files.map(f => ({ ...f }));
}
get sortedList(): ContextList {
const { docs, files } = this.config;
return [...docs, ...files].toSorted(
(a, b) => a.createdAt - b.createdAt
) as ContextList;
}
async addDocRecord(docId: string): Promise<ContextList> {
if (!this.config.docs.some(f => f.id === docId)) {
this.config.docs.push({ id: docId, createdAt: Date.now() });
await this.save();
}
return this.sortedList;
}
async removeDocRecord(docId: string): Promise<boolean> {
const index = this.config.docs.findIndex(f => f.id === docId);
if (index >= 0) {
this.config.docs.splice(index, 1);
await this.save();
return true;
}
return false;
}
async save() {
await this.dispatcher?.(this.config);
}
async [Symbol.asyncDispose]() {
await this.save();
}
}

View File

@@ -0,0 +1,69 @@
import { z } from 'zod';
declare global {
interface Events {
'workspace.doc.embedding': {
workspaceId: string;
docId: string;
};
}
}
export enum ContextFileStatus {
processing = 'processing',
finished = 'finished',
failed = 'failed',
}
export const ContextConfigSchema = z.object({
workspaceId: z.string(),
files: z
.object({
id: z.string(),
chunkSize: z.number(),
name: z.string(),
status: z.enum([
ContextFileStatus.processing,
ContextFileStatus.finished,
ContextFileStatus.failed,
]),
blobId: z.string(),
createdAt: z.number(),
})
.array(),
docs: z
.object({
id: z.string(),
createdAt: z.number(),
})
.array(),
});
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
export type ContextDoc = z.infer<typeof ContextConfigSchema>['docs'][number];
export type ContextFile = z.infer<typeof ContextConfigSchema>['files'][number];
export type ContextListItem = ContextDoc | ContextFile;
export type ContextList = ContextListItem[];
export type ChunkSimilarity = {
chunk: number;
content: string;
distance: number | null;
};
export type FileChunkSimilarity = ChunkSimilarity & {
fileId: string;
};
export type DocChunkSimilarity = ChunkSimilarity & {
docId: string;
};
export type Embedding = {
/**
* The index of the embedding in the list of embeddings.
*/
index: number;
content: string;
embedding: Array<number>;
};

View File

@@ -0,0 +1,11 @@
export class GqlSignal implements AsyncDisposable {
readonly abortController = new AbortController();
get signal() {
return this.abortController.signal;
}
async [Symbol.asyncDispose]() {
this.abortController.abort();
}
}

View File

@@ -5,6 +5,11 @@ import { FeatureModule } from '../../core/features';
import { PermissionModule } from '../../core/permission';
import { QuotaModule } from '../../core/quota';
import { Plugin } from '../registry';
import {
CopilotContextResolver,
CopilotContextRootResolver,
CopilotContextService,
} from './context';
import { CopilotController } from './controller';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
@@ -41,8 +46,13 @@ registerCopilotProvider(PerplexityProvider);
CopilotProviderService,
CopilotStorage,
PromptsManagementResolver,
// workflow
CopilotWorkflowService,
...CopilotWorkflowExecutors,
// context
CopilotContextRootResolver,
CopilotContextResolver,
CopilotContextService,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,

View File

@@ -59,6 +59,7 @@ export class OpenAIProvider
private readonly logger = new Logger(OpenAIProvider.type);
private readonly instance: OpenAI;
private existsModels: string[] | undefined;
constructor(config: ClientOptions) {

View File

@@ -23,7 +23,7 @@ import {
CallMetric,
CopilotFailedToCreateMessage,
CopilotSessionNotFound,
FileUpload,
type FileUpload,
RequestMutex,
Throttle,
TooManyRequest,

View File

@@ -198,6 +198,13 @@ const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.merge(
export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
export type CopilotContextFile = {
id: string; // fileId
created_at: number;
// embedding status
status: 'in_progress' | 'completed' | 'failed';
};
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];