feat: add copilot impl (#6230)

fix CLOUD-22
fix CLOUD-24
This commit is contained in:
darkskygit
2024-04-10 11:15:31 +00:00
parent 46a368d7f1
commit e6a576551a
23 changed files with 669 additions and 34 deletions

View File

@@ -79,6 +79,10 @@ export class QuotaConfig {
return this.config.configs.memberLimit;
}
get copilotActionLimit() {
return this.config.configs.copilotActionLimit || undefined;
}
get humanReadable() {
return {
name: this.config.configs.name,
@@ -86,6 +90,9 @@ export class QuotaConfig {
storageQuota: formatSize(this.storageQuota),
historyPeriod: formatDate(this.historyPeriod),
memberLimit: this.memberLimit.toString(),
copilotActionLimit: this.copilotActionLimit
? `${this.copilotActionLimit} times`
: 'Unlimited',
};
}
}

View File

@@ -93,11 +93,35 @@ export const Quotas: Quota[] = [
memberLimit: 3,
},
},
{
feature: QuotaType.FreePlanV1,
type: FeatureKind.Quota,
version: 4,
configs: {
// quota name
name: 'Free',
// single blob limit 10MB
blobLimit: 10 * OneMB,
// server limit will larger then client to handle a edge case:
// when a user downgrades from pro to free, he can still continue
// to upload previously added files that exceed the free limit
// NOTE: this is a product decision, may change in future
businessBlobLimit: 100 * OneMB,
// total blob limit 10GB
storageQuota: 10 * OneGB,
// history period of validity 7 days
historyPeriod: 7 * OneDay,
// member limit 3
memberLimit: 3,
// copilot action limit 10
copilotActionLimit: 10,
},
},
];
export const Quota_FreePlanV1_1 = {
feature: Quotas[4].feature,
version: Quotas[4].version,
feature: Quotas[5].feature,
version: Quotas[5].version,
};
export const Quota_ProPlanV1 = {

View File

@@ -33,6 +33,7 @@ export class QuotaManagementService {
storageQuota: quota.feature.storageQuota,
historyPeriod: quota.feature.historyPeriod,
memberLimit: quota.feature.memberLimit,
copilotActionLimit: quota.feature.copilotActionLimit,
};
}
@@ -72,6 +73,7 @@ export class QuotaManagementService {
historyPeriod,
memberLimit,
storageQuota,
copilotActionLimit,
humanReadable,
},
} = await this.quota.getUserQuota(owner.id);
@@ -85,6 +87,7 @@ export class QuotaManagementService {
historyPeriod,
memberLimit,
storageQuota,
copilotActionLimit,
humanReadable,
usedSize,
};

View File

