feat: allow undefined new model (#6933)

This commit is contained in:
darkskygit
2024-05-14 13:05:07 +00:00
parent b036f1b5c9
commit 98e218af93
8 changed files with 74 additions and 26 deletions

View File

@@ -133,7 +133,7 @@ export class CopilotController {
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
@@ -179,7 +179,7 @@ export class CopilotController {
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
@@ -246,7 +246,7 @@ export class CopilotController {
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,

View File

@@ -50,7 +50,7 @@ export class FalProvider
return FalProvider.capabilities;
}
isModelAvailable(model: string): boolean {
async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}

View File

@@ -48,7 +48,7 @@ export function registerCopilotProvider<
const providerConfig = config.plugins.copilot?.[type];
if (!provider.assetsConfig(providerConfig as C)) {
throw new Error(
`Invalid configuration for copilot provider ${type}: ${providerConfig}`
`Invalid configuration for copilot provider ${type}: ${JSON.stringify(providerConfig)}`
);
}
const instance = new provider(providerConfig as C);
@@ -116,11 +116,11 @@ export class CopilotProviderService {
return this.cachedProviders.get(provider)!;
}
getProviderByCapability<C extends CopilotCapability>(
async getProviderByCapability<C extends CopilotCapability>(
capability: C,
model?: string,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
): Promise<CapabilityToCopilotProvider[C] | null> {
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
if (Array.isArray(providers) && providers.length) {
let selectedProvider: CopilotProviderType | undefined = prefer;
@@ -137,7 +137,7 @@ export class CopilotProviderService {
const provider = this.getProvider(selectedProvider);
if (provider.getCapabilities().includes(capability)) {
if (model) {
if (provider.isModelAvailable(model)) {
if (await provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
} else {

View File

@@ -1,5 +1,6 @@
import assert from 'node:assert';
import { Logger } from '@nestjs/common';
import { ClientOptions, OpenAI } from 'openai';
import {
@@ -51,7 +52,9 @@ export class OpenAIProvider
'dall-e-3',
];
private readonly logger = new Logger(OpenAIProvider.type);
private readonly instance: OpenAI;
private existsModels: string[] | undefined;
constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
@@ -70,8 +73,20 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
async isModelAvailable(model: string): Promise<boolean> {
const knownModels = this.availableModels.includes(model);
if (knownModels) return true;
if (!this.existsModels) {
try {
this.existsModels = await this.instance.models
.list()
.then(({ data }) => data.map(m => m.id));
} catch (e) {
this.logger.error('Failed to fetch online model list', e);
}
}
return !!this.existsModels?.includes(model);
}
protected chatToGPTMessage(

View File

@@ -172,7 +172,7 @@ export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
isModelAvailable(model: string): Promise<boolean>;
}
export interface CopilotTextToTextProvider extends CopilotProvider {