feat(server): add fallback model and baseurl in schema (#13375)

fix AI-398

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added support for specifying fallback models for multiple AI
providers, enhancing reliability when primary models are unavailable.
* Providers can now fetch and update their list of available models
dynamically from external APIs.
* Configuration options expanded to allow custom base URLs for certain
providers.

* **Bug Fixes**
* Improved model selection logic to use fallback models if the requested
model is not available online.

* **Chores**
* Updated backend dependencies to include authentication support for
Google services.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2025-08-01 15:22:48 +08:00
committed by GitHub
parent 19790c1b9e
commit 5cbcf6f907
14 changed files with 544 additions and 69 deletions

View File

@@ -669,9 +669,16 @@
}, },
"providers.openai": { "providers.openai": {
"type": "object", "type": "object",
"description": "The config for the openai provider.\n@default {\"apiKey\":\"\"}\n@link https://github.com/openai/openai-node", "description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseUrl\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}\n@link https://github.com/openai/openai-node",
"default": { "default": {
"apiKey": "" "apiKey": "",
"baseUrl": "",
"fallback": {
"text": "",
"structured": "",
"image": "",
"embedding": ""
}
} }
}, },
"providers.fal": { "providers.fal": {
@@ -683,14 +690,21 @@
}, },
"providers.gemini": { "providers.gemini": {
"type": "object", "type": "object",
"description": "The config for the gemini provider.\n@default {\"apiKey\":\"\"}", "description": "The config for the gemini provider.\n@default {\"apiKey\":\"\",\"baseUrl\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}",
"default": { "default": {
"apiKey": "" "apiKey": "",
"baseUrl": "",
"fallback": {
"text": "",
"structured": "",
"image": "",
"embedding": ""
}
} }
}, },
"providers.geminiVertex": { "providers.geminiVertex": {
"type": "object", "type": "object",
"description": "The config for the google vertex provider.\n@default {}", "description": "The config for the google vertex provider.\n@default {\"baseURL\":\"\",\"fallback\":{\"text\":\"\",\"structured\":\"\",\"image\":\"\",\"embedding\":\"\"}}",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
@@ -721,25 +735,39 @@
} }
} }
}, },
"default": {} "default": {
"baseURL": "",
"fallback": {
"text": "",
"structured": "",
"image": "",
"embedding": ""
}
}
}, },
"providers.perplexity": { "providers.perplexity": {
"type": "object", "type": "object",
"description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\"}", "description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\",\"fallback\":{\"text\":\"\"}}",
"default": { "default": {
"apiKey": "" "apiKey": "",
"fallback": {
"text": ""
}
} }
}, },
"providers.anthropic": { "providers.anthropic": {
"type": "object", "type": "object",
"description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\"}", "description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\",\"fallback\":{\"text\":\"\"}}",
"default": { "default": {
"apiKey": "" "apiKey": "",
"fallback": {
"text": ""
}
} }
}, },
"providers.anthropicVertex": { "providers.anthropicVertex": {
"type": "object", "type": "object",
"description": "The config for the google vertex provider.\n@default {}", "description": "The config for the google vertex provider.\n@default {\"baseURL\":\"\",\"fallback\":{\"text\":\"\"}}",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
@@ -770,7 +798,12 @@
} }
} }
}, },
"default": {} "default": {
"baseURL": "",
"fallback": {
"text": ""
}
}
}, },
"providers.morph": { "providers.morph": {
"type": "object", "type": "object",

View File

@@ -86,6 +86,7 @@
"express": "^5.0.1", "express": "^5.0.1",
"fast-xml-parser": "^5.0.0", "fast-xml-parser": "^5.0.0",
"get-stream": "^9.0.1", "get-stream": "^9.0.1",
"google-auth-library": "^10.2.0",
"graphql": "^16.9.0", "graphql": "^16.9.0",
"graphql-scalars": "^1.24.0", "graphql-scalars": "^1.24.0",
"graphql-upload": "^17.0.0", "graphql-upload": "^17.0.0",

View File

@@ -47,6 +47,13 @@ defineModuleConfig('copilot', {
desc: 'The config for the openai provider.', desc: 'The config for the openai provider.',
default: { default: {
apiKey: '', apiKey: '',
baseUrl: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
}, },
link: 'https://github.com/openai/openai-node', link: 'https://github.com/openai/openai-node',
}, },
@@ -60,28 +67,54 @@ defineModuleConfig('copilot', {
desc: 'The config for the gemini provider.', desc: 'The config for the gemini provider.',
default: { default: {
apiKey: '', apiKey: '',
baseUrl: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
}, },
}, },
'providers.geminiVertex': { 'providers.geminiVertex': {
desc: 'The config for the gemini provider in Google Vertex AI.', desc: 'The config for the gemini provider in Google Vertex AI.',
default: {}, default: {
baseURL: '',
fallback: {
text: '',
structured: '',
image: '',
embedding: '',
},
},
schema: VertexSchema, schema: VertexSchema,
}, },
'providers.perplexity': { 'providers.perplexity': {
desc: 'The config for the perplexity provider.', desc: 'The config for the perplexity provider.',
default: { default: {
apiKey: '', apiKey: '',
fallback: {
text: '',
},
}, },
}, },
'providers.anthropic': { 'providers.anthropic': {
desc: 'The config for the anthropic provider.', desc: 'The config for the anthropic provider.',
default: { default: {
apiKey: '', apiKey: '',
fallback: {
text: '',
},
}, },
}, },
'providers.anthropicVertex': { 'providers.anthropicVertex': {
desc: 'The config for the anthropic provider in Google Vertex AI.', desc: 'The config for the anthropic provider in Google Vertex AI.',
default: {}, default: {
baseURL: '',
fallback: {
text: '',
},
},
schema: VertexSchema, schema: VertexSchema,
}, },
'providers.morph': { 'providers.morph': {

View File

@@ -3,12 +3,23 @@ import {
createAnthropic, createAnthropic,
} from '@ai-sdk/anthropic'; } from '@ai-sdk/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import {
CopilotChatOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { AnthropicProvider } from './anthropic'; import { AnthropicProvider } from './anthropic';
export type AnthropicOfficialConfig = { export type AnthropicOfficialConfig = {
apiKey: string; apiKey: string;
baseUrl?: string; baseUrl?: string;
fallback?: {
text?: string;
};
}; };
export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOfficialConfig> { export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOfficialConfig> {
@@ -67,4 +78,31 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
baseURL: this.config.baseUrl, baseURL: this.config.baseUrl,
}); });
} }
override async text(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
return super.text(fullCond, messages, options);
}
override async *streamText(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamText(fullCond, messages, options);
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamObject(fullCond, messages, options);
}
} }

