mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 21:27:20 +00:00
@@ -0,0 +1,3 @@
|
||||
export { CopilotContextResolver, CopilotContextRootResolver } from './resolver';
|
||||
export { CopilotContextService } from './service';
|
||||
export { type ContextFile, ContextFileStatus } from './types';
|
||||
260
packages/backend/server/src/plugins/copilot/context/resolver.ts
Normal file
260
packages/backend/server/src/plugins/copilot/context/resolver.ts
Normal file
@@ -0,0 +1,260 @@
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
ID,
|
||||
InputType,
|
||||
Mutation,
|
||||
ObjectType,
|
||||
Parent,
|
||||
registerEnumType,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { SafeIntResolver } from 'graphql-scalars';
|
||||
|
||||
import {
|
||||
CallMetric,
|
||||
CopilotFailedToModifyContext,
|
||||
CopilotSessionNotFound,
|
||||
RequestMutex,
|
||||
Throttle,
|
||||
TooManyRequest,
|
||||
} from '../../../base';
|
||||
import { CurrentUser } from '../../../core/auth';
|
||||
import { COPILOT_LOCKER, CopilotType } from '../resolver';
|
||||
import { ChatSessionService } from '../session';
|
||||
import { CopilotContextService } from './service';
|
||||
import { ContextDoc, type ContextFile, ContextFileStatus } from './types';
|
||||
|
||||
@InputType()
|
||||
class AddContextDocInput {
|
||||
@Field(() => String)
|
||||
contextId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
docId!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class RemoveContextFileInput {
|
||||
@Field(() => String)
|
||||
contextId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
fileId!: string;
|
||||
}
|
||||
|
||||
@ObjectType('CopilotContext')
|
||||
export class CopilotContextType {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
}
|
||||
|
||||
registerEnumType(ContextFileStatus, { name: 'ContextFileStatus' });
|
||||
|
||||
@ObjectType()
|
||||
class CopilotContextDoc implements ContextDoc {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
createdAt!: number;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotContextFile implements ContextFile {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
chunkSize!: number;
|
||||
|
||||
@Field(() => ContextFileStatus)
|
||||
status!: ContextFileStatus;
|
||||
|
||||
@Field(() => String)
|
||||
blobId!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
createdAt!: number;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotContextListItem {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
createdAt!: number;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
name!: string;
|
||||
|
||||
@Field(() => SafeIntResolver, { nullable: true })
|
||||
chunkSize!: number;
|
||||
|
||||
@Field(() => ContextFileStatus, { nullable: true })
|
||||
status!: ContextFileStatus;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
blobId!: string;
|
||||
}
|
||||
|
||||
@Throttle()
|
||||
@Resolver(() => CopilotType)
|
||||
export class CopilotContextRootResolver {
|
||||
constructor(
|
||||
private readonly mutex: RequestMutex,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly context: CopilotContextService
|
||||
) {}
|
||||
|
||||
private async checkChatSession(
|
||||
user: CurrentUser,
|
||||
sessionId: string,
|
||||
workspaceId?: string
|
||||
): Promise<void> {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (
|
||||
!session ||
|
||||
session.config.workspaceId !== workspaceId ||
|
||||
session.config.userId !== user.id
|
||||
) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
}
|
||||
|
||||
@ResolveField(() => [CopilotContextType], {
|
||||
description: 'Get the context list of a session',
|
||||
complexity: 2,
|
||||
})
|
||||
@CallMetric('ai', 'context_create')
|
||||
async contexts(
|
||||
@Parent() copilot: CopilotType,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('sessionId') sessionId: string,
|
||||
@Args('contextId', { nullable: true }) contextId?: string
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
await this.checkChatSession(user, sessionId, copilot.workspaceId);
|
||||
|
||||
if (contextId) {
|
||||
const context = await this.context.get(contextId);
|
||||
if (context) return [context];
|
||||
} else {
|
||||
const context = await this.context.getBySessionId(sessionId);
|
||||
if (context) return [context];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a context session',
|
||||
})
|
||||
@CallMetric('ai', 'context_create')
|
||||
async createCopilotContext(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string,
|
||||
@Args('sessionId') sessionId: string
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${sessionId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
await this.checkChatSession(user, sessionId, workspaceId);
|
||||
|
||||
const context = await this.context.create(sessionId);
|
||||
return context.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Throttle()
|
||||
@Resolver(() => CopilotContextType)
|
||||
export class CopilotContextResolver {
|
||||
constructor(
|
||||
private readonly mutex: RequestMutex,
|
||||
|
||||
private readonly context: CopilotContextService
|
||||
) {}
|
||||
|
||||
@ResolveField(() => [CopilotContextDoc], {
|
||||
description: 'list files in context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_list')
|
||||
async docs(@Parent() context: CopilotContextType): Promise<ContextDoc[]> {
|
||||
const session = await this.context.get(context.id);
|
||||
return session.listDocs();
|
||||
}
|
||||
|
||||
@Mutation(() => [CopilotContextListItem], {
|
||||
description: 'add a doc to context',
|
||||
})
|
||||
@CallMetric('ai', 'context_doc_add')
|
||||
async addContextDoc(
|
||||
@Args({ name: 'options', type: () => AddContextDocInput })
|
||||
options: AddContextDocInput
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.context.get(options.contextId);
|
||||
|
||||
try {
|
||||
return await session.addDocRecord(options.docId);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToModifyContext({
|
||||
contextId: options.contextId,
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean, {
|
||||
description: 'remove a doc from context',
|
||||
})
|
||||
@CallMetric('ai', 'context_doc_remove')
|
||||
async removeContextDoc(
|
||||
@Args({ name: 'options', type: () => RemoveContextFileInput })
|
||||
options: RemoveContextFileInput
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:context:${options.contextId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
const session = await this.context.get(options.contextId);
|
||||
|
||||
try {
|
||||
return await session.removeDocRecord(options.fileId);
|
||||
} catch (e: any) {
|
||||
throw new CopilotFailedToModifyContext({
|
||||
contextId: options.contextId,
|
||||
message: e.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ResolveField(() => [CopilotContextFile], {
|
||||
description: 'list files in context',
|
||||
})
|
||||
@CallMetric('ai', 'context_file_list')
|
||||
async files(
|
||||
@Parent() context: CopilotContextType
|
||||
): Promise<CopilotContextFile[]> {
|
||||
const session = await this.context.get(context.id);
|
||||
return session.listFiles();
|
||||
}
|
||||
}
|
||||
113
packages/backend/server/src/plugins/copilot/context/service.ts
Normal file
113
packages/backend/server/src/plugins/copilot/context/service.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
import {
|
||||
Cache,
|
||||
CopilotInvalidContext,
|
||||
CopilotSessionNotFound,
|
||||
} from '../../../base';
|
||||
import { ContextSession } from './session';
|
||||
import { ContextConfig, ContextConfigSchema } from './types';
|
||||
|
||||
const CONTEXT_SESSION_KEY = 'context-session';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotContextService {
|
||||
constructor(
|
||||
private readonly cache: Cache,
|
||||
private readonly db: PrismaClient
|
||||
) {}
|
||||
|
||||
private async saveConfig(
|
||||
contextId: string,
|
||||
config: ContextConfig,
|
||||
refreshCache = false
|
||||
): Promise<void> {
|
||||
if (!refreshCache) {
|
||||
await this.db.aiContext.update({
|
||||
where: { id: contextId },
|
||||
data: { config },
|
||||
});
|
||||
}
|
||||
await this.cache.set(`${CONTEXT_SESSION_KEY}:${contextId}`, config);
|
||||
}
|
||||
|
||||
private async getCachedSession(
|
||||
contextId: string
|
||||
): Promise<ContextSession | undefined> {
|
||||
const cachedSession = await this.cache.get(
|
||||
`${CONTEXT_SESSION_KEY}:${contextId}`
|
||||
);
|
||||
if (cachedSession) {
|
||||
const config = ContextConfigSchema.safeParse(cachedSession);
|
||||
if (config.success) {
|
||||
return new ContextSession(
|
||||
contextId,
|
||||
config.data,
|
||||
this.saveConfig.bind(this, contextId)
|
||||
);
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// NOTE: we only cache config to avoid frequent database queries
|
||||
// but we do not need to cache session instances because a distributed
|
||||
// lock is already apply to mutation operation for the same context in
|
||||
// the resolver, so there will be no simultaneous writing to the config
|
||||
private async cacheSession(
|
||||
contextId: string,
|
||||
config: ContextConfig
|
||||
): Promise<ContextSession> {
|
||||
const dispatcher = this.saveConfig.bind(this, contextId);
|
||||
await dispatcher(config, true);
|
||||
return new ContextSession(contextId, config, dispatcher);
|
||||
}
|
||||
|
||||
async create(sessionId: string): Promise<ContextSession> {
|
||||
const session = await this.db.aiSession.findFirst({
|
||||
where: { id: sessionId },
|
||||
select: { workspaceId: true },
|
||||
});
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
// keep the context unique per session
|
||||
const existsContext = await this.getBySessionId(sessionId);
|
||||
if (existsContext) return existsContext;
|
||||
|
||||
const context = await this.db.aiContext.create({
|
||||
data: {
|
||||
sessionId,
|
||||
config: { workspaceId: session.workspaceId, docs: [], files: [] },
|
||||
},
|
||||
});
|
||||
|
||||
const config = ContextConfigSchema.parse(context.config);
|
||||
return await this.cacheSession(context.id, config);
|
||||
}
|
||||
|
||||
async get(id: string): Promise<ContextSession> {
|
||||
const context = await this.getCachedSession(id);
|
||||
if (context) return context;
|
||||
const ret = await this.db.aiContext.findUnique({
|
||||
where: { id },
|
||||
select: { config: true },
|
||||
});
|
||||
if (ret) {
|
||||
const config = ContextConfigSchema.safeParse(ret.config);
|
||||
if (config.success) return this.cacheSession(id, config.data);
|
||||
}
|
||||
throw new CopilotInvalidContext({ contextId: id });
|
||||
}
|
||||
|
||||
async getBySessionId(sessionId: string): Promise<ContextSession | null> {
|
||||
const existsContext = await this.db.aiContext.findFirst({
|
||||
where: { sessionId },
|
||||
select: { id: true },
|
||||
});
|
||||
if (existsContext) return this.get(existsContext.id);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
import { ContextConfig, ContextDoc, ContextList } from './types';
|
||||
|
||||
export class ContextSession implements AsyncDisposable {
|
||||
constructor(
|
||||
private readonly contextId: string,
|
||||
private readonly config: ContextConfig,
|
||||
private readonly dispatcher?: (config: ContextConfig) => Promise<void>
|
||||
) {}
|
||||
|
||||
get id() {
|
||||
return this.contextId;
|
||||
}
|
||||
|
||||
get workspaceId() {
|
||||
return this.config.workspaceId;
|
||||
}
|
||||
|
||||
listDocs(): ContextDoc[] {
|
||||
return [...this.config.docs];
|
||||
}
|
||||
|
||||
listFiles() {
|
||||
return this.config.files.map(f => ({ ...f }));
|
||||
}
|
||||
|
||||
get sortedList(): ContextList {
|
||||
const { docs, files } = this.config;
|
||||
return [...docs, ...files].toSorted(
|
||||
(a, b) => a.createdAt - b.createdAt
|
||||
) as ContextList;
|
||||
}
|
||||
|
||||
async addDocRecord(docId: string): Promise<ContextList> {
|
||||
if (!this.config.docs.some(f => f.id === docId)) {
|
||||
this.config.docs.push({ id: docId, createdAt: Date.now() });
|
||||
await this.save();
|
||||
}
|
||||
return this.sortedList;
|
||||
}
|
||||
|
||||
async removeDocRecord(docId: string): Promise<boolean> {
|
||||
const index = this.config.docs.findIndex(f => f.id === docId);
|
||||
if (index >= 0) {
|
||||
this.config.docs.splice(index, 1);
|
||||
await this.save();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async save() {
|
||||
await this.dispatcher?.(this.config);
|
||||
}
|
||||
|
||||
async [Symbol.asyncDispose]() {
|
||||
await this.save();
|
||||
}
|
||||
}
|
||||
69
packages/backend/server/src/plugins/copilot/context/types.ts
Normal file
69
packages/backend/server/src/plugins/copilot/context/types.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
declare global {
|
||||
interface Events {
|
||||
'workspace.doc.embedding': {
|
||||
workspaceId: string;
|
||||
docId: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export enum ContextFileStatus {
|
||||
processing = 'processing',
|
||||
finished = 'finished',
|
||||
failed = 'failed',
|
||||
}
|
||||
|
||||
export const ContextConfigSchema = z.object({
|
||||
workspaceId: z.string(),
|
||||
files: z
|
||||
.object({
|
||||
id: z.string(),
|
||||
chunkSize: z.number(),
|
||||
name: z.string(),
|
||||
status: z.enum([
|
||||
ContextFileStatus.processing,
|
||||
ContextFileStatus.finished,
|
||||
ContextFileStatus.failed,
|
||||
]),
|
||||
blobId: z.string(),
|
||||
createdAt: z.number(),
|
||||
})
|
||||
.array(),
|
||||
docs: z
|
||||
.object({
|
||||
id: z.string(),
|
||||
createdAt: z.number(),
|
||||
})
|
||||
.array(),
|
||||
});
|
||||
|
||||
export type ContextConfig = z.infer<typeof ContextConfigSchema>;
|
||||
export type ContextDoc = z.infer<typeof ContextConfigSchema>['docs'][number];
|
||||
export type ContextFile = z.infer<typeof ContextConfigSchema>['files'][number];
|
||||
export type ContextListItem = ContextDoc | ContextFile;
|
||||
export type ContextList = ContextListItem[];
|
||||
|
||||
export type ChunkSimilarity = {
|
||||
chunk: number;
|
||||
content: string;
|
||||
distance: number | null;
|
||||
};
|
||||
|
||||
export type FileChunkSimilarity = ChunkSimilarity & {
|
||||
fileId: string;
|
||||
};
|
||||
|
||||
export type DocChunkSimilarity = ChunkSimilarity & {
|
||||
docId: string;
|
||||
};
|
||||
|
||||
export type Embedding = {
|
||||
/**
|
||||
* The index of the embedding in the list of embeddings.
|
||||
*/
|
||||
index: number;
|
||||
content: string;
|
||||
embedding: Array<number>;
|
||||
};
|
||||
11
packages/backend/server/src/plugins/copilot/context/utils.ts
Normal file
11
packages/backend/server/src/plugins/copilot/context/utils.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
export class GqlSignal implements AsyncDisposable {
|
||||
readonly abortController = new AbortController();
|
||||
|
||||
get signal() {
|
||||
return this.abortController.signal;
|
||||
}
|
||||
|
||||
async [Symbol.asyncDispose]() {
|
||||
this.abortController.abort();
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,11 @@ import { FeatureModule } from '../../core/features';
|
||||
import { PermissionModule } from '../../core/permission';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { Plugin } from '../registry';
|
||||
import {
|
||||
CopilotContextResolver,
|
||||
CopilotContextRootResolver,
|
||||
CopilotContextService,
|
||||
} from './context';
|
||||
import { CopilotController } from './controller';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
@@ -41,8 +46,13 @@ registerCopilotProvider(PerplexityProvider);
|
||||
CopilotProviderService,
|
||||
CopilotStorage,
|
||||
PromptsManagementResolver,
|
||||
// workflow
|
||||
CopilotWorkflowService,
|
||||
...CopilotWorkflowExecutors,
|
||||
// context
|
||||
CopilotContextRootResolver,
|
||||
CopilotContextResolver,
|
||||
CopilotContextService,
|
||||
],
|
||||
controllers: [CopilotController],
|
||||
contributesTo: ServerFeature.Copilot,
|
||||
|
||||
@@ -59,6 +59,7 @@ export class OpenAIProvider
|
||||
|
||||
private readonly logger = new Logger(OpenAIProvider.type);
|
||||
private readonly instance: OpenAI;
|
||||
|
||||
private existsModels: string[] | undefined;
|
||||
|
||||
constructor(config: ClientOptions) {
|
||||
|
||||
@@ -23,7 +23,7 @@ import {
|
||||
CallMetric,
|
||||
CopilotFailedToCreateMessage,
|
||||
CopilotSessionNotFound,
|
||||
FileUpload,
|
||||
type FileUpload,
|
||||
RequestMutex,
|
||||
Throttle,
|
||||
TooManyRequest,
|
||||
|
||||
@@ -198,6 +198,13 @@ const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.merge(
|
||||
|
||||
export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
|
||||
|
||||
export type CopilotContextFile = {
|
||||
id: string; // fileId
|
||||
created_at: number;
|
||||
// embedding status
|
||||
status: 'in_progress' | 'completed' | 'failed';
|
||||
};
|
||||
|
||||
export interface CopilotProvider {
|
||||
readonly type: CopilotProviderType;
|
||||
getCapabilities(): CopilotCapability[];
|
||||
|
||||
Reference in New Issue
Block a user