mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -9,6 +9,7 @@ import { z, ZodType } from 'zod';
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../fundamentals';
|
||||
import {
|
||||
@@ -217,6 +218,7 @@ export class FalProvider
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages.pop());
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
const response = await fetch(`https://fal.run/fal-ai/${model}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -237,6 +239,7 @@ export class FalProvider
|
||||
}
|
||||
return data.output;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -246,15 +249,21 @@ export class FalProvider
|
||||
model: string = 'llava-next',
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const result = await this.generateText(messages, model, options);
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const result = await this.generateText(messages, model, options);
|
||||
|
||||
for await (const content of result) {
|
||||
if (content) {
|
||||
yield content;
|
||||
if (options.signal?.aborted) {
|
||||
break;
|
||||
for await (const content of result) {
|
||||
if (content) {
|
||||
yield content;
|
||||
if (options.signal?.aborted) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,6 +308,8 @@ export class FalProvider
|
||||
}
|
||||
|
||||
try {
|
||||
metrics.ai.counter('generate_images_calls').add(1, { model });
|
||||
|
||||
const data = await this.buildResponse(messages, model, options);
|
||||
|
||||
if (!data.images?.length && !data.image?.url) {
|
||||
@@ -315,6 +326,7 @@ export class FalProvider
|
||||
.map(image => image.url) || []
|
||||
);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -324,9 +336,15 @@ export class FalProvider
|
||||
model: string = this.availableModels[0],
|
||||
options: CopilotImageOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
try {
|
||||
metrics.ai.counter('generate_images_stream_calls').add(1, { model });
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('generate_images_stream_errors').add(1, { model });
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { APIError, ClientOptions, OpenAI } from 'openai';
|
||||
import { APIError, BadRequestError, ClientOptions, OpenAI } from 'openai';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../fundamentals';
|
||||
import {
|
||||
@@ -179,10 +180,23 @@ export class OpenAIProvider
|
||||
}
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
private handleError(
|
||||
e: any,
|
||||
model: string,
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof APIError) {
|
||||
if (
|
||||
e instanceof BadRequestError &&
|
||||
(e.message.includes('safety') || e.message.includes('risk'))
|
||||
) {
|
||||
metrics.ai
|
||||
.counter('chat_text_risk_errors')
|
||||
.add(1, { model, user: options.user || undefined });
|
||||
}
|
||||
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.type || 'unknown',
|
||||
@@ -206,6 +220,7 @@ export class OpenAIProvider
|
||||
this.checkParams({ messages, model, options });
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
{
|
||||
messages: this.chatToGPTMessage(messages),
|
||||
@@ -223,7 +238,8 @@ export class OpenAIProvider
|
||||
if (!content) throw new Error('Failed to generate text');
|
||||
return content.trim();
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,6 +251,7 @@ export class OpenAIProvider
|
||||
this.checkParams({ messages, model, options });
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
{
|
||||
stream: true,
|
||||
@@ -268,7 +285,8 @@ export class OpenAIProvider
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,15 +301,19 @@ export class OpenAIProvider
|
||||
this.checkParams({ embeddings: messages, model, options });
|
||||
|
||||
try {
|
||||
metrics.ai.counter('generate_embedding_calls').add(1, { model });
|
||||
const result = await this.instance.embeddings.create({
|
||||
model: model,
|
||||
input: messages,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
user: options.user,
|
||||
});
|
||||
return result.data.map(e => e.embedding);
|
||||
return result.data
|
||||
.map(e => e?.embedding)
|
||||
.filter(v => v && Array.isArray(v));
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('generate_embedding_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,6 +327,7 @@ export class OpenAIProvider
|
||||
if (!prompt) throw new CopilotPromptInvalid('Prompt is required');
|
||||
|
||||
try {
|
||||
metrics.ai.counter('generate_images_calls').add(1, { model });
|
||||
const result = await this.instance.images.generate(
|
||||
{
|
||||
prompt,
|
||||
@@ -319,7 +342,8 @@ export class OpenAIProvider
|
||||
.map(image => image.url)
|
||||
.filter((v): v is string => !!v);
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,9 +352,15 @@ export class OpenAIProvider
|
||||
model: string = 'dall-e-3',
|
||||
options: CopilotImageOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
try {
|
||||
metrics.ai.counter('generate_images_stream_calls').add(1, { model });
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('generate_images_stream_errors').add(1, { model });
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user