View File

@@ -4,10 +4,23 @@ import {
type GoogleVertexAnthropicProviderSettings, type GoogleVertexAnthropicProviderSettings,
} from '@ai-sdk/google-vertex/anthropic'; } from '@ai-sdk/google-vertex/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import {
CopilotChatOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils';
import { AnthropicProvider } from './anthropic'; import { AnthropicProvider } from './anthropic';
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings; export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings & {
fallback?: {
text?: string;
};
};
export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> { export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
override readonly type = CopilotProviderType.AnthropicVertex; override readonly type = CopilotProviderType.AnthropicVertex;
@@ -62,4 +75,54 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
super.setup(); super.setup();
this.instance = createVertexAnthropic(this.config); this.instance = createVertexAnthropic(this.config);
} }
override async text(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
return super.text(fullCond, messages, options);
}
override async *streamText(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamText(fullCond, messages, options);
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamObject(fullCond, messages, options);
}
override async refreshOnlineModels() {
try {
const { baseUrl, headers } = await getGoogleAuth(
this.config,
'anthropic'
);
if (baseUrl && !this.onlineModelList.length) {
const { publisherModels } = await fetch(`${baseUrl}/models`, {
headers: headers(),
})
.then(r => r.json())
.then(r => VertexModelListSchema.parse(r));
this.onlineModelList = publisherModels.map(
model =>
model.name.replace('publishers/anthropic/models/', '') +
(model.versionId !== 'default' ? `@${model.versionId}` : '')
);
}
} catch (e) {
this.logger.error('Failed to fetch available models', e);
}
}
} }

