From 98e218af9322019daf9b96a92db12d2b046a6385 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Tue, 14 May 2024 13:05:07 +0000 Subject: [PATCH] feat: allow undefined new model (#6933) --- .github/workflows/build-test.yml | 1 + .../server/src/plugins/copilot/controller.ts | 6 +- .../src/plugins/copilot/providers/fal.ts | 2 +- .../src/plugins/copilot/providers/index.ts | 8 +-- .../src/plugins/copilot/providers/openai.ts | 19 +++++- .../server/src/plugins/copilot/types.ts | 2 +- packages/backend/server/tests/copilot.spec.ts | 60 ++++++++++++++----- .../backend/server/tests/utils/copilot.ts | 2 +- 8 files changed, 74 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index be4ef08b29..1fb9436a44 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -351,6 +351,7 @@ jobs: env: CARGO_TARGET_DIR: '${{ github.workspace }}/target' DATABASE_URL: postgresql://affine:affine@localhost:5432/affine + COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} - name: Upload server test coverage results uses: codecov/codecov-action@v4 diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 6fcb935ea3..042996df14 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -133,7 +133,7 @@ export class CopilotController { @Query() params: Record ): Promise { const { model } = await this.checkRequest(user.id, sessionId); - const provider = this.provider.getProviderByCapability( + const provider = await this.provider.getProviderByCapability( CopilotCapability.TextToText, model ); @@ -179,7 +179,7 @@ export class CopilotController { ): Promise> { try { const { model } = await this.checkRequest(user.id, sessionId); - const provider = this.provider.getProviderByCapability( + const provider = await this.provider.getProviderByCapability( CopilotCapability.TextToText, model ); @@ -246,7 +246,7 @@ export class CopilotController { sessionId, messageId ); - const provider = this.provider.getProviderByCapability( + const provider = await this.provider.getProviderByCapability( hasAttachment ? CopilotCapability.ImageToImage : CopilotCapability.TextToImage, diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 7752bb93c7..cfb13d879a 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -50,7 +50,7 @@ export class FalProvider return FalProvider.capabilities; } - isModelAvailable(model: string): boolean { + async isModelAvailable(model: string): Promise { return this.availableModels.includes(model); } diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts index addc873e3e..675e09e307 100644 --- a/packages/backend/server/src/plugins/copilot/providers/index.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -48,7 +48,7 @@ export function registerCopilotProvider< const providerConfig = config.plugins.copilot?.[type]; if (!provider.assetsConfig(providerConfig as C)) { throw new Error( - `Invalid configuration for copilot provider ${type}: ${providerConfig}` + `Invalid configuration for copilot provider ${type}: ${JSON.stringify(providerConfig)}` ); } const instance = new provider(providerConfig as C); @@ -116,11 +116,11 @@ export class CopilotProviderService { return this.cachedProviders.get(provider)!; } - getProviderByCapability( + async getProviderByCapability( capability: C, model?: string, prefer?: CopilotProviderType - ): CapabilityToCopilotProvider[C] | null { + ): Promise { const providers = PROVIDER_CAPABILITY_MAP.get(capability); if (Array.isArray(providers) && providers.length) { let selectedProvider: CopilotProviderType | undefined = prefer; @@ -137,7 +137,7 @@ export class CopilotProviderService { const provider = this.getProvider(selectedProvider); if (provider.getCapabilities().includes(capability)) { if (model) { - if (provider.isModelAvailable(model)) { + if (await provider.isModelAvailable(model)) { return provider as CapabilityToCopilotProvider[C]; } } else { diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index cd1724d0d6..02b6db73f2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -1,5 +1,6 @@ import assert from 'node:assert'; +import { Logger } from '@nestjs/common'; import { ClientOptions, OpenAI } from 'openai'; import { @@ -51,7 +52,9 @@ export class OpenAIProvider 'dall-e-3', ]; + private readonly logger = new Logger(OpenAIProvider.type); private readonly instance: OpenAI; + private existsModels: string[] | undefined; constructor(config: ClientOptions) { assert(OpenAIProvider.assetsConfig(config)); @@ -70,8 +73,20 @@ export class OpenAIProvider return OpenAIProvider.capabilities; } - isModelAvailable(model: string): boolean { - return this.availableModels.includes(model); + async isModelAvailable(model: string): Promise { + const knownModels = this.availableModels.includes(model); + if (knownModels) return true; + + if (!this.existsModels) { + try { + this.existsModels = await this.instance.models + .list() + .then(({ data }) => data.map(m => m.id)); + } catch (e) { + this.logger.error('Failed to fetch online model list', e); + } + } + return !!this.existsModels?.includes(model); } protected chatToGPTMessage( diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 64a770d635..ee23f5fe11 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -172,7 +172,7 @@ export type CopilotImageOptions = z.infer; export interface CopilotProvider { readonly type: CopilotProviderType; getCapabilities(): CopilotCapability[]; - isModelAvailable(model: string): boolean; + isModelAvailable(model: string): Promise; } export interface CopilotTextToTextProvider extends CopilotProvider { diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index f9976e8b17..6faf4fce1d 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -36,7 +36,7 @@ test.beforeEach(async t => { plugins: { copilot: { openai: { - apiKey: '1', + apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1', }, fal: { apiKey: '1', @@ -368,7 +368,9 @@ test('should be able to get provider', async t => { const { provider } = t.context; { - const p = provider.getProviderByCapability(CopilotCapability.TextToText); + const p = await provider.getProviderByCapability( + CopilotCapability.TextToText + ); t.is( p?.type.toString(), 'openai', @@ -377,7 +379,7 @@ test('should be able to get provider', async t => { } { - const p = provider.getProviderByCapability( + const p = await provider.getProviderByCapability( CopilotCapability.TextToEmbedding ); t.is( @@ -388,7 +390,9 @@ test('should be able to get provider', async t => { } { - const p = provider.getProviderByCapability(CopilotCapability.TextToImage); + const p = await provider.getProviderByCapability( + CopilotCapability.TextToImage + ); t.is( p?.type.toString(), 'fal', @@ -397,7 +401,9 @@ test('should be able to get provider', async t => { } { - const p = provider.getProviderByCapability(CopilotCapability.ImageToImage); + const p = await provider.getProviderByCapability( + CopilotCapability.ImageToImage + ); t.is( p?.type.toString(), 'fal', @@ -406,7 +412,9 @@ test('should be able to get provider', async t => { } { - const p = provider.getProviderByCapability(CopilotCapability.ImageToText); + const p = await provider.getProviderByCapability( + CopilotCapability.ImageToText + ); t.is( p?.type.toString(), 'openai', @@ -417,7 +425,7 @@ test('should be able to get provider', async t => { // text-to-image use fal by default, but this case can use // model dall-e-3 to select openai provider { - const p = provider.getProviderByCapability( + const p = await provider.getProviderByCapability( CopilotCapability.TextToImage, 'dall-e-3' ); @@ -427,14 +435,38 @@ test('should be able to get provider', async t => { 'should get provider support text-to-image and model' ); } + + // gpt4o is not defined now, but it already published by openai + // we should check from online api if it is available + { + const p = await provider.getProviderByCapability( + CopilotCapability.ImageToText, + 'gpt-4o' + ); + t.is( + p?.type.toString(), + 'openai', + 'should get provider support text-to-image and model' + ); + } + + // if a model is not defined and not available in online api + // it should return null + { + const p = await provider.getProviderByCapability( + CopilotCapability.ImageToText, + 'gpt-4-not-exist' + ); + t.falsy(p, 'should not get provider'); + } }); test('should be able to register test provider', async t => { const { provider } = t.context; registerCopilotProvider(MockCopilotTestProvider); - const assertProvider = (cap: CopilotCapability) => { - const p = provider.getProviderByCapability(cap, 'test'); + const assertProvider = async (cap: CopilotCapability) => { + const p = await provider.getProviderByCapability(cap, 'test'); t.is( p?.type, CopilotProviderType.Test, @@ -442,9 +474,9 @@ test('should be able to register test provider', async t => { ); }; - assertProvider(CopilotCapability.TextToText); - assertProvider(CopilotCapability.TextToEmbedding); - assertProvider(CopilotCapability.TextToImage); - assertProvider(CopilotCapability.ImageToImage); - assertProvider(CopilotCapability.ImageToText); + await assertProvider(CopilotCapability.TextToText); + await assertProvider(CopilotCapability.TextToEmbedding); + await assertProvider(CopilotCapability.TextToImage); + await assertProvider(CopilotCapability.ImageToImage); + await assertProvider(CopilotCapability.ImageToText); }); diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index 18df53783d..52b65f69ad 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -46,7 +46,7 @@ export class MockCopilotTestProvider return MockCopilotTestProvider.capabilities; } - override isModelAvailable(model: string): boolean { + override async isModelAvailable(model: string): Promise { return this.availableModels.includes(model); }