@@ -34,6 +34,7 @@ const quotaPlan = z.object({
historyPeriod: z.number().positive().int(),
memberLimit: z.number().positive().int(),
businessBlobLimit: z.number().positive().int().nullish(),
copilotActionLimit: z.number().positive().int().nullish(),
}),
});
@@ -65,6 +66,9 @@ export class HumanReadableQuotaType {
@Field(() => String)
memberLimit!: string;
@Field(() => String, { nullable: true })
copilotActionLimit?: string;
}
@ObjectType()
@@ -84,6 +88,9 @@ export class QuotaQueryType {
@Field(() => SafeIntResolver)
storageQuota!: number;
@Field(() => SafeIntResolver, { nullable: true })
copilotActionLimit?: number;
@Field(() => HumanReadableQuotaType)
humanReadable!: HumanReadableQuotaType;

View File

@@ -1,12 +1,25 @@
import { ServerFeature } from '../../core/config';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
import { PromptService } from './prompt';
import { assertProvidersConfigs, CopilotProviderService } from './providers';
import {
assertProvidersConfigs,
CopilotProviderService,
OpenAIProvider,
registerCopilotProvider,
} from './providers';
import { ChatSessionService } from './session';
registerCopilotProvider(OpenAIProvider);
@Plugin({
name: 'copilot',
providers: [ChatSessionService, PromptService, CopilotProviderService],
providers: [
PermissionService,
ChatSessionService,
PromptService,
CopilotProviderService,
],
contributesTo: ServerFeature.Copilot,
if: config => {
if (config.flavor.graphql) {

View File

@@ -5,9 +5,9 @@ import { Injectable, Logger } from '@nestjs/common';
import { Config } from '../../../fundamentals';
import {
CapabilityToCopilotProvider,
CopilotCapability,
CopilotConfig,
CopilotProvider,
CopilotProviderCapability,
CopilotProviderType,
} from '../types';
@@ -19,7 +19,7 @@ interface CopilotProviderDefinition<C extends CopilotProviderConfig> {
// type of the provider
readonly type: CopilotProviderType;
// capabilities of the provider, like text to text, text to image, etc.
readonly capabilities: CopilotProviderCapability[];
readonly capabilities: CopilotCapability[];
// asserts that the config is valid for this provider
assetsConfig(config: C): boolean;
}
@@ -32,7 +32,7 @@ const COPILOT_PROVIDER = new Map<
// map of capabilities to providers
const PROVIDER_CAPABILITY_MAP = new Map<
CopilotProviderCapability,
CopilotCapability,
CopilotProviderType[]
>();
@@ -116,7 +116,7 @@ export class CopilotProviderService {
return this.cachedProviders.get(provider)!;
}
getProviderByCapability<C extends CopilotProviderCapability>(
getProviderByCapability<C extends CopilotCapability>(
capability: C,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
@@ -133,3 +133,5 @@ export class CopilotProviderService {
return null;
}
}
export { OpenAIProvider } from './openai';

View File

@@ -0,0 +1,195 @@
import assert from 'node:assert';
import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessage,
ChatMessageRole,
CopilotCapability,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToTextProvider,
} from '../types';
export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
{
static readonly type = CopilotProviderType.OpenAI;
static readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
];
readonly availableModels = [
// text to text
'gpt-4-vision-preview',
'gpt-4-turbo-preview',
'gpt-3.5-turbo',
// embeddings
'text-embedding-3-large',
'text-embedding-3-small',
'text-embedding-ada-002',
// moderation
'text-moderation-latest',
'text-moderation-stable',
];
private readonly instance: OpenAI;
constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
this.instance = new OpenAI(config);
}
static assetsConfig(config: ClientOptions) {
return !!config.apiKey;
}
getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
private chatToGPTMessage(messages: ChatMessage[]) {
// filter redundant fields
return messages.map(message => ({
role: message.role,
content: message.content,
}));
}
private checkParams({
messages,
embeddings,
model,
}: {
messages?: ChatMessage[];
embeddings?: string[];
model: string;
}) {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
if (
messages.some(
m =>
// check non-object
typeof m !== 'object' ||
!m ||
// check content
typeof m.content !== 'string' ||
!m.content ||
!m.content.trim()
)
) {
throw new Error('Empty message content');
}
if (
messages.some(
m =>
typeof m.role !== 'string' ||
!m.role ||
!ChatMessageRole.includes(m.role)
)
) {
throw new Error('Invalid message role');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
) {
throw new Error('Invalid embedding');
}
}
// ====== text to text ======
async generateText(
messages: ChatMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): Promise<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
user: options.user,
},
{ signal: options.signal }
);
const { content } = result.choices[0].message;
if (!content) {
throw new Error('Failed to generate text');
}
return content;
}
async *generateTextStream(
messages: ChatMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
{
stream: true,
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
user: options.user,
},
{
signal: options.signal,
}
);
for await (const message of result) {
const content = message.choices[0].delta.content;
if (content) {
yield content;
if (options.signal?.aborted) {
result.controller.abort();
break;
}
}
}
}
// ====== text to embedding ======
async generateEmbedding(
messages: string | string[],
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: 256 }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
const result = await this.instance.embeddings.create({
model: model,
input: messages,
dimensions: options.dimensions,
user: options.user,
});
return result.data.map(e => e.embedding);
}
}

View File

