feat: add copilot impl (#6230)

fix CLOUD-22
fix CLOUD-24
This commit is contained in:
darkskygit
2024-04-10 11:15:31 +00:00
parent 46a368d7f1
commit e6a576551a
23 changed files with 669 additions and 34 deletions

View File

@@ -5,9 +5,9 @@ import { Injectable, Logger } from '@nestjs/common';
import { Config } from '../../../fundamentals';
import {
CapabilityToCopilotProvider,
CopilotCapability,
CopilotConfig,
CopilotProvider,
CopilotProviderCapability,
CopilotProviderType,
} from '../types';
@@ -19,7 +19,7 @@ interface CopilotProviderDefinition<C extends CopilotProviderConfig> {
// type of the provider
readonly type: CopilotProviderType;
// capabilities of the provider, like text to text, text to image, etc.
readonly capabilities: CopilotProviderCapability[];
readonly capabilities: CopilotCapability[];
// asserts that the config is valid for this provider
assetsConfig(config: C): boolean;
}
@@ -32,7 +32,7 @@ const COPILOT_PROVIDER = new Map<
// map of capabilities to providers
const PROVIDER_CAPABILITY_MAP = new Map<
CopilotProviderCapability,
CopilotCapability,
CopilotProviderType[]
>();
@@ -116,7 +116,7 @@ export class CopilotProviderService {
return this.cachedProviders.get(provider)!;
}
getProviderByCapability<C extends CopilotProviderCapability>(
getProviderByCapability<C extends CopilotCapability>(
capability: C,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
@@ -133,3 +133,5 @@ export class CopilotProviderService {
return null;
}
}
export { OpenAIProvider } from './openai';

View File

@@ -0,0 +1,195 @@
import assert from 'node:assert';
import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessage,
ChatMessageRole,
CopilotCapability,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToTextProvider,
} from '../types';
export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
{
static readonly type = CopilotProviderType.OpenAI;
static readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
];
readonly availableModels = [
// text to text
'gpt-4-vision-preview',
'gpt-4-turbo-preview',
'gpt-3.5-turbo',
// embeddings
'text-embedding-3-large',
'text-embedding-3-small',
'text-embedding-ada-002',
// moderation
'text-moderation-latest',
'text-moderation-stable',
];
private readonly instance: OpenAI;
constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
this.instance = new OpenAI(config);
}
static assetsConfig(config: ClientOptions) {
return !!config.apiKey;
}
getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
private chatToGPTMessage(messages: ChatMessage[]) {
// filter redundant fields
return messages.map(message => ({
role: message.role,
content: message.content,
}));
}
private checkParams({
messages,
embeddings,
model,
}: {
messages?: ChatMessage[];
embeddings?: string[];
model: string;
}) {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
if (
messages.some(
m =>
// check non-object
typeof m !== 'object' ||
!m ||
// check content
typeof m.content !== 'string' ||
!m.content ||
!m.content.trim()
)
) {
throw new Error('Empty message content');
}
if (
messages.some(
m =>
typeof m.role !== 'string' ||
!m.role ||
!ChatMessageRole.includes(m.role)
)
) {
throw new Error('Invalid message role');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
) {
throw new Error('Invalid embedding');
}
}
// ====== text to text ======
async generateText(
messages: ChatMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): Promise<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
user: options.user,
},
{ signal: options.signal }
);
const { content } = result.choices[0].message;
if (!content) {
throw new Error('Failed to generate text');
}
return content;
}
async *generateTextStream(
messages: ChatMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
{
stream: true,
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
user: options.user,
},
{
signal: options.signal,
}
);
for await (const message of result) {
const content = message.choices[0].delta.content;
if (content) {
yield content;
if (options.signal?.aborted) {
result.controller.abort();
break;
}
}
}
}
// ====== text to embedding ======
async generateEmbedding(
messages: string | string[],
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: 256 }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
const result = await this.instance.embeddings.create({
model: model,
input: messages,
dimensions: options.dimensions,
user: options.user,
});
return result.data.map(e => e.embedding);
}
}