feat: add blob upload support for copilot (#6584)

This commit is contained in:
DarkSky
2024-04-17 22:05:38 +08:00
committed by GitHub
parent e806169f60
commit ccb3bed91e
10 changed files with 260 additions and 54 deletions

View File

@@ -7,7 +7,10 @@ import { OneGB } from './constant';
import { QuotaService } from './service'; import { QuotaService } from './service';
import { formatSize, QuotaQueryType } from './types'; import { formatSize, QuotaQueryType } from './types';
type QuotaBusinessType = QuotaQueryType & { businessBlobLimit: number }; type QuotaBusinessType = QuotaQueryType & {
businessBlobLimit: number;
unlimited: boolean;
};
@Injectable() @Injectable()
export class QuotaManagementService { export class QuotaManagementService {
@@ -59,6 +62,52 @@ export class QuotaManagementService {
}, 0); }, 0);
} }
private generateQuotaCalculator(
quota: number,
blobLimit: number,
usedSize: number,
unlimited = false
) {
const checkExceeded = (recvSize: number) => {
const total = usedSize + recvSize;
// only skip total storage check if workspace has unlimited feature
if (total > quota && !unlimited) {
this.logger.log(`storage size limit exceeded: ${total} > ${quota}`);
return true;
} else if (recvSize > blobLimit) {
this.logger.log(`blob size limit exceeded: ${recvSize} > ${blobLimit}`);
return true;
} else {
return false;
}
};
return checkExceeded;
}
async getQuotaCalculator(userId: string) {
const quota = await this.getUserQuota(userId);
const { storageQuota, businessBlobLimit } = quota;
const usedSize = await this.getUserUsage(userId);
return this.generateQuotaCalculator(
storageQuota,
businessBlobLimit,
usedSize
);
}
async getQuotaCalculatorByWorkspace(workspaceId: string) {
const { storageQuota, usedSize, businessBlobLimit, unlimited } =
await this.getWorkspaceUsage(workspaceId);
return this.generateQuotaCalculator(
storageQuota,
businessBlobLimit,
usedSize,
unlimited
);
}
// get workspace's owner quota and total size of used // get workspace's owner quota and total size of used
// quota was apply to owner's account // quota was apply to owner's account
async getWorkspaceUsage(workspaceId: string): Promise<QuotaBusinessType> { async getWorkspaceUsage(workspaceId: string): Promise<QuotaBusinessType> {
@@ -79,6 +128,12 @@ export class QuotaManagementService {
} = await this.quota.getUserQuota(owner.id); } = await this.quota.getUserQuota(owner.id);
// get all workspaces size of owner used // get all workspaces size of owner used
const usedSize = await this.getUserUsage(owner.id); const usedSize = await this.getUserUsage(owner.id);
// relax restrictions if workspace has unlimited feature
// todo(@darkskygit): need a mechanism to allow feature as a middleware to edit quota
const unlimited = await this.feature.hasWorkspaceFeature(
workspaceId,
FeatureType.UnlimitedWorkspace
);
const quota = { const quota = {
name, name,
@@ -90,15 +145,10 @@ export class QuotaManagementService {
copilotActionLimit, copilotActionLimit,
humanReadable, humanReadable,
usedSize, usedSize,
unlimited,
}; };
// relax restrictions if workspace has unlimited feature if (quota.unlimited) {
// todo(@darkskygit): need a mechanism to allow feature as a middleware to edit quota
const unlimited = await this.feature.hasWorkspaceFeature(
workspaceId,
FeatureType.UnlimitedWorkspace
);
if (unlimited) {
return this.mergeUnlimitedQuota(quota); return this.mergeUnlimitedQuota(quota);
} }

View File

@@ -1,8 +1,4 @@
import { import { Logger, PayloadTooLargeException, UseGuards } from '@nestjs/common';
ForbiddenException,
Logger,
PayloadTooLargeException,
} from '@nestjs/common';
import { import {
Args, Args,
Int, Int,
@@ -16,20 +12,23 @@ import { SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import type { FileUpload } from '../../../fundamentals'; import type { FileUpload } from '../../../fundamentals';
import { MakeCache, PreventCache } from '../../../fundamentals'; import {
CloudThrottlerGuard,
MakeCache,
PreventCache,
} from '../../../fundamentals';
import { CurrentUser } from '../../auth'; import { CurrentUser } from '../../auth';
import { FeatureManagementService, FeatureType } from '../../features';
import { QuotaManagementService } from '../../quota'; import { QuotaManagementService } from '../../quota';
import { WorkspaceBlobStorage } from '../../storage'; import { WorkspaceBlobStorage } from '../../storage';
import { PermissionService } from '../permission'; import { PermissionService } from '../permission';
import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types'; import { Permission, WorkspaceBlobSizes, WorkspaceType } from '../types';
@UseGuards(CloudThrottlerGuard)
@Resolver(() => WorkspaceType) @Resolver(() => WorkspaceType)
export class WorkspaceBlobResolver { export class WorkspaceBlobResolver {
logger = new Logger(WorkspaceBlobResolver.name); logger = new Logger(WorkspaceBlobResolver.name);
constructor( constructor(
private readonly permissions: PermissionService, private readonly permissions: PermissionService,
private readonly feature: FeatureManagementService,
private readonly quota: QuotaManagementService, private readonly quota: QuotaManagementService,
private readonly storage: WorkspaceBlobStorage private readonly storage: WorkspaceBlobStorage
) {} ) {}
@@ -124,34 +123,8 @@ export class WorkspaceBlobResolver {
Permission.Write Permission.Write
); );
const { storageQuota, usedSize, businessBlobLimit } = const checkExceeded =
await this.quota.getWorkspaceUsage(workspaceId); await this.quota.getQuotaCalculatorByWorkspace(workspaceId);
const unlimited = await this.feature.hasWorkspaceFeature(
workspaceId,
FeatureType.UnlimitedWorkspace
);
const checkExceeded = (recvSize: number) => {
if (!storageQuota) {
throw new ForbiddenException('Cannot find user quota.');
}
const total = usedSize + recvSize;
// only skip total storage check if workspace has unlimited feature
if (total > storageQuota && !unlimited) {
this.logger.log(
`storage size limit exceeded: ${total} > ${storageQuota}`
);
return true;
} else if (recvSize > businessBlobLimit) {
this.logger.log(
`blob size limit exceeded: ${recvSize} > ${businessBlobLimit}`
);
return true;
} else {
return false;
}
};
if (checkExceeded(0)) { if (checkExceeded(0)) {
throw new PayloadTooLargeException( throw new PayloadTooLargeException(

View File

@@ -19,6 +19,7 @@ export type StorageConfig<Ext = unknown> = {
export interface StoragesConfig { export interface StoragesConfig {
avatar: StorageConfig<{ publicLinkFactory: (key: string) => string }>; avatar: StorageConfig<{ publicLinkFactory: (key: string) => string }>;
blob: StorageConfig; blob: StorageConfig;
copilot: StorageConfig;
} }
export interface AFFiNEStorageConfig { export interface AFFiNEStorageConfig {
@@ -51,6 +52,10 @@ export function getDefaultAFFiNEStorageConfig(): AFFiNEStorageConfig {
provider: 'fs', provider: 'fs',
bucket: 'blobs', bucket: 'blobs',
}, },
copilot: {
provider: 'fs',
bucket: 'copilot',
},
}, },
}; };
} }

View File

@@ -3,6 +3,8 @@ import {
Controller, Controller,
Get, Get,
InternalServerErrorException, InternalServerErrorException,
Logger,
NotFoundException,
Param, Param,
Query, Query,
Req, Req,
@@ -17,15 +19,18 @@ import {
from, from,
map, map,
merge, merge,
mergeMap,
Observable, Observable,
switchMap, switchMap,
toArray, toArray,
} from 'rxjs'; } from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user'; import { CurrentUser } from '../../core/auth/current-user';
import { Config } from '../../fundamentals'; import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers'; import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session'; import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability } from './types'; import { CopilotCapability } from './types';
export interface ChatEvent { export interface ChatEvent {
@@ -36,10 +41,13 @@ export interface ChatEvent {
@Controller('/api/copilot') @Controller('/api/copilot')
export class CopilotController { export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
constructor( constructor(
private readonly config: Config, private readonly config: Config,
private readonly chatSession: ChatSessionService, private readonly chatSession: ChatSessionService,
private readonly provider: CopilotProviderService private readonly provider: CopilotProviderService,
private readonly storage: CopilotStorage
) {} ) {}
private async hasAttachment(sessionId: string, messageId?: string) { private async hasAttachment(sessionId: string, messageId?: string) {
@@ -230,12 +238,19 @@ export class CopilotController {
delete params.message; delete params.message;
delete params.messageId; delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from( return from(
provider.generateImagesStream(session.finish(params), session.model, { provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req), signal: this.getSignal(req),
user: user.id, user: user.id,
}) })
).pipe( ).pipe(
mergeMap(handleRemoteLink),
connect(shared$ => connect(shared$ =>
merge( merge(
// actual chat event stream // actual chat event stream
@@ -294,4 +309,33 @@ export class CopilotController {
res.status(response.status).send(await response.json()); res.status(response.status).send(await response.json());
} }
@Public()
@Get('/blob/:userId/:workspaceId/:key')
async getBlob(
@Res() res: Response,
@Param('userId') userId: string,
@Param('workspaceId') workspaceId: string,
@Param('key') key: string
) {
const { body, metadata } = await this.storage.get(userId, workspaceId, key);
if (!body) {
throw new NotFoundException(
`Blob not found in ${userId}'s workspace ${workspaceId}: ${key}`
);
}
// metadata should always exists if body is not null
if (metadata) {
res.setHeader('content-type', metadata.contentType);
res.setHeader('last-modified', metadata.lastModified.toUTCString());
res.setHeader('content-length', metadata.contentLength);
} else {
this.logger.warn(`Blob ${workspaceId}/${key} has no metadata`);
}
res.setHeader('cache-control', 'public, max-age=2592000, immutable');
body.pipe(res);
}
} }

View File

@@ -1,6 +1,6 @@
import { ServerFeature } from '../../core/config'; import { ServerFeature } from '../../core/config';
import { FeatureManagementService, FeatureService } from '../../core/features'; import { FeatureModule } from '../../core/features';
import { QuotaService } from '../../core/quota'; import { QuotaModule } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission'; import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry'; import { Plugin } from '../registry';
import { CopilotController } from './controller'; import { CopilotController } from './controller';
@@ -15,23 +15,23 @@ import {
} from './providers'; } from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver'; import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session'; import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
registerCopilotProvider(FalProvider); registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider); registerCopilotProvider(OpenAIProvider);
@Plugin({ @Plugin({
name: 'copilot', name: 'copilot',
imports: [FeatureModule, QuotaModule],
providers: [ providers: [
PermissionService, PermissionService,
FeatureService,
FeatureManagementService,
QuotaService,
ChatSessionService, ChatSessionService,
CopilotResolver, CopilotResolver,
ChatMessageCache, ChatMessageCache,
UserCopilotResolver, UserCopilotResolver,
PromptService, PromptService,
CopilotProviderService, CopilotProviderService,
CopilotStorage,
], ],
controllers: [CopilotController], controllers: [CopilotController],
contributesTo: ServerFeature.Copilot, contributesTo: ServerFeature.Copilot,

View File

@@ -1,4 +1,4 @@
import { Logger } from '@nestjs/common'; import { BadRequestException, Logger } from '@nestjs/common';
import { import {
Args, Args,
Field, Field,
@@ -12,12 +12,18 @@ import {
Resolver, Resolver,
} from '@nestjs/graphql'; } from '@nestjs/graphql';
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars'; import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { CurrentUser } from '../../core/auth'; import { CurrentUser } from '../../core/auth';
import { UserType } from '../../core/user'; import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission'; import { PermissionService } from '../../core/workspaces/permission';
import { MutexService, TooManyRequestsException } from '../../fundamentals'; import {
FileUpload,
MutexService,
TooManyRequestsException,
} from '../../fundamentals';
import { ChatSessionService } from './session'; import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { import {
AvailableModels, AvailableModels,
type ChatHistory, type ChatHistory,
@@ -28,7 +34,7 @@ import {
registerEnumType(AvailableModels, { name: 'CopilotModel' }); registerEnumType(AvailableModels, { name: 'CopilotModel' });
const COPILOT_LOCKER = 'copilot'; export const COPILOT_LOCKER = 'copilot';
// ================== Input Types ================== // ================== Input Types ==================
@@ -57,6 +63,9 @@ class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
@Field(() => [String], { nullable: true }) @Field(() => [String], { nullable: true })
attachments!: string[] | undefined; attachments!: string[] | undefined;
@Field(() => [GraphQLUpload], { nullable: true })
blobs!: FileUpload[] | undefined;
@Field(() => GraphQLJSON, { nullable: true }) @Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | undefined; params!: Record<string, string> | undefined;
} }
@@ -140,7 +149,8 @@ export class CopilotResolver {
constructor( constructor(
private readonly permissions: PermissionService, private readonly permissions: PermissionService,
private readonly mutex: MutexService, private readonly mutex: MutexService,
private readonly chatSession: ChatSessionService private readonly chatSession: ChatSessionService,
private readonly storage: CopilotStorage
) {} ) {}
@ResolveField(() => CopilotQuotaType, { @ResolveField(() => CopilotQuotaType, {
@@ -260,6 +270,25 @@ export class CopilotResolver {
if (!lock) { if (!lock) {
return new TooManyRequestsException('Server is busy'); return new TooManyRequestsException('Server is busy');
} }
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (options.blobs) {
options.attachments = options.attachments || [];
const { workspaceId } = session.config;
for (const blob of options.blobs) {
const uploaded = await this.storage.handleUpload(user.id, blob);
const link = await this.storage.put(
user.id,
workspaceId,
uploaded.filename,
uploaded.buffer
);
options.attachments.push(link);
}
}
try { try {
return await this.chatSession.createMessage(options); return await this.chatSession.createMessage(options);
} catch (e: any) { } catch (e: any) {

View File

@@ -35,6 +35,18 @@ export class ChatSession implements AsyncDisposable {
return this.state.prompt.model; return this.state.prompt.model;
} }
get config() {
const {
sessionId,
userId,
workspaceId,
docId,
prompt: { name: promptName },
} = this.state;
return { sessionId, userId, workspaceId, docId, promptName };
}
push(message: ChatMessage) { push(message: ChatMessage) {
if ( if (
this.state.prompt.action && this.state.prompt.action &&

View File

@@ -0,0 +1,91 @@
import { createHash } from 'node:crypto';
import { Injectable, PayloadTooLargeException } from '@nestjs/common';
import { QuotaManagementService } from '../../core/quota';
import {
type BlobInputType,
Config,
type FileUpload,
type StorageProvider,
StorageProviderFactory,
} from '../../fundamentals';
@Injectable()
export class CopilotStorage {
public readonly provider: StorageProvider;
constructor(
private readonly config: Config,
private readonly storageFactory: StorageProviderFactory,
private readonly quota: QuotaManagementService
) {
this.provider = this.storageFactory.create('copilot');
}
async put(
userId: string,
workspaceId: string,
key: string,
blob: BlobInputType
) {
const name = `${userId}/${workspaceId}/${key}`;
await this.provider.put(name, blob);
return `${this.config.baseUrl}/api/copilot/blob/${name}`;
}
async get(userId: string, workspaceId: string, key: string) {
return this.provider.get(`${userId}/${workspaceId}/${key}`);
}
async delete(userId: string, workspaceId: string, key: string) {
return this.provider.delete(`${userId}/${workspaceId}/${key}`);
}
async handleUpload(userId: string, blob: FileUpload) {
const checkExceeded = await this.quota.getQuotaCalculator(userId);
if (checkExceeded(0)) {
throw new PayloadTooLargeException(
'Storage or blob size limit exceeded.'
);
}
const buffer = await new Promise<Buffer>((resolve, reject) => {
const stream = blob.createReadStream();
const chunks: Uint8Array[] = [];
stream.on('data', chunk => {
chunks.push(chunk);
// check size after receive each chunk to avoid unnecessary memory usage
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
if (checkExceeded(bufferSize)) {
reject(
new PayloadTooLargeException('Storage or blob size limit exceeded.')
);
}
});
stream.on('error', reject);
stream.on('end', () => {
const buffer = Buffer.concat(chunks);
if (checkExceeded(buffer.length)) {
reject(new PayloadTooLargeException('Storage limit exceeded.'));
} else {
resolve(buffer);
}
});
});
return {
buffer,
filename: blob.filename,
};
}
async handleRemoteLink(userId: string, workspaceId: string, link: string) {
const response = await fetch(link);
const buffer = new Uint8Array(await response.arrayBuffer());
const filename = createHash('sha256').update(buffer).digest('base64url');
return this.put(userId, workspaceId, filename, Buffer.from(buffer));
}
}

View File

@@ -40,6 +40,7 @@ type CopilotQuota {
input CreateChatMessageInput { input CreateChatMessageInput {
attachments: [String!] attachments: [String!]
blobs: [Upload!]
content: String content: String
params: JSON params: JSON
sessionId: String! sessionId: String!

View File

@@ -38,6 +38,7 @@ export interface Scalars {
export interface CreateChatMessageInput { export interface CreateChatMessageInput {
attachments: InputMaybe<Array<Scalars['String']['input']>>; attachments: InputMaybe<Array<Scalars['String']['input']>>;
blobs: InputMaybe<Array<Scalars['Upload']['input']>>;
content: InputMaybe<Scalars['String']['input']>; content: InputMaybe<Scalars['String']['input']>;
params: InputMaybe<Scalars['JSON']['input']>; params: InputMaybe<Scalars['JSON']['input']>;
sessionId: Scalars['String']['input']; sessionId: Scalars['String']['input'];