mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
feat(server): add fallback model and baseurl in schema (#13375)
fix AI-398 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for specifying fallback models for multiple AI providers, enhancing reliability when primary models are unavailable. * Providers can now fetch and update their list of available models dynamically from external APIs. * Configuration options expanded to allow custom base URLs for certain providers. * **Bug Fixes** * Improved model selection logic to use fallback models if the requested model is not available online. * **Chores** * Updated backend dependencies to include authentication support for Google services. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -47,6 +47,13 @@ defineModuleConfig('copilot', {
|
||||
desc: 'The config for the openai provider.',
|
||||
default: {
|
||||
apiKey: '',
|
||||
baseUrl: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
structured: '',
|
||||
image: '',
|
||||
embedding: '',
|
||||
},
|
||||
},
|
||||
link: 'https://github.com/openai/openai-node',
|
||||
},
|
||||
@@ -60,28 +67,54 @@ defineModuleConfig('copilot', {
|
||||
desc: 'The config for the gemini provider.',
|
||||
default: {
|
||||
apiKey: '',
|
||||
baseUrl: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
structured: '',
|
||||
image: '',
|
||||
embedding: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
'providers.geminiVertex': {
|
||||
desc: 'The config for the gemini provider in Google Vertex AI.',
|
||||
default: {},
|
||||
default: {
|
||||
baseURL: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
structured: '',
|
||||
image: '',
|
||||
embedding: '',
|
||||
},
|
||||
},
|
||||
schema: VertexSchema,
|
||||
},
|
||||
'providers.perplexity': {
|
||||
desc: 'The config for the perplexity provider.',
|
||||
default: {
|
||||
apiKey: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
'providers.anthropic': {
|
||||
desc: 'The config for the anthropic provider.',
|
||||
default: {
|
||||
apiKey: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
'providers.anthropicVertex': {
|
||||
desc: 'The config for the anthropic provider in Google Vertex AI.',
|
||||
default: {},
|
||||
default: {
|
||||
baseURL: '',
|
||||
fallback: {
|
||||
text: '',
|
||||
},
|
||||
},
|
||||
schema: VertexSchema,
|
||||
},
|
||||
'providers.morph': {
|
||||
|
||||
@@ -3,12 +3,23 @@ import {
|
||||
createAnthropic,
|
||||
} from '@ai-sdk/anthropic';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderType,
|
||||
ModelConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
export type AnthropicOfficialConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
fallback?: {
|
||||
text?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOfficialConfig> {
|
||||
@@ -67,4 +78,31 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
baseURL: this.config.baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
override async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
return super.text(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamText(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamText(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamObject(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamObject(fullCond, messages, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,23 @@ import {
|
||||
type GoogleVertexAnthropicProviderSettings,
|
||||
} from '@ai-sdk/google-vertex/anthropic';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderType,
|
||||
ModelConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { getGoogleAuth, VertexModelListSchema } from '../utils';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
|
||||
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings & {
|
||||
fallback?: {
|
||||
text?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
|
||||
override readonly type = CopilotProviderType.AnthropicVertex;
|
||||
@@ -62,4 +75,54 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
super.setup();
|
||||
this.instance = createVertexAnthropic(this.config);
|
||||
}
|
||||
|
||||
override async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
return super.text(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamText(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamText(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamObject(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamObject(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const { baseUrl, headers } = await getGoogleAuth(
|
||||
this.config,
|
||||
'anthropic'
|
||||
);
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
const { publisherModels } = await fetch(`${baseUrl}/models`, {
|
||||
headers: headers(),
|
||||
})
|
||||
.then(r => r.json())
|
||||
.then(r => VertexModelListSchema.parse(r));
|
||||
this.onlineModelList = publisherModels.map(
|
||||
model =>
|
||||
model.name.replace('publishers/anthropic/models/', '') +
|
||||
(model.versionId !== 'default' ? `@${model.versionId}` : '')
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,11 +37,6 @@ import {
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
export type GeminiConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
};
|
||||
|
||||
export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
private readonly MAX_STEPS = 20;
|
||||
|
||||
|
||||
@@ -2,15 +2,35 @@ import {
|
||||
createGoogleGenerativeAI,
|
||||
type GoogleGenerativeAIProvider,
|
||||
} from '@ai-sdk/google';
|
||||
import z from 'zod';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotProviderType,
|
||||
ModelConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { GeminiProvider } from './gemini';
|
||||
|
||||
export type GeminiGenerativeConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
fallback?: {
|
||||
text?: string;
|
||||
structured?: string;
|
||||
image?: string;
|
||||
embedding?: string;
|
||||
};
|
||||
};
|
||||
|
||||
const ModelListSchema = z.object({
|
||||
models: z.array(z.object({ name: z.string() })),
|
||||
});
|
||||
|
||||
export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeConfig> {
|
||||
override readonly type = CopilotProviderType.Gemini;
|
||||
|
||||
@@ -71,27 +91,16 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Text Embedding 005',
|
||||
id: 'text-embedding-005',
|
||||
name: 'Gemini Embedding',
|
||||
id: 'gemini-embedding-001',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
output: [ModelOutputType.Embedding],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
// not exists yet
|
||||
// {
|
||||
// name: 'Gemini Embedding',
|
||||
// id: 'gemini-embedding-001',
|
||||
// capabilities: [
|
||||
// {
|
||||
// input: [ModelInputType.Text],
|
||||
// output: [ModelOutputType.Embedding],
|
||||
// defaultForOutputType: true,
|
||||
// },
|
||||
// ],
|
||||
// },
|
||||
];
|
||||
|
||||
protected instance!: GoogleGenerativeAIProvider;
|
||||
@@ -107,4 +116,77 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
baseURL: this.config.baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
override async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
return super.text(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async structure(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options?: CopilotChatOptions
|
||||
): Promise<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
fallbackModel: this.config.fallback?.structured,
|
||||
};
|
||||
return super.structure(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamText(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamText(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamObject(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamObject(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async embedding(
|
||||
cond: ModelConditions,
|
||||
messages: string | string[],
|
||||
options?: CopilotEmbeddingOptions
|
||||
): Promise<number[][]> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
fallbackModel: this.config.fallback?.embedding,
|
||||
};
|
||||
return super.embedding(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const baseUrl =
|
||||
this.config.baseUrl ||
|
||||
'https://generativelanguage.googleapis.com/v1beta';
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
const { models } = await fetch(
|
||||
`${baseUrl}/models?key=${this.config.apiKey}`
|
||||
)
|
||||
.then(r => r.json())
|
||||
.then(
|
||||
r => (console.log(JSON.stringify(r)), ModelListSchema.parse(r))
|
||||
);
|
||||
this.onlineModelList = models.map(model =>
|
||||
model.name.replace('models/', '')
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,27 @@ import {
|
||||
type GoogleVertexProviderSettings,
|
||||
} from '@ai-sdk/google-vertex';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotProviderType,
|
||||
ModelConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { getGoogleAuth, VertexModelListSchema } from '../utils';
|
||||
import { GeminiProvider } from './gemini';
|
||||
|
||||
export type GeminiVertexConfig = GoogleVertexProviderSettings;
|
||||
export type GeminiVertexConfig = GoogleVertexProviderSettings & {
|
||||
fallback?: {
|
||||
text?: string;
|
||||
structured?: string;
|
||||
image?: string;
|
||||
embedding?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
override readonly type = CopilotProviderType.GeminiVertex;
|
||||
@@ -72,4 +89,73 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
super.setup();
|
||||
this.instance = createVertex(this.config);
|
||||
}
|
||||
|
||||
override async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
return super.text(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async structure(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options?: CopilotChatOptions
|
||||
): Promise<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
fallbackModel: this.config.fallback?.structured,
|
||||
};
|
||||
return super.structure(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamText(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamText(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async *streamObject(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
|
||||
yield* super.streamObject(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async embedding(
|
||||
cond: ModelConditions,
|
||||
messages: string | string[],
|
||||
options?: CopilotEmbeddingOptions
|
||||
): Promise<number[][]> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
fallbackModel: this.config.fallback?.embedding,
|
||||
};
|
||||
return super.embedding(fullCond, messages, options);
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const { baseUrl, headers } = await getGoogleAuth(this.config, 'google');
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
const { publisherModels } = await fetch(`${baseUrl}/models`, {
|
||||
headers: headers(),
|
||||
})
|
||||
.then(r => r.json())
|
||||
.then(r => VertexModelListSchema.parse(r));
|
||||
this.onlineModelList = publisherModels.map(model =>
|
||||
model.name.replace('publishers/google/models/', '')
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,8 +46,18 @@ export const DEFAULT_DIMENSIONS = 256;
|
||||
export type OpenAIConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
fallback?: {
|
||||
text?: string;
|
||||
structured?: string;
|
||||
image?: string;
|
||||
embedding?: string;
|
||||
};
|
||||
};
|
||||
|
||||
const ModelListSchema = z.object({
|
||||
data: z.array(z.object({ id: z.string() })),
|
||||
});
|
||||
|
||||
const ImageResponseSchema = z.union([
|
||||
z.object({
|
||||
data: z.array(z.object({ b64_json: z.string() })),
|
||||
@@ -271,6 +281,25 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const baseUrl = this.config.baseUrl || 'https://api.openai.com/v1';
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
const { data } = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
.then(r => r.json())
|
||||
.then(r => ModelListSchema.parse(r));
|
||||
this.onlineModelList = data.map(model => model.id);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
model: string
|
||||
@@ -291,6 +320,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
@@ -331,6 +361,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
@@ -376,7 +407,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Object };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Object,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -409,7 +444,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotStructuredOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Structured };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Structured,
|
||||
fallbackModel: this.config.fallback?.structured,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -449,7 +488,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
chunkMessages: PromptMessage[][],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<number[]> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
// get the log probability of "yes"/"no"
|
||||
@@ -594,7 +637,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Image };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Image,
|
||||
fallbackModel: this.config.fallback?.image,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -644,7 +691,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Embedding,
|
||||
fallbackModel: this.config.fallback?.embedding,
|
||||
};
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ import { chatToGPTMessage, CitationParser } from './utils';
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
endpoint?: string;
|
||||
fallback?: {
|
||||
text?: string;
|
||||
};
|
||||
};
|
||||
|
||||
const PerplexityErrorSchema = z.union([
|
||||
@@ -109,7 +112,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -149,7 +156,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
fallbackModel: this.config.fallback?.text,
|
||||
};
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ import {
|
||||
@Injectable()
|
||||
export abstract class CopilotProvider<C = any> {
|
||||
protected readonly logger = new Logger(this.constructor.name);
|
||||
protected onlineModelList: string[] = [];
|
||||
abstract readonly type: CopilotProviderType;
|
||||
abstract readonly models: CopilotProviderModel[];
|
||||
abstract configured(): boolean;
|
||||
@@ -80,11 +81,18 @@ export abstract class CopilotProvider<C = any> {
|
||||
protected setup() {
|
||||
if (this.configured()) {
|
||||
this.factory.register(this);
|
||||
if (env.selfhosted) {
|
||||
this.refreshOnlineModels().catch(e =>
|
||||
this.logger.error('Failed to refresh online models', e)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
this.factory.unregister(this);
|
||||
}
|
||||
}
|
||||
|
||||
async refreshOnlineModels() {}
|
||||
|
||||
private findValidModel(
|
||||
cond: ModelFullConditions
|
||||
): CopilotProviderModel | undefined {
|
||||
@@ -95,9 +103,26 @@ export abstract class CopilotProvider<C = any> {
|
||||
inputTypes.every(type => cap.input.includes(type)));
|
||||
|
||||
if (modelId) {
|
||||
return this.models.find(
|
||||
const hasOnlineModel = this.onlineModelList.includes(modelId);
|
||||
const hasFallbackModel = cond.fallbackModel
|
||||
? this.onlineModelList.includes(cond.fallbackModel)
|
||||
: undefined;
|
||||
|
||||
const model = this.models.find(
|
||||
m => m.id === modelId && m.capabilities.some(matcher)
|
||||
);
|
||||
|
||||
if (model) {
|
||||
// return fallback model if current model is not alive
|
||||
if (!hasOnlineModel && hasFallbackModel) {
|
||||
// oxlint-disable-next-line typescript-eslint(no-non-null-assertion)
|
||||
return { id: cond.fallbackModel!, capabilities: [] };
|
||||
}
|
||||
return model;
|
||||
}
|
||||
// allow online model without capabilities check
|
||||
if (hasOnlineModel) return { id: modelId, capabilities: [] };
|
||||
return undefined;
|
||||
}
|
||||
if (!outputType) return undefined;
|
||||
|
||||
|
||||
@@ -237,6 +237,7 @@ export interface ModelCapability {
|
||||
|
||||
export interface CopilotProviderModel {
|
||||
id: string;
|
||||
name?: string;
|
||||
capabilities: ModelCapability[];
|
||||
}
|
||||
|
||||
@@ -247,4 +248,5 @@ export type ModelConditions = {
|
||||
|
||||
export type ModelFullConditions = ModelConditions & {
|
||||
outputType?: ModelOutputType;
|
||||
fallbackModel?: string;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex';
|
||||
import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import {
|
||||
CoreAssistantMessage,
|
||||
@@ -7,7 +9,8 @@ import {
|
||||
TextPart,
|
||||
TextStreamPart,
|
||||
} from 'ai';
|
||||
import { ZodType } from 'zod';
|
||||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
|
||||
import z, { ZodType } from 'zod';
|
||||
|
||||
import { CustomAITools } from '../tools';
|
||||
import { PromptMessage, StreamObject } from './types';
|
||||
@@ -655,3 +658,54 @@ export class StreamObjectParser {
|
||||
}, '');
|
||||
}
|
||||
}
|
||||
|
||||
export const VertexModelListSchema = z.object({
|
||||
publisherModels: z.array(
|
||||
z.object({
|
||||
name: z.string(),
|
||||
versionId: z.string(),
|
||||
})
|
||||
),
|
||||
});
|
||||
|
||||
export async function getGoogleAuth(
|
||||
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
|
||||
publisher: 'anthropic' | 'google'
|
||||
) {
|
||||
function getBaseUrl() {
|
||||
const { baseURL, location } = options;
|
||||
if (baseURL?.trim()) {
|
||||
try {
|
||||
const url = new URL(baseURL);
|
||||
if (url.pathname.endsWith('/')) {
|
||||
url.pathname = url.pathname.slice(0, -1);
|
||||
}
|
||||
return url.toString();
|
||||
} catch {}
|
||||
} else if (location) {
|
||||
return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async function generateAuthToken() {
|
||||
if (!options.googleAuthOptions) {
|
||||
return undefined;
|
||||
}
|
||||
const auth = new GoogleAuth({
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
...(options.googleAuthOptions as GoogleAuthOptions),
|
||||
});
|
||||
const client = await auth.getClient();
|
||||
const token = await client.getAccessToken();
|
||||
return token.token;
|
||||
}
|
||||
|
||||
const token = await generateAuthToken();
|
||||
|
||||
return {
|
||||
baseUrl: getBaseUrl(),
|
||||
headers: () => ({ Authorization: `Bearer ${token}` }),
|
||||
fetch: options.fetch,
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user