feat: add blob upload support for copilot (#6584)

This commit is contained in:
DarkSky
2024-04-17 22:05:38 +08:00
committed by GitHub
parent e806169f60
commit ccb3bed91e
10 changed files with 260 additions and 54 deletions

View File

@@ -3,6 +3,8 @@ import {
Controller,
Get,
InternalServerErrorException,
Logger,
NotFoundException,
Param,
Query,
Req,
@@ -17,15 +19,18 @@ import {
from,
map,
merge,
mergeMap,
Observable,
switchMap,
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability } from './types';
export interface ChatEvent {
@@ -36,10 +41,13 @@ export interface ChatEvent {
@Controller('/api/copilot')
export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
constructor(
private readonly config: Config,
private readonly chatSession: ChatSessionService,
private readonly provider: CopilotProviderService
private readonly provider: CopilotProviderService,
private readonly storage: CopilotStorage
) {}
private async hasAttachment(sessionId: string, messageId?: string) {
@@ -230,12 +238,19 @@ export class CopilotController {
delete params.message;
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
@@ -294,4 +309,33 @@ export class CopilotController {
res.status(response.status).send(await response.json());
}
@Public()
@Get('/blob/:userId/:workspaceId/:key')
async getBlob(
@Res() res: Response,
@Param('userId') userId: string,
@Param('workspaceId') workspaceId: string,
@Param('key') key: string
) {
const { body, metadata } = await this.storage.get(userId, workspaceId, key);
if (!body) {
throw new NotFoundException(
`Blob not found in ${userId}'s workspace ${workspaceId}: ${key}`
);
}
// metadata should always exists if body is not null
if (metadata) {
res.setHeader('content-type', metadata.contentType);
res.setHeader('last-modified', metadata.lastModified.toUTCString());
res.setHeader('content-length', metadata.contentLength);
} else {
this.logger.warn(`Blob ${workspaceId}/${key} has no metadata`);
}
res.setHeader('cache-control', 'public, max-age=2592000, immutable');
body.pipe(res);
}
}

View File

@@ -1,6 +1,6 @@
import { ServerFeature } from '../../core/config';
import { FeatureManagementService, FeatureService } from '../../core/features';
import { QuotaService } from '../../core/quota';
import { FeatureModule } from '../../core/features';
import { QuotaModule } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
import { CopilotController } from './controller';
@@ -15,23 +15,23 @@ import {
} from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
@Plugin({
name: 'copilot',
imports: [FeatureModule, QuotaModule],
providers: [
PermissionService,
FeatureService,
FeatureManagementService,
QuotaService,
ChatSessionService,
CopilotResolver,
ChatMessageCache,
UserCopilotResolver,
PromptService,
CopilotProviderService,
CopilotStorage,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,

View File

@@ -1,4 +1,4 @@
import { Logger } from '@nestjs/common';
import { BadRequestException, Logger } from '@nestjs/common';
import {
Args,
Field,
@@ -12,12 +12,18 @@ import {
Resolver,
} from '@nestjs/graphql';
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { CurrentUser } from '../../core/auth';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
import { MutexService, TooManyRequestsException } from '../../fundamentals';
import {
FileUpload,
MutexService,
TooManyRequestsException,
} from '../../fundamentals';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
AvailableModels,
type ChatHistory,
@@ -28,7 +34,7 @@ import {
registerEnumType(AvailableModels, { name: 'CopilotModel' });
const COPILOT_LOCKER = 'copilot';
export const COPILOT_LOCKER = 'copilot';
// ================== Input Types ==================
@@ -57,6 +63,9 @@ class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
@Field(() => [String], { nullable: true })
attachments!: string[] | undefined;
@Field(() => [GraphQLUpload], { nullable: true })
blobs!: FileUpload[] | undefined;
@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | undefined;
}
@@ -140,7 +149,8 @@ export class CopilotResolver {
constructor(
private readonly permissions: PermissionService,
private readonly mutex: MutexService,
private readonly chatSession: ChatSessionService
private readonly chatSession: ChatSessionService,
private readonly storage: CopilotStorage
) {}
@ResolveField(() => CopilotQuotaType, {
@@ -260,6 +270,25 @@ export class CopilotResolver {
if (!lock) {
return new TooManyRequestsException('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (options.blobs) {
options.attachments = options.attachments || [];
const { workspaceId } = session.config;
for (const blob of options.blobs) {
const uploaded = await this.storage.handleUpload(user.id, blob);
const link = await this.storage.put(
user.id,
workspaceId,
uploaded.filename,
uploaded.buffer
);
options.attachments.push(link);
}
}
try {
return await this.chatSession.createMessage(options);
} catch (e: any) {

View File

@@ -35,6 +35,18 @@ export class ChatSession implements AsyncDisposable {
return this.state.prompt.model;
}
get config() {
const {
sessionId,
userId,
workspaceId,
docId,
prompt: { name: promptName },
} = this.state;
return { sessionId, userId, workspaceId, docId, promptName };
}
push(message: ChatMessage) {
if (
this.state.prompt.action &&

View File

@@ -0,0 +1,91 @@
import { createHash } from 'node:crypto';
import { Injectable, PayloadTooLargeException } from '@nestjs/common';
import { QuotaManagementService } from '../../core/quota';
import {
type BlobInputType,
Config,
type FileUpload,
type StorageProvider,
StorageProviderFactory,
} from '../../fundamentals';
@Injectable()
export class CopilotStorage {
public readonly provider: StorageProvider;
constructor(
private readonly config: Config,
private readonly storageFactory: StorageProviderFactory,
private readonly quota: QuotaManagementService
) {
this.provider = this.storageFactory.create('copilot');
}
async put(
userId: string,
workspaceId: string,
key: string,
blob: BlobInputType
) {
const name = `${userId}/${workspaceId}/${key}`;
await this.provider.put(name, blob);
return `${this.config.baseUrl}/api/copilot/blob/${name}`;
}
async get(userId: string, workspaceId: string, key: string) {
return this.provider.get(`${userId}/${workspaceId}/${key}`);
}
async delete(userId: string, workspaceId: string, key: string) {
return this.provider.delete(`${userId}/${workspaceId}/${key}`);
}
async handleUpload(userId: string, blob: FileUpload) {
const checkExceeded = await this.quota.getQuotaCalculator(userId);
if (checkExceeded(0)) {
throw new PayloadTooLargeException(
'Storage or blob size limit exceeded.'
);
}
const buffer = await new Promise<Buffer>((resolve, reject) => {
const stream = blob.createReadStream();
const chunks: Uint8Array[] = [];
stream.on('data', chunk => {
chunks.push(chunk);
// check size after receive each chunk to avoid unnecessary memory usage
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
if (checkExceeded(bufferSize)) {
reject(
new PayloadTooLargeException('Storage or blob size limit exceeded.')
);
}
});
stream.on('error', reject);
stream.on('end', () => {
const buffer = Buffer.concat(chunks);
if (checkExceeded(buffer.length)) {
reject(new PayloadTooLargeException('Storage limit exceeded.'));
} else {
resolve(buffer);
}
});
});
return {
buffer,
filename: blob.filename,
};
}
async handleRemoteLink(userId: string, workspaceId: string, link: string) {
const response = await fetch(link);
const buffer = new Uint8Array(await response.arrayBuffer());
const filename = createHash('sha256').update(buffer).digest('base64url');
return this.put(userId, workspaceId, filename, Buffer.from(buffer));
}
}