View File

@@ -37,11 +37,6 @@ import {
export const DEFAULT_DIMENSIONS = 256; export const DEFAULT_DIMENSIONS = 256;
export type GeminiConfig = {
apiKey: string;
baseUrl?: string;
};
export abstract class GeminiProvider<T> extends CopilotProvider<T> { export abstract class GeminiProvider<T> extends CopilotProvider<T> {
private readonly MAX_STEPS = 20; private readonly MAX_STEPS = 20;

View File

@@ -2,15 +2,35 @@ import {
createGoogleGenerativeAI, createGoogleGenerativeAI,
type GoogleGenerativeAIProvider, type GoogleGenerativeAIProvider,
} from '@ai-sdk/google'; } from '@ai-sdk/google';
import z from 'zod';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import {
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { GeminiProvider } from './gemini'; import { GeminiProvider } from './gemini';
export type GeminiGenerativeConfig = { export type GeminiGenerativeConfig = {
apiKey: string; apiKey: string;
baseUrl?: string; baseUrl?: string;
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
}; };
const ModelListSchema = z.object({
models: z.array(z.object({ name: z.string() })),
});
export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeConfig> { export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeConfig> {
override readonly type = CopilotProviderType.Gemini; override readonly type = CopilotProviderType.Gemini;
@@ -71,27 +91,16 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
], ],
}, },
{ {
name: 'Text Embedding 005', name: 'Gemini Embedding',
id: 'text-embedding-005', id: 'gemini-embedding-001',
capabilities: [ capabilities: [
{ {
input: [ModelInputType.Text], input: [ModelInputType.Text],
output: [ModelOutputType.Embedding], output: [ModelOutputType.Embedding],
defaultForOutputType: true,
}, },
], ],
}, },
// not exists yet
// {
// name: 'Gemini Embedding',
// id: 'gemini-embedding-001',
// capabilities: [
// {
// input: [ModelInputType.Text],
// output: [ModelOutputType.Embedding],
// defaultForOutputType: true,
// },
// ],
// },
]; ];
protected instance!: GoogleGenerativeAIProvider; protected instance!: GoogleGenerativeAIProvider;
@@ -107,4 +116,77 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
baseURL: this.config.baseUrl, baseURL: this.config.baseUrl,
}); });
} }
override async text(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
return super.text(fullCond, messages, options);
}
override async structure(
cond: ModelConditions,
messages: PromptMessage[],
options?: CopilotChatOptions
): Promise<string> {
const fullCond = {
...cond,
fallbackModel: this.config.fallback?.structured,
};
return super.structure(fullCond, messages, options);
}
override async *streamText(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamText(fullCond, messages, options);
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamObject(fullCond, messages, options);
}
override async embedding(
cond: ModelConditions,
messages: string | string[],
options?: CopilotEmbeddingOptions
): Promise<number[][]> {
const fullCond = {
...cond,
fallbackModel: this.config.fallback?.embedding,
};
return super.embedding(fullCond, messages, options);
}
override async refreshOnlineModels() {
try {
const baseUrl =
this.config.baseUrl ||
'https://generativelanguage.googleapis.com/v1beta';
if (baseUrl && !this.onlineModelList.length) {
const { models } = await fetch(
`${baseUrl}/models?key=${this.config.apiKey}`
)
.then(r => r.json())
.then(
r => (console.log(JSON.stringify(r)), ModelListSchema.parse(r))
);
this.onlineModelList = models.map(model =>
model.name.replace('models/', '')
);
}
} catch (e) {
this.logger.error('Failed to fetch available models', e);
}
}
} }

View File

