mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 13:25:12 +00:00
@@ -4,9 +4,13 @@ import {
|
||||
config as falConfig,
|
||||
stream as falStream,
|
||||
} from '@fal-ai/serverless-client';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { z } from 'zod';
|
||||
import { z, ZodType } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderSideError,
|
||||
UserFriendlyError,
|
||||
} from '../../../fundamentals';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotChatOptions,
|
||||
@@ -37,7 +41,10 @@ type FalImage = z.infer<typeof FalImageSchema>;
|
||||
|
||||
const FalResponseSchema = z.object({
|
||||
detail: z
|
||||
.union([z.array(z.object({ msg: z.string() })), z.string()])
|
||||
.union([
|
||||
z.array(z.object({ type: z.string(), msg: z.string() })),
|
||||
z.string(),
|
||||
])
|
||||
.optional(),
|
||||
images: z.array(FalImageSchema).optional(),
|
||||
image: FalImageSchema.optional(),
|
||||
@@ -84,8 +91,6 @@ export class FalProvider
|
||||
'llava-next',
|
||||
];
|
||||
|
||||
private readonly logger = new Logger(FalProvider.name);
|
||||
|
||||
constructor(private readonly config: FalConfig) {
|
||||
assert(FalProvider.assetsConfig(config));
|
||||
falConfig({ credentials: this.config.apiKey });
|
||||
@@ -107,23 +112,15 @@ export class FalProvider
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
private extractError(resp: FalResponse): string {
|
||||
return Array.isArray(resp.detail)
|
||||
? resp.detail[0]?.msg
|
||||
: typeof resp.detail === 'string'
|
||||
? resp.detail
|
||||
: '';
|
||||
}
|
||||
|
||||
private extractPrompt(message?: PromptMessage): FalPrompt {
|
||||
if (!message) throw new Error('Prompt is empty');
|
||||
if (!message) throw new CopilotPromptInvalid('Prompt is empty');
|
||||
const { content, attachments, params } = message;
|
||||
// prompt attachments require at least one
|
||||
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
|
||||
throw new Error('Prompt or Attachments is empty');
|
||||
throw new CopilotPromptInvalid('Prompt or Attachments is empty');
|
||||
}
|
||||
if (Array.isArray(attachments) && attachments.length > 1) {
|
||||
throw new Error('Only one attachment is allowed');
|
||||
throw new CopilotPromptInvalid('Only one attachment is allowed');
|
||||
}
|
||||
const lora = (
|
||||
params?.lora
|
||||
@@ -139,38 +136,91 @@ export class FalProvider
|
||||
};
|
||||
}
|
||||
|
||||
private extractFalError(
|
||||
resp: FalResponse,
|
||||
message?: string
|
||||
): CopilotProviderSideError {
|
||||
if (Array.isArray(resp.detail) && resp.detail.length) {
|
||||
const error = resp.detail[0].msg;
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: resp.detail[0].type,
|
||||
message: message ? `${message}: ${error}` : error,
|
||||
});
|
||||
} else if (typeof resp.detail === 'string') {
|
||||
const error = resp.detail;
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: resp.detail,
|
||||
message: message ? `${message}: ${error}` : error,
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unknown',
|
||||
message: 'No content generated',
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
// pass through user friendly errors
|
||||
return e;
|
||||
} else {
|
||||
const error = new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected fal response',
|
||||
});
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
private parseSchema<R>(schema: ZodType<R>, data: unknown): R {
|
||||
const result = schema.safeParse(data);
|
||||
if (result.success) return result.data;
|
||||
const errors = JSON.stringify(result.error.errors);
|
||||
throw new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: `Unexpected fal response: ${errors}`,
|
||||
});
|
||||
}
|
||||
|
||||
async generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'llava-next',
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
if (!this.availableModels.includes(model)) {
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
|
||||
}
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages.pop());
|
||||
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...prompt,
|
||||
sync_mode: true,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
}).then(res => res.json())) as FalResponse;
|
||||
try {
|
||||
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...prompt,
|
||||
sync_mode: true,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!data.output) {
|
||||
const error = this.extractError(data);
|
||||
throw new Error(
|
||||
error ? `Failed to generate image: ${error}` : 'No images generated'
|
||||
);
|
||||
const data = this.parseSchema(FalResponseSchema, await response.json());
|
||||
if (!data.output) {
|
||||
throw this.extractFalError(data, 'Failed to generate text');
|
||||
}
|
||||
return data.output;
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
}
|
||||
return data.output;
|
||||
}
|
||||
|
||||
async *generateTextStream(
|
||||
@@ -199,11 +249,8 @@ export class FalProvider
|
||||
const prompt = this.extractPrompt(messages.pop());
|
||||
if (model.startsWith('workflows/')) {
|
||||
const stream = await falStream(model, { input: prompt });
|
||||
|
||||
const result = FalStreamOutputSchema.safeParse(await stream.done());
|
||||
if (result.success) return result.data.output;
|
||||
const errors = JSON.stringify(result.error.errors);
|
||||
throw new Error(`Unexpected fal response: ${errors}`);
|
||||
return this.parseSchema(FalStreamOutputSchema, await stream.done())
|
||||
.output;
|
||||
} else {
|
||||
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
@@ -219,10 +266,7 @@ export class FalProvider
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
const result = FalResponseSchema.safeParse(await response.json());
|
||||
if (result.success) return result.data;
|
||||
const errors = JSON.stringify(result.error.errors);
|
||||
throw new Error(`Unexpected fal response: ${errors}`);
|
||||
return this.parseSchema(FalResponseSchema, await response.json());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,19 +277,14 @@ export class FalProvider
|
||||
options: CopilotImageOptions = {}
|
||||
): Promise<Array<string>> {
|
||||
if (!this.availableModels.includes(model)) {
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
|
||||
}
|
||||
|
||||
try {
|
||||
const data = await this.buildResponse(messages, model, options);
|
||||
|
||||
if (!data.images?.length && !data.image?.url) {
|
||||
const error = this.extractError(data);
|
||||
const finalError = error
|
||||
? `Failed to generate image: ${error}`
|
||||
: 'No images generated';
|
||||
this.logger.error(finalError);
|
||||
throw new Error(finalError);
|
||||
throw this.extractFalError(data, 'Failed to generate images');
|
||||
}
|
||||
|
||||
if (data.image?.url) {
|
||||
@@ -258,9 +297,7 @@ export class FalProvider
|
||||
.map(image => image.url) || []
|
||||
);
|
||||
} catch (e: any) {
|
||||
const error = `Failed to generate image: ${e.message}`;
|
||||
this.logger.error(error, e.stack);
|
||||
throw new Error(error);
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user