feat: integrate i18n error for copilot (#7311)

fix PD-1333 CLOUD-42
This commit is contained in:
darkskygit
2024-06-26 13:36:23 +00:00
parent 6b47c6beda
commit aeb666f95e
6 changed files with 247 additions and 124 deletions

View File

@@ -4,9 +4,13 @@ import {
config as falConfig,
stream as falStream,
} from '@fal-ai/serverless-client';
import { Logger } from '@nestjs/common';
import { z } from 'zod';
import { z, ZodType } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderSideError,
UserFriendlyError,
} from '../../../fundamentals';
import {
CopilotCapability,
CopilotChatOptions,
@@ -37,7 +41,10 @@ type FalImage = z.infer<typeof FalImageSchema>;
const FalResponseSchema = z.object({
detail: z
.union([z.array(z.object({ msg: z.string() })), z.string()])
.union([
z.array(z.object({ type: z.string(), msg: z.string() })),
z.string(),
])
.optional(),
images: z.array(FalImageSchema).optional(),
image: FalImageSchema.optional(),
@@ -84,8 +91,6 @@ export class FalProvider
'llava-next',
];
private readonly logger = new Logger(FalProvider.name);
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
falConfig({ credentials: this.config.apiKey });
@@ -107,23 +112,15 @@ export class FalProvider
return this.availableModels.includes(model);
}
private extractError(resp: FalResponse): string {
return Array.isArray(resp.detail)
? resp.detail[0]?.msg
: typeof resp.detail === 'string'
? resp.detail
: '';
}
private extractPrompt(message?: PromptMessage): FalPrompt {
if (!message) throw new Error('Prompt is empty');
if (!message) throw new CopilotPromptInvalid('Prompt is empty');
const { content, attachments, params } = message;
// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
throw new CopilotPromptInvalid('Prompt or Attachments is empty');
}
if (Array.isArray(attachments) && attachments.length > 1) {
throw new Error('Only one attachment is allowed');
throw new CopilotPromptInvalid('Only one attachment is allowed');
}
const lora = (
params?.lora
@@ -139,38 +136,91 @@ export class FalProvider
};
}
private extractFalError(
resp: FalResponse,
message?: string
): CopilotProviderSideError {
if (Array.isArray(resp.detail) && resp.detail.length) {
const error = resp.detail[0].msg;
return new CopilotProviderSideError({
provider: this.type,
kind: resp.detail[0].type,
message: message ? `${message}: ${error}` : error,
});
} else if (typeof resp.detail === 'string') {
const error = resp.detail;
return new CopilotProviderSideError({
provider: this.type,
kind: resp.detail,
message: message ? `${message}: ${error}` : error,
});
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unknown',
message: 'No content generated',
});
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
// pass through user friendly errors
return e;
} else {
const error = new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected fal response',
});
return error;
}
}
private parseSchema<R>(schema: ZodType<R>, data: unknown): R {
const result = schema.safeParse(data);
if (result.success) return result.data;
const errors = JSON.stringify(result.error.errors);
throw new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: `Unexpected fal response: ${errors}`,
});
}
async generateText(
messages: PromptMessage[],
model: string = 'llava-next',
options: CopilotChatOptions = {}
): Promise<string> {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
try {
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...prompt,
sync_mode: true,
enable_safety_checks: false,
}),
signal: options.signal,
});
if (!data.output) {
const error = this.extractError(data);
throw new Error(
error ? `Failed to generate image: ${error}` : 'No images generated'
);
const data = this.parseSchema(FalResponseSchema, await response.json());
if (!data.output) {
throw this.extractFalError(data, 'Failed to generate text');
}
return data.output;
} catch (e: any) {
throw this.handleError(e);
}
return data.output;
}
async *generateTextStream(
@@ -199,11 +249,8 @@ export class FalProvider
const prompt = this.extractPrompt(messages.pop());
if (model.startsWith('workflows/')) {
const stream = await falStream(model, { input: prompt });
const result = FalStreamOutputSchema.safeParse(await stream.done());
if (result.success) return result.data.output;
const errors = JSON.stringify(result.error.errors);
throw new Error(`Unexpected fal response: ${errors}`);
return this.parseSchema(FalStreamOutputSchema, await stream.done())
.output;
} else {
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
@@ -219,10 +266,7 @@ export class FalProvider
}),
signal: options.signal,
});
const result = FalResponseSchema.safeParse(await response.json());
if (result.success) return result.data;
const errors = JSON.stringify(result.error.errors);
throw new Error(`Unexpected fal response: ${errors}`);
return this.parseSchema(FalResponseSchema, await response.json());
}
}
@@ -233,19 +277,14 @@ export class FalProvider
options: CopilotImageOptions = {}
): Promise<Array<string>> {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
try {
const data = await this.buildResponse(messages, model, options);
if (!data.images?.length && !data.image?.url) {
const error = this.extractError(data);
const finalError = error
? `Failed to generate image: ${error}`
: 'No images generated';
this.logger.error(finalError);
throw new Error(finalError);
throw this.extractFalError(data, 'Failed to generate images');
}
if (data.image?.url) {
@@ -258,9 +297,7 @@ export class FalProvider
.map(image => image.url) || []
);
} catch (e: any) {
const error = `Failed to generate image: ${e.message}`;
this.logger.error(error, e.stack);
throw new Error(error);
throw this.handleError(e);
}
}