From 5cbcf6f907efd98352ecd5ef038840bfac1847aa Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Fri, 1 Aug 2025 15:22:48 +0800 Subject: [PATCH] feat(server): add fallback model and baseurl in schema (#13375) fix AI-398 ## Summary by CodeRabbit * **New Features** * Added support for specifying fallback models for multiple AI providers, enhancing reliability when primary models are unavailable. * Providers can now fetch and update their list of available models dynamically from external APIs. * Configuration options expanded to allow custom base URLs for certain providers. * **Bug Fixes** * Improved model selection logic to use fallback models if the requested model is not available online. * **Chores** * Updated backend dependencies to include authentication support for Google services. --- .docker/selfhost/schema.json | 57 +++++++-- packages/backend/server/package.json | 1 + .../server/src/plugins/copilot/config.ts | 37 +++++- .../copilot/providers/anthropic/official.ts | 40 ++++++- .../copilot/providers/anthropic/vertex.ts | 67 ++++++++++- .../copilot/providers/gemini/gemini.ts | 5 - .../copilot/providers/gemini/generative.ts | 112 +++++++++++++++--- .../copilot/providers/gemini/vertex.ts | 90 +++++++++++++- .../src/plugins/copilot/providers/openai.ts | 61 +++++++++- .../plugins/copilot/providers/perplexity.ts | 15 ++- .../src/plugins/copilot/providers/provider.ts | 27 ++++- .../src/plugins/copilot/providers/types.ts | 2 + .../src/plugins/copilot/providers/utils.ts | 56 ++++++++- yarn.lock | 43 +++---- 14 files changed, 544 insertions(+), 69 deletions(-) diff --git a/.docker/selfhost/schema.json b/.docker/selfhost/schema.json index 972a0f8675..0ff8196e1b 100644 --- a/.docker/selfhost/schema.json +++ b/.docker/selfhost/schema.json @@ -669,9 +669,16 @@ }, "providers.openai": { "type": "object", - "description": "The config for the openai provider.\n@default {\"apiKey\":\"\"}\n@link https://github.com/openai/openai-node", + "description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseUrl\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}\n@link https://github.com/openai/openai-node", "default": { - "apiKey": "" + "apiKey": "", + "baseUrl": "", + "fallback": { + "text": "", + "structured": "", + "image": "", + "embedding": "" + } } }, "providers.fal": { @@ -683,14 +690,21 @@ }, "providers.gemini": { "type": "object", - "description": "The config for the gemini provider.\n@default {\"apiKey\":\"\"}", + "description": "The config for the gemini provider.\n@default {\"apiKey\":\"\",\"baseUrl\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}", "default": { - "apiKey": "" + "apiKey": "", + "baseUrl": "", + "fallback": { + "text": "", + "structured": "", + "image": "", + "embedding": "" + } } }, "providers.geminiVertex": { "type": "object", - "description": "The config for the google vertex provider.\n@default {}", + "description": "The config for the google vertex provider.\n@default {\"baseURL\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}", "properties": { "location": { "type": "string", @@ -721,25 +735,39 @@ } } }, - "default": {} + "default": { + "baseURL": "", + "fallback": { + "text": "", + "structured": "", + "image": "", + "embedding": "" + } + } }, "providers.perplexity": { "type": "object", - "description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\"}", + "description": "The 
config for the perplexity provider.\n@default {\"apiKey\":\"\",\"fallback\":{\"text\":\"\"}}", "default": { - "apiKey": "" + "apiKey": "", + "fallback": { + "text": "" + } } }, "providers.anthropic": { "type": "object", - "description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\"}", + "description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\",\"fallback\":{\"text\":\"\"}}", "default": { - "apiKey": "" + "apiKey": "", + "fallback": { + "text": "" + } } }, "providers.anthropicVertex": { "type": "object", - "description": "The config for the google vertex provider.\n@default {}", + "description": "The config for the google vertex provider.\n@default {\"baseURL\":\"\",\"fallback\":{\"text\":\"\"}}", "properties": { "location": { "type": "string", @@ -770,7 +798,12 @@ } } }, - "default": {} + "default": { + "baseURL": "", + "fallback": { + "text": "" + } + } }, "providers.morph": { "type": "object", diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 3608195e70..5ad6bc0a69 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -86,6 +86,7 @@ "express": "^5.0.1", "fast-xml-parser": "^5.0.0", "get-stream": "^9.0.1", + "google-auth-library": "^10.2.0", "graphql": "^16.9.0", "graphql-scalars": "^1.24.0", "graphql-upload": "^17.0.0", diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index c36045a018..89a14e6ec0 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -47,6 +47,13 @@ defineModuleConfig('copilot', { desc: 'The config for the openai provider.', default: { apiKey: '', + baseUrl: '', + fallback: { + text: '', + structured: '', + image: '', + embedding: '', + }, }, link: 'https://github.com/openai/openai-node', }, @@ -60,28 +67,54 @@ defineModuleConfig('copilot', { desc: 'The config for the gemini provider.', default: { apiKey: '', + baseUrl: '', + fallback: { + text: '', + structured: '', + image: '', + embedding: '', + }, }, }, 'providers.geminiVertex': { desc: 'The config for the gemini provider in Google Vertex AI.', - default: {}, + default: { + baseURL: '', + fallback: { + text: '', + structured: '', + image: '', + embedding: '', + }, + }, schema: VertexSchema, }, 'providers.perplexity': { desc: 'The config for the perplexity provider.', default: { apiKey: '', + fallback: { + text: '', + }, }, }, 'providers.anthropic': { desc: 'The config for the anthropic provider.', default: { apiKey: '', + fallback: { + text: '', + }, }, }, 'providers.anthropicVertex': { desc: 'The config for the anthropic provider in Google Vertex AI.', - default: {}, + default: { + baseURL: '', + fallback: { + text: '', + }, + }, schema: VertexSchema, }, 'providers.morph': { diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts index 7c348d6918..0c68785686 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts @@ -3,12 +3,23 @@ import { createAnthropic, } from '@ai-sdk/anthropic'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { + CopilotChatOptions, + CopilotProviderType, + ModelConditions, + ModelInputType, + ModelOutputType, + PromptMessage, + StreamObject, +} from 
'../types'; import { AnthropicProvider } from './anthropic'; export type AnthropicOfficialConfig = { apiKey: string; baseUrl?: string; + fallback?: { + text?: string; + }; }; export class AnthropicOfficialProvider extends AnthropicProvider { @@ -67,4 +78,31 @@ export class AnthropicOfficialProvider extends AnthropicProvider { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + return super.text(fullCond, messages, options); + } + + override async *streamText( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamText(fullCond, messages, options); + } + + override async *streamObject( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamObject(fullCond, messages, options); + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts index b86b3c80d8..332889b033 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts @@ -4,10 +4,23 @@ import { type GoogleVertexAnthropicProviderSettings, } from '@ai-sdk/google-vertex/anthropic'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { + CopilotChatOptions, + CopilotProviderType, + ModelConditions, + ModelInputType, + ModelOutputType, + PromptMessage, + StreamObject, +} from '../types'; +import { getGoogleAuth, VertexModelListSchema } from '../utils'; import { AnthropicProvider } from './anthropic'; -export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings; +export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings & { + fallback?: { + text?: string; + }; +}; export class AnthropicVertexProvider extends AnthropicProvider { override readonly type = CopilotProviderType.AnthropicVertex; @@ -62,4 +75,54 @@ export class AnthropicVertexProvider extends AnthropicProvider { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + return super.text(fullCond, messages, options); + } + + override async *streamText( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamText(fullCond, messages, options); + } + + override async *streamObject( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamObject(fullCond, messages, options); + } + + override async refreshOnlineModels() { + try { + const { baseUrl, headers } = await getGoogleAuth( + this.config, + 'anthropic' + ); + if (baseUrl && !this.onlineModelList.length) { + const { publisherModels } = await fetch(`${baseUrl}/models`, { + headers: headers(), + }) + .then(r => r.json()) + .then(r => VertexModelListSchema.parse(r)); + this.onlineModelList = publisherModels.map( + model => + model.name.replace('publishers/anthropic/models/', '') + + (model.versionId !== 'default' ? 
`@${model.versionId}` : '') + ); + } + } catch (e) { + this.logger.error('Failed to fetch available models', e); + } + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts index 5820e35e08..075499e797 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts @@ -37,11 +37,6 @@ import { export const DEFAULT_DIMENSIONS = 256; -export type GeminiConfig = { - apiKey: string; - baseUrl?: string; -}; - export abstract class GeminiProvider extends CopilotProvider { private readonly MAX_STEPS = 20; diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts index 32b4575917..f07b42d1de 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts @@ -2,15 +2,35 @@ import { createGoogleGenerativeAI, type GoogleGenerativeAIProvider, } from '@ai-sdk/google'; +import z from 'zod'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { + CopilotChatOptions, + CopilotEmbeddingOptions, + CopilotProviderType, + ModelConditions, + ModelInputType, + ModelOutputType, + PromptMessage, + StreamObject, +} from '../types'; import { GeminiProvider } from './gemini'; export type GeminiGenerativeConfig = { apiKey: string; baseUrl?: string; + fallback?: { + text?: string; + structured?: string; + image?: string; + embedding?: string; + }; }; +const ModelListSchema = z.object({ + models: z.array(z.object({ name: z.string() })), +}); + export class GeminiGenerativeProvider extends GeminiProvider { override readonly type = CopilotProviderType.Gemini; @@ -71,27 +91,16 @@ export class GeminiGenerativeProvider extends GeminiProvider { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + return super.text(fullCond, messages, options); + } + + override async structure( + cond: ModelConditions, + messages: PromptMessage[], + options?: CopilotChatOptions + ): Promise { + const fullCond = { + ...cond, + fallbackModel: this.config.fallback?.structured, + }; + return super.structure(fullCond, messages, options); + } + + override async *streamText( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamText(fullCond, messages, options); + } + + override async *streamObject( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamObject(fullCond, messages, options); + } + + override async embedding( + cond: ModelConditions, + messages: string | string[], + options?: CopilotEmbeddingOptions + ): Promise { + const fullCond = { + ...cond, + fallbackModel: this.config.fallback?.embedding, + }; + return super.embedding(fullCond, messages, options); + } + + override async refreshOnlineModels() { + try { + const baseUrl = + this.config.baseUrl || + 'https://generativelanguage.googleapis.com/v1beta'; + if (baseUrl && !this.onlineModelList.length) { + const { models } = await fetch( + `${baseUrl}/models?key=${this.config.apiKey}` + ) + .then(r => r.json()) + .then( + r 
=> (console.log(JSON.stringify(r)), ModelListSchema.parse(r)) + ); + this.onlineModelList = models.map(model => + model.name.replace('models/', '') + ); + } + } catch (e) { + this.logger.error('Failed to fetch available models', e); + } + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts index 4217e215e4..609de0728b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts @@ -4,10 +4,27 @@ import { type GoogleVertexProviderSettings, } from '@ai-sdk/google-vertex'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { + CopilotChatOptions, + CopilotEmbeddingOptions, + CopilotProviderType, + ModelConditions, + ModelInputType, + ModelOutputType, + PromptMessage, + StreamObject, +} from '../types'; +import { getGoogleAuth, VertexModelListSchema } from '../utils'; import { GeminiProvider } from './gemini'; -export type GeminiVertexConfig = GoogleVertexProviderSettings; +export type GeminiVertexConfig = GoogleVertexProviderSettings & { + fallback?: { + text?: string; + structured?: string; + image?: string; + embedding?: string; + }; +}; export class GeminiVertexProvider extends GeminiProvider { override readonly type = CopilotProviderType.GeminiVertex; @@ -72,4 +89,73 @@ export class GeminiVertexProvider extends GeminiProvider { super.setup(); this.instance = createVertex(this.config); } + + override async text( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): Promise { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + return super.text(fullCond, messages, options); + } + + override async structure( + cond: ModelConditions, + messages: PromptMessage[], + options?: CopilotChatOptions + ): Promise { + const fullCond = { + ...cond, + fallbackModel: this.config.fallback?.structured, + }; + return super.structure(fullCond, messages, options); + } + + override async *streamText( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamText(fullCond, messages, options); + } + + override async *streamObject( + cond: ModelConditions, + messages: PromptMessage[], + options: CopilotChatOptions = {} + ): AsyncIterable { + const fullCond = { ...cond, fallbackModel: this.config.fallback?.text }; + yield* super.streamObject(fullCond, messages, options); + } + + override async embedding( + cond: ModelConditions, + messages: string | string[], + options?: CopilotEmbeddingOptions + ): Promise { + const fullCond = { + ...cond, + fallbackModel: this.config.fallback?.embedding, + }; + return super.embedding(fullCond, messages, options); + } + + override async refreshOnlineModels() { + try { + const { baseUrl, headers } = await getGoogleAuth(this.config, 'google'); + if (baseUrl && !this.onlineModelList.length) { + const { publisherModels } = await fetch(`${baseUrl}/models`, { + headers: headers(), + }) + .then(r => r.json()) + .then(r => VertexModelListSchema.parse(r)); + this.onlineModelList = publisherModels.map(model => + model.name.replace('publishers/google/models/', '') + ); + } + } catch (e) { + this.logger.error('Failed to fetch available models', e); + } + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts 
b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 8c2ffea059..efd8bd9e78 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -46,8 +46,18 @@ export const DEFAULT_DIMENSIONS = 256; export type OpenAIConfig = { apiKey: string; baseUrl?: string; + fallback?: { + text?: string; + structured?: string; + image?: string; + embedding?: string; + }; }; +const ModelListSchema = z.object({ + data: z.array(z.object({ id: z.string() })), +}); + const ImageResponseSchema = z.union([ z.object({ data: z.array(z.object({ b64_json: z.string() })), @@ -271,6 +281,25 @@ export class OpenAIProvider extends CopilotProvider { } } + override async refreshOnlineModels() { + try { + const baseUrl = this.config.baseUrl || 'https://api.openai.com/v1'; + if (baseUrl && !this.onlineModelList.length) { + const { data } = await fetch(`${baseUrl}/models`, { + headers: { + Authorization: `Bearer ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + }) + .then(r => r.json()) + .then(r => ModelListSchema.parse(r)); + this.onlineModelList = data.map(model => model.id); + } + } catch (e) { + this.logger.error('Failed to fetch available models', e); + } + } + override getProviderSpecificTools( toolName: CopilotChatTools, model: string @@ -291,6 +320,7 @@ export class OpenAIProvider extends CopilotProvider { const fullCond = { ...cond, outputType: ModelOutputType.Text, + fallbackModel: this.config.fallback?.text, }; await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); @@ -331,6 +361,7 @@ export class OpenAIProvider extends CopilotProvider { const fullCond = { ...cond, outputType: ModelOutputType.Text, + fallbackModel: this.config.fallback?.text, }; await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); @@ -376,7 +407,11 @@ export class OpenAIProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): AsyncIterable { - const fullCond = { ...cond, outputType: ModelOutputType.Object }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Object, + fallbackModel: this.config.fallback?.text, + }; await this.checkParams({ cond: fullCond, messages, options }); const model = this.selectModel(fullCond); @@ -409,7 +444,11 @@ export class OpenAIProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotStructuredOptions = {} ): Promise { - const fullCond = { ...cond, outputType: ModelOutputType.Structured }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Structured, + fallbackModel: this.config.fallback?.structured, + }; await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); @@ -449,7 +488,11 @@ export class OpenAIProvider extends CopilotProvider { chunkMessages: PromptMessage[][], options: CopilotChatOptions = {} ): Promise { - const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Text, + fallbackModel: this.config.fallback?.text, + }; await this.checkParams({ messages: [], cond: fullCond, options }); const model = this.selectModel(fullCond); // get the log probability of "yes"/"no" @@ -594,7 +637,11 @@ export class OpenAIProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotImageOptions = {} ) { - const fullCond = { ...cond, outputType: ModelOutputType.Image }; + const fullCond = { 
+ ...cond, + outputType: ModelOutputType.Image, + fallbackModel: this.config.fallback?.image, + }; await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); @@ -644,7 +691,11 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { messages = Array.isArray(messages) ? messages : [messages]; - const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Embedding, + fallbackModel: this.config.fallback?.embedding, + }; await this.checkParams({ embeddings: messages, cond: fullCond, options }); const model = this.selectModel(fullCond); diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts index 706f948ae5..7780ed054c 100644 --- a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -20,6 +20,9 @@ import { chatToGPTMessage, CitationParser } from './utils'; export type PerplexityConfig = { apiKey: string; endpoint?: string; + fallback?: { + text?: string; + }; }; const PerplexityErrorSchema = z.union([ @@ -109,7 +112,11 @@ export class PerplexityProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { - const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Text, + fallbackModel: this.config.fallback?.text, + }; await this.checkParams({ cond: fullCond, messages, options }); const model = this.selectModel(fullCond); @@ -149,7 +156,11 @@ export class PerplexityProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): AsyncIterable { - const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const fullCond = { + ...cond, + outputType: ModelOutputType.Text, + fallbackModel: this.config.fallback?.text, + }; await this.checkParams({ cond: fullCond, messages, options }); const model = this.selectModel(fullCond); diff --git a/packages/backend/server/src/plugins/copilot/providers/provider.ts b/packages/backend/server/src/plugins/copilot/providers/provider.ts index 4ab382e066..3a8a3370c6 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider.ts @@ -53,6 +53,7 @@ import { @Injectable() export abstract class CopilotProvider { protected readonly logger = new Logger(this.constructor.name); + protected onlineModelList: string[] = []; abstract readonly type: CopilotProviderType; abstract readonly models: CopilotProviderModel[]; abstract configured(): boolean; @@ -80,11 +81,18 @@ export abstract class CopilotProvider { protected setup() { if (this.configured()) { this.factory.register(this); + if (env.selfhosted) { + this.refreshOnlineModels().catch(e => + this.logger.error('Failed to refresh online models', e) + ); + } } else { this.factory.unregister(this); } } + async refreshOnlineModels() {} + private findValidModel( cond: ModelFullConditions ): CopilotProviderModel | undefined { @@ -95,9 +103,26 @@ export abstract class CopilotProvider { inputTypes.every(type => cap.input.includes(type))); if (modelId) { - return this.models.find( + const hasOnlineModel = this.onlineModelList.includes(modelId); + const hasFallbackModel = cond.fallbackModel + ? 
this.onlineModelList.includes(cond.fallbackModel) + : undefined; + + const model = this.models.find( m => m.id === modelId && m.capabilities.some(matcher) ); + + if (model) { + // return fallback model if current model is not alive + if (!hasOnlineModel && hasFallbackModel) { + // oxlint-disable-next-line typescript-eslint(no-non-null-assertion) + return { id: cond.fallbackModel!, capabilities: [] }; + } + return model; + } + // allow online model without capabilities check + if (hasOnlineModel) return { id: modelId, capabilities: [] }; + return undefined; } if (!outputType) return undefined; diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index 34a5bd9fba..443c34f579 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -237,6 +237,7 @@ export interface ModelCapability { export interface CopilotProviderModel { id: string; + name?: string; capabilities: ModelCapability[]; } @@ -247,4 +248,5 @@ export type ModelConditions = { export type ModelFullConditions = ModelConditions & { outputType?: ModelOutputType; + fallbackModel?: string; }; diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index 426d70a44f..f39509ab75 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -1,3 +1,5 @@ +import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'; +import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic'; import { Logger } from '@nestjs/common'; import { CoreAssistantMessage, @@ -7,7 +9,8 @@ import { TextPart, TextStreamPart, } from 'ai'; -import { ZodType } from 'zod'; +import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; +import z, { ZodType } from 'zod'; import { CustomAITools } from '../tools'; import { PromptMessage, StreamObject } from './types'; @@ -655,3 +658,54 @@ export class StreamObjectParser { }, ''); } } + +export const VertexModelListSchema = z.object({ + publisherModels: z.array( + z.object({ + name: z.string(), + versionId: z.string(), + }) + ), +}); + +export async function getGoogleAuth( + options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings, + publisher: 'anthropic' | 'google' +) { + function getBaseUrl() { + const { baseURL, location } = options; + if (baseURL?.trim()) { + try { + const url = new URL(baseURL); + if (url.pathname.endsWith('/')) { + url.pathname = url.pathname.slice(0, -1); + } + return url.toString(); + } catch {} + } else if (location) { + return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`; + } + return undefined; + } + + async function generateAuthToken() { + if (!options.googleAuthOptions) { + return undefined; + } + const auth = new GoogleAuth({ + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + ...(options.googleAuthOptions as GoogleAuthOptions), + }); + const client = await auth.getClient(); + const token = await client.getAccessToken(); + return token.token; + } + + const token = await generateAuthToken(); + + return { + baseUrl: getBaseUrl(), + headers: () => ({ Authorization: `Bearer ${token}` }), + fetch: options.fetch, + }; +} diff --git a/yarn.lock b/yarn.lock index 46597c2e73..01d38ae90f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -998,6 +998,7 @@ 
__metadata: express: "npm:^5.0.1" fast-xml-parser: "npm:^5.0.0" get-stream: "npm:^9.0.1" + google-auth-library: "npm:^10.2.0" graphql: "npm:^16.9.0" graphql-scalars: "npm:^1.24.0" graphql-upload: "npm:^17.0.0" @@ -22147,14 +22148,14 @@ __metadata: languageName: node linkType: hard -"gaxios@npm:^7.0.0-rc.1, gaxios@npm:^7.0.0-rc.4": - version: 7.0.0-rc.6 - resolution: "gaxios@npm:7.0.0-rc.6" +"gaxios@npm:^7.0.0, gaxios@npm:^7.0.0-rc.4": + version: 7.1.1 + resolution: "gaxios@npm:7.1.1" dependencies: extend: "npm:^3.0.2" https-proxy-agent: "npm:^7.0.1" node-fetch: "npm:^3.3.2" - checksum: 10/60c688d4c65062c97bf0f33f959713df106e207065586bf5deb546ef5d02cddcba46d138c0b7eb8712950ca880fa28d3665936b19156224f8c478d9c4f817aea + checksum: 10/9e5fa8b458c318a95d4dff0f6ac187a1b8933fb1de5b376b7098b27dfc5bf6025b62c87ed20bdae0496ae73a279834bc6b974c28849a674deed0089f2ba57b98 languageName: node linkType: hard @@ -22169,14 +22170,14 @@ __metadata: languageName: node linkType: hard -"gcp-metadata@npm:^7.0.0-rc.1": - version: 7.0.0-rc.1 - resolution: "gcp-metadata@npm:7.0.0-rc.1" +"gcp-metadata@npm:^7.0.0": + version: 7.0.1 + resolution: "gcp-metadata@npm:7.0.1" dependencies: - gaxios: "npm:^7.0.0-rc.1" + gaxios: "npm:^7.0.0" google-logging-utils: "npm:^1.0.0" json-bigint: "npm:^1.0.0" - checksum: 10/2c58401c7945c41144bc6a44a066c050d36c34ee10e04e85ffde488afec6a3f67ebe29e697d0328aa20181b38a36aae9165896c0201387ea8cff031cdb790ab9 + checksum: 10/c82f20a4ce22278998fe033e668a66bff04d2b3e95e19f968adeac829e12274e07b453fcfcf34573a6d702b3570c5556cba6eb6b59d1c03757c866e3271972c1 languageName: node linkType: hard @@ -22540,18 +22541,18 @@ __metadata: languageName: node linkType: hard -"google-auth-library@npm:^10.0.0-rc.1": - version: 10.0.0-rc.3 - resolution: "google-auth-library@npm:10.0.0-rc.3" +"google-auth-library@npm:^10.0.0-rc.1, google-auth-library@npm:^10.2.0": + version: 10.2.0 + resolution: "google-auth-library@npm:10.2.0" dependencies: base64-js: "npm:^1.3.0" ecdsa-sig-formatter: "npm:^1.0.11" - gaxios: "npm:^7.0.0-rc.4" - gcp-metadata: "npm:^7.0.0-rc.1" + gaxios: "npm:^7.0.0" + gcp-metadata: "npm:^7.0.0" google-logging-utils: "npm:^1.0.0" - gtoken: "npm:^8.0.0-rc.1" + gtoken: "npm:^8.0.0" jws: "npm:^4.0.0" - checksum: 10/d76f470ddba1d5ec84cb72b03af388722db2987a95f64c7d91169f3f67608528a327fd132bb00d74be10b08e6446906b16983cf0aa9f6fa10ac7a6d3aacff0d7 + checksum: 10/dfa6ad7240da3915b7e15d1f39cd6906e6714502b09957be07d2429665a44c6b46d8ca7a077cec7fcb83f82ed2c3cb5a9b7c1ba79ebb3e8920eba286a24bdd63 languageName: node linkType: hard @@ -22748,13 +22749,13 @@ __metadata: languageName: node linkType: hard -"gtoken@npm:^8.0.0-rc.1": - version: 8.0.0-rc.1 - resolution: "gtoken@npm:8.0.0-rc.1" +"gtoken@npm:^8.0.0": + version: 8.0.0 + resolution: "gtoken@npm:8.0.0" dependencies: - gaxios: "npm:^7.0.0-rc.1" + gaxios: "npm:^7.0.0" jws: "npm:^4.0.0" - checksum: 10/d2481344df8d9f62ec3ae7fe97b562c93dc294c7d1e7d8e1603162fe2726cfb5993f2b1a4e04388ee0f49c5fd02c1b0799dc58cebbf8af10489af8de80a72902 + checksum: 10/b921430395dcd06ee63c3fc5a5e339ca4d6dcb38b6d618beb0f260bae1088d53d130f86029a9d578f1601c64685f49a65dba57bbd617c4b14039180b67b6c5ce languageName: node linkType: hard
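
For reference, the model-selection change in `provider.ts` boils down to the rule sketched below. This is an illustrative, simplified standalone version (the `pickModel` helper is hypothetical, and capability matching plus the `CopilotProviderModel` shape are omitted), not the code as committed:

```ts
// Simplified sketch of the fallback rule added to CopilotProvider.findValidModel:
// prefer the requested model, but if it is configured locally while absent from
// the provider's online model list, and the configured fallback *is* online,
// use the fallback instead. Models only known online are accepted without a
// capability check.
function pickModel(
  requestedId: string,
  onlineModels: string[],      // populated by refreshOnlineModels()
  locallyConfigured: string[], // ids declared in the provider's `models` table
  fallbackId?: string          // e.g. config.fallback.text
): string | undefined {
  const requestedOnline = onlineModels.includes(requestedId);
  const fallbackOnline = fallbackId ? onlineModels.includes(fallbackId) : false;

  if (locallyConfigured.includes(requestedId)) {
    // requested model is not reachable online, but the fallback is
    return !requestedOnline && fallbackOnline ? fallbackId : requestedId;
  }
  // unknown locally: accept it only if the provider reports it online
  return requestedOnline ? requestedId : undefined;
}

// Example: the prompt asks for 'gpt-4o', the provider's /models endpoint only
// lists 'gpt-4o-mini', and fallback.text is 'gpt-4o-mini' -> the fallback is used.
pickModel('gpt-4o', ['gpt-4o-mini'], ['gpt-4o'], 'gpt-4o-mini');
```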