diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 8768adec35..61a1cd94c8 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -77,6 +77,7 @@ "nestjs-throttler-storage-redis": "^0.4.1", "nodemailer": "^6.9.10", "on-headers": "^1.0.2", + "openai": "^4.29.2", "parse-duration": "^1.1.0", "pretty-time": "^1.1.0", "prisma": "^5.10.2", diff --git a/packages/backend/server/src/config/affine.self.ts b/packages/backend/server/src/config/affine.self.ts index cecc33fbbc..5cc19b37f7 100644 --- a/packages/backend/server/src/config/affine.self.ts +++ b/packages/backend/server/src/config/affine.self.ts @@ -38,6 +38,11 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) { }`; } +AFFiNE.plugins.use('copilot', { + openai: { + apiKey: 'test', + }, +}); AFFiNE.plugins.use('redis'); AFFiNE.plugins.use('payment', { stripe: { diff --git a/packages/backend/server/src/core/config.ts b/packages/backend/server/src/core/config.ts index e34ee1dc51..2d09d92c3f 100644 --- a/packages/backend/server/src/core/config.ts +++ b/packages/backend/server/src/core/config.ts @@ -5,6 +5,7 @@ import { DeploymentType } from '../fundamentals'; import { Public } from './auth'; export enum ServerFeature { + Copilot = 'copilot', Payment = 'payment', OAuth = 'oauth', } diff --git a/packages/backend/server/src/plugins/config.ts b/packages/backend/server/src/plugins/config.ts index eea08c491f..ba512a65d9 100644 --- a/packages/backend/server/src/plugins/config.ts +++ b/packages/backend/server/src/plugins/config.ts @@ -1,3 +1,4 @@ +import { CopilotConfig } from './copilot'; import { GCloudConfig } from './gcloud/config'; import { OAuthConfig } from './oauth'; import { PaymentConfig } from './payment'; @@ -6,6 +7,7 @@ import { R2StorageConfig, S3StorageConfig } from './storage'; declare module '../fundamentals/config' { interface PluginsConfig { + readonly copilot: CopilotConfig; readonly payment: PaymentConfig; readonly redis: RedisOptions; readonly gcloud: GCloudConfig; diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts new file mode 100644 index 0000000000..53dd28a178 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -0,0 +1,18 @@ +import { ServerFeature } from '../../core/config'; +import { Plugin } from '../registry'; +import { assertProvidersConfigs, CopilotProviderService } from './provider'; + +@Plugin({ + name: 'copilot', + providers: [CopilotProviderService], + contributesTo: ServerFeature.Copilot, + if: config => { + if (config.flavor.graphql) { + return assertProvidersConfigs(config); + } + return false; + }, +}) +export class CopilotModule {} + +export type { CopilotConfig } from './types'; diff --git a/packages/backend/server/src/plugins/copilot/provider.ts b/packages/backend/server/src/plugins/copilot/provider.ts new file mode 100644 index 0000000000..24bf67ba63 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/provider.ts @@ -0,0 +1,135 @@ +import assert from 'node:assert'; + +import { Injectable, Logger } from '@nestjs/common'; + +import { Config } from '../../fundamentals'; +import { + CapabilityToCopilotProvider, + CopilotConfig, + CopilotProvider, + CopilotProviderCapability, + CopilotProviderType, +} from './types'; + +type CopilotProviderConfig = CopilotConfig[keyof CopilotConfig]; + +interface CopilotProviderDefinition { + // constructor signature + new (config: C): CopilotProvider; + // type of the provider + readonly type: CopilotProviderType; + // capabilities of the provider, like text to text, text to image, etc. + readonly capabilities: CopilotProviderCapability[]; + // asserts that the config is valid for this provider + assetsConfig(config: C): boolean; +} + +// registered provider factory +const COPILOT_PROVIDER = new Map< + CopilotProviderType, + (config: Config, logger: Logger) => CopilotProvider +>(); + +// map of capabilities to providers +const PROVIDER_CAPABILITY_MAP = new Map< + CopilotProviderCapability, + CopilotProviderType[] +>(); + +// config assertions for providers +const ASSERT_CONFIG = new Map void>(); + +export function registerCopilotProvider< + C extends CopilotProviderConfig = CopilotProviderConfig, +>(provider: CopilotProviderDefinition) { + const type = provider.type; + + const factory = (config: Config, logger: Logger) => { + const providerConfig = config.plugins.copilot?.[type]; + if (!provider.assetsConfig(providerConfig as C)) { + throw new Error( + `Invalid configuration for copilot provider ${type}: ${providerConfig}` + ); + } + const instance = new provider(providerConfig as C); + logger.log( + `Copilot provider ${type} registered, capabilities: ${provider.capabilities.join(', ')}` + ); + + return instance; + }; + // register the provider + COPILOT_PROVIDER.set(type, factory); + // register the provider capabilities + for (const capability of provider.capabilities) { + const providers = PROVIDER_CAPABILITY_MAP.get(capability) || []; + if (!providers.includes(type)) { + providers.push(type); + } + PROVIDER_CAPABILITY_MAP.set(capability, providers); + } + // register the provider config assertion + ASSERT_CONFIG.set(type, (config: Config) => { + assert(config.plugins.copilot); + const providerConfig = config.plugins.copilot[type]; + if (!providerConfig) return false; + return provider.assetsConfig(providerConfig as C); + }); +} + +/// Asserts that the config is valid for any registered providers +export function assertProvidersConfigs(config: Config) { + return ( + Array.from(ASSERT_CONFIG.values()).findIndex(assertConfig => + assertConfig(config) + ) !== -1 + ); +} + +@Injectable() +export class CopilotProviderService { + private readonly logger = new Logger(CopilotProviderService.name); + constructor(private readonly config: Config) {} + + private readonly cachedProviders = new Map< + CopilotProviderType, + CopilotProvider + >(); + + private create(provider: CopilotProviderType): CopilotProvider { + assert(this.config.plugins.copilot); + const providerFactory = COPILOT_PROVIDER.get(provider); + + if (!providerFactory) { + throw new Error(`Unknown copilot provider type: ${provider}`); + } + + return providerFactory(this.config, this.logger); + } + + getProvider(provider: CopilotProviderType): CopilotProvider { + if (!this.cachedProviders.has(provider)) { + this.cachedProviders.set(provider, this.create(provider)); + } + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return this.cachedProviders.get(provider)!; + } + + getProviderByCapability( + capability: C, + prefer?: CopilotProviderType + ): CapabilityToCopilotProvider[C] | null { + const providers = PROVIDER_CAPABILITY_MAP.get(capability); + if (Array.isArray(providers) && providers.length) { + const selectedCapability = + prefer && providers.includes(prefer) ? prefer : providers[0]; + + const provider = this.getProvider(selectedCapability); + assert(provider.getCapabilities().includes(capability)); + + return provider as CapabilityToCopilotProvider[C]; + } + return null; + } +} diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts new file mode 100644 index 0000000000..fd72c4fd32 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -0,0 +1,50 @@ +import type { ClientOptions as OpenAIClientOptions } from 'openai'; + +export interface CopilotConfig { + openai: OpenAIClientOptions; + fal: { + secret: string; + }; +} + +export enum CopilotProviderType { + FAL = 'fal', + OpenAI = 'openai', +} + +export enum CopilotProviderCapability { + TextToText = 'text-to-text', + TextToEmbedding = 'text-to-embedding', + TextToImage = 'text-to-image', + ImageToImage = 'image-to-image', +} + +export interface CopilotProvider { + getCapabilities(): CopilotProviderCapability[]; +} + +export type ChatMessage = { + role: 'system' | 'assistant' | 'user'; + content: string; +}; + +export interface CopilotTextToTextProvider extends CopilotProvider { + generateText(messages: ChatMessage[], model: string): Promise; + generateTextStream( + messages: ChatMessage[], + model: string + ): AsyncIterable; +} + +export interface CopilotTextToEmbeddingProvider extends CopilotProvider {} + +export interface CopilotTextToImageProvider extends CopilotProvider {} + +export interface CopilotImageToImageProvider extends CopilotProvider {} + +export type CapabilityToCopilotProvider = { + [CopilotProviderCapability.TextToText]: CopilotTextToTextProvider; + [CopilotProviderCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider; + [CopilotProviderCapability.TextToImage]: CopilotTextToImageProvider; + [CopilotProviderCapability.ImageToImage]: CopilotImageToImageProvider; +}; diff --git a/packages/backend/server/src/plugins/index.ts b/packages/backend/server/src/plugins/index.ts index 42ea147ad3..9d82b90c10 100644 --- a/packages/backend/server/src/plugins/index.ts +++ b/packages/backend/server/src/plugins/index.ts @@ -1,3 +1,4 @@ +import './copilot'; import './gcloud'; import './oauth'; import './payment'; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index e76bda04f8..982b4fa9d2 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -271,6 +271,7 @@ enum ServerDeploymentType { } enum ServerFeature { + Copilot OAuth Payment } diff --git a/yarn.lock b/yarn.lock index 17e5f89fb7..6dc9399197 100644 --- a/yarn.lock +++ b/yarn.lock @@ -726,6 +726,7 @@ __metadata: nodemailer: "npm:^6.9.10" nodemon: "npm:^3.1.0" on-headers: "npm:^1.0.2" + openai: "npm:^4.29.2" parse-duration: "npm:^1.1.0" pretty-time: "npm:^1.1.0" prisma: "npm:^5.10.2"