feat(server): add fallback model and baseurl in schema (#13375)

fix AI-398

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added support for specifying fallback models for multiple AI
providers, enhancing reliability when primary models are unavailable.
* Providers can now fetch and update their list of available models
dynamically from external APIs.
* Configuration options expanded to allow custom base URLs for certain
providers.

* **Bug Fixes**
* Improved model selection logic to use fallback models if the requested
model is not available online.

* **Chores**
* Updated backend dependencies to include authentication support for
Google services.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2025-08-01 15:22:48 +08:00
committed by GitHub
parent 19790c1b9e
commit 5cbcf6f907
14 changed files with 544 additions and 69 deletions

View File

@@ -47,6 +47,13 @@ defineModuleConfig('copilot', {
desc: 'The config for the openai provider.',
default: {
apiKey: '',
baseUrl: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
},
link: 'https://github.com/openai/openai-node',
},
@@ -60,28 +67,54 @@ defineModuleConfig('copilot', {
desc: 'The config for the gemini provider.',
default: {
apiKey: '',
baseUrl: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
},
},
'providers.geminiVertex': {
desc: 'The config for the gemini provider in Google Vertex AI.',
default: {},
default: {
baseURL: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
},
schema: VertexSchema,
},
'providers.perplexity': {
desc: 'The config for the perplexity provider.',
default: {
apiKey: '',
fallback: {
text: '',
},
},
},
'providers.anthropic': {
desc: 'The config for the anthropic provider.',
default: {
apiKey: '',
fallback: {
text: '',
},
},
},
'providers.anthropicVertex': {
desc: 'The config for the anthropic provider in Google Vertex AI.',
default: {},
default: {
baseURL: '',
fallback: {
text: '',
},
},
schema: VertexSchema,
},
'providers.morph': {

View File

@@ -3,12 +3,23 @@ import {
createAnthropic,
} from '@ai-sdk/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import {
CopilotChatOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { AnthropicProvider } from './anthropic';
export type AnthropicOfficialConfig = {
apiKey: string;
baseUrl?: string;
fallback?: {
text?: string;
};
};
export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOfficialConfig> {
@@ -67,4 +78,31 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
baseURL: this.config.baseUrl,
});
}
/** Run a text completion, attaching the configured text fallback model. */
override async text(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): Promise<string> {
  return super.text(
    { ...cond, fallbackModel: this.config.fallback?.text },
    messages,
    options
  );
}
/** Stream a text completion; same fallback wiring as `text`. */
override async *streamText(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<string> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamText(conditions, messages, options);
}
/**
 * Stream structured objects. Uses the `text` fallback model, consistent
 * with the other stream* overrides in this provider.
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamObject(conditions, messages, options);
}
}

View File

@@ -4,10 +4,23 @@ import {
type GoogleVertexAnthropicProviderSettings,
} from '@ai-sdk/google-vertex/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import {
CopilotChatOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils';
import { AnthropicProvider } from './anthropic';
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings & {
fallback?: {
text?: string;
};
};
export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
override readonly type = CopilotProviderType.AnthropicVertex;
@@ -62,4 +75,54 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
super.setup();
this.instance = createVertexAnthropic(this.config);
}
/** Run a text completion, attaching the configured text fallback model. */
override async text(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): Promise<string> {
  return super.text(
    { ...cond, fallbackModel: this.config.fallback?.text },
    messages,
    options
  );
}
/** Stream a text completion; same fallback wiring as `text`. */
override async *streamText(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<string> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamText(conditions, messages, options);
}
/**
 * Stream structured objects. Uses the `text` fallback model, consistent
 * with the other stream* overrides in this provider.
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamObject(conditions, messages, options);
}
/**
 * Fetch the Anthropic models published on Vertex AI and cache their ids
 * in `onlineModelList` (consumed by model selection to detect offline
 * models). Best-effort: failures are logged and leave the cache untouched.
 */
override async refreshOnlineModels() {
  try {
    const { baseUrl, headers } = await getGoogleAuth(
      this.config,
      'anthropic'
    );
    // Fetch only once per process; the list is not re-validated afterwards.
    if (baseUrl && !this.onlineModelList.length) {
      const res = await fetch(`${baseUrl}/models`, { headers: headers() });
      // Previously an error response body was fed straight into zod,
      // surfacing as a confusing parse error instead of an HTTP failure.
      if (!res.ok) {
        throw new Error(
          `Unexpected status ${res.status} from ${baseUrl}/models`
        );
      }
      const { publisherModels } = VertexModelListSchema.parse(
        await res.json()
      );
      this.onlineModelList = publisherModels.map(
        model =>
          // Vertex names look like `publishers/anthropic/models/<id>`;
          // non-default versions are addressed as `<id>@<version>`.
          model.name.replace('publishers/anthropic/models/', '') +
          (model.versionId !== 'default' ? `@${model.versionId}` : '')
      );
    }
  } catch (e) {
    this.logger.error('Failed to fetch available models', e);
  }
}
}

View File

@@ -37,11 +37,6 @@ import {
export const DEFAULT_DIMENSIONS = 256;
export type GeminiConfig = {
apiKey: string;
baseUrl?: string;
};
export abstract class GeminiProvider<T> extends CopilotProvider<T> {
private readonly MAX_STEPS = 20;

View File

@@ -2,15 +2,35 @@ import {
createGoogleGenerativeAI,
type GoogleGenerativeAIProvider,
} from '@ai-sdk/google';
import z from 'zod';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import {
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { GeminiProvider } from './gemini';
export type GeminiGenerativeConfig = {
apiKey: string;
baseUrl?: string;
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
};
// Shape of the Generative Language API `GET /models` response; only the
// model resource names (e.g. `models/<id>`) are consumed below.
const ModelListSchema = z.object({
  models: z.array(z.object({ name: z.string() })),
});
export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeConfig> {
override readonly type = CopilotProviderType.Gemini;
@@ -71,27 +91,16 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
],
},
{
name: 'Text Embedding 005',
id: 'text-embedding-005',
name: 'Gemini Embedding',
id: 'gemini-embedding-001',
capabilities: [
{
input: [ModelInputType.Text],
output: [ModelOutputType.Embedding],
defaultForOutputType: true,
},
],
},
// not exists yet
// {
// name: 'Gemini Embedding',
// id: 'gemini-embedding-001',
// capabilities: [
// {
// input: [ModelInputType.Text],
// output: [ModelOutputType.Embedding],
// defaultForOutputType: true,
// },
// ],
// },
];
protected instance!: GoogleGenerativeAIProvider;
@@ -107,4 +116,77 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
baseURL: this.config.baseUrl,
});
}
/** Run a text completion, attaching the configured text fallback model. */
override async text(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): Promise<string> {
  return super.text(
    { ...cond, fallbackModel: this.config.fallback?.text },
    messages,
    options
  );
}
/** Produce structured output, attaching the `structured` fallback model. */
override async structure(
  cond: ModelConditions,
  messages: PromptMessage[],
  options?: CopilotChatOptions
): Promise<string> {
  return super.structure(
    { ...cond, fallbackModel: this.config.fallback?.structured },
    messages,
    options
  );
}
/** Stream a text completion; same fallback wiring as `text`. */
override async *streamText(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<string> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamText(conditions, messages, options);
}
/**
 * Stream structured objects. Uses the `text` fallback (not `structured`),
 * matching the other stream* overrides and the OpenAI provider.
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamObject(conditions, messages, options);
}
/** Compute embeddings, attaching the `embedding` fallback model. */
override async embedding(
  cond: ModelConditions,
  messages: string | string[],
  options?: CopilotEmbeddingOptions
): Promise<number[][]> {
  return super.embedding(
    { ...cond, fallbackModel: this.config.fallback?.embedding },
    messages,
    options
  );
}
/**
 * Fetch the models exposed by the Generative Language API and cache their
 * ids in `onlineModelList`. Best-effort: failures are logged and leave the
 * cache untouched.
 *
 * Fixes: removes a leftover debug `console.log(JSON.stringify(r))` that
 * dumped the full API response on every refresh, adds an HTTP status check
 * (error bodies previously went straight into zod), and drops the
 * always-truthy `baseUrl &&` guard.
 */