@@ -4,10 +4,27 @@ import {
type GoogleVertexProviderSettings, type GoogleVertexProviderSettings,
} from '@ai-sdk/google-vertex'; } from '@ai-sdk/google-vertex';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import {
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotProviderType,
ModelConditions,
ModelInputType,
ModelOutputType,
PromptMessage,
StreamObject,
} from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils';
import { GeminiProvider } from './gemini'; import { GeminiProvider } from './gemini';
export type GeminiVertexConfig = GoogleVertexProviderSettings; export type GeminiVertexConfig = GoogleVertexProviderSettings & {
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
};
export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> { export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
override readonly type = CopilotProviderType.GeminiVertex; override readonly type = CopilotProviderType.GeminiVertex;
@@ -72,4 +89,73 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
super.setup(); super.setup();
this.instance = createVertex(this.config); this.instance = createVertex(this.config);
} }
override async text(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
return super.text(fullCond, messages, options);
}
override async structure(
cond: ModelConditions,
messages: PromptMessage[],
options?: CopilotChatOptions
): Promise<string> {
const fullCond = {
...cond,
fallbackModel: this.config.fallback?.structured,
};
return super.structure(fullCond, messages, options);
}
override async *streamText(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<string> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamText(fullCond, messages, options);
}
override async *streamObject(
cond: ModelConditions,
messages: PromptMessage[],
options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> {
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
yield* super.streamObject(fullCond, messages, options);
}
override async embedding(
cond: ModelConditions,
messages: string | string[],
options?: CopilotEmbeddingOptions
): Promise<number[][]> {
const fullCond = {
...cond,
fallbackModel: this.config.fallback?.embedding,
};
return super.embedding(fullCond, messages, options);
}
override async refreshOnlineModels() {
try {
const { baseUrl, headers } = await getGoogleAuth(this.config, 'google');
if (baseUrl && !this.onlineModelList.length) {
const { publisherModels } = await fetch(`${baseUrl}/models`, {
headers: headers(),
})
.then(r => r.json())
.then(r => VertexModelListSchema.parse(r));
this.onlineModelList = publisherModels.map(model =>
model.name.replace('publishers/google/models/', '')
);
}
} catch (e) {
this.logger.error('Failed to fetch available models', e);
}
}
} }

View File

