mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 13:25:12 +00:00
feat(server): migrate copilot provider sdk (#11584)
fix AI-15 fix AI-16
This commit is contained in:
@@ -1,27 +1,39 @@
|
||||
import OpenAI from 'openai';
|
||||
import {
|
||||
createOpenAI,
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
} from '@ai-sdk/openai';
|
||||
import { embedMany } from 'ai';
|
||||
|
||||
import { Embedding } from '../../../models';
|
||||
import { OpenAIConfig } from '../providers/openai';
|
||||
import { EmbeddingClient } from './types';
|
||||
|
||||
export class OpenAIEmbeddingClient extends EmbeddingClient {
|
||||
constructor(private readonly client: OpenAI) {
|
||||
readonly #instance: VercelOpenAIProvider;
|
||||
|
||||
constructor(config: OpenAIConfig) {
|
||||
super();
|
||||
this.#instance = createOpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
async getEmbeddings(
|
||||
input: string[],
|
||||
signal?: AbortSignal
|
||||
): Promise<Embedding[]> {
|
||||
const resp = await this.client.embeddings.create(
|
||||
{
|
||||
input,
|
||||
model: 'text-embedding-3-large',
|
||||
dimensions: 1024,
|
||||
encoding_format: 'float',
|
||||
},
|
||||
{ signal }
|
||||
);
|
||||
return resp.data.map(e => ({ ...e, content: input[e.index] }));
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
const modelInstance = this.#instance.embedding('text-embedding-3-large', {
|
||||
dimensions: 1024,
|
||||
});
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: input,
|
||||
});
|
||||
|
||||
return Array.from(embeddings.entries()).map(([index, embedding]) => ({
|
||||
index,
|
||||
embedding,
|
||||
content: input[index],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
import {
|
||||
AFFiNELogger,
|
||||
@@ -49,7 +48,7 @@ export class CopilotContextDocJob {
|
||||
this.supportEmbedding =
|
||||
await this.models.copilotContext.checkEmbeddingAvailable();
|
||||
this.client = new OpenAIEmbeddingClient(
|
||||
new OpenAI(this.config.copilot.providers.openai)
|
||||
this.config.copilot.providers.openai
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
import {
|
||||
Cache,
|
||||
@@ -46,7 +45,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
private setup() {
|
||||
const configure = this.config.copilot.providers.openai;
|
||||
if (configure.apiKey) {
|
||||
this.client = new OpenAIEmbeddingClient(new OpenAI(configure));
|
||||
this.client = new OpenAIEmbeddingClient(configure);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,14 +4,10 @@ import {
|
||||
} from '@ai-sdk/google';
|
||||
import {
|
||||
AISDKError,
|
||||
type CoreAssistantMessage,
|
||||
type CoreUserMessage,
|
||||
FilePart,
|
||||
generateObject,
|
||||
generateText,
|
||||
JSONParseError,
|
||||
streamText,
|
||||
TextPart,
|
||||
} from 'ai';
|
||||
|
||||
import {
|
||||
@@ -29,35 +25,15 @@ import {
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { chatToGPTMessage } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
const FORMAT_INFER_MAP: Record<string, string> = {
|
||||
pdf: 'application/pdf',
|
||||
mp3: 'audio/mpeg',
|
||||
wav: 'audio/wav',
|
||||
png: 'image/png',
|
||||
jpeg: 'image/jpeg',
|
||||
jpg: 'image/jpeg',
|
||||
webp: 'image/webp',
|
||||
txt: 'text/plain',
|
||||
md: 'text/plain',
|
||||
mov: 'video/mov',
|
||||
mpeg: 'video/mpeg',
|
||||
mp4: 'video/mp4',
|
||||
avi: 'video/avi',
|
||||
wmv: 'video/wmv',
|
||||
flv: 'video/flv',
|
||||
};
|
||||
|
||||
export type GeminiConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
};
|
||||
|
||||
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
|
||||
|
||||
export class GeminiProvider
|
||||
extends CopilotProvider<GeminiConfig>
|
||||
implements CopilotTextToTextProvider
|
||||
@@ -86,67 +62,6 @@ export class GeminiProvider
|
||||
});
|
||||
}
|
||||
|
||||
private inferMimeType(url: string) {
|
||||
if (url.startsWith('data:')) {
|
||||
return url.split(';')[0].split(':')[1];
|
||||
}
|
||||
const extension = url.split('.').pop();
|
||||
if (extension) {
|
||||
return FORMAT_INFER_MAP[extension];
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
protected async chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): Promise<[string | undefined, ChatMessage[], any]> {
|
||||
const system =
|
||||
messages[0]?.role === 'system' ? messages.shift() : undefined;
|
||||
const schema = system?.params?.schema;
|
||||
|
||||
// filter redundant fields
|
||||
const msgs: ChatMessage[] = [];
|
||||
for (let { role, content, attachments, params } of messages.filter(
|
||||
m => m.role !== 'system'
|
||||
)) {
|
||||
content = content.trim();
|
||||
role = role as 'user' | 'assistant';
|
||||
const mimetype = params?.mimetype;
|
||||
if (Array.isArray(attachments)) {
|
||||
const contents: (TextPart | FilePart)[] = [];
|
||||
if (content.length) {
|
||||
contents.push({
|
||||
type: 'text',
|
||||
text: content,
|
||||
});
|
||||
}
|
||||
|
||||
for (const url of attachments) {
|
||||
if (SIMPLE_IMAGE_URL_REGEX.test(url)) {
|
||||
const mimeType =
|
||||
typeof mimetype === 'string' ? mimetype : this.inferMimeType(url);
|
||||
if (mimeType) {
|
||||
const data = url.startsWith('data:')
|
||||
? await fetch(url).then(r => r.arrayBuffer())
|
||||
: new URL(url);
|
||||
contents.push({
|
||||
type: 'file' as const,
|
||||
data,
|
||||
mimeType,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgs.push({ role, content: contents } as ChatMessage);
|
||||
} else {
|
||||
msgs.push({ role, content });
|
||||
}
|
||||
}
|
||||
|
||||
return [system?.content, msgs, schema];
|
||||
}
|
||||
|
||||
protected async checkParams({
|
||||
messages,
|
||||
embeddings,
|
||||
@@ -223,7 +138,7 @@ export class GeminiProvider
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
|
||||
const [system, msgs, schema] = await this.chatToGPTMessage(messages);
|
||||
const [system, msgs, schema] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model, {
|
||||
structuredOutputs: Boolean(options.jsonMode),
|
||||
@@ -274,7 +189,7 @@ export class GeminiProvider
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const [system, msgs] = await this.chatToGPTMessage(messages);
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const { textStream } = streamText({
|
||||
model: this.#instance(model),
|
||||
|
||||
@@ -1,4 +1,15 @@
|
||||
import { APIError, BadRequestError, ClientOptions, OpenAI } from 'openai';
|
||||
import {
|
||||
createOpenAI,
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
} from '@ai-sdk/openai';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText,
|
||||
} from 'ai';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
@@ -20,12 +31,14 @@ import {
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { chatToGPTMessage } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
export type OpenAIConfig = ClientOptions;
|
||||
export type OpenAIConfig = {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
};
|
||||
|
||||
export class OpenAIProvider
|
||||
extends CopilotProvider<OpenAIConfig>
|
||||
@@ -62,8 +75,7 @@ export class OpenAIProvider
|
||||
'dall-e-3',
|
||||
];
|
||||
|
||||
#existsModels: string[] = [];
|
||||
#instance!: OpenAI;
|
||||
#instance!: VercelOpenAIProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
@@ -71,55 +83,9 @@ export class OpenAIProvider
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = new OpenAI(this.config);
|
||||
}
|
||||
|
||||
override async isModelAvailable(model: string): Promise<boolean> {
|
||||
const knownModels = this.models.includes(model);
|
||||
if (knownModels) return true;
|
||||
|
||||
if (!this.#existsModels) {
|
||||
try {
|
||||
this.#existsModels = await this.#instance.models
|
||||
.list()
|
||||
.then(({ data }) => data.map(m => m.id));
|
||||
} catch (e: any) {
|
||||
this.logger.error('Failed to fetch online model list', e.stack);
|
||||
}
|
||||
}
|
||||
return !!this.#existsModels?.includes(model);
|
||||
}
|
||||
|
||||
protected chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
|
||||
// filter redundant fields
|
||||
return messages.map(({ role, content, attachments }) => {
|
||||
content = content.trim();
|
||||
if (Array.isArray(attachments) && attachments.length) {
|
||||
const contents: OpenAI.Chat.Completions.ChatCompletionContentPart[] =
|
||||
[];
|
||||
if (content.length) {
|
||||
contents.push({
|
||||
type: 'text',
|
||||
text: content,
|
||||
});
|
||||
}
|
||||
contents.push(
|
||||
...(attachments
|
||||
.filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
|
||||
.map(url => ({
|
||||
type: 'image_url',
|
||||
image_url: { url, detail: 'high' },
|
||||
})) as OpenAI.Chat.Completions.ChatCompletionContentPartImage[])
|
||||
);
|
||||
return {
|
||||
role,
|
||||
content: contents,
|
||||
} as OpenAI.Chat.Completions.ChatCompletionMessageParam;
|
||||
} else {
|
||||
return { role, content };
|
||||
}
|
||||
this.#instance = createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -186,11 +152,8 @@ export class OpenAIProvider
|
||||
) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof APIError) {
|
||||
if (
|
||||
e instanceof BadRequestError &&
|
||||
(e.message.includes('safety') || e.message.includes('risk'))
|
||||
) {
|
||||
} else if (e instanceof AISDKError) {
|
||||
if (e.message.includes('safety') || e.message.includes('risk')) {
|
||||
metrics.ai
|
||||
.counter('chat_text_risk_errors')
|
||||
.add(1, { model, user: options.user || undefined });
|
||||
@@ -198,7 +161,7 @@ export class OpenAIProvider
|
||||
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.type || 'unknown',
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
@@ -217,26 +180,42 @@ export class OpenAIProvider
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
await this.checkParams({ messages, model, options });
|
||||
console.log('messages', messages);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
const result = await this.#instance.chat.completions.create(
|
||||
{
|
||||
messages: this.chatToGPTMessage(messages),
|
||||
model: model,
|
||||
temperature: options.temperature || 0,
|
||||
max_completion_tokens: options.maxTokens || 4096,
|
||||
response_format: {
|
||||
type: options.jsonMode ? 'json_object' : 'text',
|
||||
},
|
||||
user: options.user,
|
||||
},
|
||||
{ signal: options.signal }
|
||||
);
|
||||
const { content } = result.choices[0].message;
|
||||
if (!content) throw new Error('Failed to generate text');
|
||||
return content.trim();
|
||||
|
||||
const [system, msgs, schema] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model, {
|
||||
structuredOutputs: Boolean(options.jsonMode),
|
||||
user: options.user,
|
||||
});
|
||||
|
||||
const commonParams = {
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature || 0,
|
||||
maxTokens: options.maxTokens || 4096,
|
||||
abortSignal: options.signal,
|
||||
};
|
||||
|
||||
const { text } = schema
|
||||
? await generateObject({
|
||||
...commonParams,
|
||||
schema,
|
||||
}).then(r => ({ text: JSON.stringify(r.object) }))
|
||||
: await generateText({
|
||||
...commonParams,
|
||||
providerOptions: {
|
||||
openai: options.user ? { user: options.user } : {},
|
||||
},
|
||||
});
|
||||
|
||||
return text.trim();
|
||||
} catch (e: any) {
|
||||
console.log('error', e);
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
}
|
||||
@@ -251,34 +230,30 @@ export class OpenAIProvider
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const result = await this.#instance.chat.completions.create(
|
||||
{
|
||||
stream: true,
|
||||
messages: this.chatToGPTMessage(messages),
|
||||
model: model,
|
||||
frequency_penalty: options.frequencyPenalty || 0,
|
||||
presence_penalty: options.presencePenalty || 0,
|
||||
temperature: options.temperature || 0.5,
|
||||
max_completion_tokens: options.maxTokens || 4096,
|
||||
response_format: {
|
||||
type: options.jsonMode ? 'json_object' : 'text',
|
||||
},
|
||||
user: options.user,
|
||||
},
|
||||
{
|
||||
signal: options.signal,
|
||||
}
|
||||
);
|
||||
|
||||
for await (const message of result) {
|
||||
if (!Array.isArray(message.choices) || !message.choices.length) {
|
||||
continue;
|
||||
}
|
||||
const content = message.choices[0].delta.content;
|
||||
if (content) {
|
||||
yield content;
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model, {
|
||||
structuredOutputs: Boolean(options.jsonMode),
|
||||
user: options.user,
|
||||
});
|
||||
|
||||
const { textStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
frequencyPenalty: options.frequencyPenalty || 0,
|
||||
presencePenalty: options.presencePenalty || 0,
|
||||
temperature: options.temperature || 0,
|
||||
maxTokens: options.maxTokens || 4096,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
for await (const message of textStream) {
|
||||
if (message) {
|
||||
yield message;
|
||||
if (options.signal?.aborted) {
|
||||
result.controller.abort();
|
||||
await textStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -301,15 +276,18 @@ export class OpenAIProvider
|
||||
|
||||
try {
|
||||
metrics.ai.counter('generate_embedding_calls').add(1, { model });
|
||||
const result = await this.#instance.embeddings.create({
|
||||
model: model,
|
||||
input: messages,
|
||||
|
||||
const modelInstance = this.#instance.embedding(model, {
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
user: options.user,
|
||||
});
|
||||
return result.data
|
||||
.map(e => e?.embedding)
|
||||
.filter(v => v && Array.isArray(v));
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: messages,
|
||||
});
|
||||
|
||||
return embeddings.filter(v => v && Array.isArray(v));
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_embedding_errors').add(1, { model });
|
||||
throw this.handleError(e, model, options);
|
||||
@@ -327,18 +305,16 @@ export class OpenAIProvider
|
||||
|
||||
try {
|
||||
metrics.ai.counter('generate_images_calls').add(1, { model });
|
||||
const result = await this.#instance.images.generate(
|
||||
{
|
||||
prompt,
|
||||
model,
|
||||
response_format: 'url',
|
||||
user: options.user,
|
||||
},
|
||||
{ signal: options.signal }
|
||||
);
|
||||
|
||||
return result.data
|
||||
.map(image => image.url)
|
||||
const modelInstance = this.#instance.image(model);
|
||||
|
||||
const result = await generateImage({
|
||||
model: modelInstance,
|
||||
prompt,
|
||||
});
|
||||
|
||||
return result.images
|
||||
.map(image => image.base64)
|
||||
.filter((v): v is string => !!v);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model });
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { EventSourceParserStream } from 'eventsource-parser/stream';
|
||||
import {
|
||||
createPerplexity,
|
||||
type PerplexityProvider as VercelPerplexityProvider,
|
||||
} from '@ai-sdk/perplexity';
|
||||
import { generateText, streamText } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
@@ -14,6 +18,7 @@ import {
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { chatToGPTMessage, CitationParser } from './utils';
|
||||
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
@@ -39,130 +44,8 @@ const PerplexityErrorSchema = z.union([
|
||||
}),
|
||||
]);
|
||||
|
||||
const PerplexityDataSchema = z.object({
|
||||
citations: z.array(z.string()),
|
||||
choices: z.array(
|
||||
z.object({
|
||||
message: z.object({
|
||||
content: z.string(),
|
||||
role: z.literal('assistant'),
|
||||
}),
|
||||
delta: z.object({
|
||||
content: z.string(),
|
||||
role: z.literal('assistant'),
|
||||
}),
|
||||
finish_reason: z.union([z.literal('stop'), z.literal(null)]),
|
||||
})
|
||||
),
|
||||
});
|
||||
|
||||
const PerplexitySchema = z.union([PerplexityDataSchema, PerplexityErrorSchema]);
|
||||
|
||||
type PerplexityError = z.infer<typeof PerplexityErrorSchema>;
|
||||
|
||||
export class CitationParser {
|
||||
private readonly SQUARE_BRACKET_OPEN = '[';
|
||||
|
||||
private readonly SQUARE_BRACKET_CLOSE = ']';
|
||||
|
||||
private readonly PARENTHESES_OPEN = '(';
|
||||
|
||||
private startToken: string[] = [];
|
||||
|
||||
private endToken: string[] = [];
|
||||
|
||||
private numberToken: string[] = [];
|
||||
|
||||
private citations: string[] = [];
|
||||
|
||||
public parse(content: string, citations: string[]) {
|
||||
this.citations = citations;
|
||||
let result = '';
|
||||
const contentArray = content.split('');
|
||||
for (const [index, char] of contentArray.entries()) {
|
||||
if (char === this.SQUARE_BRACKET_OPEN) {
|
||||
if (this.numberToken.length === 0) {
|
||||
this.startToken.push(char);
|
||||
} else {
|
||||
result += this.flush() + char;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === this.SQUARE_BRACKET_CLOSE) {
|
||||
this.endToken.push(char);
|
||||
if (this.startToken.length === this.endToken.length) {
|
||||
const cIndex = Number(this.numberToken.join('').trim());
|
||||
if (
|
||||
cIndex > 0 &&
|
||||
cIndex <= citations.length &&
|
||||
contentArray[index + 1] !== this.PARENTHESES_OPEN
|
||||
) {
|
||||
const content = `[^${cIndex}]`;
|
||||
result += content;
|
||||
this.resetToken();
|
||||
} else {
|
||||
result += this.flush();
|
||||
}
|
||||
} else if (this.startToken.length < this.endToken.length) {
|
||||
result += this.flush();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.isNumeric(char)) {
|
||||
if (this.startToken.length > 0) {
|
||||
this.numberToken.push(char);
|
||||
} else {
|
||||
result += this.flush() + char;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.startToken.length > 0) {
|
||||
result += this.flush() + char;
|
||||
} else {
|
||||
result += char;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public end() {
|
||||
return this.flush() + '\n' + this.getFootnotes();
|
||||
}
|
||||
|
||||
private flush() {
|
||||
const content = this.getTokenContent();
|
||||
this.resetToken();
|
||||
return content;
|
||||
}
|
||||
|
||||
private getFootnotes() {
|
||||
const footnotes = this.citations.map((citation, index) => {
|
||||
return `[^${index + 1}]: {"type":"url","url":"${encodeURIComponent(
|
||||
citation
|
||||
)}"}`;
|
||||
});
|
||||
return footnotes.join('\n');
|
||||
}
|
||||
|
||||
private getTokenContent() {
|
||||
return this.startToken.concat(this.numberToken, this.endToken).join('');
|
||||
}
|
||||
|
||||
private resetToken() {
|
||||
this.startToken = [];
|
||||
this.endToken = [];
|
||||
this.numberToken = [];
|
||||
}
|
||||
|
||||
private isNumeric(str: string) {
|
||||
return !isNaN(Number(str)) && str.trim() !== '';
|
||||
}
|
||||
}
|
||||
|
||||
export class PerplexityProvider
|
||||
extends CopilotProvider<PerplexityConfig>
|
||||
implements CopilotTextToTextProvider
|
||||
@@ -176,10 +59,20 @@ export class PerplexityProvider
|
||||
'sonar-reasoning-pro',
|
||||
];
|
||||
|
||||
#instance!: VercelPerplexityProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createPerplexity({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.endpoint,
|
||||
});
|
||||
}
|
||||
|
||||
async generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'sonar',
|
||||
@@ -188,38 +81,26 @@ export class PerplexityProvider
|
||||
await this.checkParams({ messages, model, options });
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
const sMessages = messages
|
||||
.map(({ content, role }) => ({ content, role }))
|
||||
.filter(({ content }) => typeof content === 'string');
|
||||
|
||||
const params = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
messages: sMessages,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
}),
|
||||
};
|
||||
const response = await fetch(
|
||||
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
|
||||
params
|
||||
);
|
||||
const data = PerplexitySchema.parse(await response.json());
|
||||
if ('detail' in data || 'error' in data) {
|
||||
throw this.convertError(data);
|
||||
} else {
|
||||
const citationParser = new CitationParser();
|
||||
const { content } = data.choices[0].message;
|
||||
const { citations } = data;
|
||||
let result = content.replaceAll(/<\/?think>\n/g, '\n---\n');
|
||||
result = citationParser.parse(result, citations);
|
||||
result += citationParser.end();
|
||||
return result;
|
||||
}
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model);
|
||||
|
||||
const { text, sources } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature || 0,
|
||||
maxTokens: options.maxTokens || 4096,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
const citationParser = new CitationParser();
|
||||
const citations = sources.map(s => s.url);
|
||||
let result = text.replaceAll(/<\/?think>\n/g, '\n---\n');
|
||||
result = citationParser.parse(result, citations);
|
||||
result += citationParser.end();
|
||||
return result;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model });
|
||||
throw this.handleError(e);
|
||||
@@ -234,69 +115,54 @@ export class PerplexityProvider
|
||||
await this.checkParams({ messages, model, options });
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const sMessages = messages
|
||||
.map(({ content, role }) => ({ content, role }))
|
||||
.filter(({ content }) => typeof content === 'string');
|
||||
|
||||
const params = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
messages: sMessages,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
stream: true,
|
||||
}),
|
||||
};
|
||||
const response = await fetch(
|
||||
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
|
||||
params
|
||||
);
|
||||
const errorHandler = this.convertError;
|
||||
if (response.ok && response.body) {
|
||||
const citationParser = new CitationParser();
|
||||
const eventStream = response.body
|
||||
.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream())
|
||||
.pipeThrough(
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (options.signal?.aborted) {
|
||||
controller.enqueue(null);
|
||||
return;
|
||||
}
|
||||
const json = JSON.parse(chunk.data);
|
||||
if (json) {
|
||||
const data = PerplexitySchema.parse(json);
|
||||
if ('detail' in data || 'error' in data) {
|
||||
throw errorHandler(data);
|
||||
}
|
||||
const { content } = data.choices[0].delta;
|
||||
const { citations } = data;
|
||||
let result = content.replaceAll(/<\/?think>\n?/g, '\n---\n');
|
||||
result = citationParser.parse(result, citations);
|
||||
controller.enqueue(result);
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
controller.enqueue(citationParser.end());
|
||||
controller.enqueue(null);
|
||||
},
|
||||
})
|
||||
);
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const reader = eventStream.getReader();
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
yield value;
|
||||
const modelInstance = this.#instance(model);
|
||||
|
||||
const stream = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature || 0,
|
||||
maxTokens: options.maxTokens || 4096,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
const citationParser = new CitationParser();
|
||||
const citations = [];
|
||||
for await (const chunk of stream.fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'source': {
|
||||
citations.push(chunk.source.url);
|
||||
break;
|
||||
}
|
||||
case 'text-delta': {
|
||||
const result = citationParser.parse(
|
||||
chunk.textDelta.replaceAll(/<\/?think>\n?/g, '\n---\n'),
|
||||
citations
|
||||
);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'step-finish': {
|
||||
const result = citationParser.end();
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
const json =
|
||||
typeof chunk.error === 'string'
|
||||
? JSON.parse(chunk.error)
|
||||
: chunk.error;
|
||||
if (json && typeof json === 'object') {
|
||||
const data = PerplexityErrorSchema.parse(json);
|
||||
if ('detail' in data || 'error' in data) {
|
||||
throw this.convertError(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const result = await this.generateText(messages, model, options);
|
||||
yield result;
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
|
||||
|
||||
201
packages/backend/server/src/plugins/copilot/providers/utils.ts
Normal file
201
packages/backend/server/src/plugins/copilot/providers/utils.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import {
|
||||
CoreAssistantMessage,
|
||||
CoreUserMessage,
|
||||
FilePart,
|
||||
ImagePart,
|
||||
TextPart,
|
||||
} from 'ai';
|
||||
|
||||
import { PromptMessage } from './types';
|
||||
|
||||
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
/**
 * Lower-case file extension -> MIME type table used when an attachment URL
 * carries no explicit mimetype.
 */
const FORMAT_INFER_MAP: Record<string, string> = {
  pdf: 'application/pdf',
  mp3: 'audio/mpeg',
  wav: 'audio/wav',
  png: 'image/png',
  jpeg: 'image/jpeg',
  jpg: 'image/jpeg',
  webp: 'image/webp',
  txt: 'text/plain',
  md: 'text/plain',
  mov: 'video/mov',
  mpeg: 'video/mpeg',
  mp4: 'video/mp4',
  avi: 'video/avi',
  wmv: 'video/wmv',
  flv: 'video/flv',
};

/**
 * Best-effort MIME type detection for an attachment URL.
 *
 * - `data:` URLs: returns the media type that precedes the first `;` or `,`
 *   (lower-cased), or `undefined` when the URL declares none.
 * - Any other URL: looks up the extension of the last path segment in
 *   `FORMAT_INFER_MAP`. Query strings and fragments are ignored and the
 *   extension is matched case-insensitively.
 *
 * @param url attachment URL or data URL
 * @returns the inferred MIME type, or `undefined` when it cannot be inferred
 */
function inferMimeType(url: string): string | undefined {
  if (url.startsWith('data:')) {
    // data:[<mediatype>][;param...],<payload> -- only the header before the
    // first comma can contain the media type.
    const header = url.slice('data:'.length).split(',', 1)[0] ?? '';
    const mediaType = (header.split(';', 1)[0] ?? '').trim().toLowerCase();
    return mediaType || undefined;
  }

  // Work on the path only, so `a.png?width=100` or `a.png#x` still resolve.
  let pathname: string;
  try {
    pathname = new URL(url).pathname;
  } catch {
    // Not an absolute URL: strip query/fragment by hand.
    pathname = url.split(/[?#]/, 1)[0] ?? '';
  }

  const fileName = pathname.slice(pathname.lastIndexOf('/') + 1);
  const dotAt = fileName.lastIndexOf('.');
  if (dotAt === -1) {
    return undefined;
  }

  const extension = fileName.slice(dotAt + 1).toLowerCase();
  // Own-property check keeps inherited members (e.g. `constructor`) from
  // being returned as if they were MIME types.
  return Object.prototype.hasOwnProperty.call(FORMAT_INFER_MAP, extension)
    ? FORMAT_INFER_MAP[extension]
    : undefined;
}
|
||||
|
||||
/**
 * Split prompt messages into the pieces the AI SDK calls in this package
 * consume: the system prompt text, the user/assistant messages, and the
 * optional schema attached to the system prompt's params.
 *
 * Rules applied to each non-system message:
 * - `content` is trimmed.
 * - When `attachments` is an array, the message content becomes a list of
 *   parts: one text part (if the trimmed content is non-empty) followed by
 *   one part per attachment whose URL matches `SIMPLE_IMAGE_URL_REGEX` and
 *   whose MIME type is known (taken from `params.mimetype` when it is a
 *   string, otherwise inferred from the URL). `image/*` attachments become
 *   image parts; everything else becomes a file part.
 * - Otherwise the trimmed string content is used as-is.
 *
 * The input array is NOT modified; the leading system message is only read.
 *
 * @param messages prompt messages, optionally starting with a system message
 * @returns `[systemContent, chatMessages, schema]`
 */
export async function chatToGPTMessage(
  messages: PromptMessage[]
): Promise<[string | undefined, ChatMessage[], any]> {
  // Peek at the first message instead of shift()-ing it, so the caller's
  // array keeps its system prompt.
  const system = messages[0]?.role === 'system' ? messages[0] : undefined;
  const schema = system?.params?.schema;

  const msgs: ChatMessage[] = [];
  for (const message of messages) {
    // System messages never go into the chat list, wherever they appear.
    if (message.role === 'system') {
      continue;
    }

    const role = message.role as 'user' | 'assistant';
    const content = message.content.trim();
    const { attachments } = message;
    const mimetype = message.params?.mimetype;

    if (!Array.isArray(attachments)) {
      msgs.push({ role, content });
      continue;
    }

    const contents: (TextPart | ImagePart | FilePart)[] = [];
    if (content.length) {
      contents.push({ type: 'text', text: content });
    }

    for (const url of attachments) {
      if (!SIMPLE_IMAGE_URL_REGEX.test(url)) {
        continue;
      }

      // An explicit mimetype on the message wins over URL-based inference.
      const mimeType =
        typeof mimetype === 'string' ? mimetype : inferMimeType(url);
      if (!mimeType) {
        continue;
      }

      if (mimeType.startsWith('image/')) {
        contents.push({ type: 'image', image: url, mimeType });
        continue;
      }

      // Non-image payloads: inline data URLs are decoded to bytes, remote
      // URLs are passed through as URL objects.
      let data: ArrayBuffer | URL;
      if (url.startsWith('data:')) {
        const response = await fetch(url);
        data = await response.arrayBuffer();
      } else {
        data = new URL(url);
      }
      contents.push({ type: 'file' as const, data, mimeType });
    }

    msgs.push({ role, content: contents } as ChatMessage);
  }

  return [system?.content, msgs, schema];
}
|
||||
|
||||
/**
 * Incremental rewriter for numeric citation markers.
 *
 * Text is fed through `parse()` (possibly in several chunks). Bracketed
 * numbers such as `[2]` or `[[2]]` that refer to an existing citation and
 * are not directly followed by `(` are rewritten to footnote references
 * (`[^2]`). Anything else that merely looked like the start of a marker is
 * emitted unchanged. Bracket/digit characters that could still become a
 * marker are buffered between calls; `end()` releases that buffer and
 * appends the footnote definitions.
 */
export class CitationParser {
  private readonly SQUARE_BRACKET_OPEN = '[';

  private readonly SQUARE_BRACKET_CLOSE = ']';

  private readonly PARENTHESES_OPEN = '(';

  // Buffered pieces of a marker that is currently being recognised.
  private startToken: string[] = [];

  private endToken: string[] = [];

  private numberToken: string[] = [];

  // Citation list given to the most recent parse() call.
  private citations: string[] = [];

  /**
   * Process one chunk of text.
   *
   * @param content chunk to scan, one UTF-16 code unit at a time
   * @param citations citation URLs known so far; marker numbers are 1-based
   *   indexes into this list
   * @returns the text that can be emitted for this chunk
   */
  public parse(content: string, citations: string[]) {
    this.citations = citations;
    const chars = content.split('');
    return chars.reduce(
      (output, ch, position) =>
        output + this.consume(ch, chars[position + 1], citations.length),
      ''
    );
  }

  /**
   * Finish parsing: release whatever is still buffered, then a newline and
   * the footnote definitions for every known citation.
   */
  public end() {
    const pending = this.flush();
    return pending + '\n' + this.getFootnotes();
  }

  /** Handle a single character and return the text to emit for it. */
  private consume(
    ch: string,
    next: string | undefined,
    citationCount: number
  ): string {
    if (ch === this.SQUARE_BRACKET_OPEN) {
      // A `[` after digits cannot belong to the buffered marker: emit both.
      if (this.numberToken.length > 0) {
        return this.flush() + ch;
      }
      this.startToken.push(ch);
      return '';
    }

    if (ch === this.SQUARE_BRACKET_CLOSE) {
      return this.closeBracket(next, citationCount);
    }

    const insideMarker = this.startToken.length > 0;
    if (this.isNumeric(ch)) {
      if (insideMarker) {
        this.numberToken.push(ch);
        return '';
      }
      return this.flush() + ch;
    }

    // Any other character ends a marker attempt.
    return insideMarker ? this.flush() + ch : ch;
  }

  /** Handle a `]`: either complete a marker, keep waiting, or give up. */
  private closeBracket(next: string | undefined, citationCount: number): string {
    this.endToken.push(this.SQUARE_BRACKET_CLOSE);
    const opened = this.startToken.length;
    const closed = this.endToken.length;

    if (opened > closed) {
      // More closing brackets are still expected (e.g. `[[1]` so far).
      return '';
    }
    if (opened < closed) {
      return this.flush();
    }

    const cIndex = Number(this.numberToken.join('').trim());
    const isCitation =
      cIndex > 0 &&
      cIndex <= citationCount &&
      // `[n](` is left untouched rather than rewritten.
      next !== this.PARENTHESES_OPEN;
    if (!isCitation) {
      return this.flush();
    }

    this.resetToken();
    return `[^${cIndex}]`;
  }

  /** Return the buffered characters verbatim and clear the buffer. */
  private flush() {
    const buffered = this.getTokenContent();
    this.resetToken();
    return buffered;
  }

  /** Build one footnote definition line per citation, joined by newlines. */
  private getFootnotes() {
    return this.citations
      .map(
        (citation, index) =>
          '[^' +
          (index + 1) +
          ']: {"type":"url","url":"' +
          encodeURIComponent(citation) +
          '"}'
      )
      .join('\n');
  }

  /** Buffered text in original order: opening brackets, digits, closing brackets. */
  private getTokenContent() {
    return [...this.startToken, ...this.numberToken, ...this.endToken].join('');
  }

  /** Drop all buffered marker pieces. */
  private resetToken() {
    this.startToken = [];
    this.endToken = [];
    this.numberToken = [];
  }

  /** True when `str` is non-blank and converts to a number. */
  private isNumeric(str: string) {
    return str.trim() !== '' && !isNaN(Number(str));
  }
}
|
||||
Reference in New Issue
Block a user