mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -3,14 +3,16 @@ import assert from 'node:assert';
|
||||
import { ClientOptions, OpenAI } from 'openai';
|
||||
|
||||
import {
|
||||
ChatMessage,
|
||||
ChatMessageRole,
|
||||
CopilotCapability,
|
||||
CopilotProviderType,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
export class OpenAIProvider
|
||||
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
|
||||
{
|
||||
@@ -50,7 +52,7 @@ export class OpenAIProvider
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
|
||||
private chatToGPTMessage(messages: ChatMessage[]) {
|
||||
private chatToGPTMessage(messages: PromptMessage[]) {
|
||||
// filter redundant fields
|
||||
return messages.map(message => ({
|
||||
role: message.role,
|
||||
@@ -63,7 +65,7 @@ export class OpenAIProvider
|
||||
embeddings,
|
||||
model,
|
||||
}: {
|
||||
messages?: ChatMessage[];
|
||||
messages?: PromptMessage[];
|
||||
embeddings?: string[];
|
||||
model: string;
|
||||
}) {
|
||||
@@ -106,7 +108,7 @@ export class OpenAIProvider
|
||||
// ====== text to text ======
|
||||
|
||||
async generateText(
|
||||
messages: ChatMessage[],
|
||||
messages: PromptMessage[],
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: {
|
||||
temperature?: number;
|
||||
@@ -134,8 +136,8 @@ export class OpenAIProvider
|
||||
}
|
||||
|
||||
async *generateTextStream(
|
||||
messages: ChatMessage[],
|
||||
model: string,
|
||||
messages: PromptMessage[],
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
@@ -179,7 +181,7 @@ export class OpenAIProvider
|
||||
dimensions: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = { dimensions: 256 }
|
||||
} = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
this.checkParams({ embeddings: messages, model });
|
||||
@@ -187,7 +189,7 @@ export class OpenAIProvider
|
||||
const result = await this.instance.embeddings.create({
|
||||
model: model,
|
||||
input: messages,
|
||||
dimensions: options.dimensions,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
user: options.user,
|
||||
});
|
||||
return result.data.map(e => e.embedding);
|
||||
|
||||
Reference in New Issue
Block a user