@@ -5,8 +5,11 @@ import { PrismaClient } from '@prisma/client';
import { ChatPrompt, PromptService } from './prompt';
import {
AvailableModel,
ChatHistory,
ChatMessage,
ChatMessageSchema,
getTokenEncoder,
PromptMessage,
PromptParams,
} from './types';
@@ -27,6 +30,13 @@ export interface ChatSessionState
messages: ChatMessage[];
}
export type ListHistoriesOptions = {
action: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionId: string | undefined;
};
export class ChatSession implements AsyncDisposable {
constructor(
private readonly state: ChatSessionState,
@@ -39,6 +49,13 @@ export class ChatSession implements AsyncDisposable {
}
push(message: ChatMessage) {
if (
this.state.prompt.action &&
this.state.messages.length > 0 &&
message.role === 'user'
) {
throw new Error('Action has been taken, no more messages allowed');
}
this.state.messages.push(message);
}
@@ -167,6 +184,53 @@ export class ChatSessionService {
});
}
async listHistories(
workspaceId: string,
docId: string,
options: ListHistoriesOptions
): Promise<ChatHistory[]> {
return await this.db.aiSession
.findMany({
where: {
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
prompt: { action: { not: null } },
id: options.sessionId ? { equals: options.sessionId } : undefined,
},
select: {
id: true,
prompt: true,
messages: {
select: {
role: true,
content: true,
},
orderBy: {
createdAt: 'asc',
},
},
},
take: options.limit,
skip: options.skip,
orderBy: { createdAt: 'desc' },
})
.then(sessions =>
sessions
.map(({ id, prompt, messages }) => {
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
const encoder = getTokenEncoder(prompt.model as AvailableModel);
const tokens = ret.data
.map(m => encoder?.encode_ordinary(m.content).length || 0)
.reduce((total, length) => total + length, 0);
return { sessionId: id, tokens, messages: ret.data };
}
return undefined;
})
.filter((v): v is NonNullable<typeof v> => !!v)
);
}
async create(options: ChatSessionOptions): Promise<string> {
const sessionId = randomUUID();
const prompt = await this.prompt.get(options.promptName);

View File

@@ -90,7 +90,7 @@ export enum CopilotProviderType {
OpenAI = 'openai',
}
export enum CopilotProviderCapability {
export enum CopilotCapability {
TextToText = 'text-to-text',
TextToEmbedding = 'text-to-embedding',
TextToImage = 'text-to-image',
@@ -98,7 +98,7 @@ export enum CopilotProviderCapability {
}
export interface CopilotProvider {
getCapabilities(): CopilotProviderCapability[];
getCapabilities(): CopilotCapability[];
}
export interface CopilotTextToTextProvider extends CopilotProvider {
@@ -124,15 +124,25 @@ export interface CopilotTextToTextProvider extends CopilotProvider {
): AsyncIterable<string>;
}
export interface CopilotTextToEmbeddingProvider extends CopilotProvider {}
export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
generateEmbedding(
messages: string[] | string,
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
}
): Promise<number[][]>;
}
export interface CopilotTextToImageProvider extends CopilotProvider {}
export interface CopilotImageToImageProvider extends CopilotProvider {}
export type CapabilityToCopilotProvider = {
[CopilotProviderCapability.TextToText]: CopilotTextToTextProvider;
[CopilotProviderCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider;
[CopilotProviderCapability.TextToImage]: CopilotTextToImageProvider;
[CopilotProviderCapability.ImageToImage]: CopilotImageToImageProvider;
[CopilotCapability.TextToText]: CopilotTextToTextProvider;
[CopilotCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider;
[CopilotCapability.TextToImage]: CopilotTextToImageProvider;
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
};

View File

@@ -38,6 +38,7 @@ enum FeatureType {
type HumanReadableQuotaType {
blobLimit: String!
copilotActionLimit: String
historyPeriod: String!
memberLimit: String!
name: String!
@@ -224,6 +225,7 @@ type Query {
type QuotaQueryType {
blobLimit: SafeInt!
copilotActionLimit: SafeInt
historyPeriod: SafeInt!
humanReadable: HumanReadableQuotaType!
memberLimit: SafeInt!