@@ -46,8 +46,18 @@ export const DEFAULT_DIMENSIONS = 256;
export type OpenAIConfig = { export type OpenAIConfig = {
apiKey: string; apiKey: string;
baseUrl?: string; baseUrl?: string;
fallback?: {
text?: string;
structured?: string;
image?: string;
embedding?: string;
};
}; };
const ModelListSchema = z.object({
data: z.array(z.object({ id: z.string() })),
});
const ImageResponseSchema = z.union([ const ImageResponseSchema = z.union([
z.object({ z.object({
data: z.array(z.object({ b64_json: z.string() })), data: z.array(z.object({ b64_json: z.string() })),
@@ -271,6 +281,25 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
} }
} }
override async refreshOnlineModels() {
try {
const baseUrl = this.config.baseUrl || 'https://api.openai.com/v1';
if (baseUrl && !this.onlineModelList.length) {
const { data } = await fetch(`${baseUrl}/models`, {
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
})
.then(r => r.json())
.then(r => ModelListSchema.parse(r));
this.onlineModelList = data.map(model => model.id);
}
} catch (e) {
this.logger.error('Failed to fetch available models', e);
}
}
override getProviderSpecificTools( override getProviderSpecificTools(
toolName: CopilotChatTools, toolName: CopilotChatTools,
model: string model: string
@@ -291,6 +320,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const fullCond = { const fullCond = {
...cond, ...cond,
outputType: ModelOutputType.Text, outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
}; };
await this.checkParams({ messages, cond: fullCond, options }); await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -331,6 +361,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const fullCond = { const fullCond = {
...cond, ...cond,
outputType: ModelOutputType.Text, outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
}; };
await this.checkParams({ messages, cond: fullCond, options }); await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -376,7 +407,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> { ): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object }; const fullCond = {
...cond,
outputType: ModelOutputType.Object,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options }); await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -409,7 +444,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotStructuredOptions = {} options: CopilotStructuredOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Structured }; const fullCond = {
...cond,
outputType: ModelOutputType.Structured,
fallbackModel: this.config.fallback?.structured,
};
await this.checkParams({ messages, cond: fullCond, options }); await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -449,7 +488,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
chunkMessages: PromptMessage[][], chunkMessages: PromptMessage[][],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<number[]> { ): Promise<number[]> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ messages: [], cond: fullCond, options }); await this.checkParams({ messages: [], cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
// get the log probability of "yes"/"no" // get the log probability of "yes"/"no"
@@ -594,7 +637,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotImageOptions = {} options: CopilotImageOptions = {}
) { ) {
const fullCond = { ...cond, outputType: ModelOutputType.Image }; const fullCond = {
...cond,
outputType: ModelOutputType.Image,
fallbackModel: this.config.fallback?.image,
};
await this.checkParams({ messages, cond: fullCond, options }); await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -644,7 +691,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> { ): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages]; messages = Array.isArray(messages) ? messages : [messages];
const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; const fullCond = {
...cond,
outputType: ModelOutputType.Embedding,
fallbackModel: this.config.fallback?.embedding,
};
await this.checkParams({ embeddings: messages, cond: fullCond, options }); await this.checkParams({ embeddings: messages, cond: fullCond, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);

View File

@@ -20,6 +20,9 @@ import { chatToGPTMessage, CitationParser } from './utils';
export type PerplexityConfig = { export type PerplexityConfig = {
apiKey: string; apiKey: string;
endpoint?: string; endpoint?: string;
fallback?: {
text?: string;
};
}; };
const PerplexityErrorSchema = z.union([ const PerplexityErrorSchema = z.union([
@@ -109,7 +112,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options }); await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);
@@ -149,7 +156,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<string> { ): AsyncIterable<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = {
...cond,
outputType: ModelOutputType.Text,
fallbackModel: this.config.fallback?.text,
};
await this.checkParams({ cond: fullCond, messages, options }); await this.checkParams({ cond: fullCond, messages, options });
const model = this.selectModel(fullCond); const model = this.selectModel(fullCond);

View File

@@ -53,6 +53,7 @@ import {
@Injectable() @Injectable()
export abstract class CopilotProvider<C = any> { export abstract class CopilotProvider<C = any> {
protected readonly logger = new Logger(this.constructor.name); protected readonly logger = new Logger(this.constructor.name);
protected onlineModelList: string[] = [];
abstract readonly type: CopilotProviderType; abstract readonly type: CopilotProviderType;
abstract readonly models: CopilotProviderModel[]; abstract readonly models: CopilotProviderModel[];
abstract configured(): boolean; abstract configured(): boolean;
@@ -80,11 +81,18 @@ export abstract class CopilotProvider<C = any> {
protected setup() { protected setup() {
if (this.configured()) { if (this.configured()) {
this.factory.register(this); this.factory.register(this);
if (env.selfhosted) {
this.refreshOnlineModels().catch(e =>
this.logger.error('Failed to refresh online models', e)
);
}
} else { } else {
this.factory.unregister(this); this.factory.unregister(this);
} }
} }
async refreshOnlineModels() {}
private findValidModel( private findValidModel(
cond: ModelFullConditions cond: ModelFullConditions
): CopilotProviderModel | undefined { ): CopilotProviderModel | undefined {
@@ -95,9 +103,26 @@ export abstract class CopilotProvider<C = any> {
inputTypes.every(type => cap.input.includes(type))); inputTypes.every(type => cap.input.includes(type)));
if (modelId) { if (modelId) {
return this.models.find( const hasOnlineModel = this.onlineModelList.includes(modelId);
const hasFallbackModel = cond.fallbackModel
? this.onlineModelList.includes(cond.fallbackModel)
: undefined;
const model = this.models.find(
m => m.id === modelId && m.capabilities.some(matcher) m => m.id === modelId && m.capabilities.some(matcher)
); );
if (model) {
// return fallback model if current model is not alive
if (!hasOnlineModel && hasFallbackModel) {
// oxlint-disable-next-line typescript-eslint(no-non-null-assertion)
return { id: cond.fallbackModel!, capabilities: [] };
}
return model;
}
// allow online model without capabilities check
if (hasOnlineModel) return { id: modelId, capabilities: [] };
return undefined;
} }
if (!outputType) return undefined; if (!outputType) return undefined;

View File

@@ -237,6 +237,7 @@ export interface ModelCapability {
export interface CopilotProviderModel { export interface CopilotProviderModel {
id: string; id: string;
name?: string;
capabilities: ModelCapability[]; capabilities: ModelCapability[];
} }
@@ -247,4 +248,5 @@ export type ModelConditions = {
export type ModelFullConditions = ModelConditions & { export type ModelFullConditions = ModelConditions & {
outputType?: ModelOutputType; outputType?: ModelOutputType;
fallbackModel?: string;
}; };

View File

@@ -1,3 +1,5 @@
import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex';
import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic';
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { import {
CoreAssistantMessage, CoreAssistantMessage,
@@ -7,7 +9,8 @@ import {
TextPart, TextPart,
TextStreamPart, TextStreamPart,
} from 'ai'; } from 'ai';
import { ZodType } from 'zod'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import z, { ZodType } from 'zod';
import { CustomAITools } from '../tools'; import { CustomAITools } from '../tools';
import { PromptMessage, StreamObject } from './types'; import { PromptMessage, StreamObject } from './types';
@@ -655,3 +658,54 @@ export class StreamObjectParser {
}, ''); }, '');
} }
} }
export const VertexModelListSchema = z.object({
publisherModels: z.array(
z.object({
name: z.string(),
versionId: z.string(),
})
),
});
export async function getGoogleAuth(
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
publisher: 'anthropic' | 'google'
) {
function getBaseUrl() {
const { baseURL, location } = options;
if (baseURL?.trim()) {
try {
const url = new URL(baseURL);
if (url.pathname.endsWith('/')) {
url.pathname = url.pathname.slice(0, -1);
}
return url.toString();
} catch {}
} else if (location) {
return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
}
return undefined;
}
async function generateAuthToken() {
if (!options.googleAuthOptions) {
return undefined;
}
const auth = new GoogleAuth({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
...(options.googleAuthOptions as GoogleAuthOptions),
});
const client = await auth.getClient();
const token = await client.getAccessToken();
return token.token;
}
const token = await generateAuthToken();
return {
baseUrl: getBaseUrl(),
headers: () => ({ Authorization: `Bearer ${token}` }),
fetch: options.fetch,
};
}

View File

@@ -998,6 +998,7 @@ __metadata:
express: "npm:^5.0.1" express: "npm:^5.0.1"
fast-xml-parser: "npm:^5.0.0" fast-xml-parser: "npm:^5.0.0"
get-stream: "npm:^9.0.1" get-stream: "npm:^9.0.1"
google-auth-library: "npm:^10.2.0"
graphql: "npm:^16.9.0" graphql: "npm:^16.9.0"
graphql-scalars: "npm:^1.24.0" graphql-scalars: "npm:^1.24.0"
graphql-upload: "npm:^17.0.0" graphql-upload: "npm:^17.0.0"
@@ -22147,14 +22148,14 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"gaxios@npm:^7.0.0-rc.1, gaxios@npm:^7.0.0-rc.4": "gaxios@npm:^7.0.0, gaxios@npm:^7.0.0-rc.4":
version: 7.0.0-rc.6 version: 7.1.1
resolution: "gaxios@npm:7.0.0-rc.6" resolution: "gaxios@npm:7.1.1"
dependencies: dependencies:
extend: "npm:^3.0.2" extend: "npm:^3.0.2"
https-proxy-agent: "npm:^7.0.1" https-proxy-agent: "npm:^7.0.1"
node-fetch: "npm:^3.3.2" node-fetch: "npm:^3.3.2"
checksum: 10/60c688d4c65062c97bf0f33f959713df106e207065586bf5deb546ef5d02cddcba46d138c0b7eb8712950ca880fa28d3665936b19156224f8c478d9c4f817aea checksum: 10/9e5fa8b458c318a95d4dff0f6ac187a1b8933fb1de5b376b7098b27dfc5bf6025b62c87ed20bdae0496ae73a279834bc6b974c28849a674deed0089f2ba57b98
languageName: node languageName: node
linkType: hard linkType: hard
@@ -22169,14 +22170,14 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"gcp-metadata@npm:^7.0.0-rc.1": "gcp-metadata@npm:^7.0.0":
version: 7.0.0-rc.1 version: 7.0.1
resolution: "gcp-metadata@npm:7.0.0-rc.1" resolution: "gcp-metadata@npm:7.0.1"
dependencies: dependencies:
gaxios: "npm:^7.0.0-rc.1" gaxios: "npm:^7.0.0"
google-logging-utils: "npm:^1.0.0" google-logging-utils: "npm:^1.0.0"
json-bigint: "npm:^1.0.0" json-bigint: "npm:^1.0.0"
checksum: 10/2c58401c7945c41144bc6a44a066c050d36c34ee10e04e85ffde488afec6a3f67ebe29e697d0328aa20181b38a36aae9165896c0201387ea8cff031cdb790ab9 checksum: 10/c82f20a4ce22278998fe033e668a66bff04d2b3e95e19f968adeac829e12274e07b453fcfcf34573a6d702b3570c5556cba6eb6b59d1c03757c866e3271972c1
languageName: node languageName: node
linkType: hard linkType: hard
@@ -22540,18 +22541,18 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"google-auth-library@npm:^10.0.0-rc.1": "google-auth-library@npm:^10.0.0-rc.1, google-auth-library@npm:^10.2.0":
version: 10.0.0-rc.3 version: 10.2.0
resolution: "google-auth-library@npm:10.0.0-rc.3" resolution: "google-auth-library@npm:10.2.0"
dependencies: dependencies:
base64-js: "npm:^1.3.0" base64-js: "npm:^1.3.0"
ecdsa-sig-formatter: "npm:^1.0.11" ecdsa-sig-formatter: "npm:^1.0.11"
gaxios: "npm:^7.0.0-rc.4" gaxios: "npm:^7.0.0"
gcp-metadata: "npm:^7.0.0-rc.1" gcp-metadata: "npm:^7.0.0"
google-logging-utils: "npm:^1.0.0" google-logging-utils: "npm:^1.0.0"
gtoken: "npm:^8.0.0-rc.1" gtoken: "npm:^8.0.0"
jws: "npm:^4.0.0" jws: "npm:^4.0.0"
checksum: 10/d76f470ddba1d5ec84cb72b03af388722db2987a95f64c7d91169f3f67608528a327fd132bb00d74be10b08e6446906b16983cf0aa9f6fa10ac7a6d3aacff0d7 checksum: 10/dfa6ad7240da3915b7e15d1f39cd6906e6714502b09957be07d2429665a44c6b46d8ca7a077cec7fcb83f82ed2c3cb5a9b7c1ba79ebb3e8920eba286a24bdd63
languageName: node languageName: node
linkType: hard linkType: hard
@@ -22748,13 +22749,13 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"gtoken@npm:^8.0.0-rc.1": "gtoken@npm:^8.0.0":
version: 8.0.0-rc.1 version: 8.0.0
resolution: "gtoken@npm:8.0.0-rc.1" resolution: "gtoken@npm:8.0.0"
dependencies: dependencies:
gaxios: "npm:^7.0.0-rc.1" gaxios: "npm:^7.0.0"
jws: "npm:^4.0.0" jws: "npm:^4.0.0"
checksum: 10/d2481344df8d9f62ec3ae7fe97b562c93dc294c7d1e7d8e1603162fe2726cfb5993f2b1a4e04388ee0f49c5fd02c1b0799dc58cebbf8af10489af8de80a72902 checksum: 10/b921430395dcd06ee63c3fc5a5e339ca4d6dcb38b6d618beb0f260bae1088d53d130f86029a9d578f1601c64685f49a65dba57bbd617c4b14039180b67b6c5ce
languageName: node languageName: node
linkType: hard linkType: hard