@@ -5,9 +5,9 @@ import { Injectable, Logger } from '@nestjs/common';
import { Config } from '../../../fundamentals';
import {
  CapabilityToCopilotProvider,
+  CopilotCapability,
  CopilotConfig,
  CopilotProvider,
-  CopilotProviderCapability,
  CopilotProviderType,
} from '../types';

@@ -19,7 +19,7 @@ interface CopilotProviderDefinition<C extends CopilotProviderConfig> {
  // type of the provider
  readonly type: CopilotProviderType;
  // capabilities of the provider, like text to text, text to image, etc.
-  readonly capabilities: CopilotProviderCapability[];
+  readonly capabilities: CopilotCapability[];
  // asserts that the config is valid for this provider
  assetsConfig(config: C): boolean;
}
@@ -32,7 +32,7 @@ const COPILOT_PROVIDER = new Map<

// map of capabilities to providers
const PROVIDER_CAPABILITY_MAP = new Map<
-  CopilotProviderCapability,
+  CopilotCapability,
  CopilotProviderType[]
>();

@@ -116,7 +116,7 @@ export class CopilotProviderService {
    return this.cachedProviders.get(provider)!;
  }

-  getProviderByCapability<C extends CopilotProviderCapability>(
+  getProviderByCapability<C extends CopilotCapability>(
    capability: C,
    prefer?: CopilotProviderType
  ): CapabilityToCopilotProvider[C] | null {
@@ -133,3 +133,5 @@ export class CopilotProviderService {
    return null;
  }
}
+
+export { OpenAIProvider } from './openai';
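Below is a minimal usage sketch (not part of this diff) showing how a caller could resolve a provider through CopilotProviderService. It assumes the service is available via Nest dependency injection, that CopilotTextToTextProvider exposes the same generateText signature as OpenAIProvider below, that 'user' is a valid ChatMessageRole, and that the import paths are illustrative only.

import { Injectable } from '@nestjs/common';

import { CopilotProviderService } from './providers';
import { CopilotCapability, CopilotProviderType } from './types';

@Injectable()
export class CopilotDemoService {
  constructor(private readonly providers: CopilotProviderService) {}

  async complete(prompt: string): Promise<string | null> {
    // Resolve a text-to-text provider, preferring OpenAI when it is registered.
    // The generic parameter narrows the result via CapabilityToCopilotProvider.
    const provider = this.providers.getProviderByCapability(
      CopilotCapability.TextToText,
      CopilotProviderType.OpenAI
    );
    if (!provider) return null;

    // 'user' is assumed to be a member of ChatMessageRole.
    return provider.generateText(
      [{ role: 'user', content: prompt }],
      'gpt-3.5-turbo'
    );
  }
}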
packages/backend/server/src/plugins/copilot/providers/openai.ts (new file, 195 lines)
@@ -0,0 +1,195 @@
import assert from 'node:assert';

import { ClientOptions, OpenAI } from 'openai';

import {
  ChatMessage,
  ChatMessageRole,
  CopilotCapability,
  CopilotProviderType,
  CopilotTextToEmbeddingProvider,
  CopilotTextToTextProvider,
} from '../types';

export class OpenAIProvider
  implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
{
  static readonly type = CopilotProviderType.OpenAI;
  static readonly capabilities = [
    CopilotCapability.TextToText,
    CopilotCapability.TextToEmbedding,
    CopilotCapability.TextToImage,
  ];

  readonly availableModels = [
    // text to text
    'gpt-4-vision-preview',
    'gpt-4-turbo-preview',
    'gpt-3.5-turbo',
    // embeddings
    'text-embedding-3-large',
    'text-embedding-3-small',
    'text-embedding-ada-002',
    // moderation
    'text-moderation-latest',
    'text-moderation-stable',
  ];

  private readonly instance: OpenAI;

  constructor(config: ClientOptions) {
    assert(OpenAIProvider.assetsConfig(config));
    this.instance = new OpenAI(config);
  }

  static assetsConfig(config: ClientOptions) {
    return !!config.apiKey;
  }

  getCapabilities(): CopilotCapability[] {
    return OpenAIProvider.capabilities;
  }

  private chatToGPTMessage(messages: ChatMessage[]) {
    // filter redundant fields
    return messages.map(message => ({
      role: message.role,
      content: message.content,
    }));
  }

  private checkParams({
    messages,
    embeddings,
    model,
  }: {
    messages?: ChatMessage[];
    embeddings?: string[];
    model: string;
  }) {
    if (!this.availableModels.includes(model)) {
      throw new Error(`Invalid model: ${model}`);
    }
    if (Array.isArray(messages) && messages.length > 0) {
      if (
        messages.some(
          m =>
            // check non-object
            typeof m !== 'object' ||
            !m ||
            // check content
            typeof m.content !== 'string' ||
            !m.content ||
            !m.content.trim()
        )
      ) {
        throw new Error('Empty message content');
      }
      if (
        messages.some(
          m =>
            typeof m.role !== 'string' ||
            !m.role ||
            !ChatMessageRole.includes(m.role)
        )
      ) {
        throw new Error('Invalid message role');
      }
    } else if (
      Array.isArray(embeddings) &&
      embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
    ) {
      throw new Error('Invalid embedding');
    }
  }

  // ====== text to text ======

  async generateText(
    messages: ChatMessage[],
    model: string = 'gpt-3.5-turbo',
    options: {
      temperature?: number;
      maxTokens?: number;
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): Promise<string> {
    this.checkParams({ messages, model });
    const result = await this.instance.chat.completions.create(
      {
        messages: this.chatToGPTMessage(messages),
        model: model,
        temperature: options.temperature || 0,
        max_tokens: options.maxTokens || 4096,
        user: options.user,
      },
      { signal: options.signal }
    );
    const { content } = result.choices[0].message;
    if (!content) {
      throw new Error('Failed to generate text');
    }
    return content;
  }

  async *generateTextStream(
    messages: ChatMessage[],
    model: string,
    options: {
      temperature?: number;
      maxTokens?: number;
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): AsyncIterable<string> {
    this.checkParams({ messages, model });
    const result = await this.instance.chat.completions.create(
      {
        stream: true,
        messages: this.chatToGPTMessage(messages),
        model: model,
        temperature: options.temperature || 0,
        max_tokens: options.maxTokens || 4096,
        user: options.user,
      },
      {
        signal: options.signal,
      }
    );

    for await (const message of result) {
      const content = message.choices[0].delta.content;
      if (content) {
        yield content;
        if (options.signal?.aborted) {
          result.controller.abort();
          break;
        }
      }
    }
  }

  // ====== text to embedding ======

  async generateEmbedding(
    messages: string | string[],
    model: string,
    options: {
      dimensions: number;
      signal?: AbortSignal;
      user?: string;
    } = { dimensions: 256 }
  ): Promise<number[][]> {
    messages = Array.isArray(messages) ? messages : [messages];
    this.checkParams({ embeddings: messages, model });

    const result = await this.instance.embeddings.create({
      model: model,
      input: messages,
      dimensions: options.dimensions,
      user: options.user,
    });
    return result.data.map(e => e.embedding);
  }
}
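A standalone sketch (not part of this commit) exercising the provider above directly. It assumes OPENAI_API_KEY is set in the environment, that a ChatMessage can be built from just role and content, and that 'user' is a valid ChatMessageRole; the model names come from the availableModels list.

import { OpenAIProvider } from './openai';

async function main() {
  // assetsConfig() only checks that an apiKey is present; the constructor asserts it.
  const provider = new OpenAIProvider({ apiKey: process.env.OPENAI_API_KEY });

  // Text to text: resolves to the first choice's content, or throws if it is empty.
  const reply = await provider.generateText(
    [{ role: 'user', content: 'Say hello in one short sentence.' }],
    'gpt-3.5-turbo'
  );
  console.log(reply);

  // Streaming variant: yields content deltas as they arrive.
  for await (const chunk of provider.generateTextStream(
    [{ role: 'user', content: 'Count from 1 to 5.' }],
    'gpt-3.5-turbo'
  )) {
    process.stdout.write(chunk);
  }

  // Text to embedding: accepts a string or string[] and returns one vector per input.
  const [vector] = await provider.generateEmbedding(
    'hello world',
    'text-embedding-3-small',
    { dimensions: 256 }
  );
  console.log(vector.length); // 256
}

main().catch(console.error);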