mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 12:55:00 +00:00
feat: add blob upload support for copilot (#6584)
This commit is contained in:
@@ -3,6 +3,8 @@ import {
|
||||
Controller,
|
||||
Get,
|
||||
InternalServerErrorException,
|
||||
Logger,
|
||||
NotFoundException,
|
||||
Param,
|
||||
Query,
|
||||
Req,
|
||||
@@ -17,15 +19,18 @@ import {
|
||||
from,
|
||||
map,
|
||||
merge,
|
||||
mergeMap,
|
||||
Observable,
|
||||
switchMap,
|
||||
toArray,
|
||||
} from 'rxjs';
|
||||
|
||||
import { Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth/current-user';
|
||||
import { Config } from '../../fundamentals';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotCapability } from './types';
|
||||
|
||||
export interface ChatEvent {
|
||||
@@ -36,10 +41,13 @@ export interface ChatEvent {
|
||||
|
||||
@Controller('/api/copilot')
|
||||
export class CopilotController {
|
||||
private readonly logger = new Logger(CopilotController.name);
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly provider: CopilotProviderService
|
||||
private readonly provider: CopilotProviderService,
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
private async hasAttachment(sessionId: string, messageId?: string) {
|
||||
@@ -230,12 +238,19 @@ export class CopilotController {
|
||||
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
@@ -294,4 +309,33 @@ export class CopilotController {
|
||||
|
||||
res.status(response.status).send(await response.json());
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/blob/:userId/:workspaceId/:key')
|
||||
async getBlob(
|
||||
@Res() res: Response,
|
||||
@Param('userId') userId: string,
|
||||
@Param('workspaceId') workspaceId: string,
|
||||
@Param('key') key: string
|
||||
) {
|
||||
const { body, metadata } = await this.storage.get(userId, workspaceId, key);
|
||||
|
||||
if (!body) {
|
||||
throw new NotFoundException(
|
||||
`Blob not found in ${userId}'s workspace ${workspaceId}: ${key}`
|
||||
);
|
||||
}
|
||||
|
||||
// metadata should always exists if body is not null
|
||||
if (metadata) {
|
||||
res.setHeader('content-type', metadata.contentType);
|
||||
res.setHeader('last-modified', metadata.lastModified.toUTCString());
|
||||
res.setHeader('content-length', metadata.contentLength);
|
||||
} else {
|
||||
this.logger.warn(`Blob ${workspaceId}/${key} has no metadata`);
|
||||
}
|
||||
|
||||
res.setHeader('cache-control', 'public, max-age=2592000, immutable');
|
||||
body.pipe(res);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ServerFeature } from '../../core/config';
|
||||
import { FeatureManagementService, FeatureService } from '../../core/features';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { FeatureModule } from '../../core/features';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import { Plugin } from '../registry';
|
||||
import { CopilotController } from './controller';
|
||||
@@ -15,23 +15,23 @@ import {
|
||||
} from './providers';
|
||||
import { CopilotResolver, UserCopilotResolver } from './resolver';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
|
||||
@Plugin({
|
||||
name: 'copilot',
|
||||
imports: [FeatureModule, QuotaModule],
|
||||
providers: [
|
||||
PermissionService,
|
||||
FeatureService,
|
||||
FeatureManagementService,
|
||||
QuotaService,
|
||||
ChatSessionService,
|
||||
CopilotResolver,
|
||||
ChatMessageCache,
|
||||
UserCopilotResolver,
|
||||
PromptService,
|
||||
CopilotProviderService,
|
||||
CopilotStorage,
|
||||
],
|
||||
controllers: [CopilotController],
|
||||
contributesTo: ServerFeature.Copilot,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { BadRequestException, Logger } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -12,12 +12,18 @@ import {
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
|
||||
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
|
||||
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { UserType } from '../../core/user';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import { MutexService, TooManyRequestsException } from '../../fundamentals';
|
||||
import {
|
||||
FileUpload,
|
||||
MutexService,
|
||||
TooManyRequestsException,
|
||||
} from '../../fundamentals';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import {
|
||||
AvailableModels,
|
||||
type ChatHistory,
|
||||
@@ -28,7 +34,7 @@ import {
|
||||
|
||||
registerEnumType(AvailableModels, { name: 'CopilotModel' });
|
||||
|
||||
const COPILOT_LOCKER = 'copilot';
|
||||
export const COPILOT_LOCKER = 'copilot';
|
||||
|
||||
// ================== Input Types ==================
|
||||
|
||||
@@ -57,6 +63,9 @@ class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
|
||||
@Field(() => [String], { nullable: true })
|
||||
attachments!: string[] | undefined;
|
||||
|
||||
@Field(() => [GraphQLUpload], { nullable: true })
|
||||
blobs!: FileUpload[] | undefined;
|
||||
|
||||
@Field(() => GraphQLJSON, { nullable: true })
|
||||
params!: Record<string, string> | undefined;
|
||||
}
|
||||
@@ -140,7 +149,8 @@ export class CopilotResolver {
|
||||
constructor(
|
||||
private readonly permissions: PermissionService,
|
||||
private readonly mutex: MutexService,
|
||||
private readonly chatSession: ChatSessionService
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
@ResolveField(() => CopilotQuotaType, {
|
||||
@@ -260,6 +270,25 @@ export class CopilotResolver {
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
const session = await this.chatSession.get(options.sessionId);
|
||||
if (!session) return new BadRequestException('Session not found');
|
||||
|
||||
if (options.blobs) {
|
||||
options.attachments = options.attachments || [];
|
||||
const { workspaceId } = session.config;
|
||||
|
||||
for (const blob of options.blobs) {
|
||||
const uploaded = await this.storage.handleUpload(user.id, blob);
|
||||
const link = await this.storage.put(
|
||||
user.id,
|
||||
workspaceId,
|
||||
uploaded.filename,
|
||||
uploaded.buffer
|
||||
);
|
||||
options.attachments.push(link);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.chatSession.createMessage(options);
|
||||
} catch (e: any) {
|
||||
|
||||
@@ -35,6 +35,18 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.prompt.model;
|
||||
}
|
||||
|
||||
get config() {
|
||||
const {
|
||||
sessionId,
|
||||
userId,
|
||||
workspaceId,
|
||||
docId,
|
||||
prompt: { name: promptName },
|
||||
} = this.state;
|
||||
|
||||
return { sessionId, userId, workspaceId, docId, promptName };
|
||||
}
|
||||
|
||||
push(message: ChatMessage) {
|
||||
if (
|
||||
this.state.prompt.action &&
|
||||
|
||||
91
packages/backend/server/src/plugins/copilot/storage.ts
Normal file
91
packages/backend/server/src/plugins/copilot/storage.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { createHash } from 'node:crypto';
|
||||
|
||||
import { Injectable, PayloadTooLargeException } from '@nestjs/common';
|
||||
|
||||
import { QuotaManagementService } from '../../core/quota';
|
||||
import {
|
||||
type BlobInputType,
|
||||
Config,
|
||||
type FileUpload,
|
||||
type StorageProvider,
|
||||
StorageProviderFactory,
|
||||
} from '../../fundamentals';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotStorage {
|
||||
public readonly provider: StorageProvider;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly storageFactory: StorageProviderFactory,
|
||||
private readonly quota: QuotaManagementService
|
||||
) {
|
||||
this.provider = this.storageFactory.create('copilot');
|
||||
}
|
||||
|
||||
async put(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
key: string,
|
||||
blob: BlobInputType
|
||||
) {
|
||||
const name = `${userId}/${workspaceId}/${key}`;
|
||||
await this.provider.put(name, blob);
|
||||
return `${this.config.baseUrl}/api/copilot/blob/${name}`;
|
||||
}
|
||||
|
||||
async get(userId: string, workspaceId: string, key: string) {
|
||||
return this.provider.get(`${userId}/${workspaceId}/${key}`);
|
||||
}
|
||||
|
||||
async delete(userId: string, workspaceId: string, key: string) {
|
||||
return this.provider.delete(`${userId}/${workspaceId}/${key}`);
|
||||
}
|
||||
|
||||
async handleUpload(userId: string, blob: FileUpload) {
|
||||
const checkExceeded = await this.quota.getQuotaCalculator(userId);
|
||||
|
||||
if (checkExceeded(0)) {
|
||||
throw new PayloadTooLargeException(
|
||||
'Storage or blob size limit exceeded.'
|
||||
);
|
||||
}
|
||||
const buffer = await new Promise<Buffer>((resolve, reject) => {
|
||||
const stream = blob.createReadStream();
|
||||
const chunks: Uint8Array[] = [];
|
||||
stream.on('data', chunk => {
|
||||
chunks.push(chunk);
|
||||
|
||||
// check size after receive each chunk to avoid unnecessary memory usage
|
||||
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
|
||||
if (checkExceeded(bufferSize)) {
|
||||
reject(
|
||||
new PayloadTooLargeException('Storage or blob size limit exceeded.')
|
||||
);
|
||||
}
|
||||
});
|
||||
stream.on('error', reject);
|
||||
stream.on('end', () => {
|
||||
const buffer = Buffer.concat(chunks);
|
||||
|
||||
if (checkExceeded(buffer.length)) {
|
||||
reject(new PayloadTooLargeException('Storage limit exceeded.'));
|
||||
} else {
|
||||
resolve(buffer);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
buffer,
|
||||
filename: blob.filename,
|
||||
};
|
||||
}
|
||||
|
||||
async handleRemoteLink(userId: string, workspaceId: string, link: string) {
|
||||
const response = await fetch(link);
|
||||
const buffer = new Uint8Array(await response.arrayBuffer());
|
||||
const filename = createHash('sha256').update(buffer).digest('base64url');
|
||||
return this.put(userId, workspaceId, filename, Buffer.from(buffer));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user