override async refreshOnlineModels() {
  try {
    const baseUrl =
      this.config.baseUrl ||
      'https://generativelanguage.googleapis.com/v1beta';
    // Fetch only once per process.
    if (!this.onlineModelList.length) {
      // NOTE(review): the API key travels in the query string and may leak
      // into proxy/access logs; the API also accepts an `x-goog-api-key`
      // header — consider switching.
      const res = await fetch(`${baseUrl}/models?key=${this.config.apiKey}`);
      if (!res.ok) {
        throw new Error(`Unexpected status ${res.status} when listing models`);
      }
      const { models } = ModelListSchema.parse(await res.json());
      // Names arrive as `models/<id>`; keep only the bare id.
      this.onlineModelList = models.map(model =>
        model.name.replace('models/', '')
      );
    }
  } catch (e) {
    this.logger.error('Failed to fetch available models', e);
  }
}
}

View File

@@ -4,10 +4,27 @@ import {
type GoogleVertexProviderSettings,
} from '@ai-sdk/google-vertex';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import {
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils';
import { GeminiProvider } from './gemini';
export type GeminiVertexConfig = GoogleVertexProviderSettings;
export type GeminiVertexConfig = GoogleVertexProviderSettings & {
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
};
export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
override readonly type = CopilotProviderType.GeminiVertex;
@@ -72,4 +89,73 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
super.setup();
this.instance = createVertex(this.config);
}
/** Run a text completion, attaching the configured text fallback model. */
override async text(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): Promise<string> {
  return super.text(
    { ...cond, fallbackModel: this.config.fallback?.text },
    messages,
    options
  );
}
/** Produce structured output, attaching the `structured` fallback model. */
override async structure(
  cond: ModelConditions,
  messages: PromptMessage[],
  options?: CopilotChatOptions
): Promise<string> {
  return super.structure(
    { ...cond, fallbackModel: this.config.fallback?.structured },
    messages,
    options
  );
}
/** Stream a text completion; same fallback wiring as `text`. */
override async *streamText(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<string> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamText(conditions, messages, options);
}
/**
 * Stream structured objects. Uses the `text` fallback (not `structured`),
 * matching the other stream* overrides and the OpenAI provider.
 */
