fix: pick copilot provider depend on model (#6540)

This commit is contained in:
darkskygit
2024-04-12 12:01:39 +00:00
parent 62f90e5f10
commit fc51b68674
6 changed files with 75 additions and 21 deletions

View File

@@ -89,8 +89,10 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
@@ -139,8 +141,10 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
@@ -194,10 +198,13 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
(await this.hasAttachment(sessionId, messageId))
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');

View File

@@ -4,6 +4,7 @@ import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotProviderType,
CopilotTextToImageProvider,
PromptMessage,
} from '../types';
@@ -12,17 +13,24 @@ export type FalConfig = {
};
export type FalResponse = {
detail: Array<{ msg: string }>;
images: Array<{ url: string }>;
};
export class FalProvider implements CopilotImageToImageProvider {
export class FalProvider
implements CopilotTextToImageProvider, CopilotImageToImageProvider
{
static readonly type = CopilotProviderType.FAL;
static readonly capabilities = [CopilotCapability.ImageToImage];
static readonly capabilities = [
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
];
readonly availableModels = [
// text to image
'fast-turbo-diffusion',
// image to image
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
'110602490-lcm-sd15-i2i',
'lcm-sd15-i2i',
];
constructor(private readonly config: FalConfig) {
@@ -37,6 +45,10 @@ export class FalProvider implements CopilotImageToImageProvider {
return FalProvider.capabilities;
}
isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
@@ -50,21 +62,20 @@ export class FalProvider implements CopilotImageToImageProvider {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (!content) {
throw new Error('Prompt is required');
}
if (!Array.isArray(attachments) || !attachments.length) {
throw new Error('Attachments is required');
// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
}
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments[0],
image_url: attachments?.[0],
prompt: content,
sync_mode: true,
seed: 42,
@@ -73,7 +84,13 @@ export class FalProvider implements CopilotImageToImageProvider {
signal: options.signal,
}).then(res => res.json())) as FalResponse;
return data.images.map(image => image.url);
if (!data.images?.length) {
const error = data.detail?.[0]?.msg;
throw new Error(
error ? `Invalid message: ${error}` : 'No images generated'
);
}
return data.images?.map(image => image.url) || [];
}
async *generateImagesStream(

View File

@@ -118,17 +118,36 @@ export class CopilotProviderService {
getProviderByCapability<C extends CopilotCapability>(
capability: C,
model?: string,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
if (Array.isArray(providers) && providers.length) {
const selectedCapability =
prefer && providers.includes(prefer) ? prefer : providers[0];
let selectedProvider: CopilotProviderType | undefined = prefer;
let currentIndex = -1;
const provider = this.getProvider(selectedCapability);
assert(provider.getCapabilities().includes(capability));
if (!selectedProvider) {
currentIndex = 0;
selectedProvider = providers[currentIndex];
}
return provider as CapabilityToCopilotProvider[C];
while (selectedProvider) {
// find first provider that supports the capability and model
if (providers.includes(selectedProvider)) {
const provider = this.getProvider(selectedProvider);
if (provider.getCapabilities().includes(capability)) {
if (model) {
if (provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
} else {
return provider as CapabilityToCopilotProvider[C];
}
}
}
currentIndex += 1;
selectedProvider = providers[currentIndex];
}
}
return null;
}

View File

@@ -63,6 +63,10 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
}
private chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {

View File

@@ -141,6 +141,7 @@ export enum CopilotCapability {
export interface CopilotProvider {
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
}
export interface CopilotTextToTextProvider extends CopilotProvider {