diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index 442755edeb..b03621cda2 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -11,6 +11,7 @@ import { EventBus, JobQueue } from '../base'; import { ConfigModule } from '../base/config'; import { AuthService } from '../core/auth'; import { QuotaModule } from '../core/quota'; +import { StorageModule, WorkspaceBlobStorage } from '../core/storage'; import { ContextCategories, CopilotSessionModel, @@ -68,6 +69,7 @@ type Context = { db: PrismaClient; event: EventBus; workspace: WorkspaceModel; + workspaceStorage: WorkspaceBlobStorage; copilotSession: CopilotSessionModel; context: CopilotContextService; prompt: PromptService; @@ -114,6 +116,7 @@ test.before(async t => { }, }), QuotaModule, + StorageModule, CopilotModule, ], tapModule: builder => { @@ -127,6 +130,7 @@ test.before(async t => { const db = module.get(PrismaClient); const event = module.get(EventBus); const workspace = module.get(WorkspaceModel); + const workspaceStorage = module.get(WorkspaceBlobStorage); const copilotSession = module.get(CopilotSessionModel); const prompt = module.get(PromptService); const factory = module.get(CopilotProviderFactory); @@ -146,6 +150,7 @@ test.before(async t => { t.context.db = db; t.context.event = event; t.context.workspace = workspace; + t.context.workspaceStorage = workspaceStorage; t.context.copilotSession = copilotSession; t.context.prompt = prompt; t.context.factory = factory; @@ -1520,8 +1525,16 @@ test('TextStreamParser should process a sequence of message chunks', t => { // ==================== context ==================== test('should be able to manage context', async t => { - const { context, db, event, jobs, prompt, session, storage, workspace } = - t.context; + const { + context, + event, + jobs, + prompt, + session, + storage, + workspace, + workspaceStorage, + } = t.context; const ws = await workspace.create(userId); @@ -1614,21 +1627,9 @@ test('should be able to manage context', async t => { // blob record { const blobId = 'test-blob'; - await storage.put(userId, session.workspaceId, blobId, buffer); - await db.blob.create({ - data: { - workspaceId: session.workspaceId, - key: blobId, - size: buffer.length, - mime: 'application/pdf', - }, - }); + await workspaceStorage.put(session.workspaceId, blobId, buffer); - await jobs.embedPendingBlob({ - userId, - workspaceId: session.workspaceId, - blobId, - }); + await jobs.embedPendingBlob({ workspaceId: session.workspaceId, blobId }); const result = await t.context.context.matchWorkspaceBlobs( session.workspaceId, diff --git a/packages/backend/server/src/plugins/copilot/context/resolver.ts b/packages/backend/server/src/plugins/copilot/context/resolver.ts index 9177caf073..504646b3e3 100644 --- a/packages/backend/server/src/plugins/copilot/context/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/context/resolver.ts @@ -742,6 +742,12 @@ export class CopilotContextResolver { const contextSession = await this.context.get(options.contextId); + await this.ac + .user(user.id) + .workspace(contextSession.workspaceId) + .allowLocal() + .assert('Workspace.Copilot'); + try { const blob = await contextSession.addBlobRecord(options.blobId); if (!blob) { @@ -752,7 +758,6 @@ export class CopilotContextResolver { } await this.jobs.addBlobEmbeddingQueue({ - userId: user.id, workspaceId: contextSession.workspaceId, contextId: contextSession.id, blobId: options.blobId, diff --git a/packages/backend/server/src/plugins/copilot/embedding/job.ts b/packages/backend/server/src/plugins/copilot/embedding/job.ts index 7b5ec752b4..c40e100744 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/job.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/job.ts @@ -12,6 +12,7 @@ import { OnJob, } from '../../../base'; import { DocReader } from '../../../core/doc'; +import { WorkspaceBlobStorage } from '../../../core/storage'; import { readAllDocIdsFromWorkspaceSnapshot } from '../../../core/utils/blocksuite'; import { Models } from '../../../models'; import { CopilotStorage } from '../storage'; @@ -224,6 +225,20 @@ export class CopilotEmbeddingJob { return new File([buffer], fileName); } + private async readWorkspaceBlob( + workspaceId: string, + blobId: string, + fileName: string + ) { + const workspaceStorage = this.moduleRef.get(WorkspaceBlobStorage, { + strict: false, + }); + const { body } = await workspaceStorage.get(workspaceId, blobId); + if (!body) throw new BlobNotFound({ spaceId: workspaceId, blobId }); + const buffer = await readStream(body); + return new File([buffer], fileName); + } + @OnJob('copilot.embedding.files') async embedPendingFile({ userId, @@ -289,7 +304,6 @@ export class CopilotEmbeddingJob { @OnJob('copilot.embedding.blobs') async embedPendingBlob({ - userId, workspaceId, contextId, blobId, @@ -297,12 +311,7 @@ export class CopilotEmbeddingJob { if (!this.supportEmbedding || !this.embeddingClient) return; try { - const file = await this.readCopilotBlob( - userId, - workspaceId, - blobId, - 'blob' - ); + const file = await this.readWorkspaceBlob(workspaceId, blobId, 'blob'); const chunks = await this.embeddingClient.getFileChunks(file); const total = chunks.reduce((acc, c) => acc + c.length, 0); diff --git a/packages/backend/server/src/plugins/copilot/embedding/types.ts b/packages/backend/server/src/plugins/copilot/embedding/types.ts index 73097f2871..d56334d888 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/types.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/types.ts @@ -76,7 +76,6 @@ declare global { 'copilot.embedding.blobs': { contextId?: string; - userId: string; workspaceId: string; blobId: string; };