override async *streamObject(
  cond: ModelConditions,
  messages: PromptMessage[],
  options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
  const conditions = { ...cond, fallbackModel: this.config.fallback?.text };
  yield* super.streamObject(conditions, messages, options);
}
/** Compute embeddings, attaching the `embedding` fallback model. */
override async embedding(
  cond: ModelConditions,
  messages: string | string[],
  options?: CopilotEmbeddingOptions
): Promise<number[][]> {
  return super.embedding(
    { ...cond, fallbackModel: this.config.fallback?.embedding },
    messages,
    options
  );
}
/**
 * Fetch the Google models published on Vertex AI and cache their ids in
 * `onlineModelList`. Best-effort: failures are logged and leave the cache
 * untouched.
 */
override async refreshOnlineModels() {
  try {
    const { baseUrl, headers } = await getGoogleAuth(this.config, 'google');
    // Fetch only once per process.
    if (baseUrl && !this.onlineModelList.length) {
      const res = await fetch(`${baseUrl}/models`, { headers: headers() });
      // Previously an error response body was fed straight into zod,
      // surfacing as a confusing parse error instead of an HTTP failure.
      if (!res.ok) {
        throw new Error(
          `Unexpected status ${res.status} from ${baseUrl}/models`
        );
      }
      const { publisherModels } = VertexModelListSchema.parse(
        await res.json()
      );
      // Names arrive as `publishers/google/models/<id>`; keep the bare id.
      this.onlineModelList = publisherModels.map(model =>
        model.name.replace('publishers/google/models/', '')
      );
    }
  } catch (e) {
    this.logger.error('Failed to fetch available models', e);
  }
}
}

View File

