mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 12:55:00 +00:00
fix: pick copilot provider depend on model (#6540)
This commit is contained in:
@@ -89,8 +89,10 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
@@ -139,8 +141,10 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
@@ -194,10 +198,13 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const hasAttachment = await this.hasAttachment(sessionId, messageId);
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
(await this.hasAttachment(sessionId, messageId))
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
CopilotCapability,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToImageProvider,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
@@ -12,17 +13,24 @@ export type FalConfig = {
|
||||
};
|
||||
|
||||
export type FalResponse = {
|
||||
detail: Array<{ msg: string }>;
|
||||
images: Array<{ url: string }>;
|
||||
};
|
||||
|
||||
export class FalProvider implements CopilotImageToImageProvider {
|
||||
export class FalProvider
|
||||
implements CopilotTextToImageProvider, CopilotImageToImageProvider
|
||||
{
|
||||
static readonly type = CopilotProviderType.FAL;
|
||||
static readonly capabilities = [CopilotCapability.ImageToImage];
|
||||
static readonly capabilities = [
|
||||
CopilotCapability.TextToImage,
|
||||
CopilotCapability.ImageToImage,
|
||||
];
|
||||
|
||||
readonly availableModels = [
|
||||
// text to image
|
||||
'fast-turbo-diffusion',
|
||||
// image to image
|
||||
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
|
||||
'110602490-lcm-sd15-i2i',
|
||||
'lcm-sd15-i2i',
|
||||
];
|
||||
|
||||
constructor(private readonly config: FalConfig) {
|
||||
@@ -37,6 +45,10 @@ export class FalProvider implements CopilotImageToImageProvider {
|
||||
return FalProvider.capabilities;
|
||||
}
|
||||
|
||||
isModelAvailable(model: string): boolean {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
// ====== image to image ======
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
@@ -50,21 +62,20 @@ export class FalProvider implements CopilotImageToImageProvider {
|
||||
if (!this.availableModels.includes(model)) {
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
}
|
||||
if (!content) {
|
||||
throw new Error('Prompt is required');
|
||||
}
|
||||
if (!Array.isArray(attachments) || !attachments.length) {
|
||||
throw new Error('Attachments is required');
|
||||
|
||||
// prompt attachments require at least one
|
||||
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
|
||||
throw new Error('Prompt or Attachments is empty');
|
||||
}
|
||||
|
||||
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
|
||||
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
image_url: attachments[0],
|
||||
image_url: attachments?.[0],
|
||||
prompt: content,
|
||||
sync_mode: true,
|
||||
seed: 42,
|
||||
@@ -73,7 +84,13 @@ export class FalProvider implements CopilotImageToImageProvider {
|
||||
signal: options.signal,
|
||||
}).then(res => res.json())) as FalResponse;
|
||||
|
||||
return data.images.map(image => image.url);
|
||||
if (!data.images?.length) {
|
||||
const error = data.detail?.[0]?.msg;
|
||||
throw new Error(
|
||||
error ? `Invalid message: ${error}` : 'No images generated'
|
||||
);
|
||||
}
|
||||
return data.images?.map(image => image.url) || [];
|
||||
}
|
||||
|
||||
async *generateImagesStream(
|
||||
|
||||
@@ -118,17 +118,36 @@ export class CopilotProviderService {
|
||||
|
||||
getProviderByCapability<C extends CopilotCapability>(
|
||||
capability: C,
|
||||
model?: string,
|
||||
prefer?: CopilotProviderType
|
||||
): CapabilityToCopilotProvider[C] | null {
|
||||
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
|
||||
if (Array.isArray(providers) && providers.length) {
|
||||
const selectedCapability =
|
||||
prefer && providers.includes(prefer) ? prefer : providers[0];
|
||||
let selectedProvider: CopilotProviderType | undefined = prefer;
|
||||
let currentIndex = -1;
|
||||
|
||||
const provider = this.getProvider(selectedCapability);
|
||||
assert(provider.getCapabilities().includes(capability));
|
||||
if (!selectedProvider) {
|
||||
currentIndex = 0;
|
||||
selectedProvider = providers[currentIndex];
|
||||
}
|
||||
|
||||
return provider as CapabilityToCopilotProvider[C];
|
||||
while (selectedProvider) {
|
||||
// find first provider that supports the capability and model
|
||||
if (providers.includes(selectedProvider)) {
|
||||
const provider = this.getProvider(selectedProvider);
|
||||
if (provider.getCapabilities().includes(capability)) {
|
||||
if (model) {
|
||||
if (provider.isModelAvailable(model)) {
|
||||
return provider as CapabilityToCopilotProvider[C];
|
||||
}
|
||||
} else {
|
||||
return provider as CapabilityToCopilotProvider[C];
|
||||
}
|
||||
}
|
||||
}
|
||||
currentIndex += 1;
|
||||
selectedProvider = providers[currentIndex];
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -63,6 +63,10 @@ export class OpenAIProvider
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
|
||||
isModelAvailable(model: string): boolean {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
private chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
|
||||
|
||||
@@ -141,6 +141,7 @@ export enum CopilotCapability {
|
||||
|
||||
export interface CopilotProvider {
|
||||
getCapabilities(): CopilotCapability[];
|
||||
isModelAvailable(model: string): boolean;
|
||||
}
|
||||
|
||||
export interface CopilotTextToTextProvider extends CopilotProvider {
|
||||
|
||||
Reference in New Issue
Block a user