@@ -46,8 +46,18 @@ export const DEFAULT_DIMENSIONS = 256;
export type OpenAIConfig = {
apiKey: string;
baseUrl?: string;
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
};
// Shape of the OpenAI `GET /v1/models` response; only the model ids are
// consumed below.
const ModelListSchema = z.object({
  data: z.array(z.object({ id: z.string() })),
});
const ImageResponseSchema = z.union([
z.object({
data: z.array(z.object({ b64_json: z.string() })),
@@ -271,6 +281,25 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}
}
/**
 * Fetch the models available on the configured OpenAI-compatible endpoint
 * and cache their ids in `onlineModelList`. Best-effort: failures are
 * logged and leave the cache untouched.
 *
 * Fixes: adds an HTTP status check (error bodies previously went straight
 * into zod, surfacing as a parse error) and drops the always-truthy
 * `baseUrl &&` guard.
 */
override async refreshOnlineModels() {
  try {
    const baseUrl = this.config.baseUrl || 'https://api.openai.com/v1';
    // Fetch only once per process.
    if (!this.onlineModelList.length) {
      const res = await fetch(`${baseUrl}/models`, {
        headers: {
          Authorization: `Bearer ${this.config.apiKey}`,
          'Content-Type': 'application/json',
        },
      });
      if (!res.ok) {
        throw new Error(`Unexpected status ${res.status} when listing models`);
      }
      const { data } = ModelListSchema.parse(await res.json());
      this.onlineModelList = data.map(model => model.id);
    }
  } catch (e) {
    this.logger.error('Failed to fetch available models', e);
  }
}
override getProviderSpecificTools(
toolName: CopilotChatTools,
model: string
@@ -291,6 +320,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -331,6 +361,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -376,7 +407,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object };
const fullCond = {
...cond,
outputType: ModelOutputType.Object,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);
@@ -409,7 +444,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[],
options: CopilotStructuredOptions = {}
): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Structured };
const fullCond = {
...cond,
outputType: ModelOutputType.Structured,
fallbackModel: this.config.fallback?.structured,
};
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -449,7 +488,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
chunkMessages: PromptMessage[][],
options: CopilotChatOptions = {}
): Promise<number[]> {
const fullCond = { ...cond, outputType: ModelOutputType.Text };
const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ messages: [], cond: fullCond, options });
const model = this.selectModel(fullCond);
// get the log probability of "yes"/"no"
@@ -594,7 +637,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[],
options: CopilotImageOptions = {}
) {
const fullCond = { ...cond, outputType: ModelOutputType.Image };
const fullCond = {
...cond,
outputType: ModelOutputType.Image,
fallbackModel: this.config.fallback?.image,
};
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -644,7 +691,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
const fullCond = {
...cond,
outputType: ModelOutputType.Embedding,
fallbackModel: this.config.fallback?.embedding,
};
await this.checkParams({ embeddings: messages, cond: fullCond, options });
const model = this.selectModel(fullCond);

View File

@@ -20,6 +20,9 @@ import { chatToGPTMessage, CitationParser } from './utils';
export type PerplexityConfig = {
apiKey: string;
endpoint?: string;
fallback?: {
text?: string;
};
};
const PerplexityErrorSchema = z.union([
@@ -109,7 +112,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text };
const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);
@@ -149,7 +156,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text };
const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond);

View File

@@ -53,6 +53,7 @@ import {
@Injectable()
export abstract class CopilotProvider<C = any> {
protected readonly logger = new Logger(this.constructor.name);
protected onlineModelList: string[] = [];
abstract readonly type: CopilotProviderType;
abstract readonly models: CopilotProviderModel[];
abstract configured(): boolean;
@@ -80,11 +81,18 @@ export abstract class CopilotProvider<C = any> {
/**
 * Register or unregister this provider with the factory depending on
 * whether it is configured. On self-hosted deployments, also kick off a
 * fire-and-forget refresh of the online model list.
 */
protected setup() {
  if (!this.configured()) {
    this.factory.unregister(this);
    return;
  }
  this.factory.register(this);
  if (env.selfhosted) {
    // Intentionally not awaited; errors are logged instead of propagated.
    this.refreshOnlineModels().catch(e =>
      this.logger.error('Failed to refresh online models', e)
    );
  }
}
/** No-op by default; providers override this to populate `onlineModelList`. */
async refreshOnlineModels() {}
private findValidModel(
cond: ModelFullConditions
): CopilotProviderModel | undefined {
@@ -95,9 +103,26 @@ export abstract class CopilotProvider<C = any> {
inputTypes.every(type => cap.input.includes(type)));
if (modelId) {
return this.models.find(
const hasOnlineModel = this.onlineModelList.includes(modelId);
const hasFallbackModel = cond.fallbackModel
? this.onlineModelList.includes(cond.fallbackModel)
: undefined;
const model = this.models.find(
m => m.id === modelId && m.capabilities.some(matcher)
);
if (model) {
// return fallback model if current model is not alive
if (!hasOnlineModel && hasFallbackModel) {
// oxlint-disable-next-line typescript-eslint(no-non-null-assertion)
return { id: cond.fallbackModel!, capabilities: [] };
}
return model;
}
// allow online model without capabilities check
if (hasOnlineModel) return { id: modelId, capabilities: [] };
return undefined;
}
if (!outputType) return undefined;

View File

@@ -237,6 +237,7 @@ export interface ModelCapability {
export interface CopilotProviderModel {
id: string;
name?: string;
capabilities: ModelCapability[];
}
@@ -247,4 +248,5 @@ export type ModelConditions = {
export type ModelFullConditions = ModelConditions & {
outputType?: ModelOutputType;
fallbackModel?: string;
};

View File

@@ -1,3 +1,5 @@
import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex';
import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic';
import { Logger } from '@nestjs/common';
import {
CoreAssistantMessage,
@@ -7,7 +9,8 @@ import {
TextPart,
TextStreamPart,
} from 'ai';
import { ZodType } from 'zod';
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import z, { ZodType } from 'zod';
import { CustomAITools } from '../tools';
import { PromptMessage, StreamObject } from './types';
@@ -655,3 +658,54 @@ export class StreamObjectParser {
}, '');
}
}
// Shape of the Vertex AI `publishers.models.list` response shared by the
// Gemini and Anthropic Vertex providers; `name` carries the full resource
// path and `versionId` the model version.
export const VertexModelListSchema = z.object({
  publisherModels: z.array(
    z.object({
      name: z.string(),
      versionId: z.string(),
    })
  ),
});
/**
 * Derive the Vertex AI REST base URL and request headers from provider
 * settings.
 *
 * Fixes: `headers()` no longer emits a literal `Authorization: Bearer
 * undefined` when no credentials are configured or no token is minted, a
 * malformed `baseURL` now falls back to the `location`-derived endpoint
 * instead of silently disabling model listing, and a `null` access token
 * is normalized to `undefined`.
 *
 * @param options   Vertex provider settings (`baseURL`, `location`,
 *                  `googleAuthOptions`, `fetch`).
 * @param publisher Which publisher's model catalog to target.
 * @returns `baseUrl` (may be undefined), a `headers` factory, and the
 *          caller-supplied `fetch` override.
 */
export async function getGoogleAuth(
  options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
  publisher: 'anthropic' | 'google'
) {
  // Prefer an explicit baseURL; otherwise derive the regional endpoint.
  function getBaseUrl(): string | undefined {
    const { baseURL, location } = options;
    if (baseURL?.trim()) {
      try {
        const url = new URL(baseURL);
        // Drop a single trailing slash so callers can append `/models`.
        if (url.pathname.endsWith('/')) {
          url.pathname = url.pathname.slice(0, -1);
        }
        return url.toString();
      } catch {
        // Malformed baseURL: fall through to the location-derived endpoint.
      }
    }
    if (location) {
      return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
    }
    return undefined;
  }

  // Mint a short-lived OAuth2 access token when credentials are configured.
  async function generateAuthToken(): Promise<string | undefined> {
    if (!options.googleAuthOptions) {
      return undefined;
    }
    const auth = new GoogleAuth({
      scopes: ['https://www.googleapis.com/auth/cloud-platform'],
      ...(options.googleAuthOptions as GoogleAuthOptions),
    });
    const client = await auth.getClient();
    const token = await client.getAccessToken();
    // getAccessToken may yield null; normalize to undefined.
    return token.token ?? undefined;
  }

  const token = await generateAuthToken();
  return {
    baseUrl: getBaseUrl(),
    // Only attach Authorization when a token actually exists.
    headers: (): Record<string, string> =>
      token ? { Authorization: `Bearer ${token}` } : {},
    fetch: options.fetch,
  };
}