Compare commits

...

15 Commits

Author SHA1 Message Date
DarkSky c36b5b201e chore(i18n): update i18n (#15191)
#### PR Dependency Tree


* **PR #15191** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Updated sign-in, sign-up, and email content to show the correct server
name for cloud and self-hosted setups.
* Added clearer workspace status messages for syncing, local workspaces,
and server-connected workspaces.
* Introduced localized text updates for more languages, including new
locale coverage.

* **Bug Fixes**
* Replaced outdated “Cloud” wording throughout the app with more
accurate “Sync” and self-hosted terminology.
* Improved account deletion, password reset, sharing, and storage
prompts for clearer user guidance.
* Updated workspace status labels and tooltips to better reflect the
current server connection.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-07-03 18:37:20 +08:00
keepClamDown cebd7296b1 fix(ios): restore simulator Rust build outputs (#15190)
- sync the iOS ATT pod lockfile with the Capacitor plugin version
already referenced by the repo
- make `xc-universal-binary.sh` produce a real universal simulator Rust
archive for `arm64` and `x86_64`
- run `uniffi-bindgen` against a single-slice archive before assembling
the fat simulator output

- [x] `bash -n packages/frontend/apps/ios/App/xc-universal-binary.sh`
- [x] `xcodebuild -workspace "App.xcworkspace" -scheme App
-destination 'generic/platform=iOS Simulator' -derivedDataPath
".derivedData-rebuild" build`

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Improved the iOS build process so universal binaries are generated
from consistent output locations, reducing build path issues.
* Made device and simulator library handling more reliable during
packaging and code generation, helping produce the expected app binaries
more consistently.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-07-03 18:11:08 +08:00
akshitha-07 1f0bcd01a3 fix: enforce Doc.Read permission on workspace histories field (#15192)
The histories() resolver was returning document edit history (including
editor names, emails, and timestamps) without checking Doc.Read
permission first. This let any workspace member view history for private
docs they weren't given access to, by passing an arbitrary document
guid.

Added the same permission check already used by
WorkspaceDocResolver.doc() and recoverDoc() in this file.

Fixes #15179
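
The fix follows a check-before-read pattern: the read permission is verified before any history is returned. A minimal Rust sketch of that pattern follows; it is not the server's actual resolver code. `DocAcl`, `can_read`, `histories`, and `HistoryError` are hypothetical names introduced only for illustration, and the in-memory maps stand in for the real permission and snapshot stores.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
enum HistoryError {
    /// The caller lacks read permission on the document.
    Forbidden,
}

/// Hypothetical in-memory stand-in for the permission and history stores.
#[derive(Default)]
struct DocAcl {
    /// (user_id, doc_id) pairs that are allowed to read.
    readers: HashSet<(String, String)>,
    /// doc_id -> history timestamps in milliseconds.
    history: HashMap<String, Vec<u64>>,
}

impl DocAcl {
    fn can_read(&self, user_id: &str, doc_id: &str) -> bool {
        self.readers
            .contains(&(user_id.to_string(), doc_id.to_string()))
    }

    /// Return history only after the read check passes, so an arbitrary
    /// doc id cannot be used to probe private documents.
    fn histories(&self, user_id: &str, doc_id: &str) -> Result<Vec<u64>, HistoryError> {
        if !self.can_read(user_id, doc_id) {
            return Err(HistoryError::Forbidden);
        }
        Ok(self.history.get(doc_id).cloned().unwrap_or_default())
    }
}
```

The point of the ordering is that the permission check runs before the lookup, so a caller without access receives the same `Forbidden` result whether or not the document has any history.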

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Added permission checks when viewing document history so only
authorized users can access snapshot histories.
* Prevented workspace collaborators without read access from querying
histories for private documents.
* **Tests**
* Added an end-to-end test to verify history access is denied when the
required permission is missing.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-07-03 18:08:24 +08:00
DarkSky 5b7f83a6e3 feat(server): batch blob gc (#15183) 2026-07-02 06:14:01 +08:00
DarkSky 6f9269498f fix(server): blob gc planning 2026-07-02 03:28:56 +08:00
DarkSky e5d44b8ff2 fix(server): s3 metadata encode 2026-07-02 00:27:17 +08:00
DarkSky 8c68319094 feat(server): improve client builder 2026-07-01 23:27:20 +08:00
DarkSky 8ebdb7452f feat(server): impl storage runtime (#15181)
#### PR Dependency Tree


* **PR #15181** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added an additional storage backend option: asset-pack based storage
(provider for avatar, blob, and copilot).
* Introduced a dedicated storage runtime with provider capability
reporting and expanded object operations (put/head/get/list/delete),
including presigned and multipart flows where supported.
* Cloudflare R2 `jurisdiction` now uses an explicit default when
omitted.
* **Bug Fixes**
* Broadened avatar access to allow both fs and asset-pack providers.
* Improved workspace blob upload completion validation and handling when
stored objects are missing or mismatched.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-07-01 22:24:10 +08:00
FailSafe da7d438377 fix: enforce quota for comment attachments (#15149)
## Summary

This change includes comment attachments in workspace storage usage and
checks workspace storage quota before accepting a new comment attachment
upload.

## Impact

Comment attachments already had a per-file size limit, but they were not
counted in the same workspace storage usage path as other uploaded
blobs. A user with comment permission could keep adding attachments
without those bytes participating in workspace storage quota
calculations.

## Fix

- Count comment attachment bytes in workspace storage usage
reconciliation.
- Check the workspace quota before storing a new comment attachment.
- Return the existing comment attachment quota error when the upload
would exceed limits.
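
The two steps above (count comment attachments in usage, then check the quota before storing) can be sketched as follows. This is an illustrative Rust sketch under stated assumptions, not the server's implementation: `workspace_usage`, `check_quota`, and `QuotaError` are hypothetical names, and sizes are plain byte counts.

```rust
/// Hypothetical error for this sketch; the real server returns its existing
/// comment attachment quota error.
#[derive(Debug, PartialEq)]
enum QuotaError {
    StorageExceeded { used: u64, incoming: u64, limit: u64 },
}

/// Total workspace usage, counting comment attachments alongside other blobs.
fn workspace_usage(blob_bytes: &[u64], comment_attachment_bytes: &[u64]) -> u64 {
    blob_bytes.iter().chain(comment_attachment_bytes).sum()
}

/// Reject an upload that would push usage past the limit.
/// `checked_add` treats an arithmetic overflow as "over quota".
fn check_quota(used: u64, incoming: u64, limit: u64) -> Result<(), QuotaError> {
    match used.checked_add(incoming) {
        Some(total) if total <= limit => Ok(()),
        _ => Err(QuotaError::StorageExceeded {
            used,
            incoming,
            limit,
        }),
    }
}
```

Running the check before the write means a rejected upload never occupies storage, which is the ordering the fix describes.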

## Validation

- `git diff --check`
- Full test/lint suite was not run locally because dependencies are not
installed in this checkout.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Workspace attachment uploads now respect storage and file quota limits
more accurately.
* Workspace storage tracking now includes comment attachments, improving
quota enforcement.

* **Bug Fixes**
* Attachment uploads now fail with a clear quota error when a workspace
is out of space or blob capacity.
* Storage usage calculations now better reflect actual workspace
content, including non-deleted files.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: failsafesecurity <190101117+failsafesecurity@users.noreply.github.com>
2026-07-01 08:45:33 +08:00
DarkSky a821f67fc9 fix: config override 2026-06-30 04:37:52 +08:00
DarkSky a1363b3873 fix(server): config & update handle (#15173)
#### PR Dependency Tree


* **PR #15173** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added native document update validation to check incoming Yjs updates
for decodability before applying them.
* Introduced support for validation timeouts and cancellation during
update checks.
* Blob maintenance jobs now detect when object storage is unavailable
and skip related work gracefully.

* **Bug Fixes**
* Invalid (and oversized) updates are now filtered out earlier during
document ingestion.
* Background blob maintenance continues processing other work even if
one workspace fails.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 22:59:17 +08:00
DarkSky 1b9e21f2de fix(core): handle unsupported server error (#15164)
fix #15160
fix #15161
fix #15158
fix #15166


#### PR Dependency Tree


* **PR #15164** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added a “server version too old” message for self-hosted servers,
including the required upgrade version.
* Sign-in and OAuth-related preflight steps now verify server
compatibility before proceeding.
* **Bug Fixes**
* Improved error handling for missing/invalid server version responses
and schema/type mismatches, mapping them to the upgrade instruction.
* **Tests**
* Added coverage for server version guarding and the resulting
user-friendly error payload.
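
A version guard of this kind can be sketched in a few lines. The Rust below is a hedged illustration, not the client's actual code: `parse_version`, `check_server_version`, and `ServerCheck` are hypothetical names, the `major.minor.patch` format is an assumption, and any version numbers used with it are made up. It shows a missing, malformed, or older version all mapping to the same upgrade instruction.

```rust
/// Parse a `major.minor.patch` version; any suffix after `-` or `+` is ignored.
fn parse_version(raw: &str) -> Option<(u64, u64, u64)> {
    let core = raw.trim().split(|c: char| c == '-' || c == '+').next()?;
    let mut parts = core.split('.').map(|p| p.parse::<u64>().ok());
    let version = (parts.next()??, parts.next()??, parts.next()??);
    // Reject trailing components such as `1.2.3.4`.
    if parts.next().is_some() {
        return None;
    }
    Some(version)
}

/// Hypothetical result type for this sketch.
#[derive(Debug, PartialEq)]
enum ServerCheck {
    Compatible,
    /// Carries the version the user should upgrade to.
    TooOld { required: String },
}

fn check_server_version(reported: Option<&str>, required: (u64, u64, u64)) -> ServerCheck {
    match reported.and_then(parse_version) {
        // Tuples compare lexicographically: major, then minor, then patch.
        Some(actual) if actual >= required => ServerCheck::Compatible,
        // Missing, malformed, or older versions all map to the upgrade hint.
        _ => ServerCheck::TooOld {
            required: format!("{}.{}.{}", required.0, required.1, required.2),
        },
    }
}
```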

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 00:03:02 +08:00
DarkSky 0a422aa158 feat(server): blob reconciliation (#15165)
#### PR Dependency Tree


* **PR #15165** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added automated backend maintenance for missing blob metadata
backfill, document-to-blob reference rebuilding, and unreferenced blob
cleanup planning/execution.
* Introduced scheduled batch processing (workspace-paged) and paginated
object-storage listing.
* **Bug Fixes**
* Improved reliability of object-storage reads by treating expected “not
found” results as non-errors.
* Strengthened blob/expired cleanup flows with runtime-driven batching
and reduced coupling to metadata synchronization.
* **Tests**
* Expanded unit and e2e coverage for partial blob metadata and updated
runtime/job cleanup test assertions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 00:02:38 +08:00
DarkSky 4a7c931eca fix(server): member loading (#15156)
#### PR Dependency Tree


* **PR #15156** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Fixed Stripe subscription syncing for Team plans to update the correct
existing local subscription (avoiding duplicates) while refreshing
quantity and billing/trial period details.
* **UI/UX Improvements**
* Improved workspace member list loading/error states with shared UI
components and steadier pagination behavior.
* Refined fallback styling for cleaner, more stable layout.
* **Tests**
* Expanded subscription and projection coverage and adjusted
seat-allocation/e2e assertions to be more robust.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-26 21:29:15 +08:00
DarkSky 8e036a2f38 fix(server): workspace sub status (#15155)
#### PR Dependency Tree


* **PR #15155** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)
2026-06-26 17:07:56 +08:00
194 changed files with 10781 additions and 5726 deletions
+105
@@ -121,6 +121,18 @@
"default": {
"concurrency": 1
}
},
"queues.backendRuntime": {
"type": "object",
"description": "The config for backend runtime job queue\n@default {\"concurrency\":1}",
"properties": {
"concurrency": {
"type": "number"
}
},
"default": {
"concurrency": 1
}
}
}
},
@@ -493,6 +505,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -518,6 +531,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
@@ -691,6 +734,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -716,6 +760,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
@@ -1332,6 +1406,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -1357,6 +1432,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
Generated
+387 -589
File diff suppressed because it is too large
+2 -24
@@ -17,17 +17,12 @@ resolver = "3"
aes-gcm = "0.10"
affine_common = { path = "./packages/common/native" }
affine_nbstore = { path = "./packages/frontend/native/nbstore" }
ahash = "0.8"
anyhow = "1"
arbitrary = { version = "1.3", features = ["derive"] }
assert-json-diff = "2.0"
base64 = "0.22.1"
base64-simd = "0.8"
bitvec = "1.0"
block2 = "0.6"
byteorder = "1.5"
chrono = "0.4"
clap = { version = "4.4", features = ["derive"] }
core-foundation = "0.10"
coreaudio-rs = "0.12"
cpal = "0.15"
@@ -48,14 +43,10 @@ resolver = "3"
"webp",
] }
infer = { version = "0.19.0" }
lasso = { version = "0.7", features = ["multi-threaded"] }
lib0 = { version = "0.16", features = ["lib0-serde"] }
libc = "0.2"
libwebp-sys = "0.14.2"
little_exif = "0.6.23"
llm_adapter = { version = "0.2", default-features = false }
llm_runtime = { version = "0.2", default-features = false }
log = "0.4"
lru = "0.16"
matroska = "0.30"
memory-indexer = "0.3.1"
@@ -72,33 +63,22 @@ resolver = "3"
] }
napi-build = { version = "2" }
napi-derive = { version = "3.4" }
nom = "8"
notify = { version = "8", features = ["serde"] }
objc2 = "0.6"
objc2-foundation = "0.3"
ogg = "0.9"
once_cell = "1"
ordered-float = "5"
p256 = { version = "0.13", features = ["ecdsa", "pem"] }
parking_lot = "0.12"
phf = { version = "0.11", features = ["macros"] }
proptest = "1.3"
proptest-derive = "0.5"
pulldown-cmark = "0.13"
rand = "0.9"
rand_chacha = "0.9"
rand_distr = "0.5"
rayon = "1.10"
regex = "1.10"
rubato = "0.16"
safefetch = "0.1.0"
schemars = "0.8"
screencapturekit = "0.3"
serde = "1"
serde_json = "1"
sha2 = "0.10"
sha3 = "0.10"
smol_str = "0.3"
sha2 = "0.11"
sha3 = "0.11"
sqlx = { version = "0.8", default-features = false, features = [
"chrono",
"macros",
@@ -136,8 +116,6 @@ resolver = "3"
] }
windows-core = { version = "0.61" }
y-octo = "0.0.3"
y-sync = { version = "0.4" }
yrs = "0.23.0"
[profile.dev.package.sqlx-macros]
opt-level = 3
+20 -5
@@ -16,17 +16,20 @@ affine_common = { workspace = true, features = [
"ydoc-loader",
] }
anyhow = { workspace = true }
aws-sdk-s3 = "1.137.0"
assetpack-core = "0.1.0"
assetpack-transform-precomp2 = "0.1.0"
base64 = { workspace = true }
chrono = { workspace = true }
crc32fast = "1.5.0"
doc_extractor = { workspace = true }
file-format = { workspace = true }
hex = { workspace = true }
homedir = { workspace = true }
image = { workspace = true }
infer = { workspace = true }
instant-xml = "0.7.5"
jsonschema = "0.46"
libwebp-sys = { workspace = true }
libwebp-sys = "0.9.6"
little_exif = { workspace = true }
llm_adapter = { workspace = true, features = ["schema", "ureq-client"] }
llm_runtime = { workspace = true, features = ["schema", "ureq-client"] }
@@ -36,6 +39,15 @@ napi = { workspace = true, features = ["async", "serde-json"] }
napi-derive = { workspace = true }
p256 = { workspace = true }
rand = { workspace = true }
reqwest = { version = "0.13.4", default-features = false, features = [
"rustls",
] }
rustls = { version = "0.23", default-features = false, features = [
"aws-lc-rs",
"std",
"tls12",
] }
rusty-s3 = "0.10.0"
safefetch = { workspace = true }
schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
@@ -50,11 +62,13 @@ sqlx = { workspace = true, default-features = false, features = [
"postgres",
"runtime-tokio",
] }
thiserror.workspace = true
tiktoken-rs = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
tokio = { workspace = true, features = ["rt-multi-thread", "sync", "time"] }
url = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
v_htmlescape = { workspace = true }
webpki-roots = "1.0"
y-octo = { workspace = true, features = ["large_refs"] }
[target.'cfg(not(target_os = "linux"))'.dependencies]
@@ -64,8 +78,9 @@ mimalloc = { workspace = true }
mimalloc = { workspace = true, features = ["local_dynamic_tls"] }
[dev-dependencies]
rayon = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
rayon = { workspace = true }
tempfile = "3.27.0"
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[build-dependencies]
napi-build = { workspace = true }
+6
@@ -0,0 +1,6 @@
/* auto-generated by NAPI-RS */
/* eslint-disable */
declare const _default: typeof import('./index')
export default _default
+102 -33
@@ -1,10 +1,10 @@
/* auto-generated by NAPI-RS */
/* eslint-disable */
declare const _default: typeof import('./index')
export default _default
export declare class BackendRuntime {
completeBlobUpload(workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
completeFsBlobUpload(root: string, bucket: string, workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
cleanupExpiredPendingBlobs(cutoffMs: number, limit: number): Promise<RuntimeBlobCleanupResult>
releaseDeletedBlobs(workspaceId: string, limit: number): Promise<RuntimeBlobCleanupResult>
acquireCoordinationLease(key: string, owner: string, ttlMs: number): Promise<CoordinationLeaseGrant | null>
releaseCoordinationLease(key: string, owner: string, fencingToken: bigint | number): Promise<boolean>
renewCoordinationLease(key: string, owner: string, fencingToken: bigint | number, ttlMs: number): Promise<boolean>
@@ -27,18 +27,6 @@ export declare class BackendRuntime {
cleanupExpiredRuntimeGates(limit: number): Promise<number>
cleanupExpiredUserSessions(limit: number): Promise<number>
cleanupExpiredSnapshotHistories(limit: number): Promise<number>
objectStorageHealth(): RuntimeObjectStorageHealth
objectStoragePut(key: string, body: Buffer, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<void>
objectStoragePresignPut(key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimePresignedObjectRequest>
objectStorageCreateMultipartUpload(key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeMultipartUploadInit | null>
objectStoragePresignUploadPart(key: string, uploadId: string, partNumber: number): Promise<RuntimePresignedObjectRequest>
objectStorageListMultipartUploadParts(key: string, uploadId: string): Promise<Array<RuntimeMultipartUploadPart>>
objectStorageCompleteMultipartUpload(key: string, uploadId: string, parts: Array<RuntimeMultipartUploadPart>): Promise<void>
objectStorageAbortMultipartUpload(key: string, uploadId: string): Promise<void>
objectStorageHead(key: string): Promise<RuntimeObjectMetadata | null>
objectStorageGet(key: string): Promise<RuntimeObjectGetResult | null>
objectStorageList(prefix?: string | undefined | null): Promise<Array<RuntimeObjectListEntry>>
objectStorageDelete(key: string): Promise<void>
createAuthChallenge(purpose: string, token: string, payload: any, ttlMs: number): Promise<boolean>
getAuthChallenge(purpose: string, token: string): Promise<any | null>
consumeAuthChallenge(purpose: string, token: string): Promise<any | null>
@@ -70,6 +58,37 @@ export declare class LlmStreamHandle {
abort(): void
}
export declare class StorageRuntime {
planUnreferencedWorkspaceBlobs(workspaceId: string, gracePeriodDays: number, limit: number): Promise<RuntimeBlobCleanupPlanResult>
executeBlobCleanupCandidates(runId: string, gracePeriodDays: number, limit: number): Promise<RuntimeBlobCleanupExecuteResult>
cleanupExpiredPendingBlobs(cutoffMs: number, limit: number): Promise<RuntimeBlobCleanupResult>
releaseDeletedBlobs(workspaceId: string, limit: number): Promise<RuntimeBlobCleanupResult>
backfillMissingBlobMetadata(workspaceId: string | undefined | null, limit: number): Promise<RuntimeBlobMetadataBackfillResult>
rebuildDocBlobRefs(workspaceId: string, docId: string): Promise<RuntimeDocBlobRefsResult>
rebuildWorkspaceDocBlobRefs(workspaceId: string, limit: number): Promise<RuntimeDocBlobRefsResult>
constructor()
start(): Promise<void>
configure(configJson: string): void
stop(): Promise<void>
runMigrations(): Promise<void>
health(): Promise<StorageRuntimeHealth>
providerCapabilities(scope: string): Promise<StorageProviderCapabilities>
putObject(scope: string, key: string, body: Buffer, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeObjectMetadata>
headObject(scope: string, key: string): Promise<RuntimeObjectMetadata | null>
getObject(scope: string, key: string): Promise<RuntimeObjectGetResult | null>
listObjects(scope: string, prefix?: string | undefined | null): Promise<Array<RuntimeObjectListEntry>>
deleteObject(scope: string, key: string): Promise<void>
presignPut(scope: string, key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimePresignedObjectRequest | null>
presignGet(scope: string, key: string): Promise<RuntimePresignedObjectRequest | null>
createMultipartUpload(scope: string, key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeMultipartUploadInit | null>
presignUploadPart(scope: string, key: string, uploadId: string, partNumber: number): Promise<RuntimePresignedObjectRequest | null>
proxyUploadPart(scope: string, key: string, uploadId: string, partNumber: number, body: Buffer, contentLength?: number | undefined | null): Promise<string | null>
listMultipartUploadParts(scope: string, key: string, uploadId: string): Promise<Array<RuntimeMultipartUploadPart> | null>
completeMultipartUpload(scope: string, key: string, uploadId: string, parts: Array<RuntimeMultipartUploadPart>): Promise<boolean>
abortMultipartUpload(scope: string, key: string, uploadId: string): Promise<boolean>
completeWorkspaceBlobUpload(workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
}
export declare class Tokenizer {
count(content: string, allowedSpecial?: Array<string> | undefined | null): number
}
@@ -143,7 +162,6 @@ export interface AssertSafeUrlRequest {
export interface BackendRuntimeHealth {
started: boolean
databaseConnected: boolean
objectStorageConfigured: boolean
}
export declare function buildPublicRootDoc(rootDocBin: Buffer, docMetas: Array<PublicDocMetaInput>): Buffer
@@ -816,6 +834,25 @@ export declare function resolveEntitlementV1(input: ResolveEntitlementInput): Re
export declare function runNativeActionRecipePreparedStream(input: ActionRuntimeInput, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
export interface RuntimeBlobCleanupExecuteResult {
scannedCandidates: number
deletedObjects: number
deletedMetadata: number
skippedStillReferenced: number
failed: number
workspaceIds: Array<string>
}
export interface RuntimeBlobCleanupPlanResult {
runId?: string
scannedBlobs: number
candidatesMarked: number
protectedByDocRefs: number
protectedByMetadata: number
protectedByOtherRefs: number
nextCursor?: string
}
export interface RuntimeBlobCleanupResult {
scanned: number
deleted: number
@@ -831,12 +868,32 @@ export interface RuntimeBlobCompleteResult {
lastModifiedMs?: number
}
export interface RuntimeBlobMetadataBackfillResult {
scannedObjects: number
headedObjects: number
upsertedMetadata: number
skippedExisting: number
skippedWorkspaceMissing: number
failed: number
nextCursor?: string
workspaceIds: Array<string>
}
export interface RuntimeByokLocalLeaseRecord {
leaseId: string
payload: any
expiresAtMs: number
}
export interface RuntimeDocBlobRefsResult {
scannedDocs: number
parsedDocs: number
refsWritten: number
refsDeleted: number
failedDocs: number
nextCursor?: string
}
export interface RuntimeDocCompactionResult {
leaseAcquired: boolean
merged: boolean
@@ -891,22 +948,6 @@ export interface RuntimeObjectMetadata {
checksumCrc32?: string
}
export interface RuntimeObjectStorageHealth {
configured: boolean
provider?: string
bucket?: string
endpoint?: string
region?: string
hasCredentials: boolean
forcePathStyle: boolean
requestTimeoutMs?: number
minPartSize?: number
presignExpiresInSeconds?: number
presignSignContentTypeForPut?: boolean
usePresignedUrl: boolean
clientBuildable: boolean
}
export interface RuntimeObjectStoragePutOptions {
contentType?: string
contentLength?: number
@@ -989,6 +1030,28 @@ export interface SafeFetchResponse {
body: Buffer
}
export interface StorageProviderCapabilities {
put: boolean
get: boolean
head: boolean
list: boolean
delete: boolean
presignPut: boolean
presignGet: boolean
multipartDirect: boolean
proxyUpload: boolean
assetpack: boolean
serverMediatedOnly: boolean
}
export interface StorageRuntimeHealth {
started: boolean
databaseConnected: boolean
providerConfigured: boolean
provider?: string
bucket?: string
}
export interface ToolContract {
name: string
description?: string
@@ -1055,4 +1118,10 @@ export declare function updateLicenseSeats(request: LicenseSeatsRequest): Promis
*/
export declare function updateRootDocMetaTitle(rootDocBin: Buffer, docId: string, title: string): Buffer
/**
* Check whether a Yjs update binary can be decoded without applying it to a
* document state.
*/
export declare function validateDocUpdate(update: Buffer): Promise<boolean>
export declare function verifyChallengeResponse(response: string, bits: number, resource: string): Promise<boolean>
+1
@@ -16,6 +16,7 @@
},
"napi": {
"binaryName": "server-native",
"dtsHeaderFile": "dts-header.d.ts",
"targets": [
"aarch64-apple-darwin",
"aarch64-unknown-linux-gnu",
@@ -1,296 +0,0 @@
use std::{
fs,
path::{Path, PathBuf},
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use napi::Result;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use super::{BackendRuntime, error::napi_error, types::RuntimeBlobCompleteResult};
const MAX_BLOB_SIZE: i64 = i32::MAX as i64;
fn object_missing_error(err: &napi::Error) -> bool {
let message = err.to_string();
message.contains("NoSuchKey") || message.contains("NotFound") || message.contains("not found")
}
fn blob_complete_failure(reason: &str) -> RuntimeBlobCompleteResult {
RuntimeBlobCompleteResult {
ok: false,
reason: Some(reason.to_string()),
content_type: None,
content_length: None,
last_modified_ms: None,
}
}
fn blob_complete_success(
content_type: String,
content_length: i64,
last_modified_ms: i64,
) -> RuntimeBlobCompleteResult {
RuntimeBlobCompleteResult {
ok: true,
reason: None,
content_type: Some(content_type),
content_length: Some(content_length),
last_modified_ms: Some(last_modified_ms),
}
}
fn normalize_base64_url_key(key: &str) -> &str {
key.trim_end_matches('=')
}
fn sha256_base64_url(body: &[u8]) -> String {
URL_SAFE_NO_PAD.encode(Sha256::digest(body))
}
fn sha256_base64_url_matches(body: &[u8], key: &str) -> bool {
sha256_base64_url(body) == normalize_base64_url_key(key)
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct FsBlobMetadata {
content_type: String,
content_length: i64,
last_modified: i64,
}
fn normalize_storage_key(key: &str) -> Result<Vec<String>> {
let normalized = key.replace('\\', "/");
let segments = normalized.split('/').map(ToString::to_string).collect::<Vec<_>>();
if normalized.is_empty()
|| normalized.starts_with('/')
|| segments
.iter()
.any(|segment| segment.is_empty() || segment == "." || segment == "..")
{
return Err(napi_error(format!("Invalid storage key: {key}")));
}
Ok(segments)
}
fn fs_bucket_path(root: &str, bucket: &str) -> PathBuf {
if let Some(stripped) = root.strip_prefix("~/")
&& let Ok(Some(home)) = homedir::my_home()
{
return home.join(stripped).join(bucket);
}
Path::new(root).join(bucket)
}
fn fs_object_path(root: &str, bucket: &str, key: &str) -> Result<PathBuf> {
let mut path = fs_bucket_path(root, bucket);
for segment in normalize_storage_key(key)? {
path.push(segment);
}
Ok(path)
}
fn read_fs_metadata(path: &Path) -> Result<Option<FsBlobMetadata>> {
let metadata_path = PathBuf::from(format!("{}.metadata.json", path.display()));
let raw = match fs::read_to_string(metadata_path) {
Ok(raw) => raw,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => {
return Err(napi_error(format!("BlobComplete read fs metadata failed: {err}")));
}
};
serde_json::from_str(&raw).map(Some).map_err(|err| {
napi_error(format!(
"BlobComplete parse fs metadata failed for {}: {err}",
path.display()
))
})
}
async fn upsert_completed_blob(
runtime: &BackendRuntime,
workspace_id: &str,
key: &str,
mime: &str,
size: i64,
) -> Result<()> {
if !(0..=MAX_BLOB_SIZE).contains(&size) {
return Err(napi_error("BlobComplete size exceeds limit"));
}
let size = i32::try_from(size).map_err(|_| napi_error("BlobComplete size exceeds limit"))?;
sqlx::query(
r#"
INSERT INTO blobs (workspace_id, key, mime, size, status, upload_id)
VALUES ($1, $2, $3, $4, 'completed', NULL)
ON CONFLICT (workspace_id, key)
DO UPDATE SET
mime = EXCLUDED.mime,
size = EXCLUDED.size,
status = EXCLUDED.status,
upload_id = NULL
"#,
)
.bind(workspace_id)
.bind(key)
.bind(mime)
.bind(size)
.execute(&runtime.pool().await?)
.await
.map_err(|err| napi_error(format!("BlobComplete upsert metadata failed: {err}")))?;
Ok(())
}
#[napi_derive::napi]
impl BackendRuntime {
#[napi]
pub async fn complete_blob_upload(
&self,
workspace_id: String,
key: String,
expected_size: i64,
expected_mime: String,
) -> Result<RuntimeBlobCompleteResult> {
if !(0..=MAX_BLOB_SIZE).contains(&expected_size) {
return Ok(blob_complete_failure("size_too_large"));
}
let object_key = format!("{workspace_id}/{key}");
let object = match self.object_storage_get(object_key.clone()).await {
Ok(Some(object)) => object,
Ok(None) => return Ok(blob_complete_failure("not_found")),
Err(err) if object_missing_error(&err) => return Ok(blob_complete_failure("not_found")),
Err(err) => return Err(err),
};
if !(0..=MAX_BLOB_SIZE).contains(&object.metadata.content_length) {
match self.object_storage_delete(object_key).await {
Ok(()) => {}
Err(err) if object_missing_error(&err) => {}
Err(err) => return Err(err),
}
return Ok(blob_complete_failure("size_too_large"));
}
if object.metadata.content_length != expected_size {
return Ok(blob_complete_failure("size_mismatch"));
}
if !expected_mime.is_empty() && object.metadata.content_type != expected_mime {
return Ok(blob_complete_failure("mime_mismatch"));
}
if !sha256_base64_url_matches(&object.body, &key) {
match self.object_storage_delete(object_key).await {
Ok(()) => {}
Err(err) if object_missing_error(&err) => {}
Err(err) => return Err(err),
}
return Ok(blob_complete_failure("checksum_mismatch"));
}
upsert_completed_blob(
self,
&workspace_id,
&key,
&object.metadata.content_type,
object.metadata.content_length,
)
.await?;
Ok(blob_complete_success(
object.metadata.content_type,
object.metadata.content_length,
object.metadata.last_modified_ms,
))
}
#[napi]
pub async fn complete_fs_blob_upload(
&self,
root: String,
bucket: String,
workspace_id: String,
key: String,
expected_size: i64,
expected_mime: String,
) -> Result<RuntimeBlobCompleteResult> {
if !(0..=MAX_BLOB_SIZE).contains(&expected_size) {
return Ok(blob_complete_failure("size_too_large"));
}
let storage_key = format!("{workspace_id}/{key}");
let path = fs_object_path(&root, &bucket, &storage_key)?;
let metadata = match read_fs_metadata(&path)? {
Some(metadata) => metadata,
None => return Ok(blob_complete_failure("not_found")),
};
if !(0..=MAX_BLOB_SIZE).contains(&metadata.content_length) {
let _ = fs::remove_file(&path);
let _ = fs::remove_file(PathBuf::from(format!("{}.metadata.json", path.display())));
return Ok(blob_complete_failure("size_too_large"));
}
if metadata.content_length != expected_size {
return Ok(blob_complete_failure("size_mismatch"));
}
if !expected_mime.is_empty() && metadata.content_type != expected_mime {
return Ok(blob_complete_failure("mime_mismatch"));
}
let body = match fs::read(&path) {
Ok(body) => body,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(blob_complete_failure("not_found")),
Err(err) => return Err(napi_error(format!("BlobComplete read fs object failed: {err}"))),
};
if !sha256_base64_url_matches(&body, &key) {
let _ = fs::remove_file(&path);
let _ = fs::remove_file(PathBuf::from(format!("{}.metadata.json", path.display())));
return Ok(blob_complete_failure("checksum_mismatch"));
}
upsert_completed_blob(
self,
&workspace_id,
&key,
&metadata.content_type,
metadata.content_length,
)
.await?;
Ok(blob_complete_success(
metadata.content_type,
metadata.content_length,
metadata.last_modified,
))
}
}
#[cfg(test)]
mod tests {
use super::{sha256_base64_url, sha256_base64_url_matches};
#[test]
fn sha256_base64_url_omits_padding() {
assert_eq!(
sha256_base64_url(b"hello"),
"LPJNul-wow4m6DsqxbninhsWHlwfp0JecwQzYpOLmCQ"
);
}
#[test]
fn sha256_base64_url_matches_legacy_padding() {
assert!(sha256_base64_url_matches(
b"hello",
"LPJNul-wow4m6DsqxbninhsWHlwfp0JecwQzYpOLmCQ="
));
}
}
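Note on the two tests above: they pin down that the blob key is an unpadded base64url SHA-256 digest, and that a legacy key with trailing `=` padding is still accepted. The body of `sha256_base64_url_matches` is not part of this diff, so the following is only a minimal sketch of the comparison step, assuming the digest has already been encoded. The helper name `base64_url_eq_ignoring_padding` is hypothetical.

```rust
/// Compare two base64url strings while ignoring trailing `=` padding.
/// Hypothetical helper: the real `sha256_base64_url_matches` is not shown in this diff,
/// and it also hashes the body first, which this sketch leaves out.
fn base64_url_eq_ignoring_padding(left: &str, right: &str) -> bool {
    left.trim_end_matches('=') == right.trim_end_matches('=')
}
```

Under that assumption, a key stored as `...OLmCQ=` and a freshly computed `...OLmCQ` compare equal, while any difference in the digest characters still fails the check.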
@@ -1,128 +0,0 @@
use std::{
collections::HashMap,
env, fs,
path::{Path, PathBuf},
};
use napi::Result;
use serde::Deserialize;
use super::{
error::napi_error,
object_storage::{ObjectStorageConfig, StorageProviderConfig},
};
#[derive(Clone, Debug)]
pub(super) struct RuntimeConfig {
pub(super) database_url: String,
pub(super) storage: Option<ObjectStorageConfig>,
}
impl RuntimeConfig {
pub(super) fn from_config_files() -> Result<Self> {
let database_url =
database_url_from_config_files()?.unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string());
let storage = ObjectStorageConfig::from_config_files()?;
Ok(Self { database_url, storage })
}
}
#[derive(Debug, Deserialize)]
struct AppConfigFile {
db: Option<DbConfigFile>,
storages: Option<HashMap<String, StorageProviderConfig>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DbConfigFile {
datasource_url: Option<String>,
}
fn database_url_from_config_files() -> Result<Option<String>> {
let mut database_url = None;
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path)
.map_err(|err| napi_error(format!("failed to read config file {}: {err}", path.display())))?;
let config: AppConfigFile = serde_json::from_str(&raw)
.map_err(|err| napi_error(format!("failed to parse config file {}: {err}", path.display())))?;
if let Some(next) = config.db.and_then(|db| db.datasource_url)
&& !next.trim().is_empty()
{
database_url = Some(next);
}
}
Ok(database_url)
}
pub(super) fn blob_storage_config_from_config_files() -> Result<Option<StorageProviderConfig>> {
let mut storage = None;
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path)
.map_err(|err| napi_error(format!("failed to read config file {}: {err}", path.display())))?;
let config: AppConfigFile = serde_json::from_str(&raw)
.map_err(|err| napi_error(format!("failed to parse config file {}: {err}", path.display())))?;
if let Some(next) = config.storages.and_then(|mut storages| storages.remove("blob.storage")) {
storage = Some(next);
}
}
Ok(storage)
}
pub(super) fn config_json_paths() -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Ok(exe) = env::current_exe()
&& let Some(dir) = exe.parent()
{
paths.push(config_in(dir));
}
if let Ok(cwd) = env::current_dir() {
paths.push(config_in(&cwd));
}
dedupe_paths(paths)
}
fn config_in(dir: &Path) -> PathBuf {
dir.join("config.json")
}
fn dedupe_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
let mut deduped = Vec::new();
for path in paths {
if !deduped.contains(&path) {
deduped.push(path);
}
}
deduped
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn config_paths_are_limited_to_executable_dir_and_cwd() {
let paths = config_json_paths();
assert!(!paths.is_empty());
assert!(paths.len() <= 2);
assert!(
paths
.iter()
.all(|path| path.file_name().is_some_and(|name| name == "config.json"))
);
assert!(paths.iter().all(|path| !path.to_string_lossy().contains(".affine")));
assert!(
paths
.iter()
.all(|path| !path.to_string_lossy().contains("packages/backend/server"))
);
}
}
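Note on the removed config loader: `config_json_paths` yields the executable directory first and the current directory second, and both loaders overwrite their result on every file that provides a value, so a later file overrides an earlier one. Blank `datasourceUrl` values are skipped. A std-only sketch of that precedence rule, with parsing left out, might look like this. The names `dedupe_in_order` and `last_non_blank` are hypothetical.

```rust
use std::path::PathBuf;

/// Keep the first occurrence of each path and preserve the original order.
/// Mirrors the shape of `dedupe_paths` in the diff; the name is hypothetical.
fn dedupe_in_order(paths: Vec<PathBuf>) -> Vec<PathBuf> {
    let mut deduped: Vec<PathBuf> = Vec::new();
    for path in paths {
        if !deduped.contains(&path) {
            deduped.push(path);
        }
    }
    deduped
}

/// Return the last candidate that is present and not blank.
/// Each element stands for the value read from one config file, in lookup order.
fn last_non_blank(candidates: Vec<Option<String>>) -> Option<String> {
    let mut selected = None;
    for candidate in candidates.into_iter().flatten() {
        if !candidate.trim().is_empty() {
            selected = Some(candidate);
        }
    }
    selected
}
```

With this rule, a blank value in the current directory's `config.json` does not erase a URL found next to the executable, and `None` is returned only when no file supplies a usable value, which is when the caller falls back to its default.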
@@ -1,5 +0,0 @@
use napi::{Error, Status};
pub(super) fn napi_error(message: impl Into<String>) -> Error {
Error::new(Status::GenericFailure, message.into())
}
@@ -1,353 +0,0 @@
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use aws_sdk_s3::{
Client as S3Client, presigning::PresigningConfig, primitives::ByteStream, types::CompletedMultipartUpload,
};
use napi::Result;
use super::types::{
MultipartUploadInitResult, MultipartUploadPart, ObjectGetResult, ObjectListEntry, ObjectMetadata, ObjectPutMetadata,
PresignedObjectRequest, completed_multipart_parts, trim_etag,
};
use crate::backend_runtime::error::napi_error;
#[derive(Clone)]
pub(super) struct ObjectStorageClient {
client: S3Client,
bucket: String,
presign_expires_in_seconds: u64,
presign_sign_content_type_for_put: bool,
}
impl ObjectStorageClient {
pub(super) fn new(
config: aws_sdk_s3::Config,
bucket: String,
presign_expires_in_seconds: u64,
presign_sign_content_type_for_put: bool,
) -> Self {
Self {
client: S3Client::from_conf(config),
bucket,
presign_expires_in_seconds,
presign_sign_content_type_for_put,
}
}
pub(super) fn non_destructive_health(&self) -> bool {
let _ = &self.client;
!self.bucket.is_empty()
}
pub(super) async fn put(&self, key: &str, body: Vec<u8>, metadata: ObjectPutMetadata) -> Result<()> {
let content_length = metadata.content_length.unwrap_or(body.len() as i64);
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let mut request = self
.client
.put_object()
.bucket(&self.bucket)
.key(key)
.body(ByteStream::from(body))
.content_type(content_type)
.content_length(content_length);
if let Some(checksum) = metadata.checksum_crc32 {
request = request.checksum_crc32(checksum);
}
request
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage put failed for {key}: {err:?}")))?;
Ok(())
}
pub(super) async fn presign_put(&self, key: &str, metadata: ObjectPutMetadata) -> Result<PresignedObjectRequest> {
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
let config = PresigningConfig::expires_in(Duration::from_secs(self.presign_expires_in_seconds))
.map_err(|err| napi_error(format!("ObjectStorage presign config failed: {err}")))?;
let mut request = self.client.put_object().bucket(&self.bucket).key(key);
if self.presign_sign_content_type_for_put {
request = request.content_type(content_type.clone());
}
if let Some(content_length) = metadata.content_length {
request = request.content_length(content_length);
}
let presigned = request
.presigned(config)
.await
.map_err(|err| napi_error(format!("ObjectStorage presign put failed for {key}: {err}")))?;
let mut headers = presigned_headers(&presigned);
headers.insert("Content-Type".to_string(), content_type);
Ok(PresignedObjectRequest {
url: presigned.uri().to_string(),
headers,
expires_at_ms,
})
}
pub(super) async fn create_multipart_upload(
&self,
key: &str,
metadata: ObjectPutMetadata,
) -> Result<Option<MultipartUploadInitResult>> {
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let result = self
.client
.create_multipart_upload()
.bucket(&self.bucket)
.key(key)
.content_type(content_type)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage create multipart upload failed for {key}: {err:?}"
))
})?;
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
Ok(result.upload_id.map(|upload_id| MultipartUploadInitResult {
upload_id,
expires_at_ms,
}))
}
pub(super) async fn presign_upload_part(
&self,
key: &str,
upload_id: &str,
part_number: i32,
) -> Result<PresignedObjectRequest> {
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
let config = PresigningConfig::expires_in(Duration::from_secs(self.presign_expires_in_seconds))
.map_err(|err| napi_error(format!("ObjectStorage presign config failed: {err}")))?;
let presigned = self
.client
.upload_part()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.part_number(part_number)
.presigned(config)
.await
.map_err(|err| napi_error(format!("ObjectStorage presign upload part failed for {key}: {err}")))?;
Ok(PresignedObjectRequest {
url: presigned.uri().to_string(),
headers: presigned_headers(&presigned),
expires_at_ms,
})
}
pub(super) async fn list_multipart_upload_parts(
&self,
key: &str,
upload_id: &str,
) -> Result<Vec<MultipartUploadPart>> {
let result = self
.client
.list_parts()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage list multipart upload parts failed for {key}: {err}"
))
})?;
Ok(
result
.parts()
.iter()
.filter_map(|part| {
Some(MultipartUploadPart {
part_number: part.part_number?,
etag: trim_etag(part.e_tag.as_deref().unwrap_or_default()),
})
})
.collect(),
)
}
pub(super) async fn complete_multipart_upload(
&self,
key: &str,
upload_id: &str,
parts: Vec<MultipartUploadPart>,
) -> Result<()> {
let ordered_parts = completed_multipart_parts(parts);
self
.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.multipart_upload(
CompletedMultipartUpload::builder()
.set_parts(Some(ordered_parts))
.build(),
)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage complete multipart upload failed for {key}: {err}"
))
})?;
Ok(())
}
pub(super) async fn abort_multipart_upload(&self, key: &str, upload_id: &str) -> Result<()> {
self
.client
.abort_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage abort multipart upload failed for {key}: {err:?}"
))
})?;
Ok(())
}
pub(super) async fn head(&self, key: &str) -> Result<Option<ObjectMetadata>> {
let result = self
.client
.head_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage head failed for {key}: {err:?}")))?;
Ok(Some(ObjectMetadata {
content_type: result
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: result.content_length.unwrap_or(0),
last_modified_ms: optional_datetime_ms(result.last_modified),
checksum_crc32: result.checksum_crc32,
}))
}
pub(super) async fn get(&self, key: &str) -> Result<Option<ObjectGetResult>> {
let result = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage get failed for {key}: {err:?}")))?;
let metadata = ObjectMetadata {
content_type: result
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: result.content_length.unwrap_or(0),
last_modified_ms: optional_datetime_ms(result.last_modified),
checksum_crc32: result.checksum_crc32,
};
let body = result
.body
.collect()
.await
.map_err(|err| napi_error(format!("ObjectStorage read body failed for {key}: {err}")))?
.into_bytes()
.to_vec();
Ok(Some(ObjectGetResult { body, metadata }))
}
pub(super) async fn list(&self, prefix: Option<String>) -> Result<Vec<ObjectListEntry>> {
let mut entries = Vec::new();
let mut token = None;
loop {
let mut request = self.client.list_objects_v2().bucket(&self.bucket);
if let Some(prefix) = &prefix {
request = request.prefix(prefix);
}
if let Some(next_token) = token {
request = request.continuation_token(next_token);
}
let result = request
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage list failed: {err:?}")))?;
entries.extend(result.contents().iter().filter_map(|object| {
Some(ObjectListEntry {
key: object.key.as_ref()?.clone(),
content_length: object.size.unwrap_or(0),
last_modified_ms: optional_datetime_ms(object.last_modified),
})
}));
if result.is_truncated.unwrap_or(false) {
token = result.next_continuation_token;
} else {
break;
}
}
Ok(entries)
}
pub(super) async fn delete(&self, key: &str) -> Result<()> {
self
.client
.delete_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage delete failed for {key}: {err:?}")))?;
Ok(())
}
}
fn expires_at_ms(expires_in_seconds: u64) -> Result<i64> {
let expires_at = SystemTime::now()
.checked_add(Duration::from_secs(expires_in_seconds))
.ok_or_else(|| napi_error("ObjectStorage presign expiration overflow"))?;
system_time_ms(expires_at)
}
fn system_time_ms(time: SystemTime) -> Result<i64> {
let duration = time
.duration_since(UNIX_EPOCH)
.map_err(|err| napi_error(format!("system time before unix epoch: {err}")))?;
Ok(duration.as_millis() as i64)
}
fn optional_datetime_ms(time: Option<aws_sdk_s3::primitives::DateTime>) -> i64 {
time.and_then(|value| value.to_millis().ok()).unwrap_or(0)
}
fn presigned_headers(request: &aws_sdk_s3::presigning::PresignedRequest) -> HashMap<String, String> {
request
.headers()
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect()
}
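Note on `list` in the removed client: it follows continuation tokens until a page reports that it is not truncated. The loop can be sketched without the AWS SDK as below. `Page` and `list_all` are hypothetical stand-ins, not SDK types, and the extra check that a token is actually present is an addition of this sketch (the original loop would request the first page again if a truncated page carried no token).

```rust
/// One page of a hypothetical listing response (not an AWS SDK type).
struct Page {
    keys: Vec<String>,
    is_truncated: bool,
    next_token: Option<String>,
}

/// Collect every key by following continuation tokens.
/// Stops when a page is not truncated, or when a truncated page has no token.
fn list_all(mut fetch_page: impl FnMut(Option<String>) -> Page) -> Vec<String> {
    let mut keys = Vec::new();
    let mut token: Option<String> = None;
    loop {
        // `take` hands the token to the request and leaves `None` behind.
        let page = fetch_page(token.take());
        keys.extend(page.keys);
        if page.is_truncated && page.next_token.is_some() {
            token = page.next_token;
        } else {
            break;
        }
    }
    keys
}
```

Passing the page source as a closure keeps the loop testable without network access; a real caller would issue the `list_objects_v2` request inside it.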
@@ -1,184 +0,0 @@
mod client;
mod config;
#[cfg(test)]
mod tests;
mod types;
use client::ObjectStorageClient;
pub(super) use config::ObjectStorageConfig;
use napi::{Result, bindgen_prelude::Buffer};
pub(super) use types::StorageProviderConfig;
use super::{
BackendRuntime,
types::{
RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStorageHealth, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest,
},
};
#[napi_derive::napi]
impl BackendRuntime {
fn object_storage_client(&self) -> Result<ObjectStorageClient> {
self
.config
.storage
.as_ref()
.ok_or_else(|| super::error::napi_error("ObjectStorageClient is not configured"))?
.build_client()
}
pub(super) async fn object_storage_delete_object(&self, key: &str) -> Result<()> {
self.object_storage_client()?.delete(key).await
}
pub(super) async fn object_storage_abort_upload(&self, key: &str, upload_id: &str) -> Result<()> {
self
.object_storage_client()?
.abort_multipart_upload(key, upload_id)
.await
}
#[napi]
pub fn object_storage_health(&self) -> RuntimeObjectStorageHealth {
match &self.config.storage {
Some(storage) => storage.health(),
None => RuntimeObjectStorageHealth {
configured: false,
provider: None,
bucket: None,
endpoint: None,
region: None,
has_credentials: false,
force_path_style: false,
request_timeout_ms: None,
min_part_size: None,
presign_expires_in_seconds: None,
presign_sign_content_type_for_put: None,
use_presigned_url: false,
client_buildable: false,
},
}
}
#[napi]
pub async fn object_storage_put(
&self,
key: String,
body: Buffer,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<()> {
self
.object_storage_client()?
.put(&key, body.to_vec(), metadata.map(Into::into).unwrap_or_default())
.await
}
#[napi]
pub async fn object_storage_presign_put(
&self,
key: String,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<RuntimePresignedObjectRequest> {
self
.object_storage_client()?
.presign_put(&key, metadata.map(Into::into).unwrap_or_default())
.await?
.try_into()
}
#[napi]
pub async fn object_storage_create_multipart_upload(
&self,
key: String,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<Option<RuntimeMultipartUploadInit>> {
Ok(
self
.object_storage_client()?
.create_multipart_upload(&key, metadata.map(Into::into).unwrap_or_default())
.await?
.map(Into::into),
)
}
#[napi]
pub async fn object_storage_presign_upload_part(
&self,
key: String,
upload_id: String,
part_number: i32,
) -> Result<RuntimePresignedObjectRequest> {
self
.object_storage_client()?
.presign_upload_part(&key, &upload_id, part_number)
.await?
.try_into()
}
#[napi]
pub async fn object_storage_list_multipart_upload_parts(
&self,
key: String,
upload_id: String,
) -> Result<Vec<RuntimeMultipartUploadPart>> {
Ok(
self
.object_storage_client()?
.list_multipart_upload_parts(&key, &upload_id)
.await?
.into_iter()
.map(Into::into)
.collect(),
)
}
#[napi]
pub async fn object_storage_complete_multipart_upload(
&self,
key: String,
upload_id: String,
parts: Vec<RuntimeMultipartUploadPart>,
) -> Result<()> {
self
.object_storage_client()?
.complete_multipart_upload(&key, &upload_id, parts.into_iter().map(Into::into).collect())
.await
}
#[napi]
pub async fn object_storage_abort_multipart_upload(&self, key: String, upload_id: String) -> Result<()> {
self
.object_storage_client()?
.abort_multipart_upload(&key, &upload_id)
.await
}
#[napi]
pub async fn object_storage_head(&self, key: String) -> Result<Option<RuntimeObjectMetadata>> {
Ok(self.object_storage_client()?.head(&key).await?.map(Into::into))
}
#[napi]
pub async fn object_storage_get(&self, key: String) -> Result<Option<RuntimeObjectGetResult>> {
Ok(self.object_storage_client()?.get(&key).await?.map(Into::into))
}
#[napi]
pub async fn object_storage_list(&self, prefix: Option<String>) -> Result<Vec<RuntimeObjectListEntry>> {
Ok(
self
.object_storage_client()?
.list(prefix)
.await?
.into_iter()
.map(Into::into)
.collect(),
)
}
#[napi]
pub async fn object_storage_delete(&self, key: String) -> Result<()> {
self.object_storage_client()?.delete(&key).await
}
}
@@ -1,129 +0,0 @@
use super::{
config::ObjectStorageConfig,
types::{MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, completed_multipart_parts, trim_etag},
};
#[test]
fn resolves_r2_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "cloudflare-r2");
assert_eq!(config.bucket, "workspace-blobs");
assert_eq!(
config.endpoint.as_deref(),
Some("https://account.eu.r2.cloudflarestorage.com")
);
assert_eq!(config.region.as_deref(), Some("auto"));
assert!(config.force_path_style);
assert!(config.use_presigned_url);
assert_eq!(config.access_key_id.as_deref(), Some("key"));
}
#[test]
fn resolves_s3_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"region": "us-west-2",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret",
"sessionToken": "session"
},
"forcePathStyle": true,
"requestTimeoutMs": 1000,
"minPartSize": 1024,
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "aws-s3");
assert_eq!(config.endpoint.as_deref(), Some("https://s3.us-west-2.amazonaws.com"));
assert_eq!(config.session_token.as_deref(), Some("session"));
assert!(config.force_path_style);
assert_eq!(config.request_timeout_ms, Some(1000));
assert_eq!(config.min_part_size, Some(1024));
assert_eq!(config.presign_expires_in_seconds, Some(60));
assert_eq!(config.presign_sign_content_type_for_put, Some(false));
}
#[tokio::test]
async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
let Ok(Ok(client)) = std::panic::catch_unwind(|| config.build_client()) else {
eprintln!("skipping object storage presign test: S3 client cannot be built in this environment");
return;
};
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders="));
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert!(result.expires_at_ms > 0);
}
#[test]
fn object_storage_orders_completed_multipart_parts_and_trims_etags() {
let parts = completed_multipart_parts(vec![
MultipartUploadPart {
part_number: 2,
etag: trim_etag("\"b\""),
},
MultipartUploadPart {
part_number: 1,
etag: trim_etag("a"),
},
]);
assert_eq!(parts[0].part_number, Some(1));
assert_eq!(parts[0].e_tag.as_deref(), Some("a"));
assert_eq!(parts[1].part_number, Some(2));
assert_eq!(parts[1].e_tag.as_deref(), Some("b"));
}
@@ -1,165 +0,0 @@
use std::collections::HashMap;
use aws_sdk_s3::types::CompletedPart;
use napi::Result;
use serde::Deserialize;
use crate::backend_runtime::{
error::napi_error,
types::{
RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest,
},
};
#[derive(Clone, Debug, Default)]
pub(super) struct ObjectPutMetadata {
pub(super) content_type: Option<String>,
pub(super) content_length: Option<i64>,
pub(super) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectMetadata {
pub(super) content_type: String,
pub(super) content_length: i64,
pub(super) last_modified_ms: i64,
pub(super) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectListEntry {
pub(super) key: String,
pub(super) content_length: i64,
pub(super) last_modified_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectGetResult {
pub(super) body: Vec<u8>,
pub(super) metadata: ObjectMetadata,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct PresignedObjectRequest {
pub(super) url: String,
pub(super) headers: HashMap<String, String>,
pub(super) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct MultipartUploadInitResult {
pub(super) upload_id: String,
pub(super) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct MultipartUploadPart {
pub(super) part_number: i32,
pub(super) etag: String,
}
#[derive(Clone, Debug, Deserialize)]
pub(in crate::backend_runtime) struct StorageProviderConfig {
pub(super) provider: String,
pub(super) bucket: String,
#[serde(default)]
pub(super) config: serde_json::Value,
}
pub(super) fn trim_etag(etag: &str) -> String {
etag.trim_matches('"').to_string()
}
pub(super) fn completed_multipart_parts(mut parts: Vec<MultipartUploadPart>) -> Vec<CompletedPart> {
parts.sort_by_key(|part| part.part_number);
parts
.into_iter()
.map(|part| {
CompletedPart::builder()
.part_number(part.part_number)
.e_tag(part.etag)
.build()
})
.collect()
}
impl From<RuntimeObjectStoragePutOptions> for ObjectPutMetadata {
fn from(options: RuntimeObjectStoragePutOptions) -> Self {
Self {
content_type: options.content_type,
content_length: options.content_length,
checksum_crc32: options.checksum_crc32,
}
}
}
impl From<ObjectMetadata> for RuntimeObjectMetadata {
fn from(metadata: ObjectMetadata) -> Self {
Self {
content_type: metadata.content_type,
content_length: metadata.content_length,
last_modified_ms: metadata.last_modified_ms,
checksum_crc32: metadata.checksum_crc32,
}
}
}
impl From<ObjectListEntry> for RuntimeObjectListEntry {
fn from(entry: ObjectListEntry) -> Self {
Self {
key: entry.key,
content_length: entry.content_length,
last_modified_ms: entry.last_modified_ms,
}
}
}
impl TryFrom<PresignedObjectRequest> for RuntimePresignedObjectRequest {
type Error = napi::Error;
fn try_from(request: PresignedObjectRequest) -> Result<Self> {
Ok(Self {
url: request.url,
headers_json: serde_json::to_string(&request.headers)
.map_err(|err| napi_error(format!("ObjectStorage headers serialization failed: {err}")))?,
expires_at_ms: request.expires_at_ms,
})
}
}
impl From<ObjectGetResult> for RuntimeObjectGetResult {
fn from(result: ObjectGetResult) -> Self {
Self {
body: result.body.into(),
metadata: result.metadata.into(),
}
}
}
impl From<MultipartUploadInitResult> for RuntimeMultipartUploadInit {
fn from(init: MultipartUploadInitResult) -> Self {
Self {
upload_id: init.upload_id,
expires_at_ms: init.expires_at_ms,
}
}
}
impl From<RuntimeMultipartUploadPart> for MultipartUploadPart {
fn from(part: RuntimeMultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
impl From<MultipartUploadPart> for RuntimeMultipartUploadPart {
fn from(part: MultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
@@ -1,40 +0,0 @@
CREATE TABLE IF NOT EXISTS runtime_states (
purpose TEXT NOT NULL,
token_hash TEXT NOT NULL,
lookup_key TEXT,
payload JSONB NOT NULL,
attempts INTEGER NOT NULL DEFAULT 0,
consumed_at TIMESTAMPTZ(3),
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (purpose, token_hash)
);
CREATE INDEX IF NOT EXISTS runtime_states_lookup_idx
ON runtime_states (purpose, lookup_key)
WHERE lookup_key IS NOT NULL AND consumed_at IS NULL;
CREATE INDEX IF NOT EXISTS runtime_states_expires_at_idx
ON runtime_states (expires_at);
CREATE TABLE IF NOT EXISTS runtime_gates (
key TEXT PRIMARY KEY,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_gates_expires_at_idx
ON runtime_gates (expires_at);
CREATE TABLE IF NOT EXISTS runtime_leases (
key TEXT PRIMARY KEY,
owner TEXT NOT NULL,
fencing_token BIGINT NOT NULL,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_leases_expires_at_idx
ON runtime_leases (expires_at);
+18 -2
@@ -2,7 +2,6 @@
mod utils;
pub mod backend_runtime;
pub mod doc;
pub mod doc_loader;
pub mod entitlement;
@@ -13,12 +12,13 @@ pub mod image;
pub mod license;
pub mod llm;
pub mod permission;
pub mod runtime;
pub mod safe_fetch;
pub mod tiktoken;
use affine_common::napi_utils::map_napi_err;
use napi::{Result, Status, bindgen_prelude::*};
use y_octo::Doc;
use y_octo::{Doc, Update};
#[cfg(not(target_arch = "arm"))]
#[global_allocator]
@@ -41,6 +41,16 @@ pub fn merge_updates_in_apply_way(updates: Vec<Buffer>) -> Result<Buffer> {
Ok(buf.into())
}
/// Check whether a Yjs update binary can be decoded without applying it to a
/// document state.
#[napi(catch_unwind)]
pub async fn validate_doc_update(update: Buffer) -> Result<bool> {
let update = update.to_vec();
tokio::task::spawn_blocking(move || Update::decode_v1(update).is_ok())
.await
.map_err(|err| napi::Error::from_reason(format!("Doc update validation task failed: {err}")))
}
#[napi]
pub const AFFINE_PRO_PUBLIC_KEY: Option<&'static str> = std::option_env!("AFFINE_PRO_PUBLIC_KEY");
@@ -59,4 +69,10 @@ mod tests {
};
assert_eq!(err.status, Status::GenericFailure);
}
#[test]
fn y_octo_update_decode_accepts_valid_update_and_rejects_invalid_update() {
assert!(Update::decode_v1(vec![0, 0]).is_ok());
assert!(Update::decode_v1(vec![0]).is_err());
}
}
+2 -5
@@ -1,7 +1,7 @@
use std::{
collections::HashMap,
sync::{Mutex, OnceLock},
time::{Duration, SystemTime, UNIX_EPOCH},
time::{Duration, SystemTime},
};
use anyhow::{Context, Result as AnyResult, bail};
@@ -458,10 +458,7 @@ fn parse_future_end_at(value: &serde_json::Value) -> AnyResult<f64> {
}
fn now_millis() -> f64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as f64
crate::utils::system_time_millis(SystemTime::now()).unwrap_or_default() as f64
}
fn affine_pro_ech_config() -> AnyResult<Vec<u8>> {
@@ -7,4 +7,3 @@ pub(super) const WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE: &str = "workspace_invi
pub(super) const WORKSPACE_STATS_LEASE_KEY: &str = "workspace:admin-stats:refresh";
pub(super) const WORKSPACE_STATS_LOCK_NAMESPACE: i64 = 97_301;
pub(super) const WORKSPACE_STATS_REFRESH_LOCK_KEY: i64 = 1;
pub(super) const RUNTIME_MIGRATIONS: &str = include_str!("sql/runtime_migrations.sql");
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::{FromRow, PgPool};
use super::{BackendRuntime, error::napi_error, types::CoordinationLeaseGrant};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::CoordinationLeaseGrant};
#[derive(FromRow)]
struct LeaseGrantRow {
@@ -17,7 +17,7 @@ impl CoordinationLeaseStore {
Self { pool }
}
async fn acquire(&self, key: String, owner: String, ttl_ms: i64) -> Result<Option<CoordinationLeaseGrant>> {
async fn acquire(&self, key: String, owner: String, ttl_ms: i64) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
let row = sqlx::query_as::<_, LeaseGrantRow>(
r#"
INSERT INTO runtime_leases (key, owner, fencing_token, expires_at)
@@ -36,7 +36,7 @@ impl CoordinationLeaseStore {
.bind(ttl_ms as f64)
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease acquire failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease acquire failed", err))?;
Ok(row.map(|row| CoordinationLeaseGrant {
key,
@@ -45,7 +45,7 @@ impl CoordinationLeaseStore {
}))
}
async fn release(&self, key: &str, owner: &str, fencing_token: i64) -> Result<bool> {
async fn release(&self, key: &str, owner: &str, fencing_token: i64) -> RuntimeResult<bool> {
let result = sqlx::query(
r#"
DELETE FROM runtime_leases
@@ -57,12 +57,12 @@ impl CoordinationLeaseStore {
.bind(fencing_token)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease release failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease release failed", err))?;
Ok(result.rows_affected() == 1)
}
async fn renew(&self, key: &str, owner: &str, fencing_token: i64, ttl_ms: i64) -> Result<bool> {
async fn renew(&self, key: &str, owner: &str, fencing_token: i64, ttl_ms: i64) -> RuntimeResult<bool> {
let result = sqlx::query(
r#"
UPDATE runtime_leases
@@ -80,7 +80,7 @@ impl CoordinationLeaseStore {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease renew failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease renew failed", err))?;
Ok(result.rows_affected() == 1)
}
@@ -88,6 +88,35 @@ impl CoordinationLeaseStore {
#[napi_derive::napi]
impl BackendRuntime {
pub(crate) async fn acquire_coordination_lease_inner(
&self,
key: String,
owner: String,
ttl_ms: i64,
) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(RuntimeError::invalid_input("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(RuntimeError::invalid_input("coordination lease owner is required"));
}
CoordinationLeaseStore::new(self.pool().await?)
.acquire(key, owner, ttl_ms)
.await
}
pub(crate) async fn release_coordination_lease_inner(
&self,
key: String,
owner: String,
fencing_token: i64,
) -> RuntimeResult<bool> {
CoordinationLeaseStore::new(self.pool().await?)
.release(&key, &owner, fencing_token)
.await
}
#[napi]
pub async fn acquire_coordination_lease(
&self,
@@ -95,16 +124,10 @@ impl BackendRuntime {
owner: String,
ttl_ms: i64,
) -> Result<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(napi_error("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(napi_error("coordination lease owner is required"));
}
CoordinationLeaseStore::new(self.pool().await?)
.acquire(key, owner, ttl_ms)
self
.acquire_coordination_lease_inner(key, owner, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -114,9 +137,10 @@ impl BackendRuntime {
owner: String,
#[napi(ts_arg_type = "bigint | number")] fencing_token: i64,
) -> Result<bool> {
CoordinationLeaseStore::new(self.pool().await?)
.release(&key, &owner, fencing_token)
self
.release_coordination_lease_inner(key, owner, fencing_token)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -134,5 +158,6 @@ impl BackendRuntime {
CoordinationLeaseStore::new(self.pool().await?)
.renew(&key, &owner, fencing_token, ttl_ms)
.await
.map_err(napi::Error::from)
}
}
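Note on the lease API in this file: the SQL bodies are elided in the hunks, so the exact conflict rule is not visible here. The in-memory model below is a sketch of the contract the signatures suggest, under stated assumptions: a lease is granted only when the key is free or expired, every grant gets a strictly larger fencing token, and release succeeds only for the matching owner and token. `LeaseTable` and its fields are hypothetical, and time is passed in explicitly instead of being read from the database clock.

```rust
use std::collections::HashMap;

#[derive(Debug)]
struct Lease {
    owner: String,
    fencing_token: i64,
    expires_at_ms: i64,
}

/// Hypothetical in-memory stand-in for the `runtime_leases` table.
#[derive(Default)]
struct LeaseTable {
    leases: HashMap<String, Lease>,
    next_token: i64,
}

impl LeaseTable {
    /// Grant the lease when the key is free or expired and return the new fencing token.
    /// Returns `Ok(None)` while another lease is still live (an assumption of this sketch).
    fn acquire(&mut self, key: &str, owner: &str, ttl_ms: i64, now_ms: i64) -> Result<Option<i64>, String> {
        if ttl_ms <= 0 {
            return Err("coordination lease ttl must be positive".to_string());
        }
        if owner.is_empty() {
            return Err("coordination lease owner is required".to_string());
        }
        if let Some(current) = self.leases.get(key) {
            if current.expires_at_ms > now_ms {
                return Ok(None);
            }
        }
        self.next_token += 1;
        let lease = Lease {
            owner: owner.to_string(),
            fencing_token: self.next_token,
            expires_at_ms: now_ms + ttl_ms,
        };
        self.leases.insert(key.to_string(), lease);
        Ok(Some(self.next_token))
    }

    /// Remove the lease only when both the owner and the fencing token match.
    fn release(&mut self, key: &str, owner: &str, fencing_token: i64) -> bool {
        let matches = self
            .leases
            .get(key)
            .is_some_and(|lease| lease.owner == owner && lease.fencing_token == fencing_token);
        if matches {
            self.leases.remove(key);
        }
        matches
    }
}
```

The point of the fencing token is that a holder whose lease expired and was re-granted cannot release or renew the newer lease: its stale token no longer matches, which is what the `rows_affected() == 1` checks report in the store.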
@@ -1,9 +1,8 @@
use chrono::{DateTime, Duration, Utc};
use napi::Result;
use sqlx::{FromRow, PgPool, Postgres, Row, Transaction};
use y_octo::Doc;
use super::{BackendRuntime, error::napi_error, types::RuntimeDocCompactionResult};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::RuntimeDocCompactionResult};
#[derive(FromRow)]
struct SnapshotRow {
@@ -35,7 +34,7 @@ impl DocCompactorStore {
batch_limit: i64,
history_min_interval_ms: i64,
history_max_age_seconds: i64,
) -> Result<(i64, bool)> {
) -> RuntimeResult<(i64, bool)> {
compact_doc(
self.pool.clone(),
workspace_id,
@@ -52,31 +51,32 @@ fn is_empty_doc(bin: &[u8]) -> bool {
bin.is_empty() || (bin.len() == 1 && bin[0] == 0) || (bin.len() == 2 && bin[0] == 0 && bin[1] == 0)
}
fn apply_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> Result<Vec<u8>> {
fn apply_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> RuntimeResult<Vec<u8>> {
let mut doc = Doc::default();
for update in updates {
doc
.apply_update_from_binary_v1(&update)
.map_err(|err| napi_error(format!("DocCompactor merge failed: {err}")))?;
.map_err(|err| RuntimeError::invalid_state(format!("DocCompactor merge failed: {err}")))?;
}
doc
.encode_update_v1()
.map_err(|err| napi_error(format!("DocCompactor encode failed: {err}")))
.map_err(|err| RuntimeError::invalid_state(format!("DocCompactor encode failed: {err}")))
}
fn checked_milliseconds(value: i64, field: &str) -> Result<Duration> {
Duration::try_milliseconds(value).ok_or_else(|| napi_error(format!("DocCompactor {field} is too large")))
fn checked_milliseconds(value: i64, field: &str) -> RuntimeResult<Duration> {
Duration::try_milliseconds(value)
.ok_or_else(|| RuntimeError::invalid_input(format!("DocCompactor {field} is too large")))
}
fn checked_seconds(value: i64, field: &str) -> Result<Duration> {
Duration::try_seconds(value).ok_or_else(|| napi_error(format!("DocCompactor {field} is too large")))
fn checked_seconds(value: i64, field: &str) -> RuntimeResult<Duration> {
Duration::try_seconds(value).ok_or_else(|| RuntimeError::invalid_input(format!("DocCompactor {field} is too large")))
}
async fn load_snapshot(
tx: &mut Transaction<'_, Postgres>,
workspace_id: &str,
doc_id: &str,
) -> Result<Option<SnapshotRow>> {
) -> RuntimeResult<Option<SnapshotRow>> {
sqlx::query_as::<_, SnapshotRow>(
r#"
SELECT blob, updated_at, updated_by
@@ -89,7 +89,7 @@ async fn load_snapshot(
.bind(doc_id)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load snapshot failed: {err}")))
.map_err(|err| RuntimeError::database("DocCompactor load snapshot failed", err))
}
async fn load_updates(
@@ -97,7 +97,7 @@ async fn load_updates(
workspace_id: &str,
doc_id: &str,
batch_limit: i64,
) -> Result<Vec<UpdateRow>> {
) -> RuntimeResult<Vec<UpdateRow>> {
sqlx::query_as::<_, UpdateRow>(
r#"
SELECT blob, created_at, created_by
@@ -113,7 +113,7 @@ async fn load_updates(
.bind(batch_limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load updates failed: {err}")))
.map_err(|err| RuntimeError::database("DocCompactor load updates failed", err))
}
async fn upsert_snapshot(
@@ -123,7 +123,7 @@ async fn upsert_snapshot(
blob: &[u8],
timestamp: DateTime<Utc>,
editor: Option<&str>,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if is_empty_doc(blob) {
return Ok(false);
}
@@ -154,7 +154,7 @@ async fn upsert_snapshot(
.bind(editor)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor upsert snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor upsert snapshot failed", err))?;
Ok(row.is_some())
}
@@ -165,7 +165,7 @@ async fn should_create_history(
workspace_id: &str,
doc_id: &str,
history_min_interval_ms: i64,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if is_empty_doc(&snapshot.blob) {
return Ok(false);
}
@@ -183,7 +183,7 @@ async fn should_create_history(
.bind(doc_id)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load latest history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor load latest history failed", err))?;
let Some(row) = row else {
return Ok(true);
@@ -198,7 +198,7 @@ async fn should_create_history(
let threshold = snapshot
.updated_at
.checked_sub_signed(min_interval)
.ok_or_else(|| napi_error("DocCompactor history interval is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history interval is out of range"))?;
Ok(last_timestamp < threshold)
}
@@ -209,7 +209,7 @@ async fn create_history(
doc_id: &str,
snapshot: &SnapshotRow,
max_age_seconds: i64,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if max_age_seconds <= 0 {
return Ok(false);
}
@@ -217,7 +217,7 @@ async fn create_history(
let max_age = checked_seconds(max_age_seconds, "history max age")?;
let expired_at = Utc::now()
.checked_add_signed(max_age)
.ok_or_else(|| napi_error("DocCompactor history max age is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history max age is out of range"))?;
sqlx::query(
r#"
INSERT INTO snapshot_histories
@@ -236,7 +236,7 @@ async fn create_history(
.bind(snapshot.updated_by.as_deref())
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor create history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor create history failed", err))?;
Ok(true)
}
@@ -246,7 +246,7 @@ async fn delete_updates(
workspace_id: &str,
doc_id: &str,
timestamps: &[DateTime<Utc>],
) -> Result<i64> {
) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM updates
@@ -260,7 +260,7 @@ async fn delete_updates(
.bind(timestamps)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor delete updates failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor delete updates failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -272,18 +272,18 @@ async fn compact_doc(
batch_limit: i64,
history_min_interval_ms: i64,
history_max_age_seconds: i64,
) -> Result<(i64, bool)> {
) -> RuntimeResult<(i64, bool)> {
let mut tx = pool
.begin()
.await
.map_err(|err| napi_error(format!("DocCompactor begin transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor begin transaction failed", err))?;
let snapshot = load_snapshot(&mut tx, workspace_id, doc_id).await?;
let updates = load_updates(&mut tx, workspace_id, doc_id, batch_limit).await?;
if updates.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("DocCompactor commit transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor commit transaction failed", err))?;
return Ok((0, false));
}
@@ -323,7 +323,7 @@ async fn compact_doc(
tx.commit()
.await
.map_err(|err| napi_error(format!("DocCompactor commit transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor commit transaction failed", err))?;
Ok((deleted, history_created))
}
@@ -350,7 +350,7 @@ impl BackendRuntime {
history_max_age_seconds: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeDocCompactionResult> {
) -> napi::Result<RuntimeDocCompactionResult> {
if batch_limit <= 0 {
return Err(napi_error("doc compactor batch limit must be positive"));
}
@@ -365,7 +365,7 @@ impl BackendRuntime {
let max_age = checked_seconds(history_max_age_seconds, "history max age")?;
Utc::now()
.checked_add_signed(max_age)
.ok_or_else(|| napi_error("DocCompactor history max age is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history max age is out of range"))?;
}
let lease_key = format!("doc:update:{workspace_id}:{doc_id}");
@@ -394,7 +394,7 @@ impl BackendRuntime {
.release_coordination_lease(lease.key, lease.owner, lease.fencing_token)
.await?;
if !released {
return Err(napi_error("DocCompactor failed to release coordination lease"));
return Err(RuntimeError::invalid_state("DocCompactor failed to release coordination lease").into());
}
let (updates_merged, history_created) = result?;
@@ -1,14 +1,18 @@
use chrono::{DateTime, Duration, Utc};
use napi::{Result, bindgen_prelude::Buffer};
use napi::bindgen_prelude::Buffer;
use sqlx::{PgPool, Row};
use super::{BackendRuntime, error::napi_error, types::RuntimeDocHistoryInput};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::RuntimeDocHistoryInput};
fn is_empty_doc(bin: &[u8]) -> bool {
bin.is_empty() || (bin.len() == 1 && bin[0] == 0) || (bin.len() == 2 && bin[0] == 0 && bin[1] == 0)
}
async fn latest_history_timestamp(pool: &PgPool, workspace_id: &str, doc_id: &str) -> Result<Option<DateTime<Utc>>> {
async fn latest_history_timestamp(
pool: &PgPool,
workspace_id: &str,
doc_id: &str,
) -> RuntimeResult<Option<DateTime<Utc>>> {
sqlx::query(
r#"
SELECT timestamp
@@ -23,7 +27,7 @@ async fn latest_history_timestamp(pool: &PgPool, workspace_id: &str, doc_id: &st
.fetch_optional(pool)
.await
.map(|row| row.map(|row| row.get("timestamp")))
.map_err(|err| napi_error(format!("DocStorage load latest history failed: {err}")))
.map_err(|err| RuntimeError::database("DocStorage load latest history failed", err))
}
#[napi_derive::napi]
@@ -36,13 +40,13 @@ impl BackendRuntime {
blob: Buffer,
timestamp_ms: i64,
editor_id: Option<String>,
) -> Result<bool> {
) -> napi::Result<bool> {
if is_empty_doc(blob.as_ref()) {
return Ok(false);
}
let timestamp = DateTime::<Utc>::from_timestamp_millis(timestamp_ms)
.ok_or_else(|| napi_error(format!("Invalid doc snapshot timestamp: {timestamp_ms}")))?;
.ok_or_else(|| RuntimeError::invalid_input(format!("Invalid doc snapshot timestamp: {timestamp_ms}")))?;
let pool = self.pool().await?;
let row = sqlx::query(
r#"
@@ -70,13 +74,13 @@ impl BackendRuntime {
.bind(editor_id.as_deref())
.fetch_optional(&pool)
.await
.map_err(|err| napi_error(format!("DocStorage upsert snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage upsert snapshot failed", err))?;
Ok(row.is_some())
}
#[napi]
pub async fn create_doc_history(&self, input: RuntimeDocHistoryInput) -> Result<bool> {
pub async fn create_doc_history(&self, input: RuntimeDocHistoryInput) -> napi::Result<bool> {
if input.history_min_interval_ms < 0 {
return Err(napi_error("doc history interval must be non-negative"));
}
@@ -85,7 +89,7 @@ impl BackendRuntime {
}
let timestamp = DateTime::<Utc>::from_timestamp_millis(input.timestamp_ms)
.ok_or_else(|| napi_error(format!("Invalid doc history timestamp: {}", input.timestamp_ms)))?;
.ok_or_else(|| RuntimeError::invalid_input(format!("Invalid doc history timestamp: {}", input.timestamp_ms)))?;
let pool = self.pool().await?;
let should_create = match latest_history_timestamp(&pool, &input.workspace_id, &input.doc_id).await? {
None => true,
@@ -118,41 +122,41 @@ impl BackendRuntime {
.bind(input.editor_id.as_deref())
.execute(&pool)
.await
.map_err(|err| napi_error(format!("DocStorage create history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage create history failed", err))?;
Ok(true)
}
#[napi]
pub async fn delete_doc_storage(&self, workspace_id: String, doc_id: String) -> Result<()> {
pub async fn delete_doc_storage(&self, workspace_id: String, doc_id: String) -> napi::Result<()> {
let pool = self.pool().await?;
let mut tx = pool
.begin()
.await
.map_err(|err| napi_error(format!("DocStorage delete begin transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete begin transaction failed", err))?;
sqlx::query("DELETE FROM snapshots WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete snapshot failed", err))?;
sqlx::query("DELETE FROM updates WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete updates failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete updates failed", err))?;
sqlx::query("DELETE FROM snapshot_histories WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete histories failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete histories failed", err))?;
tx.commit()
.await
.map_err(|err| napi_error(format!("DocStorage delete commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete commit failed", err))?;
Ok(())
}
}
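Throughout these files the call shape changes from `napi_error(format!("...: {err}"))` to `RuntimeError::database("...", err)`, that is, the context string and the underlying error are passed separately. The diff does not show what the constructor does with them, so the sketch below is only one possible reading: a hypothetical `ContextError` that stores the original error as its `source()`, using std only and an integer parse failure in place of a database error.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical error type modelled on the `database(context, err)` call shape.
// Whether the real RuntimeError keeps the source error is an assumption.
#[derive(Debug)]
pub struct ContextError {
    context: String,
    source: Box<dyn Error + Send + Sync + 'static>,
}

impl ContextError {
    pub fn new(context: impl Into<String>, source: impl Error + Send + Sync + 'static) -> Self {
        Self {
            context: context.into(),
            source: Box::new(source),
        }
    }
}

impl fmt::Display for ContextError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same rendered shape as the old `format!("{context}: {err}")` messages.
        write!(f, "{}: {}", self.context, self.source)
    }
}

impl Error for ContextError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Expose the original error so callers can inspect or downcast it.
        let source: &(dyn Error + 'static) = &*self.source;
        Some(source)
    }
}

// Illustrative fallible operation; a failed parse stands in for a failed query.
pub fn parse_limit(raw: &str) -> Result<i64, ContextError> {
    raw.parse::<i64>()
        .map_err(|err| ContextError::new("parse limit failed", err))
}
```

With this shape the rendered message is unchanged, while callers that need the cause can reach it through `source()` rather than re-parsing the text.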
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::PgPool;
use super::{BackendRuntime, error::napi_error};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
struct RuntimeGateStore {
pool: PgPool,
@@ -12,18 +12,18 @@ impl RuntimeGateStore {
Self { pool }
}
async fn put_if_absent(&self, key: &str, ttl_ms: i64) -> Result<bool> {
async fn put_if_absent(&self, key: &str, ttl_ms: i64) -> RuntimeResult<bool> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("RuntimeGate transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate transaction failed", err))?;
sqlx::query("DELETE FROM runtime_gates WHERE key = $1 AND expires_at <= CURRENT_TIMESTAMP")
.bind(key)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeGate expired cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate expired cleanup failed", err))?;
let inserted = sqlx::query(
r#"
@@ -36,18 +36,18 @@ impl RuntimeGateStore {
.bind(ttl_ms as f64)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeGate put_if_absent failed: {err}")))?
.map_err(|err| RuntimeError::database("RuntimeGate put_if_absent failed", err))?
.rows_affected()
== 1;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeGate transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate transaction commit failed", err))?;
Ok(inserted)
}
async fn cleanup_expired(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM runtime_gates
@@ -62,7 +62,7 @@ impl RuntimeGateStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("RuntimeGate cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -78,6 +78,7 @@ impl BackendRuntime {
RuntimeGateStore::new(self.pool().await?)
.put_if_absent(&key, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -85,6 +86,9 @@ impl BackendRuntime {
if limit <= 0 {
return Err(napi_error("runtime gate cleanup limit must be positive"));
}
RuntimeGateStore::new(self.pool().await?).cleanup_expired(limit).await
RuntimeGateStore::new(self.pool().await?)
.cleanup_expired(limit)
.await
.map_err(napi::Error::from)
}
}
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::PgPool;
use super::{BackendRuntime, error::napi_error};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
struct HousekeepingStore {
pool: PgPool,
@@ -12,7 +12,7 @@ impl HousekeepingStore {
Self { pool }
}
async fn cleanup_expired_user_sessions(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired_user_sessions(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM user_sessions
@@ -27,12 +27,12 @@ impl HousekeepingStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("Housekeeping user sessions cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("Housekeeping user sessions cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn cleanup_expired_snapshot_histories(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired_snapshot_histories(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM snapshot_histories
@@ -48,7 +48,7 @@ impl HousekeepingStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("Housekeeping snapshot histories cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("Housekeeping snapshot histories cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -65,6 +65,7 @@ impl BackendRuntime {
HousekeepingStore::new(self.pool().await?)
.cleanup_expired_user_sessions(limit)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -76,5 +77,6 @@ impl BackendRuntime {
HousekeepingStore::new(self.pool().await?)
.cleanup_expired_snapshot_histories(limit)
.await
.map_err(napi::Error::from)
}
}
@@ -1,28 +1,25 @@
mod blob_complete;
mod blob_reclaimer;
mod config;
mod constants;
mod coordination_lease;
mod doc_compactor;
mod doc_storage;
mod error;
mod gate;
mod housekeeping;
mod object_storage;
mod runtime_state;
#[cfg(test)]
mod tests;
mod types;
mod workspace_stats;
use std::time::Duration;
use std::{sync::RwLock, time::Duration};
use napi::Result;
use sha2::{Digest, Sha256};
use sqlx::{PgPool, Row, postgres::PgPoolOptions};
use tokio::sync::Mutex;
use self::{config::RuntimeConfig, constants::RUNTIME_MIGRATIONS, error::napi_error, types::BackendRuntimeHealth};
use self::types::BackendRuntimeHealth;
pub(crate) use super::types;
use super::{
BackendRuntimeConfig, RuntimeError, RuntimeResult, migrations::migrate_runtime_tables, napi_error, to_napi_error,
};
pub(super) fn token_hash(token: &str) -> String {
hex::encode(Sha256::digest(token.as_bytes()))
@@ -30,7 +27,7 @@ pub(super) fn token_hash(token: &str) -> String {
#[napi_derive::napi]
pub struct BackendRuntime {
config: RuntimeConfig,
config: RwLock<BackendRuntimeConfig>,
pool: Mutex<Option<PgPool>>,
}
@@ -39,29 +36,37 @@ impl BackendRuntime {
#[napi(constructor)]
pub fn new() -> Result<Self> {
Ok(Self {
config: RuntimeConfig::from_config_files()?,
config: RwLock::new(BackendRuntimeConfig::from_config_files().map_err(to_napi_error)?),
pool: Mutex::new(None),
})
}
#[napi]
pub async fn start(&self) -> Result<()> {
self.start_inner().await.map_err(to_napi_error)
}
async fn start_inner(&self) -> RuntimeResult<()> {
let mut guard = self.pool.lock().await;
if guard.is_some() {
return Ok(());
}
let database_url = self.config()?.database_url;
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(5))
.connect(&self.config.database_url)
.connect(&database_url)
.await
.map_err(|err| napi_error(format!("BackendRuntime failed to connect postgres: {err}")))?;
.map_err(|err| RuntimeError::database("BackendRuntime failed to connect postgres", err))?;
sqlx::query("SELECT 1")
.execute(&pool)
.await
.map_err(|err| napi_error(format!("BackendRuntime postgres health check failed: {err}")))?;
.map_err(|err| RuntimeError::database("BackendRuntime postgres health check failed", err))?;
let config = self.config()?.with_db_overrides(&pool).await?;
self.update_config(config)?;
*guard = Some(pool);
Ok(())
@@ -91,38 +96,38 @@ impl BackendRuntime {
Ok(BackendRuntimeHealth {
started: pool.is_some(),
database_connected,
object_storage_configured: self.config.storage.is_some(),
})
}
#[napi]
pub async fn run_migrations(&self) -> Result<()> {
let pool = self.pool().await?;
migrate_runtime_tables(&pool).await
migrate_runtime_tables(&pool).await.map_err(to_napi_error)
}
async fn pool(&self) -> Result<PgPool> {
pub(crate) async fn pool(&self) -> RuntimeResult<PgPool> {
self
.pool
.lock()
.await
.as_ref()
.cloned()
.ok_or_else(|| napi_error("BackendRuntime must be started before using postgres operations"))
}
}
async fn migrate_runtime_tables(pool: &PgPool) -> Result<()> {
for statement in RUNTIME_MIGRATIONS
.split(';')
.map(str::trim)
.filter(|statement| !statement.is_empty())
{
sqlx::query(statement)
.execute(pool)
.await
.map_err(|err| napi_error(format!("BackendRuntime migration failed: {err}")))?;
.ok_or_else(|| RuntimeError::invalid_state("BackendRuntime must be started before using postgres operations"))
}
Ok(())
pub(crate) fn config(&self) -> RuntimeResult<BackendRuntimeConfig> {
self
.config
.read()
.map(|config| config.clone())
.map_err(|_| RuntimeError::invalid_state("BackendRuntime config lock poisoned"))
}
fn update_config(&self, config: BackendRuntimeConfig) -> RuntimeResult<()> {
*self
.config
.write()
.map_err(|_| RuntimeError::invalid_state("BackendRuntime config lock poisoned"))? = config;
Ok(())
}
}
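The `config()` and `update_config()` helpers above wrap the configuration in an `RwLock`, hand out clones on read, and turn a poisoned lock into an error instead of panicking. A minimal std-only sketch of that pattern follows; `Config`, `Runtime`, and `ConfigError` are hypothetical names standing in for `BackendRuntimeConfig`, `BackendRuntime`, and `RuntimeError`.

```rust
use std::sync::RwLock;

// Hypothetical error type; the diff maps poisoning to RuntimeError::invalid_state.
#[derive(Debug, PartialEq)]
pub enum ConfigError {
    LockPoisoned,
}

// Hypothetical config with a single field for illustration.
#[derive(Clone, Debug, PartialEq)]
pub struct Config {
    pub database_url: String,
}

pub struct Runtime {
    config: RwLock<Config>,
}

impl Runtime {
    pub fn new(config: Config) -> Self {
        Self {
            config: RwLock::new(config),
        }
    }

    // Returns a clone so the read guard is released before the caller uses the value.
    pub fn config(&self) -> Result<Config, ConfigError> {
        self.config
            .read()
            .map(|guard| (*guard).clone())
            .map_err(|_| ConfigError::LockPoisoned)
    }

    // Replaces the whole config under the write lock.
    pub fn update_config(&self, config: Config) -> Result<(), ConfigError> {
        *self.config.write().map_err(|_| ConfigError::LockPoisoned)? = config;
        Ok(())
    }
}
```

Returning an owned clone keeps the lock from being held across an `.await`, which matters because `std::sync::RwLock` guards are not meant to live across suspension points.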
@@ -1,6 +1,4 @@
use napi::Result;
use super::{auth_challenge_purpose, dto::RuntimeStateRows};
use super::{Result, auth_challenge_purpose, dto::RuntimeStateRows};
pub(super) async fn create(
rows: &RuntimeStateRows,
@@ -1,10 +1,6 @@
use napi::Result;
use super::dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows};
use crate::backend_runtime::{
constants::{BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE},
error::napi_error,
types::RuntimeByokLocalLeaseRecord,
use super::{
BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE, Result, RuntimeByokLocalLeaseRecord, RuntimeError,
dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows},
};
pub(super) async fn get(rows: &RuntimeStateRows, lease_id: String) -> Result<Option<RuntimeByokLocalLeaseRecord>> {
@@ -19,7 +15,7 @@ pub(super) async fn create(
ttl_ms: i64,
) -> Result<RuntimeByokLocalLeaseRecord> {
if ttl_ms <= 0 {
return Err(napi_error("BYOK local lease ttl must be positive"));
return Err(RuntimeError::invalid_input("BYOK local lease ttl must be positive"));
}
let mut tx = rows.begin("RuntimeState BYOK local lease").await?;
@@ -27,7 +23,7 @@ pub(super) async fn create(
.bind(&active_key)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeState BYOK local lease active lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease active lock failed", err))?;
if let Some(active) = rows
.active_payload_with_expires_for_update_in_tx(
@@ -43,11 +39,9 @@ pub(super) async fn create(
None => None,
};
if let Some(lease) = existing_lease {
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState BYOK local lease transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease transaction commit failed", err))?;
return Ok(lease);
}
@@ -89,11 +83,9 @@ pub(super) async fn create(
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState BYOK local lease transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease transaction commit failed", err))?;
Ok(RuntimeByokLocalLeaseRecord {
lease_id,
@@ -1,7 +1,8 @@
use napi::Result;
use sqlx::{PgPool, Row};
use crate::backend_runtime::{error::napi_error, token_hash};
use super::{RuntimeError, RuntimeResult, token_hash};
type Result<T> = RuntimeResult<T>;
pub(super) struct RuntimeStatePayloadRow {
pub(super) payload: serde_json::Value,
@@ -42,7 +43,7 @@ impl RuntimeStateRows {
.pool
.begin()
.await
.map_err(|err| napi_error(format!("{context} transaction failed: {err}")))
.map_err(|err| RuntimeError::database(format!("{context} transaction failed"), err))
}
pub(super) async fn insert_payload(
@@ -67,7 +68,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -95,7 +96,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?
.map_err(|err| RuntimeError::database(context, err))?
.rows_affected()
== 1;
@@ -131,7 +132,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -156,7 +157,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| row.get::<serde_json::Value, _>("payload")))
}
@@ -181,7 +182,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -208,7 +209,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| row.get::<serde_json::Value, _>("payload")))
}
@@ -235,7 +236,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -262,7 +263,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -288,7 +289,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| RuntimeStateLockedRow {
payload: row.get("payload"),
@@ -316,7 +317,7 @@ impl RuntimeStateRows {
.bind(input.ttl_ms as f64)
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("{} failed: {err}", input.context)))?;
.map_err(|err| RuntimeError::database(input.context, err))?;
Ok(row.get::<i64, _>("expires_at_ms"))
}
@@ -348,7 +349,7 @@ impl RuntimeStateRows {
.bind(input.ttl_ms as f64)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{} failed: {err}", input.context)))?;
.map_err(|err| RuntimeError::database(input.context, err))?;
Ok(row.map(|row| row.get::<i64, _>("expires_at_ms")))
}
@@ -375,7 +376,7 @@ impl RuntimeStateRows {
.bind(attempts)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -392,7 +393,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -413,7 +414,7 @@ impl RuntimeStateRows {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(result.rows_affected() as i64)
}
@@ -440,7 +441,7 @@ impl RuntimeStateRows {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(result.rows_affected() as i64)
}
@@ -1,10 +1,7 @@
use napi::Result;
use super::dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows};
use crate::backend_runtime::{
constants::{WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE},
error::napi_error,
types::RuntimeWorkspaceInviteLinkRecord,
use super::{
Result, RuntimeError, RuntimeWorkspaceInviteLinkRecord, WORKSPACE_INVITE_LINK_ID_PURPOSE,
WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE,
dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows},
};
pub(super) async fn get_by_workspace(
@@ -29,7 +26,9 @@ pub(super) async fn create(
ttl_ms: i64,
) -> Result<RuntimeWorkspaceInviteLinkRecord> {
if ttl_ms <= 0 {
return Err(napi_error("workspace invite link ttl must be positive"));
return Err(RuntimeError::invalid_input(
"workspace invite link ttl must be positive",
));
}
let mut tx = rows.begin("RuntimeState workspace invite link").await?;
@@ -37,16 +36,14 @@ pub(super) async fn create(
.bind(&workspace_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeState workspace invite link active lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link active lock failed", err))?;
if let Some(existing) =
get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?
{
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return Ok(existing);
}
@@ -71,12 +68,11 @@ pub(super) async fn create(
.await?
else {
let existing = get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
return existing.ok_or_else(|| napi_error("RuntimeState workspace invite link active conflict missing row"));
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return existing
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link active conflict missing row"));
};
rows
.insert_payload_returning_expires_in_tx(
@@ -92,11 +88,9 @@ pub(super) async fn create(
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
Ok(RuntimeWorkspaceInviteLinkRecord {
workspace_id,
@@ -110,11 +104,9 @@ pub(super) async fn revoke(rows: &RuntimeStateRows, workspace_id: String) -> Res
let mut tx = rows.begin("RuntimeState workspace invite link").await?;
let existing = get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?;
let Some(existing) = existing else {
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return Ok(false);
};
@@ -135,11 +127,9 @@ pub(super) async fn revoke(rows: &RuntimeStateRows, workspace_id: String) -> Res
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
Ok(true)
}
@@ -175,19 +165,19 @@ fn record_from_row(row: RuntimeStatePayloadRow) -> Result<RuntimeWorkspaceInvite
.payload
.get("workspaceId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing workspaceId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing workspaceId"))?
.to_string(),
invite_id: row
.payload
.get("inviteId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing inviteId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing inviteId"))?
.to_string(),
inviter_user_id: row
.payload
.get("inviterUserId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing inviterUserId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing inviterUserId"))?
.to_string(),
expires_at_ms: row.expires_at_ms,
})
@@ -1,10 +1,6 @@
use napi::Result;
use super::dto::RuntimeStateRows;
use crate::backend_runtime::{
constants::{MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS},
error::napi_error,
types::RuntimeMagicLinkOtpConsumeResult,
use super::{
MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS, Result, RuntimeError, RuntimeMagicLinkOtpConsumeResult,
dto::RuntimeStateRows,
};
impl RuntimeMagicLinkOtpConsumeResult {
@@ -34,7 +30,7 @@ pub(super) async fn upsert(
ttl_ms: i64,
) -> Result<()> {
if ttl_ms <= 0 {
return Err(napi_error("magic link otp ttl must be positive"));
return Err(RuntimeError::invalid_input("magic link otp ttl must be positive"));
}
let payload = serde_json::json!({
@@ -73,11 +69,9 @@ pub(super) async fn consume(
.await?;
let Some(row) = row else {
tx.rollback().await.map_err(|err| {
napi_error(format!(
"RuntimeState magic link otp transaction rollback failed: {err}"
))
})?;
tx.rollback()
.await
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction rollback failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("not_found"));
};
@@ -96,7 +90,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("expired"));
}
@@ -104,7 +98,7 @@ pub(super) async fn consume(
if stored_client_nonce.is_some() && stored_client_nonce != client_nonce.as_deref() {
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("nonce_mismatch"));
}
@@ -119,7 +113,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("locked"));
}
@@ -137,7 +131,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("locked"));
}
@@ -153,14 +147,14 @@ pub(super) async fn consume(
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("invalid_otp"));
}
let token = payload
.get("token")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState magic link otp payload missing token"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState magic link otp payload missing token"))?
.to_string();
rows
.delete_by_key_in_tx(
@@ -172,7 +166,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
Ok(RuntimeMagicLinkOtpConsumeResult::ok(token))
}
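Review note: the hunks above replace `napi_error(format!("...: {err}"))` with typed constructors such as `RuntimeError::database(context, err)`, `RuntimeError::invalid_input(...)` and `RuntimeError::invalid_state(...)`. Below is a minimal std-only sketch of that pattern. `SketchError`, `FakeDbError` and `check_ttl` are hypothetical names made up for illustration; the real type wraps `sqlx::Error` and derives `thiserror::Error`, so this shows the shape, not the crate's actual definition.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical stand-in for a driver error such as `sqlx::Error`.
#[derive(Debug)]
struct FakeDbError(String);

impl fmt::Display for FakeDbError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Error for FakeDbError {}

// Hypothetical error type mirroring three of the categories used in the diff.
#[derive(Debug)]
enum SketchError {
    InvalidInput(String),
    InvalidState(String),
    Database { context: String, source: FakeDbError },
}

impl SketchError {
    fn invalid_input(message: impl Into<String>) -> Self {
        Self::InvalidInput(message.into())
    }

    fn invalid_state(message: impl Into<String>) -> Self {
        Self::InvalidState(message.into())
    }

    // The context string stays static; the driver error is kept as a value.
    fn database(context: impl Into<String>, source: FakeDbError) -> Self {
        Self::Database { context: context.into(), source }
    }
}

impl fmt::Display for SketchError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidInput(message) | Self::InvalidState(message) => write!(f, "{}", message),
            Self::Database { context, source } => write!(f, "{}: {}", context, source),
        }
    }
}

impl Error for SketchError {
    // Exposing the cause lets callers walk the chain instead of parsing text.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Database { source, .. } => Some(source),
            _ => None,
        }
    }
}

// Validation fails with a typed category instead of a pre-formatted string.
fn check_ttl(ttl_ms: i64) -> Result<(), SketchError> {
    if ttl_ms <= 0 {
        return Err(SketchError::invalid_input("magic link otp ttl must be positive"));
    }
    Ok(())
}
```

Under these assumptions the rendered message keeps the old `context: cause` form, while the category and the underlying cause remain inspectable until the error is converted at the N-API boundary.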
@@ -1,8 +1,10 @@
use napi::Result;
use super::{
BackendRuntime,
error::napi_error,
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
pub(super) use super::{
constants::{
BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE, MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS,
WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE,
},
token_hash,
types::{
RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord,
@@ -18,6 +20,8 @@ mod store;
mod verification_token;
use store::RuntimeStateStore;
pub(super) type Result<T> = RuntimeResult<T>;
pub(super) fn auth_challenge_purpose(purpose: &str) -> String {
format!("auth_challenge:{purpose}")
}
@@ -35,27 +39,34 @@ impl BackendRuntime {
token: String,
payload: serde_json::Value,
ttl_ms: i64,
) -> Result<bool> {
) -> napi::Result<bool> {
if ttl_ms <= 0 {
return Err(napi_error("auth challenge ttl must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.create_auth_challenge(&purpose, &token, payload, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_auth_challenge(&self, purpose: String, token: String) -> Result<Option<serde_json::Value>> {
pub async fn get_auth_challenge(&self, purpose: String, token: String) -> napi::Result<Option<serde_json::Value>> {
RuntimeStateStore::new(self.pool().await?)
.get_auth_challenge(&purpose, &token)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn consume_auth_challenge(&self, purpose: String, token: String) -> Result<Option<serde_json::Value>> {
pub async fn consume_auth_challenge(
&self,
purpose: String,
token: String,
) -> napi::Result<Option<serde_json::Value>> {
RuntimeStateStore::new(self.pool().await?)
.consume_auth_challenge(&purpose, &token)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -64,13 +75,14 @@ impl BackendRuntime {
token_type: i32,
credential: Option<String>,
ttl_ms: i64,
) -> Result<String> {
) -> napi::Result<String> {
if ttl_ms <= 0 {
return Err(napi_error("verification token ttl must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.create_verification_token(token_type, credential, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -79,11 +91,12 @@ impl BackendRuntime {
token_type: i32,
token: String,
keep: Option<bool>,
) -> Result<Option<RuntimeVerificationTokenRecord>> {
) -> napi::Result<Option<RuntimeVerificationTokenRecord>> {
let keep = keep.unwrap_or(false);
RuntimeStateStore::new(self.pool().await?)
.get_verification_token(token_type, token, keep)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -93,21 +106,23 @@ impl BackendRuntime {
token: String,
credential: Option<String>,
keep: Option<bool>,
) -> Result<Option<RuntimeVerificationTokenRecord>> {
) -> napi::Result<Option<RuntimeVerificationTokenRecord>> {
let keep = keep.unwrap_or(false);
RuntimeStateStore::new(self.pool().await?)
.verify_verification_token(token_type, token, credential, keep)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn cleanup_expired_verification_tokens(&self, limit: i64) -> Result<i64> {
pub async fn cleanup_expired_verification_tokens(&self, limit: i64) -> napi::Result<i64> {
if limit <= 0 {
return Err(napi_error("verification token cleanup limit must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.cleanup_expired_verification_tokens(limit)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -118,10 +133,11 @@ impl BackendRuntime {
token: String,
client_nonce: Option<String>,
ttl_ms: i64,
) -> Result<()> {
) -> napi::Result<()> {
RuntimeStateStore::new(self.pool().await?)
.upsert_magic_link_otp(email, otp_hash, token, client_nonce, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -130,10 +146,11 @@ impl BackendRuntime {
email: String,
otp_hash: String,
client_nonce: Option<String>,
) -> Result<RuntimeMagicLinkOtpConsumeResult> {
) -> napi::Result<RuntimeMagicLinkOtpConsumeResult> {
RuntimeStateStore::new(self.pool().await?)
.consume_magic_link_otp(email, otp_hash, client_nonce)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -143,37 +160,41 @@ impl BackendRuntime {
invite_id: String,
inviter_user_id: String,
ttl_ms: i64,
) -> Result<RuntimeWorkspaceInviteLinkRecord> {
) -> napi::Result<RuntimeWorkspaceInviteLinkRecord> {
RuntimeStateStore::new(self.pool().await?)
.create_workspace_invite_link(workspace_id, invite_id, inviter_user_id, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_workspace_invite_link(
&self,
workspace_id: String,
) -> Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
) -> napi::Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_workspace_invite_link(workspace_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_workspace_invite_link_by_id(
&self,
invite_id: String,
) -> Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
) -> napi::Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_workspace_invite_link_by_id(invite_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn revoke_workspace_invite_link(&self, workspace_id: String) -> Result<bool> {
pub async fn revoke_workspace_invite_link(&self, workspace_id: String) -> napi::Result<bool> {
RuntimeStateStore::new(self.pool().await?)
.revoke_workspace_invite_link(workspace_id)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -183,35 +204,37 @@ impl BackendRuntime {
lease_id: String,
payload: serde_json::Value,
ttl_ms: i64,
) -> Result<RuntimeByokLocalLeaseRecord> {
) -> napi::Result<RuntimeByokLocalLeaseRecord> {
RuntimeStateStore::new(self.pool().await?)
.create_byok_local_lease(active_key, lease_id, payload, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_byok_local_lease(&self, lease_id: String) -> Result<Option<RuntimeByokLocalLeaseRecord>> {
pub async fn get_byok_local_lease(&self, lease_id: String) -> napi::Result<Option<RuntimeByokLocalLeaseRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_byok_local_lease(lease_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn cleanup_expired_runtime_states(&self, limit: i64) -> Result<i64> {
pub async fn cleanup_expired_runtime_states(&self, limit: i64) -> napi::Result<i64> {
if limit <= 0 {
return Err(napi_error("runtime state cleanup limit must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.cleanup_expired_runtime_states(limit)
.await
.map_err(napi::Error::from)
}
}
#[cfg(test)]
mod tests {
use crate::backend_runtime::{
constants::{MAGIC_LINK_OTP_PURPOSE, WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE},
token_hash,
use super::{
MAGIC_LINK_OTP_PURPOSE, WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, token_hash,
};
#[test]
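Review note: in the hunks above, the `#[napi]` methods now return `napi::Result<T>` explicitly and end with `.map_err(napi::Error::from)`, while the store layer uses the module-local `Result<T> = RuntimeResult<T>`. A minimal std-only sketch of that "convert once at the edge" layering follows. `InternalError`, `BoundaryError` and every function name here are hypothetical stand-ins for `RuntimeError`, `napi::Error` and the exported methods.

```rust
use std::fmt;

// Hypothetical internal error, playing the role of `RuntimeError`.
#[derive(Debug)]
enum InternalError {
    InvalidState(String),
    Database(String),
}

impl fmt::Display for InternalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidState(message) => write!(f, "{}", message),
            Self::Database(message) => write!(f, "database: {}", message),
        }
    }
}

// Hypothetical boundary error, playing the role of `napi::Error`:
// it keeps only a rendered message.
#[derive(Debug, PartialEq)]
struct BoundaryError {
    reason: String,
}

impl From<InternalError> for BoundaryError {
    fn from(error: InternalError) -> Self {
        Self { reason: error.to_string() }
    }
}

// Internal layer: typed errors only.
fn load_token(payload: Option<&str>) -> Result<String, InternalError> {
    payload
        .map(str::to_string)
        .ok_or_else(|| InternalError::InvalidState("payload missing token".to_string()))
}

fn count_rows(db_up: bool) -> Result<i64, InternalError> {
    if db_up {
        Ok(0)
    } else {
        Err(InternalError::Database("connection refused".to_string()))
    }
}

// Exported layer, form 1: explicit `map_err` with the `From` impl.
fn load_token_exported(payload: Option<&str>) -> Result<String, BoundaryError> {
    load_token(payload).map_err(BoundaryError::from)
}

// Exported layer, form 2: `?` applies the same `From` conversion implicitly.
fn count_rows_exported(db_up: bool) -> Result<i64, BoundaryError> {
    Ok(count_rows(db_up)?)
}
```

Both forms rely on the same `From` impl; the second mirrors the `Ok(result?)` lines that appear in the workspace stats hunks.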
@@ -1,10 +1,9 @@
use napi::Result;
use sqlx::PgPool;
use super::{auth_challenge, byok_local_lease, dto::RuntimeStateRows, invite_link, magic_link_otp, verification_token};
use crate::backend_runtime::types::{
RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord,
use super::{
Result, RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord, auth_challenge, byok_local_lease, dto::RuntimeStateRows, invite_link,
magic_link_otp, verification_token,
};
pub(super) struct RuntimeStateStore {
@@ -1,12 +1,11 @@
use napi::Result;
use sqlx::{PgPool, Row};
use uuid::Uuid;
use super::{
Result, RuntimeError, RuntimeVerificationTokenRecord,
dto::{RuntimeStatePayloadRow, RuntimeStateRows},
verification_token_purpose,
token_hash, verification_token_purpose,
};
use crate::backend_runtime::{error::napi_error, token_hash, types::RuntimeVerificationTokenRecord};
pub(super) async fn create(
rows: &RuntimeStateRows,
@@ -64,7 +63,7 @@ pub(super) async fn verify(
} else {
consume_payload_with_credential(rows.pool(), &purpose, &token, credential.as_deref()).await
}
.map_err(|err| napi_error(format!("RuntimeState verification token verify failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState verification token verify failed", err))?;
Ok(row.map(|row| record_from_row(token_type, token, row)))
}
@@ -1,6 +1,10 @@
use anyhow::{Context, Result as AnyResult, anyhow};
use super::{runtime_state::*, *};
use super::{
super::migrations::{RUNTIME_MIGRATIONS, migrate_runtime_tables},
runtime_state::*,
*,
};
static PG_TEST_LOCK: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
const TEST_VERIFICATION_TOKEN_TYPE: i32 = 99_999;
@@ -14,6 +18,10 @@ fn migrations_include_runtime_tables_without_worker_heartbeats() {
assert!(RUNTIME_MIGRATIONS.contains("runtime_states"));
assert!(RUNTIME_MIGRATIONS.contains("runtime_gates"));
assert!(RUNTIME_MIGRATIONS.contains("runtime_leases"));
assert!(RUNTIME_MIGRATIONS.contains("blob_reconciliation_runs"));
assert!(RUNTIME_MIGRATIONS.contains("blob_reconciliation_checkpoints"));
assert!(RUNTIME_MIGRATIONS.contains("doc_blob_refs"));
assert!(RUNTIME_MIGRATIONS.contains("blob_cleanup_candidates"));
assert!(!RUNTIME_MIGRATIONS.contains("runtime_worker_heartbeats"));
}
@@ -66,10 +74,7 @@ async fn runtime_from_database_url() -> AnyResult<Option<BackendRuntime>> {
.context("cleanup runtime_leases for backend runtime tests")?;
Ok(Some(BackendRuntime {
config: RuntimeConfig {
database_url,
storage: None,
},
config: std::sync::RwLock::new(BackendRuntimeConfig { database_url }),
pool: Mutex::new(Some(pool)),
}))
}
@@ -126,7 +131,7 @@ async fn runtime_gate_sql_semantics_are_atomic_and_ttl_bound() {
let mut tasks = Vec::new();
for _ in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
tasks.push(tokio::spawn(async move {
@@ -185,7 +190,7 @@ async fn coordination_lease_sql_semantics_are_fenced_and_ttl_bound() {
let mut tasks = Vec::new();
for index in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
tasks.push(tokio::spawn(async move {
@@ -389,7 +394,7 @@ async fn verification_token_sql_state_machine_handles_keep_verify_and_cleanup()
let mut tasks = Vec::new();
for _ in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
let token = concurrent_token.clone();
@@ -1,11 +1,10 @@
use napi::Result;
use sqlx::{FromRow, PgPool, Postgres, Row, Transaction};
use tokio::time::{Duration as TokioDuration, sleep};
use super::{
BackendRuntime,
BackendRuntime, RuntimeError, RuntimeResult,
constants::{WORKSPACE_STATS_LEASE_KEY, WORKSPACE_STATS_LOCK_NAMESPACE, WORKSPACE_STATS_REFRESH_LOCK_KEY},
error::napi_error,
napi_error,
types::{
CoordinationLeaseGrant, RuntimeWorkspaceStatsDailyRecalibrationResult, RuntimeWorkspaceStatsRecalibrationResult,
RuntimeWorkspaceStatsRefreshResult, RuntimeWorkspaceStatsSnapshotResult,
@@ -22,13 +21,13 @@ impl BackendRuntime {
batch_limit: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsRefreshResult> {
) -> napi::Result<RuntimeWorkspaceStatsRefreshResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats dirty refresh limit must be positive"));
}
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsRefreshResult {
@@ -46,7 +45,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -56,13 +55,13 @@ impl BackendRuntime {
batch_limit: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsRecalibrationResult> {
) -> napi::Result<RuntimeWorkspaceStatsRecalibrationResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats recalibration limit must be positive"));
}
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
@@ -80,7 +79,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -88,9 +87,9 @@ impl BackendRuntime {
&self,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsSnapshotResult> {
) -> napi::Result<RuntimeWorkspaceStatsSnapshotResult> {
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsSnapshotResult {
@@ -107,7 +106,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -118,7 +117,7 @@ impl BackendRuntime {
lease_ttl_ms: i64,
lock_retry_times: i64,
lock_retry_delay_ms: i64,
) -> Result<RuntimeWorkspaceStatsDailyRecalibrationResult> {
) -> napi::Result<RuntimeWorkspaceStatsDailyRecalibrationResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats daily recalibration limit must be positive"));
}
@@ -150,7 +149,7 @@ impl BackendRuntime {
});
};
let result = async {
let result: RuntimeResult<RuntimeWorkspaceStatsDailyRecalibrationResult> = async {
let store = WorkspaceStatsStore::new(self.pool().await?);
let mut processed = 0;
let mut last_sid = 0;
@@ -195,7 +194,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
}
@@ -214,16 +213,16 @@ impl WorkspaceStatsStore {
Self { pool }
}
async fn refresh_dirty(&self, batch_limit: i64) -> Result<RuntimeWorkspaceStatsRefreshResult> {
async fn refresh_dirty(&self, batch_limit: i64) -> RuntimeResult<RuntimeWorkspaceStatsRefreshResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRefreshResult {
processed: 0,
backlog: 0,
@@ -236,7 +235,7 @@ impl WorkspaceStatsStore {
if dirty.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRefreshResult {
processed: 0,
backlog,
@@ -248,7 +247,7 @@ impl WorkspaceStatsStore {
clear_dirty(&mut tx, &dirty).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
Ok(RuntimeWorkspaceStatsRefreshResult {
processed: dirty.len() as i64,
@@ -257,16 +256,20 @@ impl WorkspaceStatsStore {
})
}
async fn recalibrate(&self, last_sid: i64, batch_limit: i64) -> Result<RuntimeWorkspaceStatsRecalibrationResult> {
async fn recalibrate(
&self,
last_sid: i64,
batch_limit: i64,
) -> RuntimeResult<RuntimeWorkspaceStatsRecalibrationResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: 0,
last_sid,
@@ -278,7 +281,7 @@ impl WorkspaceStatsStore {
if workspaces.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: 0,
last_sid,
@@ -297,7 +300,7 @@ impl WorkspaceStatsStore {
upsert_stats(&mut tx, &ids).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: ids.len() as i64,
@@ -306,16 +309,16 @@ impl WorkspaceStatsStore {
})
}
async fn write_daily_snapshot(&self) -> Result<RuntimeWorkspaceStatsSnapshotResult> {
async fn write_daily_snapshot(&self) -> RuntimeResult<RuntimeWorkspaceStatsSnapshotResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot commit failed", err))?;
return Ok(RuntimeWorkspaceStatsSnapshotResult {
snapshotted: 0,
skipped: true,
@@ -324,7 +327,7 @@ impl WorkspaceStatsStore {
let snapshotted = write_daily_snapshot(&mut tx).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot commit failed", err))?;
Ok(RuntimeWorkspaceStatsSnapshotResult {
snapshotted,
@@ -333,9 +336,9 @@ impl WorkspaceStatsStore {
}
}
async fn release_workspace_stats_lease(runtime: &BackendRuntime, lease: CoordinationLeaseGrant) -> Result<()> {
async fn release_workspace_stats_lease(runtime: &BackendRuntime, lease: CoordinationLeaseGrant) -> RuntimeResult<()> {
let _ = runtime
.release_coordination_lease(lease.key, lease.owner, lease.fencing_token)
.release_coordination_lease_inner(lease.key, lease.owner, lease.fencing_token)
.await?;
Ok(())
}
@@ -346,10 +349,10 @@ async fn acquire_workspace_stats_lease_with_retry(
lease_ttl_ms: i64,
retry_times: i64,
retry_delay_ms: i64,
) -> Result<Option<CoordinationLeaseGrant>> {
) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
for attempt in 0..retry_times {
let lease = runtime
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner.clone(), lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner.clone(), lease_ttl_ms)
.await?;
if lease.is_some() {
return Ok(lease);
@@ -367,11 +370,11 @@ async fn retry_workspace_stats_operation<T, F, Fut>(
retry_times: i64,
retry_delay_ms: i64,
mut operation: F,
) -> Result<T>
) -> RuntimeResult<T>
where
T: WorkspaceStatsSkippable,
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
Fut: std::future::Future<Output = RuntimeResult<T>>,
{
for attempt in 0..retry_times {
let result = operation().await?;
@@ -403,7 +406,7 @@ impl WorkspaceStatsSkippable for RuntimeWorkspaceStatsSnapshotResult {
}
}
async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> Result<bool> {
async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<bool> {
let row = sqlx::query(
r#"
SELECT pg_try_advisory_xact_lock(($1::bigint << 32) + $2::bigint) AS locked
@@ -413,12 +416,12 @@ async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> Result<bool
.bind(WORKSPACE_STATS_REFRESH_LOCK_KEY)
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats transaction lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats transaction lock failed", err))?;
Ok(row.get::<bool, _>("locked"))
}
async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> Result<Vec<String>> {
async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> RuntimeResult<Vec<String>> {
let rows = sqlx::query(
r#"
SELECT workspace_id
@@ -431,20 +434,20 @@ async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> Result<Ve
.bind(limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats load dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats load dirty workspaces failed", err))?;
Ok(rows.into_iter().map(|row| row.get("workspace_id")).collect())
}
async fn count_dirty(tx: &mut Transaction<'_, Postgres>) -> Result<i64> {
async fn count_dirty(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<i64> {
let row = sqlx::query("SELECT COUNT(*) AS total FROM workspace_admin_stats_dirty")
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats count dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats count dirty workspaces failed", err))?;
Ok(row.get::<i64, _>("total"))
}
async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> Result<()> {
async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> RuntimeResult<()> {
sqlx::query(
r#"
DELETE FROM workspace_admin_stats_dirty
@@ -454,11 +457,11 @@ async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String
.bind(workspace_ids)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats clear dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats clear dirty workspaces failed", err))?;
Ok(())
}
async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> Result<()> {
async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> RuntimeResult<()> {
if workspace_ids.is_empty() {
return Ok(());
}
@@ -467,7 +470,7 @@ async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[Strin
.bind(workspace_ids)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats upsert stats failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats upsert stats failed", err))?;
Ok(())
}
@@ -475,7 +478,7 @@ async fn fetch_workspace_batch(
tx: &mut Transaction<'_, Postgres>,
last_sid: i64,
limit: i64,
) -> Result<Vec<WorkspaceSid>> {
) -> RuntimeResult<Vec<WorkspaceSid>> {
sqlx::query_as::<_, WorkspaceSid>(
r#"
SELECT id, sid
@@ -489,10 +492,10 @@ async fn fetch_workspace_batch(
.bind(limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats fetch workspace batch failed: {err}")))
.map_err(|err| RuntimeError::database("WorkspaceStats fetch workspace batch failed", err))
}
async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> Result<i64> {
async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
INSERT INTO workspace_admin_stats_daily (
@@ -521,7 +524,7 @@ async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> Result<i64>
)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot failed", err))?;
Ok(result.rows_affected() as i64)
}
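Review note: each exported workspace stats method above stores the operation's outcome in `result`, releases the coordination lease, and only then returns `Ok(result?)`. A minimal synchronous sketch of that ordering follows. `Lease`, `release` and `run_with_lease` are hypothetical names, and errors are plain `String`s to keep the example std-only; the real code is async and talks to Postgres.

```rust
// Hypothetical lease handle; the real lease lives in a database table.
struct Lease {
    released: bool,
}

// Hypothetical release step; it always succeeds in this sketch.
fn release(lease: &mut Lease) -> Result<(), String> {
    lease.released = true;
    Ok(())
}

fn run_with_lease<T>(
    lease: &mut Lease,
    operation: impl FnOnce() -> Result<T, String>,
) -> Result<T, String> {
    // Capture the outcome without `?` so a failure cannot skip the release.
    let result = operation();
    // A release failure is returned first and takes precedence over `result`.
    release(lease)?;
    result
}
```

With this ordering a failed operation still frees the lease instead of leaving it held until its TTL expires.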
@@ -0,0 +1,200 @@
use std::{
env, fs,
path::{Path, PathBuf},
};
use serde::Deserialize;
use serde_json::Map;
use sqlx::{PgPool, Row};
use super::{RuntimeError, RuntimeResult};
#[derive(Clone, Debug)]
pub(crate) struct BackendRuntimeConfig {
pub(crate) database_url: String,
}
impl BackendRuntimeConfig {
pub(crate) fn from_config_files() -> RuntimeResult<Self> {
let app_config = app_config_from_config_files()?;
let database_url = database_url_from_env()
.or(app_config.database_url())
.unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string());
Ok(Self { database_url })
}
pub(crate) async fn with_db_overrides(&self, pool: &PgPool) -> RuntimeResult<Self> {
let mut app_config = app_config_from_config_files()?;
app_config.apply_file_config(load_app_config_overrides_from_db(pool).await?);
Ok(Self {
// The DB override is loaded after this connection already exists, so it
// must not rewrite the active datasource URL.
database_url: self.database_url.clone(),
})
}
}
#[derive(Debug, Default, Deserialize)]
struct AppConfigFile {
db: Option<DbConfigFile>,
}
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DbConfigFile {
datasource_url: Option<String>,
}
impl AppConfigFile {
fn database_url(&self) -> Option<String> {
self
.db
.as_ref()
.and_then(|db| db.datasource_url.clone())
.and_then(non_empty_string)
}
}
fn database_url_from_env() -> Option<String> {
env::var("DATABASE_URL").ok().and_then(non_empty_string)
}
fn non_empty_string(value: String) -> Option<String> {
if value.trim().is_empty() { None } else { Some(value) }
}
fn app_config_from_config_files() -> RuntimeResult<AppConfigFile> {
let mut merged = AppConfigFile::default();
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path).map_err(|err| RuntimeError::io("failed to read config file", err))?;
let config: AppConfigFile =
serde_json::from_str(&raw).map_err(|err| RuntimeError::json("failed to parse config file", err))?;
merged.apply_file_config(config);
}
Ok(merged)
}
impl AppConfigFile {
fn apply_file_config(&mut self, config: AppConfigFile) {
if config.db.is_some() {
self.db = config.db;
}
}
}
async fn load_app_config_overrides_from_db(pool: &PgPool) -> RuntimeResult<AppConfigFile> {
let rows = match sqlx::query("SELECT id, value FROM app_configs").fetch_all(pool).await {
Ok(rows) => rows,
Err(sqlx::Error::Database(err)) if err.code().as_deref() == Some("42P01") => return Ok(AppConfigFile::default()),
Err(err) => return Err(RuntimeError::database("failed to load app config overrides", err)),
};
app_config_from_flat_overrides(rows.into_iter().map(|row| {
let id: String = row.get("id");
let value: serde_json::Value = row.get("value");
(id, value)
}))
}
fn app_config_from_flat_overrides<I, S>(rows: I) -> RuntimeResult<AppConfigFile>
where
I: IntoIterator<Item = (S, serde_json::Value)>,
S: AsRef<str>,
{
let mut root = Map::new();
for (path, value) in rows {
let Some((module, key)) = path.as_ref().split_once('.') else {
continue;
};
root
.entry(module.to_string())
.or_insert_with(|| serde_json::Value::Object(Map::new()));
if let Some(serde_json::Value::Object(module_object)) = root.get_mut(module) {
module_object.insert(key.to_string(), value);
}
}
serde_json::from_value(serde_json::Value::Object(root))
.map_err(|err| RuntimeError::json("invalid app config overrides", err))
}
pub(super) fn config_json_paths() -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Ok(exe) = env::current_exe()
&& let Some(dir) = exe.parent()
{
paths.push(config_in(dir));
}
if let Ok(cwd) = env::current_dir() {
paths.push(config_in(&cwd));
}
dedupe_paths(paths)
}
fn config_in(dir: &Path) -> PathBuf {
dir.join("config.json")
}
fn dedupe_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
let mut deduped = Vec::new();
for path in paths {
if !deduped.contains(&path) {
deduped.push(path);
}
}
deduped
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn config_paths_are_limited_to_executable_dir_and_cwd() {
let paths = config_json_paths();
assert!(!paths.is_empty());
assert!(paths.len() <= 2);
assert!(
paths
.iter()
.all(|path| path.file_name().is_some_and(|name| name == "config.json"))
);
assert!(paths.iter().all(|path| !path.to_string_lossy().contains(".affine")));
assert!(
paths
.iter()
.all(|path| !path.to_string_lossy().contains("packages/backend/server"))
);
}
#[test]
fn blank_database_urls_are_ignored() {
assert_eq!(non_empty_string("".to_string()), None);
assert_eq!(non_empty_string(" ".to_string()), None);
assert_eq!(
non_empty_string("postgresql://affine:affine@localhost:5432/affine".to_string()),
Some("postgresql://affine:affine@localhost:5432/affine".to_string())
);
}
#[test]
fn ignores_storage_app_config_values() {
let app_config = app_config_from_flat_overrides([
(
"storages.blob.storage",
serde_json::json!({"provider": "cloudflare-r2"}),
),
("db.datasourceUrl", serde_json::json!("postgresql://example/runtime")),
])
.unwrap();
assert_eq!(
app_config.database_url().as_deref(),
Some("postgresql://example/runtime")
);
}
}
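Review note: the new config module above resolves the database URL from `DATABASE_URL`, then from `db.datasourceUrl` in `config.json`, then from a built-in default, and it folds flat `module.key` override rows into a nested object. A minimal std-only sketch of both steps follows. `resolve_database_url` and `group_overrides` are hypothetical helpers, and override values are plain strings here rather than `serde_json::Value`.

```rust
use std::collections::BTreeMap;

const DEFAULT_DATABASE_URL: &str = "postgresql://localhost:5432/affine";

// Treat empty or whitespace-only values as absent, as `non_empty_string` does.
fn non_empty(value: String) -> Option<String> {
    if value.trim().is_empty() { None } else { Some(value) }
}

// Precedence: environment value, then config file value, then the default.
fn resolve_database_url(env_value: Option<String>, file_value: Option<String>) -> String {
    env_value
        .and_then(non_empty)
        .or(file_value.and_then(non_empty))
        .unwrap_or_else(|| DEFAULT_DATABASE_URL.to_string())
}

// Group flat `module.key` ids into per-module maps. The split happens at the
// first dot only, and ids without a dot are skipped.
fn group_overrides<'a>(rows: &[(&'a str, &'a str)]) -> BTreeMap<&'a str, BTreeMap<&'a str, &'a str>> {
    let mut root = BTreeMap::new();
    for &(path, value) in rows {
        if let Some((module, key)) = path.split_once('.') {
            root.entry(module).or_insert_with(BTreeMap::new).insert(key, value);
        }
    }
    root
}
```

Because the split is on the first dot, an id such as `storages.blob.storage` lands under module `storages` with key `blob.storage`. In the diff, `AppConfigFile` only declares a `db` field, so such a module is ignored when deserializing, which is what `ignores_storage_app_config_values` checks.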
@@ -0,0 +1,126 @@
use napi::{Error, Status};
use super::storage_runtime::object_storage::error::ObjectStorageError;
pub(crate) type RuntimeResult<T> = std::result::Result<T, RuntimeError>;
#[derive(Debug, thiserror::Error)]
pub(crate) enum RuntimeError {
#[error("{0}")]
Config(String),
#[error("{0}")]
InvalidInput(String),
#[error("{0}")]
InvalidState(String),
#[error("{context}: {source}")]
Database {
context: String,
#[source]
source: sqlx::Error,
},
#[error("{context}: {source}")]
Io {
context: String,
#[source]
source: std::io::Error,
},
#[error("{context}: {source}")]
Json {
context: String,
#[source]
source: serde_json::Error,
},
#[error("{context}: {source}")]
Time {
context: String,
#[source]
source: std::time::SystemTimeError,
},
#[error(transparent)]
ObjectStorage(#[from] ObjectStorageError),
#[error("{0}")]
NapiBoundary(String),
}
impl RuntimeError {
pub(crate) fn config(message: impl Into<String>) -> Self {
Self::Config(message.into())
}
pub(crate) fn invalid_input(message: impl Into<String>) -> Self {
Self::InvalidInput(message.into())
}
pub(crate) fn invalid_state(message: impl Into<String>) -> Self {
Self::InvalidState(message.into())
}
pub(crate) fn database(context: impl Into<String>, source: sqlx::Error) -> Self {
Self::Database {
context: context.into(),
source,
}
}
pub(crate) fn io(context: impl Into<String>, source: std::io::Error) -> Self {
Self::Io {
context: context.into(),
source,
}
}
pub(crate) fn json(context: impl Into<String>, source: serde_json::Error) -> Self {
Self::Json {
context: context.into(),
source,
}
}
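/// Returns `true` when the error most likely means the requested object does not exist.
/// Typed variants are checked structurally; message-only variants fall back to substring
/// matching, which can misclassify unrelated messages that happen to contain "not found".
pub(crate) fn is_object_missing(&self) -> bool {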
match self {
Self::ObjectStorage(error) => error.is_not_found(),
Self::Io { source, .. } => source.kind() == std::io::ErrorKind::NotFound,
Self::InvalidState(message)
| Self::InvalidInput(message)
| Self::Config(message)
| Self::NapiBoundary(message) => {
message.contains("NoSuchKey") || message.contains("NotFound") || message.contains("not found")
}
_ => false,
}
}
}
pub(crate) fn to_napi_error(error: RuntimeError) -> Error {
Error::new(Status::GenericFailure, error.to_string())
}
impl From<RuntimeError> for Error {
fn from(error: RuntimeError) -> Self {
to_napi_error(error)
}
}
impl From<ObjectStorageError> for Error {
fn from(error: ObjectStorageError) -> Self {
to_napi_error(RuntimeError::from(error))
}
}
impl From<Error> for RuntimeError {
fn from(error: Error) -> Self {
Self::NapiBoundary(error.to_string())
}
}
pub(crate) fn napi_error(message: impl Into<String>) -> Error {
Error::new(Status::GenericFailure, message.into())
}
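The `is_object_missing` classification above mixes a typed check (`io::ErrorKind::NotFound`) with a substring heuristic for message-only variants. The sketch below reproduces that split with a reduced, hypothetical `SketchError` enum built on the standard library only; it is not the crate's `RuntimeError` and omits the `sqlx`, `serde_json`, and object-storage variants.

```rust
use std::fmt;
use std::io;

/// Hypothetical, reduced mirror of `RuntimeError` with two variants.
#[derive(Debug)]
enum SketchError {
    /// Free-form message, similar to `InvalidState` or `NapiBoundary`.
    Message(String),
    /// I/O failure with a context prefix, similar to the `Io` variant.
    Io { context: String, source: io::Error },
}

impl fmt::Display for SketchError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            SketchError::Message(message) => write!(f, "{}", message),
            SketchError::Io { context, source } => write!(f, "{}: {}", context, source),
        }
    }
}

impl SketchError {
    fn is_object_missing(&self) -> bool {
        match self {
            // Typed check: rely on the error kind, not on the text.
            SketchError::Io { source, .. } => source.kind() == io::ErrorKind::NotFound,
            // Heuristic check: the same substrings the diff looks for.
            SketchError::Message(message) => {
                message.contains("NoSuchKey")
                    || message.contains("NotFound")
                    || message.contains("not found")
            }
        }
    }
}
```

The typed branch ignores the message text, so a `PermissionDenied` error whose text says "not found" is not treated as missing. The heuristic branch has no such protection, which is the trade-off of carrying errors across the N-API boundary as strings.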
@@ -0,0 +1,20 @@
use sqlx::PgPool;
use super::{RuntimeError, RuntimeResult};
pub(crate) const RUNTIME_MIGRATIONS: &str = include_str!("sql/runtime_migrations.sql");
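/// Applies `RUNTIME_MIGRATIONS` one statement at a time.
///
/// The script is split on `;`, so it must not contain semicolons inside string
/// literals, comments, or function bodies. Statements run directly on the pool
/// without a wrapping transaction, so the script relies on `IF NOT EXISTS` to
/// stay safe across reruns.
pub(crate) async fn migrate_runtime_tables(pool: &PgPool) -> RuntimeResult<()> {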
for statement in RUNTIME_MIGRATIONS
.split(';')
.map(str::trim)
.filter(|statement| !statement.is_empty())
{
sqlx::query(statement)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Runtime migration failed", err))?;
}
Ok(())
}
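The migration runner above splits the script on `;`, trims each piece, and drops empty pieces. The sketch below isolates that splitting step so its behaviour, including its main limitation, is easy to see. `split_statements` is a hypothetical name, and the SQL strings in the usage note are made up for illustration.

```rust
/// Splits a migration script the same way the loop in the diff does:
/// on every `;`, trimming whitespace and skipping empty fragments.
fn split_statements(script: &str) -> Vec<&str> {
    script
        .split(';')
        .map(str::trim)
        .filter(|statement| !statement.is_empty())
        .collect()
}
```

For a script such as `"CREATE TABLE a (id INT);\n\nCREATE INDEX a_idx ON a (id);\n"` this yields two statements. The split is purely textual, so a statement like `INSERT INTO t VALUES ('a;b');` would be cut in the middle of the string literal; the approach is only safe while the bundled script keeps semicolons out of literals and function bodies.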
@@ -0,0 +1,10 @@
pub mod backend_runtime;
pub mod storage_runtime;
pub(crate) mod config;
pub(crate) mod error;
pub(crate) mod migrations;
pub(crate) mod types;
pub(crate) use config::BackendRuntimeConfig;
pub(crate) use error::{RuntimeError, RuntimeResult, napi_error, to_napi_error};
@@ -0,0 +1,112 @@
CREATE TABLE IF NOT EXISTS runtime_states (
purpose TEXT NOT NULL,
token_hash TEXT NOT NULL,
lookup_key TEXT,
payload JSONB NOT NULL,
attempts INTEGER NOT NULL DEFAULT 0,
consumed_at TIMESTAMPTZ(3),
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (purpose, token_hash)
);
CREATE INDEX IF NOT EXISTS runtime_states_lookup_idx
ON runtime_states (purpose, lookup_key)
WHERE lookup_key IS NOT NULL AND consumed_at IS NULL;
CREATE INDEX IF NOT EXISTS runtime_states_expires_at_idx
ON runtime_states (expires_at);
CREATE TABLE IF NOT EXISTS runtime_gates (
key TEXT PRIMARY KEY,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_gates_expires_at_idx
ON runtime_gates (expires_at);
CREATE TABLE IF NOT EXISTS runtime_leases (
key TEXT PRIMARY KEY,
owner TEXT NOT NULL,
fencing_token BIGINT NOT NULL,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_leases_expires_at_idx
ON runtime_leases (expires_at);
CREATE TABLE IF NOT EXISTS blob_reconciliation_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
kind TEXT NOT NULL,
mode TEXT NOT NULL,
status TEXT NOT NULL,
workspace_id TEXT,
started_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
finished_at TIMESTAMPTZ(3),
cursor JSONB NOT NULL DEFAULT '{}',
scanned INTEGER NOT NULL DEFAULT 0,
changed INTEGER NOT NULL DEFAULT 0,
failed INTEGER NOT NULL DEFAULT 0,
metadata JSONB NOT NULL DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS blob_reconciliation_runs_workspace_idx
ON blob_reconciliation_runs (workspace_id, started_at DESC);
CREATE TABLE IF NOT EXISTS blob_reconciliation_checkpoints (
kind TEXT NOT NULL,
scope TEXT NOT NULL,
status TEXT NOT NULL,
cursor JSONB NOT NULL DEFAULT '{}',
last_key TEXT,
last_sid INTEGER,
completed_at TIMESTAMPTZ(3),
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
metadata JSONB NOT NULL DEFAULT '{}',
PRIMARY KEY (kind, scope)
);
CREATE INDEX IF NOT EXISTS blob_reconciliation_checkpoints_status_idx
ON blob_reconciliation_checkpoints (kind, status, updated_at DESC);
CREATE TABLE IF NOT EXISTS doc_blob_refs (
workspace_id TEXT NOT NULL,
doc_id TEXT NOT NULL,
blob_key TEXT NOT NULL,
block_id TEXT NOT NULL,
flavour TEXT NOT NULL,
snapshot_updated_at TIMESTAMPTZ(3) NOT NULL,
indexed_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
parser_version INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'fresh',
error TEXT,
PRIMARY KEY (workspace_id, doc_id, blob_key, block_id)
);
CREATE INDEX IF NOT EXISTS doc_blob_refs_workspace_blob_idx
ON doc_blob_refs (workspace_id, blob_key);
CREATE INDEX IF NOT EXISTS doc_blob_refs_workspace_status_idx
ON doc_blob_refs (workspace_id, status);
CREATE TABLE IF NOT EXISTS blob_cleanup_candidates (
workspace_id TEXT NOT NULL,
blob_key TEXT NOT NULL,
reason TEXT NOT NULL,
status TEXT NOT NULL,
object_size BIGINT NOT NULL,
object_last_modified TIMESTAMPTZ(3),
planned_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
executed_at TIMESTAMPTZ(3),
run_id UUID NOT NULL,
evidence JSONB NOT NULL DEFAULT '{}',
error TEXT,
PRIMARY KEY (workspace_id, blob_key)
);
CREATE INDEX IF NOT EXISTS blob_cleanup_candidates_run_idx
ON blob_cleanup_candidates (run_id, status);
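The `runtime_leases` table above stores an `owner`, a `fencing_token`, and an `expires_at` per key, but the diff does not show how rows are acquired. The sketch below is an in-memory model of one common way such a table is used, under stated assumptions: a lease can be taken when it is absent or expired, renewed by its current owner, and the fencing token increases on every change of ownership. `Lease`, `LeaseTable`, and `acquire` are hypothetical names, and the semantics are an assumption, not the runtime's actual queries.

```rust
use std::collections::HashMap;

/// In-memory stand-in for one `runtime_leases` row (hypothetical model).
#[derive(Debug, Clone, PartialEq)]
struct Lease {
    owner: String,
    fencing_token: i64,
    expires_at_ms: i64,
}

/// In-memory stand-in for the table, keyed by lease key.
#[derive(Default)]
struct LeaseTable {
    rows: HashMap<String, Lease>,
}

impl LeaseTable {
    /// Tries to take or renew the lease for `key`.
    /// Returns the fencing token on success and `None` while another owner
    /// holds an unexpired lease.
    fn acquire(&mut self, key: &str, owner: &str, now_ms: i64, ttl_ms: i64) -> Option<i64> {
        let next = match self.rows.get(key) {
            // No row yet: first holder starts at token 1.
            None => Lease {
                owner: owner.to_string(),
                fencing_token: 1,
                expires_at_ms: now_ms + ttl_ms,
            },
            // Same owner: renew the expiry, keep the token.
            Some(current) if current.owner == owner => Lease {
                expires_at_ms: now_ms + ttl_ms,
                ..current.clone()
            },
            // Expired lease: hand over and bump the token so writes from
            // the previous holder can be rejected downstream.
            Some(current) if current.expires_at_ms <= now_ms => Lease {
                owner: owner.to_string(),
                fencing_token: current.fencing_token + 1,
                expires_at_ms: now_ms + ttl_ms,
            },
            // Held by someone else and still valid.
            Some(_) => return None,
        };
        let token = next.fencing_token;
        self.rows.insert(key.to_string(), next);
        Some(token)
    }
}
```

In a PostgreSQL-backed version the same decision would typically be made inside a single `INSERT ... ON CONFLICT ... DO UPDATE ... WHERE` statement so that the check and the write are atomic; that statement is not shown in this part of the diff.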
@@ -0,0 +1,358 @@
use std::{
io::{BufReader, Cursor},
path::PathBuf,
time::SystemTime,
};
use assetpack_core::{
Codec, FileHint, FileTransformConfig, Hash32, ObjectKind, Pipeline, PipelineConfig, SqliteStore, TransformRegistry,
TransformSelector, build_recipe, pack::ObjectRecord, parse_recipe_checked,
};
use sqlx::Row;
use super::{
FsStorageConfig, MAX_BLOB_SIZE, ObjectGetResult, ObjectListEntry, ObjectMetadata, ObjectPutMetadata, RuntimeError,
RuntimeResult, fs_bucket_path, normalize_storage_key, system_time_ms,
};
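/// Stores `body` as content-addressed chunks plus a recipe object, then upserts the
/// manifest row for `scope`/`key`. Objects, recipe cache, and manifest are written
/// in one write transaction.
pub(super) async fn put(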
config: &FsStorageConfig,
scope: &str,
key: &str,
body: Vec<u8>,
metadata: ObjectPutMetadata,
) -> RuntimeResult<ObjectMetadata> {
normalize_storage_key(key)?;
let metadata = metadata.complete_for_body(&body);
let content_length = metadata.content_length.unwrap_or(body.len() as i64);
if content_length != body.len() as i64 {
return Err(RuntimeError::invalid_input(
"Assetpack contentLength does not match body length",
));
}
if !(0..=MAX_BLOB_SIZE).contains(&content_length) {
return Err(RuntimeError::invalid_input(
"Assetpack contentLength exceeds supported blob size",
));
}
let store = open_store(config).await?;
let transform_config = FileTransformConfig::default();
let bucket_path = fs_bucket_path(config);
let selector = TransformSelector::new(
transform_config.clone(),
transform_config.resolved_temp_dir(&bucket_path),
assetpack_transform_precomp2::default_specs(),
);
let original_hash = Hash32::sha3_256(&body);
let hint = FileHint {
size: body.len() as u64,
extension: extension_from_key(key),
head: Some(body.iter().take(4096).copied().collect()),
};
let plan = Pipeline::new(PipelineConfig::default())
.run(body, &hint, original_hash, Some(&selector))
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack pipeline failed: {err}")))?;
let chunks = plan
.chunks
.iter()
.map(|chunk| (chunk.hash, chunk.raw_len))
.collect::<Vec<_>>();
let recipe = build_recipe(
plan.original_size,
&chunks,
plan.original_hash,
plan.transform_id,
plan.transform_version,
);
let recipe_hash = Hash32::sha3_256(&recipe);
let mut objects = Vec::with_capacity(plan.chunks.len() + 1);
for chunk in plan.chunks {
let Some(payload) = chunk.payload else {
return Err(RuntimeError::invalid_state(
"Assetpack pipeline unexpectedly discarded chunk payload",
));
};
objects.push(ObjectRecord {
hash: chunk.hash,
kind: ObjectKind::Chunk,
size: chunk.raw_len as u64,
codec: chunk.codec,
content: payload,
});
}
objects.push(ObjectRecord {
hash: recipe_hash,
kind: ObjectKind::Recipe,
size: recipe.len() as u64,
codec: Codec::Raw,
content: recipe,
});
let mut tx = store
.begin_write_tx()
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack begin write failed: {err}")))?;
store
.put_objects_batch_tx(&mut tx, &objects)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack object write failed: {err}")))?;
store
.put_file_recipe_cache_batch_tx(&mut tx, &[(original_hash, recipe_hash)])
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe cache write failed: {err}")))?;
let object_metadata = ObjectMetadata {
content_type: metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length,
last_modified_ms: system_time_ms(SystemTime::now())?,
checksum_crc32: metadata.checksum_crc32,
};
sqlx::query(
r#"
INSERT INTO storage_assetpack_blobs
(scope, key, recipe_hash, content_type, content_length, checksum_crc32, last_modified_ms)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
ON CONFLICT (scope, key)
DO UPDATE SET
recipe_hash = excluded.recipe_hash,
content_type = excluded.content_type,
content_length = excluded.content_length,
checksum_crc32 = excluded.checksum_crc32,
last_modified_ms = excluded.last_modified_ms
"#,
)
.bind(scope)
.bind(key)
.bind(recipe_hash.to_hex())
.bind(&object_metadata.content_type)
.bind(object_metadata.content_length)
.bind(&object_metadata.checksum_crc32)
.bind(object_metadata.last_modified_ms)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Assetpack manifest write failed", err))?;
tx.commit()
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack commit failed: {err}")))?;
Ok(object_metadata)
}
pub(super) async fn head(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<Option<ObjectMetadata>> {
normalize_storage_key(key)?;
let store = open_store(config).await?;
manifest_row(&store, scope, key)
.await
.map(|row| row.map(|row| row.metadata))
}
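/// Rebuilds a blob from its recipe. Each chunk's kind and length are checked while
/// reading, and the decoded body must match the recorded size and hash before it
/// is returned. A missing manifest row yields `Ok(None)`.
pub(super) async fn get(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<Option<ObjectGetResult>> {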
normalize_storage_key(key)?;
let store = open_store(config).await?;
let Some(row) = manifest_row(&store, scope, key).await? else {
return Ok(None);
};
let recipe_hash = Hash32::from_hex(&row.recipe_hash)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack manifest recipe hash is invalid: {err}")))?;
let Some(recipe_object) = store
.get_object(&recipe_hash)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe read failed: {err}")))?
else {
return Err(RuntimeError::invalid_state(format!(
"Assetpack recipe object is missing for {key}"
)));
};
let recipe = parse_recipe_checked(&recipe_object.content, &recipe_hash)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe parse failed: {err}")))?;
let mut stored_stream = Vec::with_capacity(recipe.stored_stream_size as usize);
for (chunk_hash, expected_len) in &recipe.chunks {
let Some(chunk) = store
.get_object(chunk_hash)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack chunk read failed: {err}")))?
else {
return Err(RuntimeError::invalid_state(format!(
"Assetpack chunk is missing for {key}: {chunk_hash}"
)));
};
if chunk.kind != ObjectKind::Chunk || chunk.size != *expected_len as u64 {
return Err(RuntimeError::invalid_state(format!(
"Assetpack chunk metadata mismatch for {key}: {chunk_hash}"
)));
}
stored_stream.extend_from_slice(&chunk.content);
}
let body = decode_stored_stream(recipe.transform_id, stored_stream)?;
if body.len() as u64 != recipe.original_file_size || Hash32::sha3_256(&body) != recipe.original_file_hash {
return Err(RuntimeError::invalid_state(format!(
"Assetpack reconstructed body failed integrity check for {key}"
)));
}
Ok(Some(ObjectGetResult {
body,
metadata: row.metadata,
}))
}
pub(super) async fn list(
config: &FsStorageConfig,
scope: &str,
prefix: Option<String>,
) -> RuntimeResult<Vec<ObjectListEntry>> {
let prefix = prefix
.map(|prefix| super::normalize_storage_prefix(&prefix))
.transpose()?
.unwrap_or_default();
let store = open_store(config).await?;
let rows = sqlx::query(
r#"
SELECT key, content_length, last_modified_ms
FROM storage_assetpack_blobs
WHERE scope = ?1 AND key LIKE ?2 ESCAPE '\'
ORDER BY key ASC
"#,
)
.bind(scope)
.bind(format!("{}%", escape_sqlite_like(&prefix)))
.fetch_all(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest list failed", err))?;
rows
.into_iter()
.map(|row| {
Ok(ObjectListEntry {
key: row.get("key"),
content_length: row.get::<i64, _>("content_length"),
last_modified_ms: row.get("last_modified_ms"),
})
})
.collect()
}
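/// Escapes the `LIKE` wildcards `%` and `_`, and the escape character `\` itself,
/// so that a user-supplied prefix only matches literally.
fn escape_sqlite_like(value: &str) -> String {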
let mut escaped = String::with_capacity(value.len());
for ch in value.chars() {
match ch {
'%' | '_' | '\\' => {
escaped.push('\\');
escaped.push(ch);
}
_ => escaped.push(ch),
}
}
escaped
}
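/// Removes the manifest row for `scope`/`key`. Chunk and recipe objects in the
/// store are not touched here, and deleting a key that does not exist is not an error.
pub(super) async fn delete(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<()> {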
normalize_storage_key(key)?;
let store = open_store(config).await?;
sqlx::query("DELETE FROM storage_assetpack_blobs WHERE scope = ?1 AND key = ?2")
.bind(scope)
.bind(key)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest delete failed", err))?;
Ok(())
}
async fn open_store(config: &FsStorageConfig) -> RuntimeResult<SqliteStore> {
let store = SqliteStore::open(store_path(config))
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack store open failed: {err}")))?;
ensure_manifest_schema(&store).await?;
Ok(store)
}
fn store_path(config: &FsStorageConfig) -> PathBuf {
fs_bucket_path(config).join("assetpack.sqlite")
}
fn extension_from_key(key: &str) -> Option<String> {
key
.rsplit_once('.')
.and_then(|(_, extension)| (!extension.is_empty()).then(|| extension.to_ascii_lowercase()))
}
struct ManifestRow {
recipe_hash: String,
metadata: ObjectMetadata,
}
async fn ensure_manifest_schema(store: &SqliteStore) -> RuntimeResult<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS storage_assetpack_blobs (
scope TEXT NOT NULL,
key TEXT NOT NULL,
recipe_hash TEXT NOT NULL,
content_type TEXT NOT NULL,
content_length INTEGER NOT NULL,
checksum_crc32 TEXT,
last_modified_ms INTEGER NOT NULL,
PRIMARY KEY (scope, key)
)
"#,
)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest schema create failed", err))?;
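// The composite PRIMARY KEY above already gives SQLite an index on (scope, key),
// so this explicit index is likely redundant; it is kept as written.
sqlx::query(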
"CREATE INDEX IF NOT EXISTS storage_assetpack_blobs_scope_prefix_idx ON storage_assetpack_blobs (scope, key)",
)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest index create failed", err))?;
Ok(())
}
async fn manifest_row(store: &SqliteStore, scope: &str, key: &str) -> RuntimeResult<Option<ManifestRow>> {
let row = sqlx::query(
r#"
SELECT recipe_hash, content_type, content_length, checksum_crc32, last_modified_ms
FROM storage_assetpack_blobs
WHERE scope = ?1 AND key = ?2
"#,
)
.bind(scope)
.bind(key)
.fetch_optional(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest read failed", err))?;
row
.map(|row| {
Ok(ManifestRow {
recipe_hash: row.get("recipe_hash"),
metadata: ObjectMetadata {
content_type: row.get("content_type"),
content_length: row.get("content_length"),
checksum_crc32: row.get("checksum_crc32"),
last_modified_ms: row.get("last_modified_ms"),
},
})
})
.transpose()
}
fn decode_stored_stream(transform_id: u16, stored_stream: Vec<u8>) -> RuntimeResult<Vec<u8>> {
let transform_config = FileTransformConfig::default();
let registry = TransformRegistry::new(&transform_config, assetpack_transform_precomp2::default_specs());
let transform = registry
.get(transform_id)
.ok_or_else(|| RuntimeError::invalid_state(format!("Assetpack transform is not registered: {transform_id}")))?;
let mut out = Vec::new();
transform
.decode(&mut BufReader::new(Cursor::new(stored_stream)), &mut out)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack transform decode failed: {err}")))?;
Ok(out)
}
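The `list` function above turns a key prefix into a `LIKE` pattern by escaping wildcard characters and appending `%`. The sketch below isolates that step with the standard library only. `escape_like` mirrors `escape_sqlite_like` from the diff, while `prefix_pattern` is a hypothetical helper standing in for the inline `format!` call.

```rust
/// Escapes `%`, `_`, and `\` for use with `LIKE ... ESCAPE '\'`.
fn escape_like(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        // Prefix every special character with the escape character.
        if matches!(ch, '%' | '_' | '\\') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}

/// Hypothetical helper: builds the pattern bound to the `LIKE` parameter.
fn prefix_pattern(prefix: &str) -> String {
    format!("{}%", escape_like(prefix))
}
```

Without the escaping, a prefix such as `ws_1/` would also match `wsX1/...`, because `_` matches any single character in `LIKE`. An empty prefix becomes the pattern `%`, which matches every key in the scope.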
@@ -0,0 +1,691 @@
use std::collections::HashMap;
use chrono::{DateTime, Duration, Utc};
use sqlx::{FromRow, PgPool};
use super::{
RuntimeBlobCleanupExecuteResult, RuntimeBlobCleanupPlanResult, RuntimeError, RuntimeResult, StorageRuntime,
napi_error,
};
#[derive(FromRow)]
struct BlobCandidateRow {
workspace_id: String,
key: String,
size: i32,
}
#[derive(FromRow)]
struct MarkedCandidateRow {
workspace_id: String,
blob_key: String,
}
struct DeletableCandidate {
workspace_id: String,
blob_key: String,
object_key: String,
}
fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
if !workspace_ids.iter().any(|id| id == workspace_id) {
workspace_ids.push(workspace_id.to_string());
}
}
async fn checkpoint_completed(pool: &PgPool, kind: &str, scope: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM blob_reconciliation_checkpoints WHERE kind = $1 AND scope = $2 AND status = \
'completed')",
)
.bind(kind)
.bind(scope)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup checkpoint check failed", err))
}
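/// A workspace projection is stale when its `doc_blob_refs` checkpoint is not
/// `completed` or when any of its `doc_blob_refs` rows is not `fresh`.
async fn projection_is_stale(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {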
let checkpoint_fresh = checkpoint_completed(pool, "doc_blob_refs", workspace_id).await?;
let has_stale_rows = sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM doc_blob_refs WHERE workspace_id = $1 AND status <> 'fresh')",
)
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup projection freshness check failed", err))?;
Ok(!checkpoint_fresh || has_stale_rows)
}
async fn stale_projection_workspaces(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Vec<String>> {
if projection_is_stale(pool, workspace_id).await? {
Ok(vec![workspace_id.to_string()])
} else {
Ok(Vec::new())
}
}
async fn metadata_backfill_is_complete(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {
checkpoint_completed(pool, "blob_metadata_backfill", workspace_id).await
}
async fn has_doc_ref(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM doc_blob_refs WHERE workspace_id = $1 AND blob_key = $2 AND status = 'fresh')",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup doc ref check failed", err))
}
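/// Checks references held outside `doc_blob_refs`: workspace avatars, AI transcript
/// tasks, AI jobs, AI context configs, and, when those tables exist, AI workspace
/// files and blob embeddings.
async fn has_other_ref(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {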
let required_ref = sqlx::query_scalar::<_, bool>(
r#"
SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND avatar_key = $2)
OR EXISTS(SELECT 1 FROM ai_transcript_tasks WHERE workspace_id = $1 AND blob_id = $2)
OR EXISTS(SELECT 1 FROM ai_jobs WHERE workspace_id = $1 AND blob_id = $2)
OR EXISTS(
SELECT 1
FROM ai_contexts c
JOIN ai_sessions_metadata s ON s.id = c.session_id
WHERE s.workspace_id = $1
AND jsonb_path_exists(
c.config::jsonb,
'$.** ? (@ == $blobKey)',
jsonb_build_object('blobKey', to_jsonb($2::text))
)
)
"#,
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup protected ref check failed", err))?;
if required_ref {
return Ok(true);
}
if table_exists(pool, "ai_workspace_files").await?
&& sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM ai_workspace_files WHERE workspace_id = $1 AND blob_id = $2)",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup workspace file ref check failed", err))?
{
return Ok(true);
}
if table_exists(pool, "ai_workspace_blob_embeddings").await?
&& sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM ai_workspace_blob_embeddings WHERE workspace_id = $1 AND blob_id = $2)",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup workspace blob embedding ref check failed", err))?
{
return Ok(true);
}
Ok(false)
}
async fn table_exists(pool: &PgPool, table: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT to_regclass($1) IS NOT NULL")
.bind(format!("public.{table}"))
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup table existence check failed", err))
}
async fn load_completed_blobs(
pool: &PgPool,
workspace_id: &str,
after_key: Option<&str>,
limit: i64,
) -> RuntimeResult<Vec<BlobCandidateRow>> {
sqlx::query_as::<_, BlobCandidateRow>(
r#"
SELECT workspace_id, key, size
FROM blobs
WHERE workspace_id = $1
AND status = 'completed'
AND deleted_at IS NULL
AND ($2::text IS NULL OR key > $2)
ORDER BY key ASC
LIMIT $3
"#,
)
.bind(workspace_id)
.bind(after_key)
.bind(limit)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup load completed blobs failed", err))
}
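/// Returns the key to resume from for an in-progress plan. Returns `None` when no
/// checkpoint exists or the previous pass completed, so the next pass starts over.
async fn load_plan_cursor(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Option<String>> {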
let row = sqlx::query_as::<_, (String, serde_json::Value)>(
"SELECT status, cursor FROM blob_reconciliation_checkpoints WHERE kind = 'blob_cleanup_plan' AND scope = $1",
)
.bind(workspace_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup plan checkpoint load failed", err))?;
let Some((status, cursor)) = row else {
return Ok(None);
};
if status == "completed" {
return Ok(None);
}
Ok(
cursor
.get("lastBlobKey")
.and_then(|value| value.as_str())
.map(ToString::to_string),
)
}
async fn upsert_plan_checkpoint(
pool: &PgPool,
workspace_id: &str,
last_blob_key: Option<&str>,
completed: bool,
) -> RuntimeResult<()> {
let status = if completed { "completed" } else { "running" };
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, last_key, completed_at)
VALUES ('blob_cleanup_plan', $1, $2, $3, $4, CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
last_key = COALESCE(EXCLUDED.last_key, blob_reconciliation_checkpoints.last_key),
completed_at = CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(workspace_id)
.bind(status)
.bind(serde_json::json!({ "lastBlobKey": last_blob_key }))
.bind(last_blob_key)
.bind(completed)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup plan checkpoint write failed", err))?;
Ok(())
}
async fn create_run(pool: &PgPool, workspace_id: &str) -> RuntimeResult<String> {
sqlx::query_scalar::<_, String>(
r#"
INSERT INTO blob_reconciliation_runs (kind, mode, status, workspace_id)
VALUES ('blob_cleanup_plan', 'mark_only', 'running', $1)
RETURNING id::text
"#,
)
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup create run failed", err))
}
async fn finish_run(
pool: &PgPool,
run_id: &str,
workspace_id: &str,
result: &RuntimeBlobCleanupPlanResult,
stale_projection_workspaces: Vec<String>,
) -> RuntimeResult<()> {
let candidate_bytes = sqlx::query_scalar::<_, Option<i64>>(
"SELECT SUM(object_size)::bigint FROM blob_cleanup_candidates WHERE run_id = $1::uuid AND status = 'marked'",
)
.bind(run_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup candidate bytes audit failed", err))?
.unwrap_or(0);
sqlx::query(
r#"
UPDATE blob_reconciliation_runs
SET status = 'finished',
finished_at = CURRENT_TIMESTAMP,
scanned = $2,
changed = $3,
metadata = $4
WHERE id = $1::uuid
"#,
)
.bind(run_id)
.bind(result.scanned_blobs as i32)
.bind(result.candidates_marked as i32)
.bind(serde_json::json!({
"protectedByDocRefs": result.protected_by_doc_refs,
"protectedByMetadata": result.protected_by_metadata,
"protectedByOtherRefs": result.protected_by_other_refs,
"topWorkspaceCandidateBytes": [{
"workspaceId": workspace_id,
"candidateBytes": candidate_bytes,
}],
"staleOrFailedProjectionWorkspaces": stale_projection_workspaces,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup finish run failed", err))?;
Ok(())
}
async fn mark_candidate_status(
pool: &PgPool,
run_id: &str,
workspace_id: &str,
blob_key: &str,
status: &str,
evidence: serde_json::Value,
error: Option<&str>,
) -> RuntimeResult<()> {
sqlx::query(
r#"
UPDATE blob_cleanup_candidates
SET status = $3,
executed_at = CURRENT_TIMESTAMP,
evidence = evidence || $4,
error = $5
WHERE workspace_id = $1 AND blob_key = $2 AND run_id = $6::uuid
"#,
)
.bind(workspace_id)
.bind(blob_key)
.bind(status)
.bind(evidence)
.bind(error)
.bind(run_id)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup mark candidate status failed", err))?;
Ok(())
}
async fn finish_execute_run(
pool: &PgPool,
run_id: &str,
result: &RuntimeBlobCleanupExecuteResult,
) -> RuntimeResult<()> {
sqlx::query(
r#"
UPDATE blob_reconciliation_runs
SET status = 'finished',
finished_at = CURRENT_TIMESTAMP,
scanned = $2,
changed = $3,
failed = $4,
metadata = metadata || $5
WHERE id = $1::uuid
"#,
)
.bind(run_id)
.bind(result.scanned_candidates as i32)
.bind(result.deleted_metadata as i32)
.bind(result.failed as i32)
.bind(serde_json::json!({
"deletedObjects": result.deleted_objects,
"deletedMetadata": result.deleted_metadata,
"skippedStillReferenced": result.skipped_still_referenced,
"failed": result.failed,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup execute run finish failed", err))?;
Ok(())
}
async fn mark_candidate(
pool: &PgPool,
run_id: &str,
row: &BlobCandidateRow,
object_size: i64,
object_last_modified: DateTime<Utc>,
) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
INSERT INTO blob_cleanup_candidates
(workspace_id, blob_key, reason, status, object_size, object_last_modified, run_id, evidence)
VALUES ($1, $2, 'unreferenced_completed_blob', 'marked', $3, $4, $5::uuid, $6)
ON CONFLICT (workspace_id, blob_key) DO UPDATE
SET reason = EXCLUDED.reason,
status = 'marked',
object_size = EXCLUDED.object_size,
object_last_modified = EXCLUDED.object_last_modified,
planned_at = CURRENT_TIMESTAMP,
run_id = EXCLUDED.run_id,
evidence = EXCLUDED.evidence,
error = NULL
"#,
)
.bind(&row.workspace_id)
.bind(&row.key)
.bind(object_size)
.bind(object_last_modified)
.bind(run_id)
.bind(serde_json::json!({ "metadataSize": row.size }))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup mark candidate failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn load_marked_candidates(pool: &PgPool, run_id: &str, limit: i64) -> RuntimeResult<Vec<MarkedCandidateRow>> {
sqlx::query_as::<_, MarkedCandidateRow>(
r#"
SELECT workspace_id, blob_key
FROM blob_cleanup_candidates
WHERE run_id = $1::uuid AND status IN ('marked', 'failed')
ORDER BY CASE WHEN status = 'marked' THEN 0 ELSE 1 END, planned_at ASC
LIMIT $2
"#,
)
.bind(run_id)
.bind(limit)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup load marked candidates failed", err))
}
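The plan step pages through blobs with keyset pagination: `load_completed_blobs` selects keys strictly greater than the cursor in ascending order, and `plan_unreferenced_workspace_blobs` treats a full page (`rows.len() == limit`) as a sign that more rows may follow. The sketch below models that loop over an in-memory sorted slice. `Page` and `scan_page` are hypothetical names, and the slice stands in for the `blobs` table.

```rust
/// One page of a keyset-paginated scan (hypothetical helper type).
struct Page<'a> {
    keys: Vec<&'a str>,
    next_cursor: Option<&'a str>,
}

/// Returns up to `limit` keys strictly greater than `after`, in ascending order,
/// mirroring `key > $2 ORDER BY key ASC LIMIT $3`. Assumes `sorted_keys` is sorted.
fn scan_page<'a>(sorted_keys: &[&'a str], after: Option<&str>, limit: usize) -> Page<'a> {
    let keys: Vec<&'a str> = sorted_keys
        .iter()
        .copied()
        .filter(|key| after.map_or(true, |cursor| *key > cursor))
        .take(limit)
        .collect();
    // A full page is treated as "there may be more", as in the diff.
    let next_cursor = if keys.len() == limit {
        keys.last().copied()
    } else {
        None
    };
    Page { keys, next_cursor }
}
```

Because a full page always produces a cursor, a table whose row count is an exact multiple of `limit` needs one extra call that returns an empty page before the scan is considered complete.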
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn plan_unreferenced_workspace_blobs(
&self,
workspace_id: String,
grace_period_days: i64,
limit: i64,
) -> napi::Result<RuntimeBlobCleanupPlanResult> {
if limit <= 0 {
return Err(napi_error("blob cleanup plan limit must be positive"));
}
if grace_period_days < 0 {
return Err(napi_error("blob cleanup grace period must be non-negative"));
}
let pool = self.pool().await?;
let run_id = create_run(&pool, &workspace_id).await?;
let mut result = RuntimeBlobCleanupPlanResult {
run_id: Some(run_id.clone()),
scanned_blobs: 0,
candidates_marked: 0,
protected_by_doc_refs: 0,
protected_by_metadata: 0,
protected_by_other_refs: 0,
next_cursor: None,
};
let cursor = load_plan_cursor(&pool, &workspace_id).await?;
let stale_projection_workspaces = stale_projection_workspaces(&pool, &workspace_id).await?;
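// Without a completed metadata backfill and a fresh projection, no blob can be
// proven unreferenced, so the whole page is reported as protected and nothing is marked.
if !metadata_backfill_is_complete(&pool, &workspace_id).await? || !stale_projection_workspaces.is_empty() {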
result.protected_by_metadata = load_completed_blobs(&pool, &workspace_id, cursor.as_deref(), limit)
.await?
.len() as i64;
finish_run(&pool, &run_id, &workspace_id, &result, stale_projection_workspaces).await?;
return Ok(result);
}
let min_last_modified = Utc::now() - Duration::days(grace_period_days);
let rows = load_completed_blobs(&pool, &workspace_id, cursor.as_deref(), limit).await?;
let has_more = rows.len() == limit as usize;
let mut last_blob_key = None;
for row in rows {
result.scanned_blobs += 1;
last_blob_key = Some(row.key.clone());
if has_doc_ref(&pool, &row.workspace_id, &row.key).await? {
result.protected_by_doc_refs += 1;
continue;
}
if has_other_ref(&pool, &row.workspace_id, &row.key).await? {
result.protected_by_other_refs += 1;
continue;
}
let object_key = format!("{}/{}", row.workspace_id, row.key);
let Some(metadata) = self.object_storage_head(object_key).await? else {
result.protected_by_metadata += 1;
continue;
};
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("blob cleanup object last modified is invalid"))?;
if metadata.content_length != i64::from(row.size) || last_modified > min_last_modified {
result.protected_by_metadata += 1;
continue;
}
result.candidates_marked += mark_candidate(&pool, &run_id, &row, metadata.content_length, last_modified).await?;
}
if has_more {
result.next_cursor = last_blob_key.clone();
}
upsert_plan_checkpoint(&pool, &workspace_id, last_blob_key.as_deref(), !has_more).await?;
finish_run(&pool, &run_id, &workspace_id, &result, Vec::new()).await?;
Ok(result)
}
#[napi]
pub async fn execute_blob_cleanup_candidates(
&self,
run_id: String,
grace_period_days: i64,
limit: i64,
) -> napi::Result<RuntimeBlobCleanupExecuteResult> {
if limit <= 0 {
return Err(napi_error("blob cleanup execute limit must be positive"));
}
if grace_period_days < 0 {
return Err(napi_error("blob cleanup grace period must be non-negative"));
}
let pool = self.pool().await?;
let min_last_modified = Utc::now() - Duration::days(grace_period_days);
let rows = load_marked_candidates(&pool, &run_id, limit).await?;
let mut result = RuntimeBlobCleanupExecuteResult {
scanned_candidates: rows.len() as i64,
deleted_objects: 0,
deleted_metadata: 0,
skipped_still_referenced: 0,
failed: 0,
workspace_ids: Vec::new(),
};
let mut deletable_candidates = Vec::new();
for row in rows {
if projection_is_stale(&pool, &row.workspace_id).await?
|| has_doc_ref(&pool, &row.workspace_id, &row.blob_key).await?
|| has_other_ref(&pool, &row.workspace_id, &row.blob_key).await?
{
result.skipped_still_referenced += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"skipped",
serde_json::json!({ "skipReason": "referenced_or_projection_stale" }),
None,
)
.await?;
continue;
}
let object_key = format!("{}/{}", row.workspace_id, row.blob_key);
let metadata = match self.object_storage_head(object_key.clone()).await {
Ok(metadata) => metadata,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "object_head_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
if let Some(metadata) = metadata {
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("blob cleanup execute object last modified is invalid"))?;
if last_modified > min_last_modified {
result.skipped_still_referenced += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"skipped",
serde_json::json!({ "skipReason": "object_inside_grace_period" }),
None,
)
.await?;
continue;
}
deletable_candidates.push(DeletableCandidate {
workspace_id: row.workspace_id,
blob_key: row.blob_key,
object_key,
});
continue;
}
let deleted_metadata =
match sqlx::query("DELETE FROM blobs WHERE workspace_id = $1 AND key = $2 AND deleted_at IS NULL")
.bind(&row.workspace_id)
.bind(&row.blob_key)
.execute(&pool)
.await
{
Ok(result) => result.rows_affected() as i64,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "metadata_delete_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
result.deleted_metadata += deleted_metadata;
push_workspace_once(&mut result.workspace_ids, &row.workspace_id);
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"executed",
serde_json::json!({
"deletedMetadata": deleted_metadata,
"objectMissingBeforeDelete": true,
}),
None,
)
.await?;
}
if !deletable_candidates.is_empty() {
let object_keys = deletable_candidates
.iter()
.map(|candidate| candidate.object_key.clone())
.collect::<Vec<_>>();
let outcomes = match self.object_storage_delete_many(object_keys.clone()).await {
Ok(outcomes) => outcomes,
Err(err) => object_keys
.into_iter()
.map(|key| super::object_storage::types::ObjectDeleteOutcome {
key,
error: Some(err.to_string()),
})
.collect(),
};
let mut outcomes_by_key = outcomes
.into_iter()
.map(|outcome| (outcome.key, outcome.error))
.collect::<HashMap<_, _>>();
for row in deletable_candidates {
let delete_error = match outcomes_by_key.remove(&row.object_key) {
Some(Some(error)) => Some(error),
Some(None) => None,
None => Some("DeleteObjects response did not include this key".to_string()),
};
if let Some(error) = delete_error {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "object_delete_failed" }),
Some(&error),
)
.await?;
continue;
}
result.deleted_objects += 1;
let deleted_metadata =
match sqlx::query("DELETE FROM blobs WHERE workspace_id = $1 AND key = $2 AND deleted_at IS NULL")
.bind(&row.workspace_id)
.bind(&row.blob_key)
.execute(&pool)
.await
{
Ok(result) => result.rows_affected() as i64,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "metadata_delete_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
result.deleted_metadata += deleted_metadata;
push_workspace_once(&mut result.workspace_ids, &row.workspace_id);
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"executed",
serde_json::json!({
"deletedMetadata": deleted_metadata,
"objectMissingBeforeDelete": false,
}),
None,
)
.await?;
}
}
finish_execute_run(&pool, &run_id, &result).await?;
Ok(result)
}
}
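The batch-delete step above resolves one result per candidate from a single `object_storage_delete_many` call. If the whole call fails, every key gets that error. If the response omits a key, that key is treated as a failure. A minimal std-only sketch of that reconciliation follows; `DeleteOutcome` and `reconcile_outcomes` are hypothetical names standing in for the real types, which this diff does not show in full.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the storage layer's per-key delete outcome.
#[derive(Debug, Clone, PartialEq)]
struct DeleteOutcome {
    key: String,
    error: Option<String>,
}

/// Resolves exactly one error slot per requested key.
/// - A failed batch call marks every key with the batch error.
/// - A key missing from a successful response is reported as failed.
fn reconcile_outcomes(
    requested: &[String],
    batch: Result<Vec<DeleteOutcome>, String>,
) -> Vec<(String, Option<String>)> {
    let mut by_key: HashMap<String, Option<String>> = match batch {
        Ok(outcomes) => outcomes.into_iter().map(|o| (o.key, o.error)).collect(),
        Err(err) => requested.iter().map(|k| (k.clone(), Some(err.clone()))).collect(),
    };
    requested
        .iter()
        .map(|key| {
            let error = match by_key.remove(key) {
                Some(error) => error,
                None => Some("response did not include this key".to_string()),
            };
            (key.clone(), error)
        })
        .collect()
}
```

Treating an omitted key as a failure keeps the metadata row in place, so a later run can retry the object instead of orphaning it.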
@@ -2,7 +2,7 @@ use chrono::{DateTime, Utc};
use napi::Result;
use sqlx::{FromRow, PgPool};
use super::{BackendRuntime, error::napi_error, types::RuntimeBlobCleanupResult};
use super::{RuntimeBlobCleanupResult, RuntimeError, RuntimeResult, StorageRuntime, napi_error};
#[derive(FromRow)]
struct BlobRow {
@@ -20,7 +20,7 @@ impl BlobReclaimerStore {
Self { pool }
}
async fn load_expired_pending(&self, cutoff: DateTime<Utc>, limit: i64) -> Result<Vec<BlobRow>> {
async fn load_expired_pending(&self, cutoff: DateTime<Utc>, limit: i64) -> RuntimeResult<Vec<BlobRow>> {
sqlx::query_as::<_, BlobRow>(
r#"
SELECT workspace_id, key, upload_id
@@ -36,10 +36,10 @@ impl BlobReclaimerStore {
.bind(limit)
.fetch_all(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer load pending blobs failed: {err}")))
.map_err(|err| RuntimeError::database("BlobReclaimer load pending blobs failed", err))
}
async fn load_deleted(&self, workspace_id: &str, limit: i64) -> Result<Vec<BlobRow>> {
async fn load_deleted(&self, workspace_id: &str, limit: i64) -> RuntimeResult<Vec<BlobRow>> {
sqlx::query_as::<_, BlobRow>(
r#"
SELECT workspace_id, key, upload_id
@@ -54,10 +54,10 @@ impl BlobReclaimerStore {
.bind(limit)
.fetch_all(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer load deleted blobs failed: {err}")))
.map_err(|err| RuntimeError::database("BlobReclaimer load deleted blobs failed", err))
}
async fn delete_pending_metadata(&self, workspace_id: &str, key: &str) -> Result<i64> {
async fn delete_pending_metadata(&self, workspace_id: &str, key: &str) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM blobs
@@ -70,11 +70,11 @@ impl BlobReclaimerStore {
.bind(key)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer delete pending blob metadata failed: {err}")))?;
.map_err(|err| RuntimeError::database("BlobReclaimer delete pending blob metadata failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn delete_released_metadata(&self, workspace_id: &str, key: &str) -> Result<i64> {
async fn delete_released_metadata(&self, workspace_id: &str, key: &str) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM blobs
@@ -86,31 +86,23 @@ impl BlobReclaimerStore {
.bind(key)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer delete blob metadata failed: {err}")))?;
.map_err(|err| RuntimeError::database("BlobReclaimer delete blob metadata failed", err))?;
Ok(result.rows_affected() as i64)
}
}
fn object_missing_error(err: &napi::Error) -> bool {
let message = err.to_string();
message.contains("NoSuchKey")
|| message.contains("NoSuchUpload")
|| message.contains("NotFound")
|| message.contains("not found")
}
async fn delete_object_idempotent(runtime: &BackendRuntime, key: &str) -> Result<()> {
async fn delete_object_idempotent(runtime: &StorageRuntime, key: &str) -> RuntimeResult<()> {
match runtime.object_storage_delete_object(key).await {
Ok(()) => Ok(()),
Err(err) if object_missing_error(&err) => Ok(()),
Err(err) if err.is_object_missing() => Ok(()),
Err(err) => Err(err),
}
}
async fn abort_upload_idempotent(runtime: &BackendRuntime, key: &str, upload_id: &str) -> Result<()> {
async fn abort_upload_idempotent(runtime: &StorageRuntime, key: &str, upload_id: &str) -> RuntimeResult<()> {
match runtime.object_storage_abort_upload(key, upload_id).await {
Ok(()) => Ok(()),
Err(err) if object_missing_error(&err) => Ok(()),
Err(err) if err.is_object_missing() => Ok(()),
Err(err) => Err(err),
}
}
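The hunk above replaces string matching on error messages (`NoSuchKey`, `NotFound`, ...) with a typed `is_object_missing()` check. The idempotent-delete pattern can be sketched on its own as below. `StorageError` and `delete_idempotent` are hypothetical: the real `RuntimeError` definition is not part of this diff, and the closure stands in for the storage call.

```rust
/// Hypothetical error type; the real `RuntimeError` is not shown in this diff.
#[derive(Debug, PartialEq)]
enum StorageError {
    ObjectMissing,
    Other(String),
}

impl StorageError {
    /// Typed check, replacing substring matching on the error message.
    fn is_object_missing(&self) -> bool {
        matches!(self, StorageError::ObjectMissing)
    }
}

/// Treats "already gone" as success so that retries are safe.
fn delete_idempotent<F>(delete: F) -> Result<(), StorageError>
where
    F: FnOnce() -> Result<(), StorageError>,
{
    match delete() {
        Ok(()) => Ok(()),
        Err(err) if err.is_object_missing() => Ok(()),
        Err(err) => Err(err),
    }
}
```

A typed variant avoids false positives when an unrelated error message happens to contain "not found".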
@@ -122,7 +114,7 @@ fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
}
#[napi_derive::napi]
impl BackendRuntime {
impl StorageRuntime {
#[napi]
pub async fn cleanup_expired_pending_blobs(&self, cutoff_ms: i64, limit: i64) -> Result<RuntimeBlobCleanupResult> {
if limit <= 0 {
@@ -130,7 +122,7 @@ impl BackendRuntime {
}
let cutoff = DateTime::<Utc>::from_timestamp_millis(cutoff_ms)
.ok_or_else(|| napi_error("pending blob cleanup cutoff is invalid"))?;
.ok_or_else(|| RuntimeError::invalid_input("pending blob cleanup cutoff is invalid"))?;
let store = BlobReclaimerStore::new(self.pool().await?);
let rows = store.load_expired_pending(cutoff, limit).await?;
@@ -0,0 +1,273 @@
use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool};
use super::{
RuntimeBlobMetadataBackfillResult, RuntimeError, RuntimeObjectMetadata, RuntimeResult, StorageRuntime, napi_error,
};
async fn workspace_exists(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)")
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill workspace check failed", err))
}
async fn blob_exists(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM blobs WHERE workspace_id = $1 AND key = $2)")
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill blob check failed", err))
}
async fn upsert_blob_metadata(
pool: &PgPool,
workspace_id: &str,
key: &str,
metadata: RuntimeObjectMetadata,
) -> RuntimeResult<i64> {
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("Blob metadata backfill object last modified is invalid"))?;
let result = sqlx::query(
r#"
INSERT INTO blobs (workspace_id, key, size, mime, status, upload_id, created_at, deleted_at)
VALUES ($1, $2, $3, $4, 'completed', NULL, $5, NULL)
ON CONFLICT (workspace_id, key) DO UPDATE
SET size = EXCLUDED.size,
mime = EXCLUDED.mime,
status = 'completed',
upload_id = NULL,
deleted_at = NULL
WHERE blobs.deleted_at IS NULL
"#,
)
.bind(workspace_id)
.bind(key)
.bind(metadata.content_length as i32)
.bind(metadata.content_type)
.bind(last_modified)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill upsert failed", err))?;
Ok(result.rows_affected() as i64)
}
fn split_workspace_blob_key(full_key: &str) -> Option<(&str, &str)> {
let (workspace_id, key) = full_key.split_once('/')?;
if workspace_id.is_empty() || key.is_empty() || key.contains('/') {
return None;
}
Some((workspace_id, key))
}
fn checkpoint_scope(workspace_id: Option<&str>) -> String {
workspace_id.unwrap_or("__all__").to_string()
}
#[derive(FromRow)]
struct BackfillCheckpoint {
last_key: Option<String>,
cursor: serde_json::Value,
}
impl BackfillCheckpoint {
fn continuation_token(&self) -> Option<String> {
self
.cursor
.get("continuationToken")
.and_then(|value| value.as_str())
.map(ToString::to_string)
}
}
async fn load_checkpoint(pool: &PgPool, scope: &str) -> RuntimeResult<Option<BackfillCheckpoint>> {
sqlx::query_as::<_, BackfillCheckpoint>(
"SELECT last_key, cursor FROM blob_reconciliation_checkpoints WHERE kind = 'blob_metadata_backfill' AND scope = $1",
)
.bind(scope)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill checkpoint load failed", err))
}
async fn upsert_checkpoint(
pool: &PgPool,
scope: &str,
last_key: Option<&str>,
continuation_token: Option<&str>,
completed: bool,
) -> RuntimeResult<()> {
let status = if completed { "completed" } else { "running" };
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, last_key, completed_at, metadata)
VALUES ('blob_metadata_backfill', $1, $2, $3, $4, CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END, $6)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
last_key = COALESCE(EXCLUDED.last_key, blob_reconciliation_checkpoints.last_key),
completed_at = CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(scope)
.bind(status)
.bind(serde_json::json!({
"lastKey": last_key,
"continuationToken": continuation_token,
}))
.bind(last_key)
.bind(completed)
.bind(serde_json::json!({
"quotaReportingReconciliationRequired": true,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill checkpoint write failed", err))?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn blob_metadata_backfill_splits_workspace_blob_keys() {
assert_eq!(
split_workspace_blob_key("workspace/blob-key"),
Some(("workspace", "blob-key"))
);
assert_eq!(split_workspace_blob_key("workspace/nested/blob-key"), None);
assert_eq!(split_workspace_blob_key("workspace/"), None);
assert_eq!(split_workspace_blob_key("blob-key"), None);
}
#[test]
fn blob_metadata_backfill_checkpoint_scope_is_explicit() {
assert_eq!(checkpoint_scope(Some("workspace")), "workspace");
assert_eq!(checkpoint_scope(None), "__all__");
}
}
fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
if !workspace_ids.iter().any(|id| id == workspace_id) {
workspace_ids.push(workspace_id.to_string());
}
}
fn checked_list_page_limit(limit: i64) -> RuntimeResult<i32> {
i32::try_from(limit).map_err(|_| RuntimeError::invalid_input("blob metadata backfill limit exceeds i32::MAX"))
}
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn backfill_missing_blob_metadata(
&self,
workspace_id: Option<String>,
limit: i64,
) -> napi::Result<RuntimeBlobMetadataBackfillResult> {
if limit <= 0 {
return Err(napi_error("blob metadata backfill limit must be positive"));
}
let page_limit = checked_list_page_limit(limit)?;
let pool = self.pool().await?;
let prefix = workspace_id.as_ref().map(|id| format!("{id}/"));
let scope = checkpoint_scope(workspace_id.as_deref());
let checkpoint = load_checkpoint(&pool, &scope).await?;
let page = self
.object_storage_list_page(
prefix,
checkpoint.as_ref().and_then(BackfillCheckpoint::continuation_token),
checkpoint.as_ref().and_then(|checkpoint| checkpoint.last_key.clone()),
page_limit,
)
.await?;
let has_more = page.next_continuation_token.is_some();
let mut result = RuntimeBlobMetadataBackfillResult {
scanned_objects: 0,
headed_objects: 0,
upserted_metadata: 0,
skipped_existing: 0,
skipped_workspace_missing: 0,
failed: 0,
next_cursor: None,
workspace_ids: Vec::new(),
};
let mut last_scanned_key = None;
for object in &page.entries {
result.scanned_objects += 1;
last_scanned_key = Some(object.key.clone());
let Some((object_workspace_id, key)) = split_workspace_blob_key(&object.key) else {
result.failed += 1;
continue;
};
if workspace_id.as_deref().is_some_and(|id| id != object_workspace_id) {
result.failed += 1;
continue;
}
if !workspace_exists(&pool, object_workspace_id).await? {
result.skipped_workspace_missing += 1;
continue;
}
if blob_exists(&pool, object_workspace_id, key).await? {
result.skipped_existing += 1;
continue;
}
result.headed_objects += 1;
let Some(metadata) = self.object_storage_head(object.key.clone()).await? else {
result.failed += 1;
continue;
};
let affected = upsert_blob_metadata(&pool, object_workspace_id, key, metadata).await?;
if affected > 0 {
result.upserted_metadata += affected;
push_workspace_once(&mut result.workspace_ids, object_workspace_id);
}
}
if has_more {
result.next_cursor = last_scanned_key.clone();
}
upsert_checkpoint(
&pool,
&scope,
last_scanned_key.as_deref(),
page.next_continuation_token.as_deref(),
!has_more,
)
.await?;
sqlx::query(
r#"
INSERT INTO blob_reconciliation_runs
(kind, mode, status, workspace_id, finished_at, scanned, changed, failed, metadata)
VALUES ('blob_metadata_backfill', 'execute', 'finished', $1, CURRENT_TIMESTAMP, $2, $3, $4, $5)
"#,
)
.bind(workspace_id)
.bind(result.scanned_objects as i32)
.bind(result.upserted_metadata as i32)
.bind(result.failed as i32)
.bind(serde_json::json!({
"headedObjects": result.headed_objects,
"skippedExisting": result.skipped_existing,
"skippedWorkspaceMissing": result.skipped_workspace_missing,
"checkpointScope": scope,
"nextCursor": result.next_cursor,
"quotaReportingReconciliationRequired": true,
}))
.execute(&pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill run record failed", err))?;
Ok(result)
}
}
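One detail in this file may deserve a second look: `upsert_blob_metadata` binds `metadata.content_length as i32`, and an `as` cast wraps silently when the value exceeds `i32::MAX`. The file already uses `i32::try_from` in `checked_list_page_limit`, so the same pattern could be applied to the size. The sketch below assumes `content_length` is an `i64` and the `size` column is 32-bit; neither is confirmed by this diff, and `checked_blob_size` is a hypothetical helper.

```rust
/// Hypothetical helper: a checked alternative to `content_length as i32`.
/// Assumes the object size arrives as `i64` (not confirmed by the diff).
fn checked_blob_size(content_length: i64) -> Result<i32, String> {
    i32::try_from(content_length)
        .map_err(|_| format!("blob size {content_length} does not fit in i32"))
}
```

With this shape, an oversized object would be reported as an error rather than stored with a wrapped size.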
@@ -0,0 +1,427 @@
use affine_common::doc_parser;
use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool};
use y_octo::Doc;
use super::{RuntimeDocBlobRefsResult, RuntimeError, RuntimeResult, StorageRuntime, napi_error};
const PARSER_VERSION: i32 = 1;
#[derive(FromRow)]
struct SnapshotRow {
workspace_id: String,
doc_id: String,
blob: Vec<u8>,
updated_at: DateTime<Utc>,
}
#[derive(FromRow)]
struct UpdateRow {
blob: Vec<u8>,
created_at: DateTime<Utc>,
}
struct ExtractedRef {
blob_key: String,
block_id: String,
flavour: String,
}
async fn load_snapshot(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Option<SnapshotRow>> {
sqlx::query_as::<_, SnapshotRow>(
r#"
SELECT workspace_id, guid AS doc_id, blob, updated_at
FROM snapshots
WHERE workspace_id = $1 AND guid = $2
"#,
)
.bind(workspace_id)
.bind(doc_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs load snapshot failed", err))
}
async fn load_updates(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Vec<UpdateRow>> {
sqlx::query_as::<_, UpdateRow>(
r#"
SELECT blob, created_at
FROM updates
WHERE workspace_id = $1 AND guid = $2
ORDER BY created_at ASC
"#,
)
.bind(workspace_id)
.bind(doc_id)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs load updates failed", err))
}
fn apply_doc_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> RuntimeResult<Vec<u8>> {
let mut doc = Doc::default();
for update in updates {
doc
.apply_update_from_binary_v1(&update)
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs merge failed: {err}")))?;
}
doc
.encode_update_v1()
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs encode failed: {err}")))
}
async fn load_current_doc(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Option<SnapshotRow>> {
let snapshot = load_snapshot(pool, workspace_id, doc_id).await?;
let updates = load_updates(pool, workspace_id, doc_id).await?;
if snapshot.is_none() && updates.is_empty() {
return Ok(None);
}
let mut merge_inputs = Vec::with_capacity(updates.len() + usize::from(snapshot.is_some()));
let mut updated_at = snapshot
.as_ref()
.map(|snapshot| snapshot.updated_at)
.unwrap_or_else(Utc::now);
if let Some(snapshot) = snapshot {
merge_inputs.push(snapshot.blob);
}
for update in updates {
updated_at = update.created_at;
merge_inputs.push(update.blob);
}
Ok(Some(SnapshotRow {
workspace_id: workspace_id.to_string(),
doc_id: doc_id.to_string(),
blob: apply_doc_updates(merge_inputs)?,
updated_at,
}))
}
async fn load_workspace_doc_ids(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Vec<String>> {
let Some(root) = load_current_doc(pool, workspace_id, workspace_id).await? else {
return Ok(Vec::new());
};
let ids = doc_parser::get_doc_ids_from_binary(root.blob, false)
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs root doc parse failed: {err}")))?;
let mut ids = ids;
ids.sort();
Ok(ids)
}
async fn upsert_projection_checkpoint(
pool: &PgPool,
workspace_id: &str,
result: &RuntimeDocBlobRefsResult,
) -> RuntimeResult<()> {
let completed = result.next_cursor.is_none();
let status = if completed && result.failed_docs == 0 {
"completed"
} else if result.failed_docs > 0 {
"failed"
} else {
"running"
};
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, completed_at, metadata)
VALUES ('doc_blob_refs', $1, $2, $3, CASE WHEN $4 THEN CURRENT_TIMESTAMP ELSE NULL END, $5)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
completed_at = CASE WHEN $4 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(workspace_id)
.bind(status)
.bind(serde_json::json!({ "lastDocId": result.next_cursor }))
.bind(completed && result.failed_docs == 0)
.bind(serde_json::json!({
"parserVersion": PARSER_VERSION,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs checkpoint write failed", err))?;
Ok(())
}
async fn upsert_projection_failure_checkpoint(pool: &PgPool, workspace_id: &str, error: &str) -> RuntimeResult<()> {
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, completed_at, metadata)
VALUES ('doc_blob_refs', $1, 'failed', '{}', NULL, $2)
ON CONFLICT (kind, scope) DO UPDATE
SET status = 'failed',
cursor = '{}',
completed_at = NULL,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(workspace_id)
.bind(serde_json::json!({
"parserVersion": PARSER_VERSION,
"error": error,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs failure checkpoint write failed", err))?;
Ok(())
}
async fn load_projection_cursor(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Option<String>> {
let cursor = sqlx::query_scalar::<_, serde_json::Value>(
"SELECT cursor FROM blob_reconciliation_checkpoints WHERE kind = 'doc_blob_refs' AND scope = $1",
)
.bind(workspace_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs checkpoint load failed", err))?;
Ok(cursor.and_then(|cursor| {
cursor
.get("lastDocId")
.and_then(|value| value.as_str())
.map(ToString::to_string)
}))
}
async fn purge_removed_doc_refs(pool: &PgPool, workspace_id: &str, current_doc_ids: &[String]) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM doc_blob_refs
WHERE workspace_id = $1
AND NOT (doc_id = ANY($2))
"#,
)
.bind(workspace_id)
.bind(current_doc_ids)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs purge removed docs failed", err))?;
Ok(result.rows_affected() as i64)
}
fn extract_refs(snapshot: &SnapshotRow) -> RuntimeResult<Vec<ExtractedRef>> {
let parsed = doc_parser::parse_doc_from_binary(snapshot.blob.clone(), snapshot.doc_id.clone())
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs parse failed: {err}")))?;
let mut refs = Vec::new();
for block in parsed.blocks {
let Some(blob_keys) = block.blob else {
continue;
};
for blob_key in blob_keys {
refs.push(ExtractedRef {
blob_key,
block_id: block.block_id.clone(),
flavour: block.flavour.clone(),
});
}
}
Ok(refs)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn doc_blob_refs_extracts_image_refs() {
let doc_id = "doc-blob-ref-test".to_string();
let blob =
doc_parser::build_full_doc("Doc", "![Alt](blob://image-blob-key)", &doc_id).expect("doc fixture should build");
let snapshot = SnapshotRow {
workspace_id: "workspace".to_string(),
doc_id,
blob,
updated_at: Utc::now(),
};
let refs = extract_refs(&snapshot).expect("refs should parse");
assert!(
refs
.iter()
.any(|reference| { reference.blob_key == "image-blob-key" && reference.flavour == "affine:image" })
);
}
}
async fn replace_doc_refs(pool: &PgPool, snapshot: &SnapshotRow, refs: Vec<ExtractedRef>) -> RuntimeResult<(i64, i64)> {
let mut tx = pool
.begin()
.await
.map_err(|err| RuntimeError::database("Doc blob refs transaction failed", err))?;
let deleted = sqlx::query("DELETE FROM doc_blob_refs WHERE workspace_id = $1 AND doc_id = $2")
.bind(&snapshot.workspace_id)
.bind(&snapshot.doc_id)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Doc blob refs delete failed", err))?
.rows_affected() as i64;
let mut written = 0;
for reference in refs {
let affected = sqlx::query(
r#"
INSERT INTO doc_blob_refs
(workspace_id, doc_id, blob_key, block_id, flavour, snapshot_updated_at, parser_version, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'fresh')
ON CONFLICT (workspace_id, doc_id, blob_key, block_id) DO UPDATE
SET flavour = EXCLUDED.flavour,
snapshot_updated_at = EXCLUDED.snapshot_updated_at,
indexed_at = CURRENT_TIMESTAMP,
parser_version = EXCLUDED.parser_version,
status = 'fresh',
error = NULL
"#,
)
.bind(&snapshot.workspace_id)
.bind(&snapshot.doc_id)
.bind(reference.blob_key)
.bind(reference.block_id)
.bind(reference.flavour)
.bind(snapshot.updated_at)
.bind(PARSER_VERSION)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Doc blob refs insert failed", err))?
.rows_affected() as i64;
written += affected;
}
tx.commit()
.await
.map_err(|err| RuntimeError::database("Doc blob refs transaction commit failed", err))?;
Ok((written, deleted))
}
async fn mark_doc_failed(pool: &PgPool, workspace_id: &str, doc_id: &str, error: &str) -> RuntimeResult<()> {
sqlx::query(
r#"
INSERT INTO doc_blob_refs
(workspace_id, doc_id, blob_key, block_id, flavour, snapshot_updated_at, parser_version, status, error)
VALUES ($1, $2, '__parse_failed__', '__parse_failed__', '__parse_failed__', CURRENT_TIMESTAMP, $3, 'failed', $4)
ON CONFLICT (workspace_id, doc_id, blob_key, block_id) DO UPDATE
SET indexed_at = CURRENT_TIMESTAMP,
status = 'failed',
error = EXCLUDED.error
"#,
)
.bind(workspace_id)
.bind(doc_id)
.bind(PARSER_VERSION)
.bind(error)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs mark failure failed", err))?;
Ok(())
}
async fn rebuild_doc_blob_refs_inner(
runtime: &StorageRuntime,
workspace_id: String,
doc_id: String,
) -> RuntimeResult<RuntimeDocBlobRefsResult> {
let pool = runtime.pool().await?;
let mut result = RuntimeDocBlobRefsResult {
scanned_docs: 1,
parsed_docs: 0,
refs_written: 0,
refs_deleted: 0,
failed_docs: 0,
next_cursor: None,
};
let Some(snapshot) = load_current_doc(&pool, &workspace_id, &doc_id).await? else {
result.failed_docs = 1;
mark_doc_failed(&pool, &workspace_id, &doc_id, "snapshot_missing").await?;
return Ok(result);
};
match extract_refs(&snapshot) {
Ok(refs) => {
let (written, deleted) = replace_doc_refs(&pool, &snapshot, refs).await?;
result.parsed_docs = 1;
result.refs_written = written;
result.refs_deleted = deleted;
}
Err(err) => {
result.failed_docs = 1;
mark_doc_failed(&pool, &workspace_id, &doc_id, &err.to_string()).await?;
}
}
Ok(result)
}
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn rebuild_doc_blob_refs(
&self,
workspace_id: String,
doc_id: String,
) -> napi::Result<RuntimeDocBlobRefsResult> {
Ok(rebuild_doc_blob_refs_inner(self, workspace_id, doc_id).await?)
}
#[napi]
pub async fn rebuild_workspace_doc_blob_refs(
&self,
workspace_id: String,
limit: i64,
) -> napi::Result<RuntimeDocBlobRefsResult> {
if limit <= 0 {
return Err(napi_error("doc blob refs rebuild limit must be positive"));
}
let pool = self.pool().await?;
let doc_ids = match load_workspace_doc_ids(&pool, &workspace_id).await {
Ok(doc_ids) => doc_ids,
Err(err) => {
upsert_projection_failure_checkpoint(&pool, &workspace_id, &err.to_string()).await?;
return Err(err.into());
}
};
let cursor = load_projection_cursor(&pool, &workspace_id).await?;
let current_doc_ids = doc_ids.clone();
let doc_ids = doc_ids
.into_iter()
.filter(|doc_id| cursor.as_ref().is_none_or(|cursor| doc_id > cursor))
.collect::<Vec<_>>();
let has_more = doc_ids.len() > limit as usize;
let mut total = RuntimeDocBlobRefsResult {
scanned_docs: 0,
parsed_docs: 0,
refs_written: 0,
refs_deleted: 0,
failed_docs: 0,
next_cursor: None,
};
let mut last_doc_id = None;
for doc_id in doc_ids.into_iter().take(limit as usize) {
last_doc_id = Some(doc_id.clone());
let result = rebuild_doc_blob_refs_inner(self, workspace_id.clone(), doc_id).await?;
total.scanned_docs += result.scanned_docs;
total.parsed_docs += result.parsed_docs;
total.refs_written += result.refs_written;
total.refs_deleted += result.refs_deleted;
total.failed_docs += result.failed_docs;
}
if has_more {
total.next_cursor = last_doc_id;
} else if total.failed_docs == 0 {
total.refs_deleted += purge_removed_doc_refs(&pool, &workspace_id, &current_doc_ids).await?;
}
upsert_projection_checkpoint(&pool, &workspace_id, &total).await?;
Ok(total)
}
}
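`rebuild_workspace_doc_blob_refs` pages through the sorted doc IDs with a string cursor: it keeps only IDs greater than the cursor, processes at most `limit` of them, and returns the last processed ID as `next_cursor` only when more remain. A minimal sketch of that paging logic, independent of the database, is below. `next_page` is a hypothetical helper, and it uses `map_or` rather than `Option::is_none_or` so that it also compiles on older toolchains.

```rust
/// Returns the page of IDs strictly after `cursor`, plus the cursor for the
/// following page (None when this page exhausts the list).
fn next_page(
    mut ids: Vec<String>,
    cursor: Option<&str>,
    limit: usize,
) -> (Vec<String>, Option<String>) {
    // Sorting makes the string cursor a stable resume point.
    ids.sort();
    let remaining: Vec<String> = ids
        .into_iter()
        .filter(|id| cursor.map_or(true, |c| id.as_str() > c))
        .collect();
    let has_more = remaining.len() > limit;
    let page: Vec<String> = remaining.into_iter().take(limit).collect();
    let next_cursor = if has_more { page.last().cloned() } else { None };
    (page, next_cursor)
}
```

As in the diff, callers are expected to reject a non-positive `limit` before paging; the sketch does not repeat that check.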
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,29 +1,29 @@
use aws_sdk_s3::config::{
BehaviorVersion, Credentials, Region, RequestChecksumCalculation, ResponseChecksumValidation, timeout::TimeoutConfig,
};
use napi::Result;
use rusty_s3::{Bucket, Credentials, UrlStyle};
use serde::Deserialize;
use url::Url;
use super::{client::ObjectStorageClient, types::StorageProviderConfig};
use crate::backend_runtime::{
config::blob_storage_config_from_config_files, error::napi_error, types::RuntimeObjectStorageHealth,
use super::{
client::ObjectStorageClient,
error::{ObjectStorageError, ObjectStorageResult},
types::StorageProviderConfig,
};
#[derive(Clone, Debug)]
pub(in crate::backend_runtime) struct ObjectStorageConfig {
pub(super) provider: String,
pub(super) bucket: String,
pub(super) endpoint: Option<String>,
pub(super) region: Option<String>,
pub(super) access_key_id: Option<String>,
pub(super) secret_access_key: Option<String>,
pub(super) session_token: Option<String>,
pub(super) force_path_style: bool,
pub(super) request_timeout_ms: Option<u64>,
pub(super) min_part_size: Option<u64>,
pub(super) presign_expires_in_seconds: Option<u64>,
pub(super) presign_sign_content_type_for_put: Option<bool>,
pub(super) use_presigned_url: bool,
pub(crate) struct ObjectStorageConfig {
pub(crate) provider: String,
pub(crate) bucket: String,
pub(crate) endpoint: Option<String>,
pub(crate) region: Option<String>,
pub(crate) access_key_id: Option<String>,
pub(crate) secret_access_key: Option<String>,
pub(crate) session_token: Option<String>,
pub(crate) force_path_style: bool,
pub(crate) request_timeout_ms: Option<u64>,
pub(crate) min_part_size: Option<u64>,
pub(crate) presign_expires_in_seconds: Option<u64>,
pub(crate) presign_sign_content_type_for_put: Option<bool>,
pub(crate) use_presigned_url: bool,
pub(crate) proxy_upload: bool,
}
#[derive(Debug, Deserialize)]
@@ -70,13 +70,16 @@ struct S3PresignConfigFile {
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsePresignedUrlConfigFile {
enabled: bool,
url_prefix: Option<String>,
sign_key: Option<String>,
}
impl ObjectStorageConfig {
pub(in crate::backend_runtime) fn from_config_files() -> Result<Option<Self>> {
let Some(storage) = blob_storage_config_from_config_files()? else {
pub(crate) fn from_provider_config(storage: Option<StorageProviderConfig>) -> ObjectStorageResult<Option<Self>> {
let Some(storage) = storage else {
return Ok(None);
};
@@ -84,18 +87,18 @@ impl ObjectStorageConfig {
"aws-s3" => Self::from_s3_config(storage),
"cloudflare-r2" => Self::from_r2_config(storage),
"fs" => Ok(None),
provider => Err(napi_error(format!(
"unsupported blob storage provider for BackendRuntime: {provider}"
provider => Err(ObjectStorageError::Config(format!(
"unsupported blob storage provider for StorageRuntime: {provider}"
))),
}
}
pub(super) fn from_s3_config(storage: StorageProviderConfig) -> Result<Option<Self>> {
pub(crate) fn from_s3_config(storage: StorageProviderConfig) -> ObjectStorageResult<Option<Self>> {
let config: S3ConfigFile = serde_json::from_value(storage.config)
.map_err(|err| napi_error(format!("invalid aws-s3 blob storage config: {err}")))?;
.map_err(|err| ObjectStorageError::Config(format!("invalid aws-s3 blob storage config: {err}")))?;
let region = config
.region
.ok_or_else(|| napi_error("aws-s3 blob storage config requires region"))?;
.ok_or_else(|| ObjectStorageError::Config("aws-s3 blob storage config requires region".to_string()))?;
let endpoint = config.endpoint.or_else(|| Some(resolve_s3_endpoint(&region)));
let credentials = config.credentials.unwrap_or_default();
@@ -113,17 +116,29 @@ impl ObjectStorageConfig {
presign_expires_in_seconds: config.presign.as_ref().and_then(|v| v.expires_in_seconds),
presign_sign_content_type_for_put: config.presign.as_ref().and_then(|v| v.sign_content_type_for_put),
use_presigned_url: config.use_presigned_url.map(|v| v.enabled).unwrap_or(false),
proxy_upload: false,
}))
}
pub(super) fn from_r2_config(storage: StorageProviderConfig) -> Result<Option<Self>> {
pub(crate) fn from_r2_config(storage: StorageProviderConfig) -> ObjectStorageResult<Option<Self>> {
let config: R2ConfigFile = serde_json::from_value(storage.config)
.map_err(|err| napi_error(format!("invalid cloudflare-r2 blob storage config: {err}")))?;
.map_err(|err| ObjectStorageError::Config(format!("invalid cloudflare-r2 blob storage config: {err}")))?;
let account = match config.jurisdiction {
Some(jurisdiction) => format!("{}.{}", config.account_id, jurisdiction),
None => config.account_id,
};
let credentials = config.credentials.unwrap_or_default();
let (use_presigned_url, proxy_upload) = config
.use_presigned_url
.map(|value| {
(
value.enabled,
value.enabled
&& value.url_prefix.as_ref().is_some_and(|prefix| !prefix.is_empty())
&& value.sign_key.as_ref().is_some_and(|key| !key.is_empty()),
)
})
.unwrap_or((false, false));
Ok(Some(Self {
provider: storage.provider,
@@ -138,81 +153,51 @@ impl ObjectStorageConfig {
min_part_size: config.min_part_size,
presign_expires_in_seconds: config.presign.as_ref().and_then(|v| v.expires_in_seconds),
presign_sign_content_type_for_put: config.presign.as_ref().and_then(|v| v.sign_content_type_for_put),
use_presigned_url: config.use_presigned_url.map(|v| v.enabled).unwrap_or(false),
use_presigned_url,
proxy_upload,
}))
}
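The R2 path now derives two flags from the `usePresignedUrl` block: `use_presigned_url` follows `enabled`, while `proxy_upload` additionally requires a non-empty `url_prefix` and `sign_key`. The derivation can be sketched in isolation as follows. `UsePresignedUrl` and `presign_flags` are hypothetical names that mirror the config struct in the diff without serde.

```rust
/// Hypothetical mirror of the `usePresignedUrl` config block (no serde).
struct UsePresignedUrl {
    enabled: bool,
    url_prefix: Option<String>,
    sign_key: Option<String>,
}

/// Returns `(use_presigned_url, proxy_upload)`.
/// `proxy_upload` is true only when the block is enabled and both the
/// URL prefix and the signing key are present and non-empty.
fn presign_flags(config: Option<&UsePresignedUrl>) -> (bool, bool) {
    let non_empty = |value: &Option<String>| value.as_deref().map_or(false, |v| !v.is_empty());
    config
        .map(|c| {
            (
                c.enabled,
                c.enabled && non_empty(&c.url_prefix) && non_empty(&c.sign_key),
            )
        })
        .unwrap_or((false, false))
}
```

The S3 path in the same diff sets `proxy_upload: false` unconditionally, so this derivation applies to the R2 provider only.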
pub(super) fn build_client(&self) -> Result<ObjectStorageClient> {
pub(crate) fn build_client(&self) -> ObjectStorageResult<ObjectStorageClient> {
let region = self
.region
.clone()
.ok_or_else(|| napi_error("object storage region is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage region is required".to_string()))?;
let access_key_id = self
.access_key_id
.clone()
.ok_or_else(|| napi_error("object storage accessKeyId is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage accessKeyId is required".to_string()))?;
let secret_access_key = self
.secret_access_key
.clone()
.ok_or_else(|| napi_error("object storage secretAccessKey is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage secretAccessKey is required".to_string()))?;
let credentials = Credentials::new(
access_key_id,
secret_access_key,
self.session_token.clone(),
None,
"affine-server-config-json",
);
let mut builder = aws_sdk_s3::Config::builder()
.behavior_version(BehaviorVersion::latest())
.region(Region::new(region))
.credentials_provider(credentials)
.force_path_style(self.force_path_style)
.request_checksum_calculation(RequestChecksumCalculation::WhenRequired)
.response_checksum_validation(ResponseChecksumValidation::WhenRequired);
if let Some(endpoint) = &self.endpoint {
builder = builder.endpoint_url(endpoint);
}
if let Some(request_timeout_ms) = self.request_timeout_ms {
builder = builder.timeout_config(
TimeoutConfig::builder()
.operation_timeout(std::time::Duration::from_millis(request_timeout_ms))
.build(),
);
}
Ok(ObjectStorageClient::new(
builder.build(),
let endpoint = self.endpoint.clone().unwrap_or_else(|| resolve_s3_endpoint(&region));
let endpoint = Url::parse(&endpoint)
.map_err(|err| ObjectStorageError::Config(format!("object storage endpoint is invalid: {err}")))?;
let bucket = Bucket::new(
endpoint,
if self.force_path_style {
UrlStyle::Path
} else {
UrlStyle::VirtualHost
},
self.bucket.clone(),
region,
)
.map_err(|err| ObjectStorageError::Config(format!("object storage bucket url is invalid: {err}")))?;
let credentials = match self.session_token.as_ref().filter(|token| !token.is_empty()) {
Some(session_token) => Credentials::new_with_token(access_key_id, secret_access_key, session_token.clone()),
None => Credentials::new(access_key_id, secret_access_key),
};
ObjectStorageClient::new(
bucket,
credentials,
self.request_timeout_ms,
self.presign_expires_in_seconds.unwrap_or(60),
self.presign_sign_content_type_for_put.unwrap_or(true),
))
}
pub(super) fn health(&self) -> RuntimeObjectStorageHealth {
let client_buildable = self
.build_client()
.map(|client| client.non_destructive_health())
.unwrap_or(false);
RuntimeObjectStorageHealth {
configured: true,
provider: Some(self.provider.clone()),
bucket: Some(self.bucket.clone()),
endpoint: self.endpoint.clone(),
region: self.region.clone(),
has_credentials: self.access_key_id.is_some()
&& self.secret_access_key.is_some()
&& self.session_token.as_ref().map(|v| !v.is_empty()).unwrap_or(true),
force_path_style: self.force_path_style,
request_timeout_ms: self.request_timeout_ms.map(|v| v as i64),
min_part_size: self.min_part_size.map(|v| v as i64),
presign_expires_in_seconds: self.presign_expires_in_seconds.map(|v| v as i64),
presign_sign_content_type_for_put: self.presign_sign_content_type_for_put,
use_presigned_url: self.use_presigned_url,
client_buildable,
}
)
}
}
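The `from_r2_config` change above derives two flags from the optional `usePresignedURL` block: presigning is on whenever `enabled` is true, while proxy upload additionally needs a non-empty `urlPrefix` and `signKey`. A minimal std-only sketch of that rule follows. The names `PresignedUrlOptions` and `resolve_upload_flags` are hypothetical and exist only for illustration; the field names mirror the ones used in the diff.

```rust
/// Hypothetical stand-in for the `usePresignedURL` config block.
#[derive(Default)]
struct PresignedUrlOptions {
    enabled: bool,
    url_prefix: Option<String>,
    sign_key: Option<String>,
}

/// Returns `(use_presigned_url, proxy_upload)`.
/// Proxy upload is only reported when presigning is enabled AND both
/// the URL prefix and the signing key are present and non-empty.
fn resolve_upload_flags(options: Option<&PresignedUrlOptions>) -> (bool, bool) {
    options
        .map(|value| {
            let has_prefix = value.url_prefix.as_deref().is_some_and(|p| !p.is_empty());
            let has_key = value.sign_key.as_deref().is_some_and(|k| !k.is_empty());
            (value.enabled, value.enabled && has_prefix && has_key)
        })
        // A missing block disables both features.
        .unwrap_or((false, false))
}
```

Under this rule a config with only `enabled: true` yields `(true, false)`, which matches what `resolves_r2_config_from_config_json_shape` asserts further down in the diff.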
@@ -0,0 +1,64 @@
use reqwest::StatusCode;
#[derive(Debug, thiserror::Error)]
pub(crate) enum ObjectStorageError {
#[error("ObjectStorage config error: {0}")]
Config(String),
#[error("{context}: {source}")]
Operation {
context: String,
#[source]
source: Box<ObjectStorageError>,
},
#[error("ObjectStorage http client build failed: {0}")]
HttpClientBuild(#[source] reqwest::Error),
#[error("ObjectStorage http request failed: {0}")]
HttpRequest(#[source] reqwest::Error),
#[error("ObjectStorage invalid http header: {0}")]
InvalidHeader(String),
#[error("ObjectStorage response body exceeds {limit} bytes")]
BodyTooLarge { limit: usize },
#[error("{context}: status={status} body={body}")]
HttpStatus {
context: String,
status: StatusCode,
body: String,
},
#[error("{context}: invalid utf8 response: {source}")]
InvalidUtf8 {
context: String,
#[source]
source: std::string::FromUtf8Error,
},
#[error("{context}: invalid xml response: {source}")]
InvalidXml {
context: String,
#[source]
source: instant_xml::Error,
},
#[error("ObjectStorage invalid input: {0}")]
InvalidInput(String),
}
impl ObjectStorageError {
pub(crate) fn is_not_found(&self) -> bool {
match self {
Self::Operation { source, .. } => source.is_not_found(),
Self::HttpStatus { status, body, .. } => {
*status == StatusCode::NOT_FOUND
&& (body.contains("NoSuchKey") || body.contains("NoSuchUpload") || body.contains("NotFound"))
}
_ => false,
}
}
pub(crate) fn is_retryable_http_status(&self) -> bool {
match self {
Self::Operation { source, .. } => source.is_retryable_http_status(),
Self::HttpStatus { status, .. } => *status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error(),
_ => false,
}
}
}
pub(crate) type ObjectStorageResult<T> = std::result::Result<T, ObjectStorageError>;
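The two classifiers above recurse through `Operation` wrappers and then inspect the HTTP status, with `is_not_found` also requiring an object-level error code in the body. The sketch below reproduces that logic with std only, so it can be read without `reqwest` or `thiserror`. `StorageError` is a simplified stand-in with a plain `u16` status, and `ignore_not_found` is a hypothetical helper showing one way a caller might use the classification; neither appears in the diff.

```rust
/// Simplified stand-in for `ObjectStorageError` (status is a plain u16 here).
#[allow(dead_code)]
#[derive(Debug)]
enum StorageError {
    Config(String),
    Operation { context: String, source: Box<StorageError> },
    HttpStatus { status: u16, body: String },
}

impl StorageError {
    /// A 404 only counts as "object missing" when the body names an
    /// object-level error code; a bare 404 may mean a wrong bucket or route.
    fn is_not_found(&self) -> bool {
        match self {
            Self::Operation { source, .. } => source.is_not_found(),
            Self::HttpStatus { status, body } => {
                *status == 404
                    && ["NoSuchKey", "NoSuchUpload", "NotFound"]
                        .iter()
                        .any(|code| body.contains(*code))
            }
            _ => false,
        }
    }

    /// Throttling (429) and server errors (5xx) are treated as retryable.
    fn is_retryable(&self) -> bool {
        match self {
            Self::Operation { source, .. } => source.is_retryable(),
            Self::HttpStatus { status, .. } => *status == 429 || (500..=599).contains(status),
            _ => false,
        }
    }
}

/// Hypothetical caller-side helper: map "object missing" to `None`
/// and pass every other error through unchanged.
fn ignore_not_found<T>(result: Result<T, StorageError>) -> Result<Option<T>, StorageError> {
    match result {
        Ok(value) => Ok(Some(value)),
        Err(err) if err.is_not_found() => Ok(None),
        Err(err) => Err(err),
    }
}
```

Keeping the body check separate from the status check is what lets a misconfigured bucket surface as an error instead of being silently read as "no such object".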
@@ -0,0 +1,9 @@
pub(crate) mod client;
pub(crate) mod config;
pub(crate) mod error;
#[cfg(test)]
mod tests;
pub(crate) mod types;
pub(crate) use config::ObjectStorageConfig;
pub(crate) use types::StorageProviderConfig;
@@ -0,0 +1,376 @@
use reqwest::StatusCode;
use super::{
config::ObjectStorageConfig,
error::ObjectStorageError,
types::{
MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, checksum_crc32_base64, completed_multipart_parts,
trim_etag,
},
};
fn storage_config(provider: &str, config: serde_json::Value) -> StorageProviderConfig {
StorageProviderConfig {
provider: provider.to_string(),
bucket: "test-bucket".to_string(),
config,
}
}
#[test]
fn resolves_r2_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "cloudflare-r2");
assert_eq!(config.bucket, "workspace-blobs");
assert_eq!(
config.endpoint.as_deref(),
Some("https://account.eu.r2.cloudflarestorage.com")
);
assert_eq!(config.region.as_deref(), Some("auto"));
assert!(config.force_path_style);
assert!(config.use_presigned_url);
assert!(!config.proxy_upload);
assert_eq!(config.access_key_id.as_deref(), Some("key"));
}
#[test]
fn resolves_r2_endpoint_cases_from_config_json_shape() {
for (case, config, expected_endpoint) in [
(
"default account endpoint",
serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.r2.cloudflarestorage.com"),
),
(
"explicit null jurisdiction",
serde_json::json!({
"accountId": "account",
"jurisdiction": null,
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.r2.cloudflarestorage.com"),
),
(
"eu jurisdiction",
serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.eu.r2.cloudflarestorage.com"),
),
] {
let config = ObjectStorageConfig::from_r2_config(storage_config("cloudflare-r2", config))
.unwrap()
.unwrap();
assert_eq!(config.endpoint.as_deref(), expected_endpoint, "{case}");
assert!(config.force_path_style, "{case}");
}
assert!(
ObjectStorageConfig::from_r2_config(storage_config(
"cloudflare-r2",
serde_json::json!({
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
})
))
.is_err()
);
}
#[test]
fn object_storage_not_found_requires_object_error_code() {
let bucket_or_route_missing = ObjectStorageError::HttpStatus {
context: "head failed".to_string(),
status: StatusCode::NOT_FOUND,
body: String::new(),
};
let object_missing = ObjectStorageError::HttpStatus {
context: "get failed".to_string(),
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchKey</Code></Error>".to_string(),
};
let upload_missing = ObjectStorageError::HttpStatus {
context: "abort failed".to_string(),
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchUpload</Code></Error>".to_string(),
};
assert!(!bucket_or_route_missing.is_not_found());
assert!(object_missing.is_not_found());
assert!(upload_missing.is_not_found());
}
#[test]
fn resolves_r2_proxy_upload_capability_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true,
"urlPrefix": "https://cdn.example.com",
"signKey": "secret"
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert!(config.use_presigned_url);
assert!(config.proxy_upload);
}
#[test]
fn resolves_s3_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"region": "us-west-2",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret",
"sessionToken": "session"
},
"forcePathStyle": true,
"requestTimeoutMs": 1000,
"minPartSize": 1024,
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "aws-s3");
assert_eq!(config.endpoint.as_deref(), Some("https://s3.us-west-2.amazonaws.com"));
assert_eq!(config.session_token.as_deref(), Some("session"));
assert!(config.force_path_style);
assert_eq!(config.request_timeout_ms, Some(1000));
assert_eq!(config.min_part_size, Some(1024));
assert_eq!(config.presign_expires_in_seconds, Some(60));
assert_eq!(config.presign_sign_content_type_for_put, Some(false));
}
#[test]
fn resolves_s3_default_endpoint_cases_from_config_json_shape() {
for (region, expected_endpoint) in [
("us-east-1", "https://s3.amazonaws.com"),
("us-west-2", "https://s3.us-west-2.amazonaws.com"),
] {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": region,
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
))
.unwrap()
.unwrap();
assert_eq!(config.endpoint.as_deref(), Some(expected_endpoint), "{region}");
}
}
#[tokio::test]
async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
let Ok(Ok(client)) = std::panic::catch_unwind(|| config.build_client()) else {
eprintln!("skipping object storage presign test: S3 client cannot be built in this environment");
return;
};
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders="));
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert!(result.expires_at_ms > 0);
}
#[tokio::test]
async fn object_storage_presign_put_respects_content_length_and_signed_content_type_flag() {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
))
.unwrap()
.unwrap();
let client = config.build_client().unwrap();
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
content_length: Some(42),
..Default::default()
},
)
.await
.unwrap();
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert_eq!(result.headers.get("Content-Length").map(String::as_str), Some("42"));
assert!(!result.url.contains("content-type"));
assert!(result.url.contains("content-length"));
}
#[tokio::test]
async fn object_storage_presign_get_returns_sigv4_url_without_headers() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
let client = config.build_client().unwrap();
let result = client.presign_get("workspace/key").await.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders=host"));
assert!(result.url.contains("/test-bucket/workspace/key?"));
assert!(result.headers.is_empty());
assert!(result.expires_at_ms > 0);
}
#[tokio::test]
async fn object_storage_presign_upload_part_returns_sigv4_url() {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
))
.unwrap()
.unwrap();
let client = config.build_client().unwrap();
let result = client.presign_upload_part("key", "upload-1", 3).await.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("partNumber=3"));
assert!(result.url.contains("uploadId=upload-1"));
assert!(result.headers.is_empty());
assert!(result.expires_at_ms > 0);
}
#[test]
fn object_storage_orders_completed_multipart_parts_and_trims_etags() {
let parts = completed_multipart_parts(vec![
MultipartUploadPart {
part_number: 2,
etag: trim_etag("\"b\""),
},
MultipartUploadPart {
part_number: 1,
etag: trim_etag("a"),
},
]);
assert_eq!(parts[0].part_number, 1);
assert_eq!(parts[0].etag, "a");
assert_eq!(parts[1].part_number, 2);
assert_eq!(parts[1].etag, "b");
}
#[test]
fn object_storage_crc32_checksum_uses_s3_base64_format() {
assert_eq!(checksum_crc32_base64(b"hello"), "NhCmhg==");
assert_ne!(checksum_crc32_base64(b"hello"), "3610a686");
}
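The last test above pins the checksum format: the CRC-32 of the body, serialized as four big-endian bytes and then base64-encoded, rather than printed as hex. The sketch below works through that with std only, so each step is visible. The diff itself uses `crc32fast` and the `base64` crate; the bitwise CRC and hand-rolled encoder here are illustrative substitutes, and all three function names are hypothetical.

```rust
/// Bitwise CRC-32 (IEEE, reflected polynomial 0xEDB88320).
fn crc32(data: &[u8]) -> u32 {
    let mut crc = 0xFFFF_FFFFu32;
    for &byte in data {
        crc ^= byte as u32;
        for _ in 0..8 {
            // All-ones mask when the low bit is set, zero otherwise.
            let mask = (crc & 1).wrapping_neg();
            crc = (crc >> 1) ^ (0xEDB8_8320 & mask);
        }
    }
    !crc
}

const ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Standard base64 with `=` padding.
fn base64_standard(input: &[u8]) -> String {
    let mut out = String::new();
    for chunk in input.chunks(3) {
        let b1 = *chunk.get(1).unwrap_or(&0);
        let b2 = *chunk.get(2).unwrap_or(&0);
        let n = (chunk[0] as u32) << 16 | (b1 as u32) << 8 | (b2 as u32);
        out.push(ALPHABET[((n >> 18) & 63) as usize] as char);
        out.push(ALPHABET[((n >> 12) & 63) as usize] as char);
        out.push(if chunk.len() > 1 { ALPHABET[((n >> 6) & 63) as usize] as char } else { '=' });
        out.push(if chunk.len() > 2 { ALPHABET[(n & 63) as usize] as char } else { '=' });
    }
    out
}

/// CRC-32 as big-endian bytes, then base64: the shape the test expects.
fn checksum_crc32_base64_sketch(body: &[u8]) -> String {
    base64_standard(&crc32(body).to_be_bytes())
}
```

For `b"hello"` the CRC-32 is `0x3610a686`, so the encoded bytes are `36 10 a6 86` and the result is `NhCmhg==`, the value asserted in the test.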
@@ -0,0 +1,191 @@
use std::collections::HashMap;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use serde::Deserialize;
use super::super::{
RuntimeError, RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest, RuntimeResult,
};
#[derive(Clone, Debug, Default)]
pub(crate) struct ObjectPutMetadata {
pub(crate) content_type: Option<String>,
pub(crate) content_length: Option<i64>,
pub(crate) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectMetadata {
pub(crate) content_type: String,
pub(crate) content_length: i64,
pub(crate) last_modified_ms: i64,
pub(crate) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectListEntry {
pub(crate) key: String,
pub(crate) content_length: i64,
pub(crate) last_modified_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectListPage {
pub(crate) entries: Vec<ObjectListEntry>,
pub(crate) next_continuation_token: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectDeleteOutcome {
pub(crate) key: String,
pub(crate) error: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectGetResult {
pub(crate) body: Vec<u8>,
pub(crate) metadata: ObjectMetadata,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct PresignedObjectRequest {
pub(crate) url: String,
pub(crate) headers: HashMap<String, String>,
pub(crate) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct MultipartUploadInitResult {
pub(crate) upload_id: String,
pub(crate) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct MultipartUploadPart {
pub(crate) part_number: i32,
pub(crate) etag: String,
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct StorageProviderConfig {
pub(crate) provider: String,
pub(crate) bucket: String,
#[serde(default)]
pub(crate) config: serde_json::Value,
}
pub(crate) fn trim_etag(etag: &str) -> String {
etag.trim_matches('"').to_string()
}
pub(crate) fn completed_multipart_parts(mut parts: Vec<MultipartUploadPart>) -> Vec<MultipartUploadPart> {
parts.sort_by_key(|part| part.part_number);
parts
}
impl From<RuntimeObjectStoragePutOptions> for ObjectPutMetadata {
fn from(options: RuntimeObjectStoragePutOptions) -> Self {
Self {
content_type: options.content_type,
content_length: options.content_length,
checksum_crc32: options.checksum_crc32,
}
}
}
impl ObjectPutMetadata {
pub(crate) fn complete_for_body(mut self, body: &[u8]) -> Self {
self.content_length.get_or_insert(body.len() as i64);
self.checksum_crc32.get_or_insert_with(|| checksum_crc32_base64(body));
self
.content_type
.get_or_insert_with(|| crate::file_type::get_mime(body));
self
}
pub(crate) fn into_object_metadata(self, last_modified_ms: i64) -> ObjectMetadata {
ObjectMetadata {
content_type: self
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: self.content_length.unwrap_or(0),
last_modified_ms,
checksum_crc32: self.checksum_crc32,
}
}
}
pub(crate) fn checksum_crc32_base64(body: &[u8]) -> String {
STANDARD.encode(crc32fast::hash(body).to_be_bytes())
}
impl From<ObjectMetadata> for RuntimeObjectMetadata {
fn from(metadata: ObjectMetadata) -> Self {
Self {
content_type: metadata.content_type,
content_length: metadata.content_length,
last_modified_ms: metadata.last_modified_ms,
checksum_crc32: metadata.checksum_crc32,
}
}
}
impl From<ObjectListEntry> for RuntimeObjectListEntry {
fn from(entry: ObjectListEntry) -> Self {
Self {
key: entry.key,
content_length: entry.content_length,
last_modified_ms: entry.last_modified_ms,
}
}
}
impl TryFrom<PresignedObjectRequest> for RuntimePresignedObjectRequest {
type Error = RuntimeError;
fn try_from(request: PresignedObjectRequest) -> RuntimeResult<Self> {
Ok(Self {
url: request.url,
headers_json: serde_json::to_string(&request.headers)
.map_err(|err| RuntimeError::json("ObjectStorage headers serialization failed", err))?,
expires_at_ms: request.expires_at_ms,
})
}
}
impl From<ObjectGetResult> for RuntimeObjectGetResult {
fn from(result: ObjectGetResult) -> Self {
Self {
body: result.body.into(),
metadata: result.metadata.into(),
}
}
}
impl From<MultipartUploadInitResult> for RuntimeMultipartUploadInit {
fn from(init: MultipartUploadInitResult) -> Self {
Self {
upload_id: init.upload_id,
expires_at_ms: init.expires_at_ms,
}
}
}
impl From<RuntimeMultipartUploadPart> for MultipartUploadPart {
fn from(part: RuntimeMultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
impl From<MultipartUploadPart> for RuntimeMultipartUploadPart {
fn from(part: MultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
@@ -12,24 +12,6 @@ pub struct RuntimeVerificationTokenRecord {
pub struct BackendRuntimeHealth {
pub started: bool,
pub database_connected: bool,
pub object_storage_configured: bool,
}
#[napi_derive::napi(object)]
pub struct RuntimeObjectStorageHealth {
pub configured: bool,
pub provider: Option<String>,
pub bucket: Option<String>,
pub endpoint: Option<String>,
pub region: Option<String>,
pub has_credentials: bool,
pub force_path_style: bool,
pub request_timeout_ms: Option<i64>,
pub min_part_size: Option<i64>,
pub presign_expires_in_seconds: Option<i64>,
pub presign_sign_content_type_for_put: Option<bool>,
pub use_presigned_url: bool,
pub client_buildable: bool,
}
#[napi_derive::napi(object)]
@@ -138,6 +120,49 @@ pub struct RuntimeBlobCompleteResult {
pub last_modified_ms: Option<i64>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobMetadataBackfillResult {
pub scanned_objects: i64,
pub headed_objects: i64,
pub upserted_metadata: i64,
pub skipped_existing: i64,
pub skipped_workspace_missing: i64,
pub failed: i64,
pub next_cursor: Option<String>,
pub workspace_ids: Vec<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeDocBlobRefsResult {
pub scanned_docs: i64,
pub parsed_docs: i64,
pub refs_written: i64,
pub refs_deleted: i64,
pub failed_docs: i64,
pub next_cursor: Option<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobCleanupPlanResult {
pub run_id: Option<String>,
pub scanned_blobs: i64,
pub candidates_marked: i64,
pub protected_by_doc_refs: i64,
pub protected_by_metadata: i64,
pub protected_by_other_refs: i64,
pub next_cursor: Option<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobCleanupExecuteResult {
pub scanned_candidates: i64,
pub deleted_objects: i64,
pub deleted_metadata: i64,
pub skipped_still_referenced: i64,
pub failed: i64,
pub workspace_ids: Vec<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeDocCompactionResult {
pub lease_acquired: bool,
@@ -1,3 +1,9 @@
use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH};
pub(crate) fn system_time_millis(time: SystemTime) -> Result<u128, SystemTimeError> {
Ok(time.duration_since(UNIX_EPOCH)?.as_millis())
}
fn collapse_whitespace(s: &str) -> String {
let mut result = String::new();
let mut prev_was_whitespace = false;
@@ -65,7 +65,7 @@ Generated by [AVA](https://avajs.dev).
<td>␊
<p␊
style="font-size:20px;line-height:28px;font-weight:600;font-family:Inter, Arial, Helvetica, sans-serif;margin-top:24px;margin-bottom:0;color:#141414">␊
Sign in to AFFiNE Cloud
Sign in to AFFiNE␊
</p>␊
</td>␊
</tr>␊
@@ -205,7 +205,7 @@ Generated by [AVA](https://avajs.dev).
<td>␊
<p␊
style="font-size:20px;line-height:28px;font-weight:600;font-family:Inter, Arial, Helvetica, sans-serif;margin-top:24px;margin-bottom:0;color:#141414">␊
Sign up to AFFiNE Cloud
Sign up to AFFiNE␊
</p>␊
</td>␊
</tr>␊
@@ -2,17 +2,25 @@ import { ScheduleModule } from '@nestjs/schedule';
import { TestingModule } from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import test from 'ava';
import Sinon from 'sinon';
import { AuthModule, AuthService } from '../../core/auth';
import { AuthCronJob } from '../../core/auth/job';
import { BackendRuntimeProvider } from '../../core/backend-runtime';
import { createTestingModule } from '../utils';
let m: TestingModule;
let db: PrismaClient;
const runtime = {
cleanupExpiredUserSessions: Sinon.stub(),
};
test.before(async () => {
m = await createTestingModule({
imports: [ScheduleModule.forRoot(), AuthModule],
tapModule: builder => {
builder.overrideProvider(BackendRuntimeProvider).useValue(runtime);
},
});
db = m.get(PrismaClient);
@@ -32,16 +40,17 @@ test('should clean expired user sessions', async t => {
let userSessions = await db.userSession.findMany();
t.is(userSessions.length, 2);
// no expired sessions
runtime.cleanupExpiredUserSessions.reset();
runtime.cleanupExpiredUserSessions.resolves(0);
await job.cleanExpiredUserSessions();
userSessions = await db.userSession.findMany();
t.is(userSessions.length, 2);
t.true(runtime.cleanupExpiredUserSessions.calledOnce);
t.deepEqual(runtime.cleanupExpiredUserSessions.firstCall.args, [1000]);
// clean all expired sessions
await db.userSession.updateMany({
data: { expiresAt: new Date(Date.now() - 1000) },
});
runtime.cleanupExpiredUserSessions.reset();
runtime.cleanupExpiredUserSessions.onCall(0).resolves(1000);
runtime.cleanupExpiredUserSessions.onCall(1).resolves(2);
await job.cleanExpiredUserSessions();
userSessions = await db.userSession.findMany();
t.is(userSessions.length, 0);
t.is(runtime.cleanupExpiredUserSessions.callCount, 2);
t.deepEqual(runtime.cleanupExpiredUserSessions.firstCall.args, [1000]);
t.deepEqual(runtime.cleanupExpiredUserSessions.secondCall.args, [1000]);
});
@@ -1,13 +1,16 @@
import { randomUUID } from 'node:crypto';
import { Global, Module } from '@nestjs/common';
import type { Prisma } from '@prisma/client';
import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { z } from 'zod';
import { AppModuleBuilder, FunctionalityModules } from '../../app.module';
import { JobModule, JobQueue } from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { AuthService } from '../../core/auth';
import { AuthModule, AuthService } from '../../core/auth';
import { QuotaModule } from '../../core/quota';
import { Models } from '../../models';
import { llmImageDispatchPlan } from '../../native';
@@ -25,6 +28,7 @@ import { ChatSession, ChatSessionService } from '../../plugins/copilot/session';
import { TranscriptPayloadSchema } from '../../plugins/copilot/transcript/schema';
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript/service';
import { TestingPromptService } from '../mocks/prompt-service.mock';
import { MockJobQueue } from '../mocks/queue.mock';
import { createTestingModule, TestingModule } from '../utils';
import { TestAssets } from '../utils/copilot';
import {
@@ -48,6 +52,13 @@ type Tester = {
const test = ava as TestFn<Tester>;
@Global()
@Module({
providers: [{ provide: JobQueue, useClass: MockJobQueue }],
exports: [JobQueue],
})
class MockJobModule {}
let isCopilotConfigured = false;
const runIfCopilotConfigured = test.macro(
async (
@@ -64,8 +75,20 @@ const runIfCopilotConfigured = test.macro(
);
test.serial.before(async t => {
const appModule = new AppModuleBuilder()
.use(
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
AuthModule,
QuotaModule,
CopilotModule
)
.compile();
const module = await createTestingModule({
imports: [QuotaModule, CopilotModule],
imports: [appModule],
tapModule: builder => {
builder.overrideProvider(PromptService).useClass(TestingPromptService);
},
@@ -156,8 +156,6 @@ test.before(async t => {
CopilotModule,
],
tapModule: builder => {
// use real JobQueue for testing
builder.overrideProvider(JobQueue).useClass(JobQueue);
builder.overrideProvider(RequestMutex).useValue({
acquire: async () => ({
async [Symbol.asyncDispose]() {},
@@ -811,7 +809,9 @@ test('should schedule title generation as a background job', async t => {
const chatSession = await session.get(sessionId);
t.truthy(chatSession);
const addJob = Sinon.stub(jobs, 'add').resolves();
const addJob = jobs.add as Sinon.SinonStub;
addJob.resetHistory();
addJob.resolves();
chatSession!.pushTurn(
buildTurn(sessionId, {
@@ -1835,6 +1835,14 @@ test('should be able to manage workspace embedding', async t => {
fileId: file.fileId,
fileName: file.fileName,
});
await jobs.embedPendingFile({
userId,
workspaceId: ws.id,
contextId: undefined,
blobId,
fileId: file.fileId,
fileName: file.fileName,
});
let ret = 0;
while (!ret) {
@@ -2059,7 +2067,9 @@ test('should handle copilot cron jobs correctly', async t => {
copilotSession,
'toBeGenerateTitle'
).resolves(mockSessions);
const jobAddStub = Sinon.stub(cronJobs['jobs'], 'add').resolves();
const jobAddStub = cronJobs['jobs'].add as Sinon.SinonStub;
jobAddStub.resetHistory();
jobAddStub.resolves();
// daily cleanup job scheduling
{
@@ -2107,7 +2117,7 @@ test('should handle copilot cron jobs correctly', async t => {
cleanupStub.restore();
toBeGenerateStub.restore();
jobAddStub.restore();
jobAddStub.resetHistory();
});
test('model selection policy should resolve requested optional models consistently', async t => {
@@ -7,8 +7,13 @@ import {
import { PrismaClient } from '@prisma/client';
import { FunctionalityModules } from '../app.module';
import { AFFiNELogger, EventBus, JobQueue } from '../base';
import { createFactory, MockEventBus, MockJobQueue } from './mocks';
import { AFFiNELogger, EventBus, JobModule, JobQueue } from '../base';
import {
createFactory,
MockEventBus,
MockJobModule,
MockJobQueue,
} from './mocks';
import { TEST_LOG_LEVEL } from './utils';
interface TestingModuleMetadata extends ModuleMetadata {
@@ -26,10 +31,17 @@ export async function createModule(
metadata: TestingModuleMetadata = {}
): Promise<TestingModule> {
const { tapModule, ...meta } = metadata;
const functionalityModules = [
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
];
const builder = Test.createTestingModule({
...meta,
imports: [...FunctionalityModules, ...(meta.imports ?? [])],
imports: [...functionalityModules, ...(meta.imports ?? [])],
});
builder
@@ -1,7 +1,9 @@
import { ScheduleModule } from '@nestjs/schedule';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { BackendRuntimeProvider } from '../../core/backend-runtime';
import { DocStorageModule } from '../../core/doc';
import { DocStorageCronJob } from '../../core/doc/job';
import { createTestingModule, type TestingModule } from '../utils';
@@ -10,14 +12,23 @@ interface Context {
module: TestingModule;
db: PrismaClient;
cronJob: DocStorageCronJob;
runtime: { cleanupExpiredSnapshotHistories: Sinon.SinonStub };
}
const test = ava as TestFn<Context>;
// cleanup database before each test
test.before(async t => {
t.context.runtime = {
cleanupExpiredSnapshotHistories: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), DocStorageModule],
tapModule: builder => {
builder
.overrideProvider(BackendRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.db = t.context.module.get(PrismaClient);
@@ -26,6 +37,7 @@ test.before(async t => {
test.beforeEach(async t => {
await t.context.module.initTestingDB();
t.context.runtime.cleanupExpiredSnapshotHistories.reset();
});
test.after.always(async t => {
@@ -33,7 +45,7 @@ test.after.always(async t => {
});
test('should be able to cleanup expired history', async t => {
const { db } = t.context;
const { db, runtime } = t.context;
const timestamp = Date.now();
// insert expired data
@@ -65,12 +77,10 @@ test('should be able to cleanup expired history', async t => {
let count = await db.snapshotHistory.count();
t.is(count, 20);
runtime.cleanupExpiredSnapshotHistories.onCall(0).resolves(1000);
runtime.cleanupExpiredSnapshotHistories.onCall(1).resolves(10);
await t.context.cronJob.cleanExpiredHistories();
count = await db.snapshotHistory.count();
t.is(count, 10);
const example = await db.snapshotHistory.findFirst();
t.truthy(example);
t.true(example!.expiredAt > new Date());
t.is(runtime.cleanupExpiredSnapshotHistories.callCount, 2);
});
@@ -13,6 +13,7 @@ import {
AFFiNELogger,
CacheInterceptor,
CloudThrottlerGuard,
ConfigFactory,
EventBus,
GlobalExceptionFilter,
JobQueue,
@@ -250,6 +251,31 @@ export async function createApp(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
module.useCustomApplicationConstructor(TestingApp);
@@ -306,3 +306,50 @@ e2e('should require Doc.Read to query workspace page meta', async t => {
})
);
});
e2e('should require Doc.Read to query doc histories', async t => {
const owner = await app.signup();
const member = await app.createUser();
await app.login(member);
await app.switchUser(owner);
const workspace = await app.create(Mockers.Workspace, {
owner: { id: owner.id },
});
await app.create(Mockers.WorkspaceUser, {
workspaceId: workspace.id,
userId: member.id,
type: WorkspaceRole.Collaborator,
});
const docSnapshot = await app.create(Mockers.DocSnapshot, {
workspaceId: workspace.id,
user: owner,
});
const doc = await app.create(Mockers.DocMeta, {
workspaceId: workspace.id,
docId: docSnapshot.id,
title: 'private-doc',
defaultRole: DocRole.None,
});
await app.switchUser(member);
await t.throwsAsync(
app.gql({
query: {
id: 'workspaceDocHistoriesPermissionCheck',
op: 'workspaceDocHistoriesPermissionCheck',
query: `
query {
workspace(id: "${workspace.id}") {
histories(guid: "${doc.docId}") {
timestamp
}
}
}
`,
} satisfies GraphQLQuery,
variables: undefined,
})
);
});
@@ -1,4 +1,4 @@
import { createHash } from 'node:crypto';
import { createHash, createHmac } from 'node:crypto';
import { mock } from 'node:test';
import {
@@ -6,22 +6,13 @@ import {
ConfigFactory,
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
StorageProviderConfig,
StorageProviderFactory,
toBuffer,
type R2StorageConfig,
SIGNED_URL_EXPIRED,
type StorageProviderConfig,
} from '../../../base';
import {
R2StorageConfig,
R2StorageProvider,
} from '../../../base/storage/providers/r2';
import { SIGNED_URL_EXPIRED } from '../../../base/storage/providers/utils';
import { EntitlementService } from '../../../core/entitlement';
import {
CommentAttachmentStorage,
WorkspaceBlobStorage,
} from '../../../core/storage';
import { MULTIPART_THRESHOLD } from '../../../core/storage/constants';
import { R2UploadController } from '../../../core/storage/r2-proxy';
import { StorageRuntimeProvider } from '../../../core/storage-runtime';
import {
SubscriptionPlan,
SubscriptionRecurring,
@@ -29,7 +20,7 @@ import {
} from '../../../plugins/payment/types';
import { app, e2e, Mockers } from '../test';
class MockR2Provider extends R2StorageProvider {
class MockStorageRuntime {
createMultipartCalls = 0;
putCalls: {
key: string;
@@ -46,45 +37,72 @@ class MockR2Provider extends R2StorageProvider {
contentLength?: number;
}[] = [];
constructor(config: R2StorageConfig, bucket: string) {
super(config, bucket);
async providerCapabilities() {
const storage = app.get(Config).storages.blob.storage;
const usePresignedURL = (storage.config as R2StorageConfig).usePresignedURL;
if (storage.provider !== 'cloudflare-r2') {
return {
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: false,
presignGet: false,
multipartDirect: false,
proxyUpload: false,
assetpack: false,
serverMediatedOnly: true,
};
}
return {
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: true,
presignGet: false,
multipartDirect: true,
proxyUpload: !!usePresignedURL?.enabled,
assetpack: false,
serverMediatedOnly: false,
};
}
destroy() {}
override async proxyPutObject(
async presignPut(
_scope: string,
key: string,
body: any,
options: { contentType?: string; contentLength?: number } = {}
metadata: { contentType?: string; contentLength?: number } = {}
) {
this.putCalls.push({
key,
body: await toBuffer(body),
contentType: options.contentType,
contentLength: options.contentLength,
});
const storage = app.get(Config).storages.blob.storage;
const r2 = storage.config as R2StorageConfig;
if (!r2.usePresignedURL?.enabled) {
return {
url: 'https://test-bucket.r2.example.com/object?X-Amz-Algorithm=AWS4-HMAC-SHA256',
headers: {},
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
};
}
const [workspaceId, blobKey] = key.split('/');
return createProxyUrl(
PROXY_UPLOAD_PATH,
[
workspaceId,
blobKey,
metadata.contentType ?? 'application/octet-stream',
metadata.contentLength,
],
{
workspaceId,
key: blobKey,
contentType: metadata.contentType ?? 'application/octet-stream',
contentLength: metadata.contentLength,
}
);
}
override async proxyUploadPart(
key: string,
uploadId: string,
partNumber: number,
body: any,
options: { contentLength?: number } = {}
) {
const etag = `etag-${partNumber}`;
this.partCalls.push({
key,
uploadId,
partNumber,
etag,
body: await toBuffer(body),
contentLength: options.contentLength,
});
return etag;
}
override async createMultipartUpload() {
async createMultipartUpload() {
this.createMultipartCalls += 1;
return {
uploadId: 'upload-id',
@@ -92,7 +110,30 @@ class MockR2Provider extends R2StorageProvider {
};
}
override async listMultipartUploadParts(key: string, uploadId: string) {
async presignUploadPart(
_scope: string,
key: string,
uploadId: string,
partNumber: number
) {
const [workspaceId, blobKey] = key.split('/');
return createProxyUrl(
PROXY_MULTIPART_PATH,
[workspaceId, blobKey, uploadId, partNumber],
{
workspaceId,
key: blobKey,
uploadId,
partNumber,
}
);
}
async listMultipartUploadParts(
_scope: string,
key: string,
uploadId: string
) {
const latest = new Map<number, string>();
for (const part of this.partCalls) {
if (part.key !== key || part.uploadId !== uploadId) {
@@ -104,6 +145,45 @@ class MockR2Provider extends R2StorageProvider {
.sort((left, right) => left[0] - right[0])
.map(([partNumber, etag]) => ({ partNumber, etag }));
}
async putObject(
_scope: string,
key: string,
body: Buffer,
options: { contentType?: string; contentLength?: number } = {}
) {
this.putCalls.push({
key,
body,
contentType: options.contentType,
contentLength: options.contentLength,
});
return {
contentType: options.contentType ?? 'application/octet-stream',
contentLength: options.contentLength ?? body.length,
lastModified: new Date(),
};
}
async proxyUploadPart(
_scope: string,
key: string,
uploadId: string,
partNumber: number,
body: Buffer,
contentLength?: number
) {
const etag = `etag-${partNumber}`;
this.partCalls.push({
key,
uploadId,
partNumber,
etag,
body,
contentLength,
});
return etag;
}
}
const baseR2Storage: StorageProviderConfig = {
@@ -125,55 +205,40 @@ const baseR2Storage: StorageProviderConfig = {
};
let defaultBlobStorage: StorageProviderConfig;
let provider: MockR2Provider | null = null;
let factoryCreateUnmocked: StorageProviderFactory['create'];
let runtime: MockStorageRuntime;
e2e.before(() => {
defaultBlobStorage = structuredClone(app.get(Config).storages.blob.storage);
const factory = app.get(StorageProviderFactory);
factoryCreateUnmocked = factory.create.bind(factory);
});
e2e.beforeEach(async () => {
provider?.destroy();
provider = null;
const factory = app.get(StorageProviderFactory);
mock.method(factory, 'create', (config: StorageProviderConfig) => {
if (config.provider === 'cloudflare-r2') {
if (!provider) {
provider = new MockR2Provider(
config.config as R2StorageConfig,
config.bucket
);
}
return provider;
}
return factoryCreateUnmocked(config);
});
runtime = new MockStorageRuntime();
const rt = app.get(StorageRuntimeProvider);
for (const method of [
'providerCapabilities',
'presignPut',
'createMultipartUpload',
'presignUploadPart',
'listMultipartUploadParts',
'putObject',
'proxyUploadPart',
] as const) {
mock.method(rt, method, (...args: any[]) =>
(runtime[method] as any)(...args)
);
}
await useR2Storage();
});
e2e.afterEach.always(async () => {
await setBlobStorage(defaultBlobStorage);
provider?.destroy();
provider = null;
mock.reset();
});
async function setBlobStorage(storage: StorageProviderConfig) {
provider?.destroy();
provider = null;
const configFactory = app.get(ConfigFactory);
configFactory.override({ storages: { blob: { storage } } });
const blobStorage = app.get(WorkspaceBlobStorage);
await blobStorage.onConfigInit();
const commentAttachmentStorage = app.get(CommentAttachmentStorage);
await commentAttachmentStorage.onConfigInit();
const controller = app.get(R2UploadController);
// reset cached provider in controller
(controller as any).provider = null;
}
async function useR2Storage(
@@ -193,11 +258,8 @@ async function useR2Storage(
return storage;
}
function getProvider(): MockR2Provider {
if (!provider) {
throw new Error('R2 provider is not initialized');
}
return provider;
function getRuntime(): MockStorageRuntime {
return runtime;
}
async function createBlobUpload(
@@ -285,7 +347,7 @@ async function gql<QueryData = any>(
return res.body.data;
}
e2e('should proxy single upload with valid signature', async t => {
e2e.serial('should proxy single upload with valid signature', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('r2-proxy');
const key = sha256Base64urlWithPadding(buffer);
@@ -300,6 +362,7 @@ e2e('should proxy single upload with valid signature', async t => {
t.is(init.method, 'PRESIGNED');
t.truthy(init.uploadUrl);
const uploadUrl = new URL(init.uploadUrl, app.url);
t.is(uploadUrl.origin, 'https://cdn.example.com');
t.is(uploadUrl.pathname, PROXY_UPLOAD_PATH);
const res = await app
@@ -309,7 +372,7 @@ e2e('should proxy single upload with valid signature', async t => {
.send(buffer);
t.is(res.status, 200);
const calls = getProvider().putCalls;
const calls = getRuntime().putCalls;
t.is(calls.length, 1);
t.is(calls[0].key, `${workspace.id}/${key}`);
t.is(calls[0].contentType, 'text/plain');
@@ -317,7 +380,7 @@ e2e('should proxy single upload with valid signature', async t => {
t.deepEqual(calls[0].body, buffer);
});
e2e('should proxy multipart upload and return etag', async t => {
e2e.serial('should proxy multipart upload and return etag', async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-object';
const totalSize = MULTIPART_THRESHOLD + 1024;
@@ -329,6 +392,7 @@ e2e('should proxy multipart upload and return etag', async t => {
const part = await getBlobUploadPartUrl(workspace.id, key, init.uploadId, 1);
const partUrl = new URL(part.uploadUrl, app.url);
t.is(partUrl.origin, 'https://cdn.example.com');
t.is(partUrl.pathname, PROXY_MULTIPART_PATH);
const payload = Buffer.from('part-body');
@@ -340,7 +404,7 @@ e2e('should proxy multipart upload and return etag', async t => {
t.is(res.status, 200);
t.is(res.get('etag'), 'etag-1');
const calls = getProvider().partCalls;
const calls = getRuntime().partCalls;
t.is(calls.length, 1);
t.is(calls[0].key, `${workspace.id}/${key}`);
t.is(calls[0].uploadId, 'upload-id');
@@ -349,34 +413,42 @@ e2e('should proxy multipart upload and return etag', async t => {
t.deepEqual(calls[0].body, payload);
});
e2e('should resume multipart upload and return uploaded parts', async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-resume';
const totalSize = MULTIPART_THRESHOLD + 1024;
e2e.serial(
'should resume multipart upload and return uploaded parts',
async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-resume';
const totalSize = MULTIPART_THRESHOLD + 1024;
const init1 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init1.method, 'MULTIPART');
t.is(init1.uploadId, 'upload-id');
t.deepEqual(init1.uploadedParts, []);
t.is(getProvider().createMultipartCalls, 1);
const init1 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init1.method, 'MULTIPART');
t.is(init1.uploadId, 'upload-id');
t.deepEqual(init1.uploadedParts, []);
t.is(getRuntime().createMultipartCalls, 1);
const part = await getBlobUploadPartUrl(workspace.id, key, init1.uploadId, 1);
const payload = Buffer.from('part-body');
const partUrl = new URL(part.uploadUrl, app.url);
await app
.PUT(partUrl.pathname + partUrl.search)
.set('content-length', payload.length.toString())
.send(payload)
.expect(200);
const part = await getBlobUploadPartUrl(
workspace.id,
key,
init1.uploadId,
1
);
const payload = Buffer.from('part-body');
const partUrl = new URL(part.uploadUrl, app.url);
await app
.PUT(partUrl.pathname + partUrl.search)
.set('content-length', payload.length.toString())
.send(payload)
.expect(200);
const init2 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init2.method, 'MULTIPART');
t.is(init2.uploadId, 'upload-id');
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: 'etag-1' }]);
t.is(getProvider().createMultipartCalls, 1);
});
const init2 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init2.method, 'MULTIPART');
t.is(init2.uploadId, 'upload-id');
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: 'etag-1' }]);
t.is(getRuntime().createMultipartCalls, 1);
}
);
e2e('should reject upload when token is invalid', async t => {
e2e.serial('should reject upload when token is invalid', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('payload');
const init = await createBlobUpload(
@@ -396,10 +468,10 @@ e2e('should reject upload when token is invalid', async t => {
t.is(res.status, 400);
t.is(res.body.message, 'Invalid upload token');
t.is(getProvider().putCalls.length, 0);
t.is(getRuntime().putCalls.length, 0);
});
e2e('should reject upload when url is expired', async t => {
e2e.serial('should reject upload when url is expired', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('expired');
const init = await createBlobUpload(
@@ -422,10 +494,10 @@ e2e('should reject upload when url is expired', async t => {
t.is(res.status, 400);
t.is(res.body.message, 'Upload URL expired');
t.is(getProvider().putCalls.length, 0);
t.is(getRuntime().putCalls.length, 0);
});
e2e(
e2e.serial(
'should fall back to direct presign when custom domain is disabled',
async t => {
await useR2Storage({
@@ -449,7 +521,7 @@ e2e(
}
);
e2e(
e2e.serial(
'should still fallback to graphql when provider does not support presign',
async t => {
await setBlobStorage({
@@ -473,6 +545,40 @@ e2e(
}
);
function createProxyUrl(
path: string,
canonicalFields: (string | number | undefined)[],
query: Record<string, string | number | undefined>
) {
const signKey = (
app.get(Config).storages.blob.storage.config as R2StorageConfig
).usePresignedURL?.signKey;
if (!signKey) {
throw new Error('missing R2 proxy sign key');
}
const exp = Math.floor(Date.now() / 1000) + SIGNED_URL_EXPIRED;
const canonical = [
path,
...canonicalFields.map(field =>
field === undefined ? '' : field.toString()
),
exp.toString(),
].join('\n');
const token = createHmac('sha256', signKey)
.update(canonical)
.digest('base64');
const url = new URL(`http://localhost${path}`);
for (const [key, value] of Object.entries(query)) {
if (value !== undefined) {
url.searchParams.set(key, value.toString());
}
}
url.searchParams.set('exp', exp.toString());
url.searchParams.set('token', `${exp}-${token}`);
return { url: url.pathname + url.search, expiresAt: new Date(exp * 1000) };
}
function sha256Base64urlWithPadding(buffer: Buffer) {
return createHash('sha256')
.update(buffer)
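Review note on the `createProxyUrl` helper in the hunk above: it signs a newline-joined canonical string (path, the canonical fields, then the expiry) with HMAC-SHA256 and emits a token shaped like `${exp}-${base64}`. The sketch below reproduces that signing step and pairs it with a possible check. `canonicalize`, `signProxyToken`, and `verifyProxyToken` are hypothetical names used only for illustration, and the verification logic is an assumption about how such a token could be validated, not the project's actual controller code.

```typescript
import { createHmac, timingSafeEqual } from 'node:crypto';

type Field = string | number | undefined;

// Canonical string as built by the test helper: path, each field
// (undefined becomes ''), then the expiry, joined with newlines.
function canonicalize(path: string, fields: Field[], exp: number): string {
  return [
    path,
    ...fields.map(field => (field === undefined ? '' : field.toString())),
    exp.toString(),
  ].join('\n');
}

// Hypothetical helper: returns a token shaped like `${exp}-${base64 HMAC}`.
function signProxyToken(
  signKey: string,
  path: string,
  fields: Field[],
  exp: number
): string {
  const mac = createHmac('sha256', signKey)
    .update(canonicalize(path, fields, exp))
    .digest('base64');
  return `${exp}-${mac}`;
}

// Hypothetical verifier (an assumption, not the project's implementation):
// re-sign with the expiry carried in the token, compare in constant time,
// and only then check whether the expiry has passed.
function verifyProxyToken(
  signKey: string,
  path: string,
  fields: Field[],
  token: string,
  nowSeconds: number
): 'ok' | 'expired' | 'invalid' {
  // Standard base64 never contains '-', so the first '-' ends the expiry.
  const sep = token.indexOf('-');
  if (sep <= 0) return 'invalid';
  const exp = Number(token.slice(0, sep));
  if (!Number.isInteger(exp)) return 'invalid';

  const expected = Buffer.from(signProxyToken(signKey, path, fields, exp));
  const actual = Buffer.from(token);
  if (expected.length !== actual.length || !timingSafeEqual(expected, actual)) {
    return 'invalid';
  }
  return exp < nowSeconds ? 'expired' : 'ok';
}
```

Because the expiry is part of the signed canonical string, editing `exp` in the URL changes the expected signature. In this sketch a tampered expiry therefore comes back as `invalid` rather than `expired`.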
@@ -1,4 +1,5 @@
import { randomUUID } from 'node:crypto';
import { Readable } from 'node:stream';
import { mock } from 'node:test';
import {
@@ -7,6 +8,8 @@ import {
type StorageProviderConfig,
} from '../../../base';
import { CommentAttachmentStorage } from '../../../core/storage';
import { StorageRuntimeProvider } from '../../../core/storage-runtime';
import { getMime } from '../../../native';
import { Mockers } from '../../mocks';
import { app, e2e } from '../test';
@@ -26,14 +29,68 @@ e2e.afterEach.always(() => {
mock.reset();
});
const objects = new Map<
string,
{
body: Buffer;
metadata?: {
contentLength?: number;
contentType?: string;
checksumCRC32?: string;
lastModified?: Date;
};
}
>();
e2e.beforeEach(() => {
objects.clear();
const rt = app.get(StorageRuntimeProvider);
mock.method(
rt,
'putObject',
async (
_scope: string,
key: string,
body: Buffer,
metadata?: {
contentLength?: number;
contentType?: string;
checksumCRC32?: string;
}
) => {
const object = {
body,
metadata: {
...metadata,
contentType: metadata?.contentType ?? getMime(body),
contentLength: metadata?.contentLength ?? body.length,
lastModified: new Date(),
},
};
objects.set(key, object);
return object.metadata;
}
);
mock.method(rt, 'getObject', async (_scope: string, key: string) => {
const object = objects.get(key);
if (!object) {
return {};
}
return {
body: Readable.from(object.body),
metadata: object.metadata,
};
});
mock.method(rt, 'presignGet', async () => undefined);
});
async function useCommentAttachmentBlobStorage(storage: StorageProviderConfig) {
app.get(ConfigFactory).override({ storages: { blob: { storage } } });
await app.get(CommentAttachmentStorage).onConfigInit();
}
// #region comment attachment
e2e(
e2e.serial(
'should get comment attachment not found when key is not exists',
async t => {
const { owner, workspace } = await createWorkspace();
@@ -50,7 +107,7 @@ e2e(
}
);
e2e(
e2e.serial(
'should get comment attachment no permission when user is not member',
async t => {
const { workspace } = await createWorkspace();
@@ -117,7 +174,7 @@ e2e.serial('should get comment attachment body', async t => {
}
});
e2e('should get comment attachment redirect url', async t => {
e2e.serial('should get comment attachment redirect url', async t => {
const { owner, workspace } = await createWorkspace();
await app.login(owner);
@@ -145,6 +145,28 @@ e2e('should set new invited users to waiting-seat status', async t => {
t.is(invitationInfo.status, WorkspaceMemberStatus.NeedMoreSeat);
});
e2e('should allocate existing team seats for new invited users', async t => {
const { owner, workspace } = await createTeamWorkspace(4);
await app.login(owner);
const u1 = await app.createUser();
const result = await app.gql({
query: inviteByEmailsMutation,
variables: {
workspaceId: workspace.id,
emails: [u1.email],
},
});
t.not(result.inviteMembers[0].inviteId, null);
const invitationInfo = await getInvitationInfo(
result.inviteMembers[0].inviteId!
);
t.is(invitationInfo.status, WorkspaceMemberStatus.Pending);
});
e2e('should allocate seats', async t => {
const { owner, workspace } = await createTeamWorkspace();
await app.login(owner);
@@ -165,6 +187,7 @@ e2e('should allocate seats', async t => {
source: 'Link',
});
const invitationCount = app.queue.count('notification.sendInvitation');
await app.eventBus.emitAsync('workspace.members.allocateSeats', {
workspaceId: workspace.id,
quantity: 5,
@@ -184,7 +207,7 @@ e2e('should allocate seats', async t => {
WorkspaceMemberStatus.Accepted
);
t.is(app.queue.count('notification.sendInvitation'), 1);
t.is(app.queue.count('notification.sendInvitation') - invitationCount, 1);
});
e2e('should set all rests to NeedMoreSeat', async t => {
@@ -12,7 +12,7 @@ import { MockDocSnapshot } from './doc-snapshot.mock';
import { MockDocUser } from './doc-user.mock';
import { MockEventBus } from './eventbus.mock';
import { MockMailer } from './mailer.mock';
import { MockJobQueue } from './queue.mock';
import { MockJobModule, MockJobQueue } from './queue.mock';
import { MockTeamWorkspace } from './team-workspace.mock';
import { MockUser } from './user.mock';
import { MockUserSettings } from './user-settings.mock';
@@ -35,6 +35,7 @@ export {
installMockCopilotRuntime,
MockCopilotProvider,
MockEventBus,
MockJobModule,
MockJobQueue,
MockMailer,
};
@@ -1,12 +1,16 @@
import { Global, Module } from '@nestjs/common';
import { interval, map, take, takeUntil } from 'rxjs';
import Sinon from 'sinon';
import { JobQueue } from '../../base';
export class MockJobQueue {
add = Sinon.createStubInstance(JobQueue).add.resolves();
remove = Sinon.createStubInstance(JobQueue).remove.resolves();
removeWhere = Sinon.createStubInstance(JobQueue).removeWhere.resolves([]);
private readonly sandbox = Sinon.createSandbox();
add = this.sandbox.stub().resolves();
get = this.sandbox.stub().resolves();
remove = this.sandbox.stub().resolves();
removeWhere = this.sandbox.stub().resolves([]);
last<Job extends JobName>(name: Job): { name: Job; payload: Jobs[Job] } {
const addJobName = this.add.lastCall?.args[0];
@@ -57,3 +61,10 @@ export class MockJobQueue {
return this.add.getCalls().filter(call => call.args[0] === name).length;
}
}
@Global()
@Module({
providers: [{ provide: JobQueue, useClass: MockJobQueue }],
exports: [JobQueue],
})
export class MockJobModule {}
@@ -303,28 +303,3 @@ test('should delete userSession fail when sessionId not match', async t => {
);
t.is(count, 0);
});
test('should cleanup expired userSessions', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const session = await t.context.db.session.create({
data: {},
});
const userSession = await t.context.session.createOrRefreshUserSession(
user.id,
session.id
);
await t.context.session.cleanExpiredUserSessions();
let count = await t.context.db.userSession.count();
t.is(count, 1);
// Set expiresAt to past time
await t.context.db.userSession.update({
where: { id: userSession.id },
data: { expiresAt: new Date('2022-01-01') },
});
await t.context.session.cleanExpiredUserSessions();
count = await t.context.db.userSession.count();
t.is(count, 0);
});
@@ -1475,6 +1475,41 @@ test('should not be able to checkout for workspace if subscribed', async t => {
);
});
test('should be able to checkout for workspace after canceled subscription', async t => {
const { service, u1, db, stripe } = t.context;
await db.subscription.create({
data: {
targetId: 'ws_1',
stripeSubscriptionId: 'sub_1',
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Monthly,
status: SubscriptionStatus.Canceled,
start: new Date(Date.now() - 100000),
end: new Date(Date.now() - 1000),
quantity: 1,
},
});
await service.checkout(
{
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Monthly,
variant: null,
successCallbackLink: '',
},
{
user: u1,
workspaceId: 'ws_1',
}
);
t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), {
price: TEAM_MONTHLY,
coupon: undefined,
});
});
const teamSub: Stripe.Subscription = {
...sub,
items: {
@@ -1517,6 +1552,48 @@ test('should be able to create team subscription', async t => {
t.is(subInDB?.stripeSubscriptionId, sub.id);
});
test('should replace old team subscription row when stripe creates a new subscription', async t => {
const { service, db } = t.context;
const old = await db.subscription.create({
data: {
targetId: 'ws_1',
stripeSubscriptionId: 'sub_old_team',
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Yearly,
status: SubscriptionStatus.Canceled,
start: new Date('2026-03-26T08:23:57.000Z'),
end: new Date('2027-03-26T08:23:57.000Z'),
quantity: 24,
},
});
await service.saveStripeSubscription({
...teamSub,
id: 'sub_new_team',
status: SubscriptionStatus.Active,
items: {
...teamSub.items,
data: [
{
...teamSub.items.data[0],
quantity: 11,
},
],
},
});
const subscriptions = await db.subscription.findMany({
where: { targetId: 'ws_1', plan: SubscriptionPlan.Team },
});
t.is(subscriptions.length, 1);
t.is(subscriptions[0].id, old.id);
t.is(subscriptions[0].stripeSubscriptionId, 'sub_new_team');
t.is(subscriptions[0].status, SubscriptionStatus.Active);
t.is(subscriptions[0].quantity, 11);
});
test('should be able to update team subscription', async t => {
const { service, db, event } = t.context;
@@ -1551,6 +1628,77 @@ test('should be able to update team subscription', async t => {
);
});
test('should persist mutable team subscription fields on same stripe subscription update', async t => {
const { service, db } = t.context;
await service.saveStripeSubscription(teamSub);
await service.saveStripeSubscription({
...teamSub,
current_period_start: 1780000000,
current_period_end: 1811536000,
trial_start: 1780000000,
trial_end: 1780604800,
items: {
...teamSub.items,
data: [
{
...teamSub.items.data[0],
quantity: 9,
price: {
...PRICES[TEAM_YEARLY],
lookup_key: TEAM_YEARLY,
},
},
],
},
});
const subInDB = await db.subscription.findFirst({
where: { targetId: 'ws_1' },
});
const entitlement = await db.entitlement.findFirst({
where: {
source: 'cloud_subscription',
subjectId: teamSub.id,
},
});
const providerFact = await db.providerSubscription.findUnique({
where: {
provider_externalSubscriptionId: {
provider: 'stripe',
externalSubscriptionId: teamSub.id,
},
},
});
t.like(subInDB, {
recurring: SubscriptionRecurring.Yearly,
quantity: 9,
start: new Date(1780000000 * 1000),
end: new Date(1811536000 * 1000),
trialStart: new Date(1780000000 * 1000),
trialEnd: new Date(1780604800 * 1000),
});
t.like(entitlement, {
plan: 'team',
quantity: 9,
startsAt: new Date(1780000000 * 1000),
expiresAt: new Date(1811536000 * 1000),
});
t.like(providerFact, {
recurring: SubscriptionRecurring.Yearly,
externalPriceId: TEAM_YEARLY,
currency: 'usd',
amount: 14400,
quantity: 9,
periodStart: new Date(1780000000 * 1000),
periodEnd: new Date(1811536000 * 1000),
trialStart: new Date(1780000000 * 1000),
trialEnd: new Date(1780604800 * 1000),
});
});
test('should suspend on dispute and restore when dispute won', async t => {
const { service, db, stripe, event } = t.context;
@@ -6,6 +6,7 @@ import Sinon from 'sinon';
import { OneDay } from '../../base';
import { StorageModule, WorkspaceBlobStorage } from '../../core/storage';
import { BlobUploadCleanupJob } from '../../core/storage/job';
import { StorageRuntimeProvider } from '../../core/storage-runtime';
import { MockUser, MockWorkspace } from '../mocks';
import { createTestingModule, TestingModule } from '../utils';
@@ -14,13 +15,22 @@ interface Context {
db: PrismaClient;
job: BlobUploadCleanupJob;
storage: WorkspaceBlobStorage;
runtime: { cleanupExpiredPendingBlobs: Sinon.SinonStub };
}
const test = ava as TestFn<Context>;
test.before(async t => {
t.context.runtime = {
cleanupExpiredPendingBlobs: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), StorageModule],
tapModule: builder => {
builder
.overrideProvider(StorageRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.db = t.context.module.get(PrismaClient);
@@ -30,6 +40,7 @@ test.before(async t => {
test.beforeEach(async t => {
await t.context.module.initTestingDB();
t.context.runtime.cleanupExpiredPendingBlobs.reset();
});
test.after.always(async t => {
@@ -86,24 +97,14 @@ test('should cleanup expired pending blobs', async t => {
],
});
const abortSpy = Sinon.stub(
t.context.storage,
'abortMultipartUpload'
).resolves();
const deleteSpy = Sinon.spy(t.context.storage, 'delete');
t.teardown(() => {
abortSpy.restore();
deleteSpy.restore();
t.context.runtime.cleanupExpiredPendingBlobs.resolves({
scanned: 2,
deleted: 2,
abortedMultipart: 1,
workspaceIds: [workspace.id],
});
await t.context.job.cleanExpiredPendingBlobs();
t.is(abortSpy.callCount, 1);
t.is(deleteSpy.callCount, 2);
const remaining = await t.context.db.blob.findMany({
where: { workspaceId: workspace.id },
});
const remainingKeys = remaining.map(record => record.key).sort();
t.deepEqual(remainingKeys, ['completed-keep', 'pending-active']);
t.true(t.context.runtime.cleanupExpiredPendingBlobs.calledOnce);
});
@@ -9,7 +9,7 @@ import {
import { PrismaClient } from '@prisma/client';
import { buildAppModule, FunctionalityModules } from '../../app.module';
import { AFFiNELogger, JobQueue } from '../../base';
import { AFFiNELogger, ConfigFactory, JobModule, JobQueue } from '../../base';
import { GqlModule } from '../../base/graphql';
import { ServerConfigModule } from '../../core';
import { AuthGuard, AuthModule } from '../../core/auth';
@@ -18,7 +18,7 @@ import { ModelsModule } from '../../models';
// for jsdoc inference
// oxlint-disable-next-line no-unused-vars
import type { createModule } from '../create-module';
import { createFactory, MockJobQueue } from '../mocks';
import { createFactory, MockJobModule, MockJobQueue } from '../mocks';
import { MockMailer } from '../mocks/mailer.mock';
import { initTestingDB, TEST_LOG_LEVEL } from './utils';
@@ -48,6 +48,16 @@ function dedupeModules(modules: NonNullable<ModuleMetadata['imports']>) {
return Array.from(map.values());
}
function testingFunctionalityModules() {
return [
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
];
}
@Resolver(() => String)
class MockResolver {
@Query(() => String)
@@ -70,7 +80,7 @@ export async function createTestingModule(
imports[0].module?.name === 'AppModule'
? imports
: dedupeModules([
...FunctionalityModules,
...testingFunctionalityModules(),
ModelsModule,
AuthModule,
GqlModule,
@@ -99,6 +109,31 @@ export async function createTestingModule(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
const testingModule = module as TestingModule;
@@ -1,12 +1,15 @@
import { createHash } from 'node:crypto';
import { Readable } from 'node:stream';
import test from 'ava';
import Sinon from 'sinon';
import { Config, ConfigFactory, StorageProviderFactory } from '../../base';
import { ConfigFactory } from '../../base';
import { QuotaStateService } from '../../core/quota/state';
import { WorkspaceBlobStorage } from '../../core/storage/wrappers/blob';
import { StorageRuntimeProvider } from '../../core/storage-runtime';
import { BlobModel, WorkspaceFeatureModel } from '../../models';
import { getMime } from '../../native';
import {
collectAllBlobSizes,
completeBlobUpload,
@@ -32,9 +35,117 @@ const RESTRICTED_QUOTA = {
let app: TestingApp;
let model: WorkspaceFeatureModel;
type CompleteResult =
| {
ok: true;
contentType: string;
contentLength: number;
lastModifiedMs: number;
}
| {
ok: false;
reason:
| 'not_found'
| 'size_mismatch'
| 'mime_mismatch'
| 'checksum_mismatch'
| 'size_too_large';
};
const objects = new Map<
string,
{
body: Buffer;
metadata: {
contentType: string;
contentLength: number;
lastModified: Date;
};
}
>();
const completeResults = new Map<string, CompleteResult>();
const storageRuntime = {
providerCapabilities: async () => ({
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: false,
presignGet: false,
multipartDirect: false,
proxyUpload: false,
assetpack: false,
serverMediatedOnly: true,
}),
putObject: async (
_scope: string,
key: string,
body: Buffer,
metadata?: { contentType?: string; contentLength?: number }
) => {
const object = {
body,
metadata: {
contentType: metadata?.contentType ?? getMime(body),
contentLength: metadata?.contentLength ?? body.length,
lastModified: new Date(),
},
};
objects.set(key, object);
return object.metadata;
},
headObject: async (_scope: string, key: string) => {
return objects.get(key)?.metadata;
},
getObject: async (_scope: string, key: string) => {
const object = objects.get(key);
return object
? { body: Readable.from(object.body), metadata: object.metadata }
: {};
},
listObjects: async (_scope: string, prefix?: string) => {
return Array.from(objects.entries())
.filter(([key]) => !prefix || key.startsWith(prefix))
.map(([key, object]) => ({ key, ...object.metadata }));
},
deleteObject: async (_scope: string, key: string) => {
objects.delete(key);
},
presignPut: async () => undefined,
presignGet: async () => undefined,
createMultipartUpload: async () => undefined,
presignUploadPart: async () => undefined,
listMultipartUploadParts: async () => undefined,
completeMultipartUpload: async () => undefined,
completeWorkspaceBlobUpload: async (workspaceId: string, key: string) => {
const objectKey = `${workspaceId}/${key}`;
const configured = completeResults.get(objectKey);
if (configured) return configured;
const object = objects.get(objectKey);
if (!object) return { ok: false, reason: 'not_found' };
await app.get(BlobModel).upsert({
workspaceId,
key,
mime: object.metadata.contentType,
size: object.metadata.contentLength,
status: 'completed',
uploadId: null,
});
return {
ok: true,
contentType: object.metadata.contentType,
contentLength: object.metadata.contentLength,
lastModifiedMs: object.metadata.lastModified.getTime(),
};
},
};
test.before(async () => {
app = await createTestingApp();
app = await createTestingApp({
tapModule: builder => {
builder.overrideProvider(StorageRuntimeProvider).useValue(storageRuntime);
},
});
model = app.get(WorkspaceFeatureModel);
app.get(ConfigFactory).override({
storages: {
@@ -47,11 +158,12 @@ test.before(async () => {
},
},
});
await app.get(WorkspaceBlobStorage).onConfigInit();
});
test.beforeEach(async () => {
await app.initTestingDB();
objects.clear();
completeResults.clear();
});
test.after.always(async () => {
@@ -119,6 +231,47 @@ test('should list blobs', async t => {
t.deepEqual(ret.map(x => x.key).sort(), [hash1, hash2].sort());
});
test('should keep partial blob metadata listing on DB path without storage scan', async t => {
await app.signupV1('u1@affine.pro');
const workspace = await createWorkspace(app);
const storage = app.get(WorkspaceBlobStorage);
const rt = app.get(StorageRuntimeProvider);
const listSpy = Sinon.spy(rt, 'listObjects');
t.teardown(() => listSpy.restore());
const buffer1 = Buffer.from('with metadata');
const buffer2 = Buffer.from('without metadata');
const key1 = sha256Base64urlWithPadding(buffer1);
const key2 = sha256Base64urlWithPadding(buffer2);
await rt.putObject('blob', `${workspace.id}/${key1}`, buffer1, {
contentType: 'text/plain',
contentLength: buffer1.length,
});
await rt.putObject('blob', `${workspace.id}/${key2}`, buffer2, {
contentType: 'text/plain',
contentLength: buffer2.length,
});
const blobModel = app.get(BlobModel);
await blobModel.upsert({
workspaceId: workspace.id,
key: key1,
mime: 'text/plain',
size: buffer1.length,
status: 'completed',
uploadId: null,
});
const listed = await storage.list(workspace.id);
t.deepEqual(
listed.map(blob => blob.key),
[key1]
);
t.true(listSpy.notCalled);
});
test('should create pending blob upload with graphql fallback', async t => {
await app.signupV1('u1@affine.pro');
@@ -150,11 +303,9 @@ test('should complete pending blob upload', async t => {
await createBlobUpload(app, workspace.id, key, buffer.length, mime);
const config = app.get(Config);
const factory = app.get(StorageProviderFactory);
const provider = factory.create(config.storages.blob.storage);
const rt = app.get(StorageRuntimeProvider);
await provider.put(`${workspace.id}/${key}`, buffer, {
await rt.putObject('blob', `${workspace.id}/${key}`, buffer, {
contentType: mime,
contentLength: buffer.length,
});
@@ -181,14 +332,16 @@ test('should reject complete when blob key mismatched', async t => {
const wrongKey = sha256Base64urlWithPadding(Buffer.from('other'));
await createBlobUpload(app, workspace.id, wrongKey, buffer.length, mime);
const config = app.get(Config);
const factory = app.get(StorageProviderFactory);
const provider = factory.create(config.storages.blob.storage);
const rt = app.get(StorageRuntimeProvider);
await provider.put(`${workspace.id}/${wrongKey}`, buffer, {
await rt.putObject('blob', `${workspace.id}/${wrongKey}`, buffer, {
contentType: mime,
contentLength: buffer.length,
});
completeResults.set(`${workspace.id}/${wrongKey}`, {
ok: false,
reason: 'checksum_mismatch',
});
await t.throwsAsync(() => completeBlobUpload(app, workspace.id, wrongKey), {
message: 'Blob key mismatch',
@@ -221,10 +374,12 @@ test('should auto delete blobs when workspace is deleted', async t => {
const blobs = await listBlobs(app, workspace.id);
t.is(blobs.length, 2);
const workspaceBlobStorage = Sinon.spy(app.get(WorkspaceBlobStorage));
const rt = app.get(StorageRuntimeProvider);
const listSpy = Sinon.spy(rt, 'listObjects');
t.teardown(() => listSpy.restore());
await deleteWorkspace(app, workspace.id);
// should not emit workspace.blob.sync event
t.is(workspaceBlobStorage.syncBlobMeta.callCount, 0);
t.is(listSpy.callCount, 0);
});
test('should calc blobs size', async t => {
@@ -212,8 +212,7 @@ test('should be able to get permission granted workspace', async t => {
test('should return 404 if blob not found', async t => {
const { app, storage } = t.context;
// @ts-expect-error mock
storage.get.resolves({ body: null });
storage.get.resolves({ body: undefined });
const res = await app.GET('/api/workspaces/public/blobs/test');
t.is(res.status, HttpStatus.NOT_FOUND);
+4 -2
@@ -25,11 +25,11 @@ import { MetricsModule } from './base/metrics';
import { MutexModule } from './base/mutex';
import { PrismaModule } from './base/prisma';
import { RedisModule } from './base/redis';
import { StorageProviderModule } from './base/storage';
import { RateLimiterModule } from './base/throttler';
import { WebSocketModule } from './base/websocket';
import { AccessTokenModule } from './core/access-token';
import { AuthModule } from './core/auth';
import { BackendRuntimeModule } from './core/backend-runtime';
import { CommentModule } from './core/comment';
import { ServerConfigModule, ServerConfigResolverModule } from './core/config';
import { DocStorageModule } from './core/doc';
@@ -46,6 +46,7 @@ import { RealtimeModule } from './core/realtime';
import { SelfhostModule } from './core/selfhost';
import { StaticFileModule } from './core/static-files';
import { StorageModule } from './core/storage';
import { StorageRuntimeModule } from './core/storage-runtime';
import { SyncModule } from './core/sync';
import { TelemetryModule } from './core/telemetry';
import { UserModule } from './core/user';
@@ -113,13 +114,14 @@ export const FunctionalityModules = [
MutexModule,
MetricsModule,
RateLimiterModule,
StorageProviderModule,
HelpersModule,
ErrorModule,
WebSocketModule,
JobModule.forRoot(),
RealtimeModule,
ModelsModule,
BackendRuntimeModule,
StorageRuntimeModule,
ScheduleModule.forRoot(),
MonitorModule,
];
@@ -21,6 +21,7 @@ export type JSONSchema = { description?: string } & (
| {
type: 'object';
properties?: Record<string, JSONSchema>;
required?: string[];
}
);
@@ -907,6 +907,14 @@ export const USER_FRIENDLY_ERRORS = {
message: ({ clientVersion, requiredVersion }) =>
`Unsupported client with version [${clientVersion}], required version is [${requiredVersion}].`,
},
unsupported_server_version: {
type: 'action_forbidden',
args: {
requiredVersion: 'string',
},
message: ({ requiredVersion }) =>
`This AFFiNE server is too old for this client. Please upgrade the server to ${requiredVersion}.`,
},
// Notification Errors
notification_not_found: {
@@ -1059,6 +1059,16 @@ export class UnsupportedClientVersion extends UserFriendlyError {
super('action_forbidden', 'unsupported_client_version', message, args);
}
}
@ObjectType()
class UnsupportedServerVersionDataType {
@Field() requiredVersion!: string
}
export class UnsupportedServerVersion extends UserFriendlyError {
constructor(args: UnsupportedServerVersionDataType, message?: string | ((args: UnsupportedServerVersionDataType) => string)) {
super('action_forbidden', 'unsupported_server_version', message, args);
}
}
export class NotificationNotFound extends UserFriendlyError {
constructor(message?: string) {
@@ -1288,6 +1298,7 @@ export enum ErrorNames {
INVALID_LICENSE_UPDATE_PARAMS,
LICENSE_EXPIRED,
UNSUPPORTED_CLIENT_VERSION,
UNSUPPORTED_SERVER_VERSION,
NOTIFICATION_NOT_FOUND,
MENTION_USER_DOC_ACCESS_DENIED,
MENTION_USER_ONESELF_DENIED,
@@ -1308,5 +1319,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, ImageFormatNotSupportedDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, ImageFormatNotSupportedDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, UnsupportedServerVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
});
+1 -6
@@ -30,11 +30,6 @@ export { Lock, Locker, Mutex, RequestMutex } from './mutex';
export * from './nestjs';
export { type PrismaTransaction } from './prisma';
export * from './storage';
export {
autoMetadata,
type StorageProvider,
type StorageProviderConfig,
StorageProviderFactory,
} from './storage';
export { type StorageProviderConfig } from './storage';
export { CloudThrottlerGuard, SkipThrottle, Throttle } from './throttler';
export * from './utils';
@@ -94,4 +94,12 @@ defineModuleConfig('job', {
},
schema,
},
'queues.backendRuntime': {
desc: 'The config for backend runtime job queue',
default: {
concurrency: 1,
},
schema,
},
});
@@ -29,6 +29,7 @@ export enum Queue {
COPILOT = 'copilot',
INDEXER = 'indexer',
CALENDAR = 'calendar',
BACKENDRUNTIME = 'backendRuntime',
}
export const QUEUES = Object.values(Queue);
@@ -8,6 +8,13 @@ import { Redis as IORedis, RedisOptions } from 'ioredis';
import { Config } from '../config';
function redisOptions(options: RedisOptions) {
return {
...(env.testing ? { lazyConnect: true } : {}),
...options,
};
}
class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(this.constructor.name);
@@ -47,41 +54,47 @@ class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
@Injectable()
export class CacheRedis extends Redis {
constructor(config: Config) {
super({ ...config.redis, ...config.redis.ioredis });
super(redisOptions({ ...config.redis, ...config.redis.ioredis }));
}
}
@Injectable()
export class SessionRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 2,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 2,
})
);
}
}
@Injectable()
export class SocketIoRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 3,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 3,
})
);
}
}
@Injectable()
export class QueueRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 4,
// required explicitly set to `null` by bullmq
maxRetriesPerRequest: null,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 4,
// required explicitly set to `null` by bullmq
maxRetriesPerRequest: null,
})
);
}
}
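The `redisOptions` helper in this hunk depends on object-spread ordering: the testing default is spread first, so any option supplied by config still overrides it. A minimal sketch of that merge follows. `SketchRedisOptions` is a hypothetical, trimmed-down stand-in for ioredis's `RedisOptions`, and the `testing` parameter stands in for `env.testing`; neither is part of this change.

```typescript
// Hypothetical, trimmed-down stand-in for ioredis's RedisOptions.
interface SketchRedisOptions {
  host?: string;
  db?: number;
  lazyConnect?: boolean;
  maxRetriesPerRequest?: number | null;
}

// `testing` stands in for `env.testing` used by the real helper.
function mergeRedisOptions(
  options: SketchRedisOptions,
  testing: boolean
): SketchRedisOptions {
  return {
    // The default is spread first so that explicit options override it.
    ...(testing ? { lazyConnect: true } : {}),
    ...options,
  };
}

// Queue-style options under test: the lazy default is added.
const queueOptions = mergeRedisOptions(
  { db: 0 + 4, maxRetriesPerRequest: null },
  true
);
// An explicit lazyConnect from config wins over the testing default.
const explicitOptions = mergeRedisOptions({ lazyConnect: false }, true);
// Outside of testing, nothing is added.
const productionOptions = mergeRedisOptions({ db: 2 }, false);
```

Reversing the two spreads would force `lazyConnect: true` in tests regardless of config, which is presumably why the default comes first.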
@@ -1,130 +0,0 @@
import { promises as fs } from 'node:fs';
import { join } from 'node:path';
import test from 'ava';
import { getStreamAsBuffer } from 'get-stream';
import { ListObjectsMetadata } from '../providers';
import { FsStorageProvider } from '../providers/fs';
const config = {
path: join(process.cwd(), 'node_modules', '.cache/affine-test-storage'),
};
function createProvider() {
return new FsStorageProvider(
config,
'test' + Math.random().toString(16).substring(2, 8)
);
}
function keys(list: ListObjectsMetadata[]) {
return list.map(i => i.key);
}
async function randomPut(
provider: FsStorageProvider,
prefix = ''
): Promise<string> {
const key = prefix + 'test-key-' + Math.random().toString(16).substring(2, 8);
const body = Buffer.from(key);
await provider.put(key, body);
return key;
}
test.after.always(() => {
fs.rm(config.path, { recursive: true }).catch(console.error);
});
test('put & get', async t => {
const provider = createProvider();
const key = 'testKey';
const body = Buffer.from('testBody');
await provider.put(key, body);
const result = await provider.get(key);
t.deepEqual(await getStreamAsBuffer(result.body!), body);
t.is(result.metadata?.contentLength, body.length);
});
test('list - one level', async t => {
const provider = createProvider();
const list = await Promise.all(
Array.from({ length: 100 }).map(() => randomPut(provider))
);
list.sort();
// keys were generated in random order, so compare against the sorted list
const result = await provider.list();
t.deepEqual(keys(result), list);
const result2 = await provider.list('test-key');
t.deepEqual(keys(result2), list);
const result3 = await provider.list('testKey');
t.is(result3.length, 0);
});
test('list recursively', async t => {
const provider = createProvider();
await Promise.all([
Promise.all(Array.from({ length: 10 }).map(() => randomPut(provider))),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/'))
),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/b/'))
),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/b/t/'))
),
]);
const r1 = await provider.list();
t.is(r1.length, 40);
// contains all `a/xxx`, `a/b/xxx`, and `a/b/t/xxx`
const r2 = await provider.list('a');
t.is(r2.length, 30);
// contains `a/b/xxx` and `a/b/t/xxx`
const r3 = await provider.list('a/b');
const r4 = await provider.list('a/b/');
t.is(r3.length, 20);
t.deepEqual(r3, r4);
// prefix is not ended with '/', it's open to all files and sub dirs
// contains all `a/b/t/xxx` and `a/b/t{xxxx}`
const r5 = await provider.list('a/b/t');
t.is(r5.length, 20);
});
test('delete', async t => {
const provider = createProvider();
const key = 'testKey';
const body = Buffer.from('testBody');
await provider.put(key, body);
await provider.delete(key);
await t.throwsAsync(() => fs.access(join(config.path, provider.bucket, key)));
});
test('rejects unsafe object keys', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.put('../escape', Buffer.from('nope')));
await t.throwsAsync(() => provider.get('nested/../escape'));
await t.throwsAsync(() => provider.head('./escape'));
t.throws(() => provider.delete('nested//escape'));
});
test('rejects unsafe list prefixes', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.list('../escape'));
await t.throwsAsync(() => provider.list('nested/../../escape'));
await t.throwsAsync(() => provider.list('/absolute'));
});
@@ -1,82 +0,0 @@
import test from 'ava';
import { R2StorageProvider } from '../providers/r2';
function endpointOf(provider: R2StorageProvider) {
return provider.endpointUrl;
}
test('R2 provider should use account endpoint by default', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.r2.cloudflarestorage.com/test-bucket'
);
});
test('R2 provider should append jurisdiction suffix for EU buckets', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
jurisdiction: 'eu',
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.eu.r2.cloudflarestorage.com/test-bucket'
);
});
test('R2 provider should throw when accountId is missing', t => {
t.throws(
() =>
new R2StorageProvider(
{
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
} as any,
'test-bucket'
)
);
});
test('R2 provider should use default endpoint when jurisdiction is explicitly undefined', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
jurisdiction: undefined,
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.r2.cloudflarestorage.com/test-bucket'
);
});
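The endpoint shapes asserted by the removed R2 tests can be summarized in one small function. This is only a sketch inferred from the expected URLs in those tests: `r2Endpoint` is a hypothetical name, the removed provider's actual implementation is not shown in this diff, and treating the `'default'` jurisdiction the same as an omitted one is an assumption.

```typescript
type R2Jurisdiction = 'default' | 'eu';

// Hypothetical helper; it mirrors the URLs the removed tests expect.
function r2Endpoint(
  accountId: string,
  bucket: string,
  jurisdiction?: R2Jurisdiction
): string {
  if (!accountId) {
    throw new Error('accountId is required');
  }
  // Assumption: 'default' and undefined both mean "no jurisdiction suffix".
  const suffix =
    jurisdiction && jurisdiction !== 'default' ? `.${jurisdiction}` : '';
  return `https://${accountId}${suffix}.r2.cloudflarestorage.com/${bucket}`;
}

const defaultEndpoint = r2Endpoint('test-account', 'test-bucket');
const euEndpoint = r2Endpoint('test-account', 'test-bucket', 'eu');
```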
@@ -1,49 +0,0 @@
import { parseListPartsXml } from '@affine/s3-compat';
import test from 'ava';
test('parseListPartsXml handles array parts and pagination', t => {
const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<PartNumberMarker>0</PartNumberMarker>
<NextPartNumberMarker>3</NextPartNumberMarker>
<IsTruncated>true</IsTruncated>
<Part>
<PartNumber>1</PartNumber>
<ETag>"etag-1"</ETag>
</Part>
<Part>
<PartNumber>2</PartNumber>
<ETag>etag-2</ETag>
</Part>
</ListPartsResult>`;
const result = parseListPartsXml(xml);
t.deepEqual(result.parts, [
{ partNumber: 1, etag: 'etag-1' },
{ partNumber: 2, etag: 'etag-2' },
]);
t.true(result.isTruncated);
t.is(result.nextPartNumberMarker, '3');
});
test('parseListPartsXml handles single part', t => {
const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<IsTruncated>false</IsTruncated>
<Part>
<PartNumber>5</PartNumber>
<ETag>"etag-5"</ETag>
</Part>
</ListPartsResult>`;
const result = parseListPartsXml(xml);
t.deepEqual(result.parts, [{ partNumber: 5, etag: 'etag-5' }]);
t.false(result.isTruncated);
t.is(result.nextPartNumberMarker, undefined);
});
@@ -1,90 +0,0 @@
import test from 'ava';
import { S3StorageProvider } from '../providers/s3';
import { SIGNED_URL_EXPIRED } from '../providers/utils';
const config = {
region: 'us-east-1',
endpoint: 'https://s3.us-east-1.amazonaws.com',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
};
function createProvider() {
return new S3StorageProvider(config, 'test-bucket');
}
test('presignPut should return url and headers', async t => {
const provider = createProvider();
const result = await provider.presignPut('key', {
contentType: 'text/plain',
});
t.truthy(result);
t.true(result!.url.length > 0);
t.true(result!.url.includes('X-Amz-Algorithm=AWS4-HMAC-SHA256'));
t.true(result!.url.includes('X-Amz-SignedHeaders='));
t.true(result!.url.includes('content-type'));
t.deepEqual(result!.headers, { 'Content-Type': 'text/plain' });
const now = Date.now();
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
t.true(result!.expiresAt.getTime() <= now + SIGNED_URL_EXPIRED * 1000 + 2000);
});
test('presignUploadPart should return url', async t => {
const provider = createProvider();
const result = await provider.presignUploadPart('key', 'upload-1', 3);
t.truthy(result);
t.true(result!.url.length > 0);
t.true(result!.url.includes('X-Amz-Algorithm=AWS4-HMAC-SHA256'));
});
test('createMultipartUpload should return uploadId', async t => {
const provider = createProvider();
let receivedKey: string | undefined;
let receivedMeta: any;
(provider as any).client = {
createMultipartUpload: async (key: string, meta: any) => {
receivedKey = key;
receivedMeta = meta;
return { uploadId: 'upload-1' };
},
};
const now = Date.now();
const result = await provider.createMultipartUpload('key', {
contentType: 'text/plain',
});
t.is(result?.uploadId, 'upload-1');
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
t.true(result!.expiresAt.getTime() <= now + SIGNED_URL_EXPIRED * 1000 + 2000);
t.is(receivedKey, 'key');
t.is(receivedMeta.contentType, 'text/plain');
});
test('completeMultipartUpload should order parts', async t => {
const provider = createProvider();
let receivedParts: any;
(provider as any).client = {
completeMultipartUpload: async (
_key: string,
_uploadId: string,
parts: any
) => {
receivedParts = parts;
},
};
await provider.completeMultipartUpload('key', 'upload-1', [
{ partNumber: 2, etag: 'b' },
{ partNumber: 1, etag: 'a' },
]);
t.deepEqual(receivedParts, [
{ partNumber: 1, etag: 'a' },
{ partNumber: 2, etag: 'b' },
]);
});
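The last removed test checks that parts reach the client sorted by part number even when the caller passes them out of order. A minimal sketch of that ordering step is below. `orderParts` is a hypothetical helper name, not the provider's real method, and the `MultipartUploadPart` shape is copied from the interface removed elsewhere in this diff.

```typescript
interface MultipartUploadPart {
  partNumber: number;
  etag: string;
}

// Hypothetical helper: returns a new array sorted by ascending partNumber.
function orderParts(parts: MultipartUploadPart[]): MultipartUploadPart[] {
  // Copy first so the caller's array is not mutated by sort().
  return [...parts].sort((a, b) => a.partNumber - b.partNumber);
}

const unordered: MultipartUploadPart[] = [
  { partNumber: 2, etag: 'b' },
  { partNumber: 1, etag: 'a' },
];
const ordered = orderParts(unordered);
```

Sorting a copy keeps the input untouched, which matters if the caller reuses the list for retries or logging.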
@@ -1,20 +0,0 @@
import { Injectable } from '@nestjs/common';
import {
StorageProvider,
StorageProviderConfig,
StorageProviders,
} from './providers';
@Injectable()
export class StorageProviderFactory {
create(config: StorageProviderConfig): StorageProvider {
const Provider = StorageProviders[config.provider];
if (!Provider) {
throw new Error(`Unknown storage provider type: ${config.provider}`);
}
return new Provider(config.config, config.bucket);
}
}
@@ -1,12 +1,3 @@
import { Global, Module } from '@nestjs/common';
import { StorageProviderFactory } from './factory';
@Global()
@Module({
providers: [StorageProviderFactory],
exports: [StorageProviderFactory],
})
export class StorageProviderModule {}
export { StorageProviderFactory } from './factory';
export * from './providers';
export type * from './types';
export * from './utils';
@@ -1,318 +0,0 @@
import {
accessSync,
constants,
createReadStream,
Dirent,
mkdirSync,
readdirSync,
readFileSync,
rmSync,
statSync,
writeFileSync,
} from 'node:fs';
import { homedir } from 'node:os';
import { join, parse } from 'node:path';
import { Readable } from 'node:stream';
import { Logger } from '@nestjs/common';
import {
BlobInputType,
GetObjectMetadata,
ListObjectsMetadata,
PutObjectMetadata,
StorageProvider,
} from './provider';
import { autoMetadata, toBuffer } from './utils';
function normalizeStorageKey(key: string): string {
const normalized = key.replaceAll('\\', '/');
const segments = normalized.split('/');
if (
!normalized ||
normalized.startsWith('/') ||
segments.some(segment => !segment || segment === '.' || segment === '..')
) {
throw new Error(`Invalid storage key: ${key}`);
}
return segments.join('/');
}
function normalizeStoragePrefix(prefix: string): string {
const normalized = prefix.replaceAll('\\', '/');
if (!normalized) {
return normalized;
}
if (normalized.startsWith('/')) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
const segments = normalized.split('/');
const lastSegment = segments.pop();
if (
lastSegment === undefined ||
segments.some(segment => !segment || segment === '.' || segment === '..') ||
lastSegment === '.' ||
lastSegment === '..'
) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
if (lastSegment === '') {
return `${segments.join('/')}/`;
}
return [...segments, lastSegment].join('/');
}
export interface FsStorageConfig {
path: string;
}
export class FsStorageProvider implements StorageProvider {
private readonly path: string;
private readonly logger: Logger;
readonly type = 'fs';
constructor(
config: FsStorageConfig,
public readonly bucket: string
) {
this.path = config.path.startsWith('~/')
? join(homedir(), config.path.slice(2), bucket)
: join(config.path, bucket);
this.ensureAvailability();
this.logger = new Logger(`${FsStorageProvider.name}:${bucket}`);
}
async put(
key: string,
body: BlobInputType,
metadata: PutObjectMetadata = {}
): Promise<void> {
key = normalizeStorageKey(key);
const blob = await toBuffer(body);
// write object
this.writeObject(key, blob);
// write metadata
await this.writeMetadata(key, blob, metadata);
this.logger.verbose(`Object \`${key}\` put`);
}
async head(key: string) {
key = normalizeStorageKey(key);
const metadata = this.readMetadata(key);
if (!metadata) {
this.logger.verbose(`Object \`${key}\` not found`);
return undefined;
}
return metadata;
}
async get(key: string): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
}> {
key = normalizeStorageKey(key);
try {
const metadata = this.readMetadata(key);
const stream = this.readObject(this.join(key));
this.logger.verbose(`Read object \`${key}\``);
return {
body: stream,
metadata,
};
} catch (e) {
this.logger.error(`Failed to read object \`${key}\``, e);
return {};
}
}
async list(prefix?: string): Promise<ListObjectsMetadata[]> {
// prefix cases:
// - `undefined`: list all objects
// - `a/b`: list objects under dir `a` with prefix `b`, `b` might be a dir under `a` as well.
// - `a/b/`: list objects under dir `a/b`
// read dir recursively and filter out '.metadata.json' files
let dir = this.path;
if (prefix) {
prefix = normalizeStoragePrefix(prefix);
const parts = prefix.split(/[/\\]/);
// for prefix `a/b/c`, move `a/b` to dir and `c` to key prefix
if (parts.length > 1) {
dir = join(dir, ...parts.slice(0, -1));
prefix = parts[parts.length - 1];
}
}
const results: ListObjectsMetadata[] = [];
async function getFiles(dir: string, prefix?: string): Promise<void> {
try {
const entries: Dirent[] = readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
const res = join(dir, entry.name);
if (entry.isDirectory()) {
if (!prefix || entry.name.startsWith(prefix)) {
await getFiles(res);
}
} else if (
(!prefix || entry.name.startsWith(prefix)) &&
!entry.name.endsWith('.metadata.json')
) {
const stat = statSync(res);
results.push({
key: res,
lastModified: stat.mtime,
contentLength: stat.size,
});
}
}
} catch {
// failed to read dir, stop recursion
}
}
await getFiles(dir, prefix);
// trim path with `this.path` prefix
results.forEach(r => (r.key = r.key.slice(this.path.length + 1)));
return results;
}
delete(key: string): Promise<void> {
key = normalizeStorageKey(key);
try {
rmSync(this.join(key), { force: true });
rmSync(this.join(`${key}.metadata.json`), { force: true });
} catch (e) {
throw new Error(`Failed to delete object \`${key}\``, {
cause: e,
});
}
this.logger.verbose(`Object \`${key}\` deleted`);
return Promise.resolve();
}
ensureAvailability() {
// check stats
const stats = statSync(this.path, {
throwIfNoEntry: false,
});
// not existing, create it
if (!stats) {
try {
mkdirSync(this.path, { recursive: true });
} catch (e) {
throw new Error(
`Failed to create target directory for fs storage provider: ${this.path}`,
{
cause: e,
}
);
}
} else if (stats.isDirectory()) {
// the target directory already exists; check that it is readable and writable
try {
accessSync(this.path, constants.W_OK | constants.R_OK);
} catch (e) {
throw new Error(
`The target directory for fs storage provider has already existed, but it is not readable & writable: ${this.path}`,
{
cause: e,
}
);
}
} else if (stats.isFile()) {
throw new Error(
`The target directory for fs storage provider is a file: ${this.path}`
);
}
}
private join(...paths: string[]) {
return join(this.path, ...paths);
}
private readObject(file: string): Readable | undefined {
const state = statSync(file, { throwIfNoEntry: false });
if (state?.isFile()) {
return createReadStream(file);
}
return undefined;
}
private writeObject(key: string, blob: Buffer) {
const path = this.join(key);
mkdirSync(parse(path).dir, { recursive: true });
writeFileSync(path, blob);
}
private async writeMetadata(
key: string,
blob: Buffer,
raw: PutObjectMetadata
) {
try {
const metadata = autoMetadata(blob, raw);
if (raw.checksumCRC32 && metadata.checksumCRC32 !== raw.checksumCRC32) {
throw new Error(
'The checksum of the uploaded file is not matched with the one you provide, the file may be corrupted and the uploading will not be processed.'
);
}
if (raw.contentLength && metadata.contentLength !== raw.contentLength) {
throw new Error(
'The content length of the uploaded file is not matched with the one you provide, the file may be corrupted and the uploading will not be processed.'
);
}
writeFileSync(
this.join(`${key}.metadata.json`),
JSON.stringify({
...metadata,
lastModified: Date.now(),
})
);
} catch (e) {
this.logger.warn(`Failed to write metadata of object \`${key}\``, e);
}
}
private readMetadata(key: string): GetObjectMetadata | undefined {
try {
const raw = JSON.parse(
readFileSync(this.join(`${key}.metadata.json`), {
encoding: 'utf-8',
})
);
return {
...raw,
lastModified: new Date(raw.lastModified),
expires: raw.expires ? new Date(raw.expires) : undefined,
};
} catch (e) {
this.logger.warn(`Failed to read metadata of object \`${key}\``, e);
return;
}
}
}
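The key rules enforced by the removed `normalizeStorageKey` are easiest to see against concrete inputs. The sketch below restates that function under a different, hypothetical name (`normalizeKeySketch`) and uses a regex replace in place of `replaceAll`; it is an illustration of the rules in the removed file, not a description of what the replacement storage runtime does.

```typescript
// Restatement of the removed key check, for illustration only.
function normalizeKeySketch(key: string): string {
  // Treat backslashes as path separators before validating.
  const normalized = key.replace(/\\/g, '/');
  const segments = normalized.split('/');
  if (
    !normalized ||
    normalized.startsWith('/') ||
    // Empty, '.', and '..' segments could escape or alias the bucket dir.
    segments.some(segment => !segment || segment === '.' || segment === '..')
  ) {
    throw new Error(`Invalid storage key: ${key}`);
  }
  return segments.join('/');
}

// Keys taken from the removed "rejects unsafe object keys" test, plus
// an absolute path and an empty key.
const rejectedKeys = [
  '../escape',
  'nested/../escape',
  './escape',
  'nested//escape',
  '/absolute',
  '',
];

// Collect the keys that the check refuses.
const actuallyRejected = rejectedKeys.filter(key => {
  try {
    normalizeKeySketch(key);
    return false;
  } catch {
    return true;
  }
});
```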
@@ -1,20 +1,45 @@
import { Type } from '@nestjs/common';
import { JSONSchema } from '../../config';
import { FsStorageConfig, FsStorageProvider } from './fs';
import { StorageProvider } from './provider';
import { R2_JURISDICTIONS, R2StorageConfig, R2StorageProvider } from './r2';
import { S3StorageConfig, S3StorageProvider } from './s3';
export type StorageProviderName = 'fs' | 'aws-s3' | 'cloudflare-r2';
export const StorageProviders: Record<
StorageProviderName,
Type<StorageProvider>
> = {
fs: FsStorageProvider,
'aws-s3': S3StorageProvider,
'cloudflare-r2': R2StorageProvider,
};
export type StorageProviderName =
| 'fs'
| 'aws-s3'
| 'cloudflare-r2'
| 'assetpack';
export interface FsStorageConfig {
path: string;
}
export type AssetpackStorageConfig = FsStorageConfig;
export interface S3StorageConfig {
endpoint?: string;
region: string;
credentials?: {
accessKeyId?: string;
secretAccessKey?: string;
sessionToken?: string;
};
forcePathStyle?: boolean;
requestTimeoutMs?: number;
minPartSize?: number;
presign?: {
expiresInSeconds?: number;
signContentTypeForPut?: boolean;
};
}
export const R2_JURISDICTIONS = ['default', 'eu'] as const;
export interface R2StorageConfig extends Omit<S3StorageConfig, 'endpoint'> {
accountId: string;
jurisdiction?: (typeof R2_JURISDICTIONS)[number];
usePresignedURL?: {
enabled: boolean;
urlPrefix?: string;
signKey?: string;
};
}
export type StorageProviderConfig = { bucket: string } & (
| {
@@ -29,6 +54,10 @@ export type StorageProviderConfig = { bucket: string } & (
provider: 'cloudflare-r2';
config: R2StorageConfig;
}
| {
provider: 'assetpack';
config: AssetpackStorageConfig;
}
);
const S3ConfigSchema: JSONSchema = {
@@ -186,16 +215,37 @@ export const StorageJSONSchema: JSONSchema = {
},
},
},
{
type: 'object',
properties: {
provider: {
type: 'string',
enum: ['assetpack'],
},
bucket: {
type: 'string',
},
config: {
type: 'object',
properties: {
path: {
type: 'string',
},
},
required: ['path'],
},
},
required: ['provider', 'bucket', 'config'],
},
],
};
export type * from './provider';
export type * from '../types';
export {
applyAttachHeaders,
autoMetadata,
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
sniffMime,
STORAGE_PROXY_ROOT,
toBuffer,
} from './utils';
} from '../utils';
@@ -1,84 +0,0 @@
import type { Readable } from 'node:stream';
export interface GetObjectMetadata {
/**
* @default 'application/octet-stream'
*/
contentType: string;
contentLength: number;
lastModified: Date;
checksumCRC32?: string;
}
export interface PutObjectMetadata {
contentType?: string;
contentLength?: number;
checksumCRC32?: string;
}
export interface ListObjectsMetadata {
key: string;
lastModified: Date;
contentLength: number;
}
export type BlobInputType = Buffer | Readable | string;
export type BlobOutputType = Readable;
export interface PresignedUpload {
url: string;
headers?: Record<string, string>;
expiresAt: Date;
}
export interface MultipartUploadInit {
uploadId: string;
expiresAt: Date;
}
export interface MultipartUploadPart {
partNumber: number;
etag: string;
}
export interface StorageProvider {
put(
key: string,
body: BlobInputType,
metadata?: PutObjectMetadata
): Promise<void>;
presignPut?(
key: string,
metadata?: PutObjectMetadata
): Promise<PresignedUpload | undefined>;
createMultipartUpload?(
key: string,
metadata?: PutObjectMetadata
): Promise<MultipartUploadInit | undefined>;
presignUploadPart?(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined>;
listMultipartUploadParts?(
key: string,
uploadId: string
): Promise<MultipartUploadPart[] | undefined>;
completeMultipartUpload?(
key: string,
uploadId: string,
parts: MultipartUploadPart[]
): Promise<void>;
abortMultipartUpload?(key: string, uploadId: string): Promise<void>;
head(key: string): Promise<GetObjectMetadata | undefined>;
get(
key: string,
signedUrl?: boolean
): Promise<{
redirectUrl?: string;
body?: BlobOutputType;
metadata?: GetObjectMetadata;
}>;
list(prefix?: string): Promise<ListObjectsMetadata[]>;
delete(key: string): Promise<void>;
}
@@ -1,251 +0,0 @@
import assert from 'node:assert';
import { Readable } from 'node:stream';
import { Logger } from '@nestjs/common';
import {
GetObjectMetadata,
PresignedUpload,
PutObjectMetadata,
} from './provider';
import { S3StorageConfig, S3StorageProvider } from './s3';
import {
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
SIGNED_URL_EXPIRED,
} from './utils';
export const R2_JURISDICTIONS = ['eu'] as const;
type R2Jurisdiction = (typeof R2_JURISDICTIONS)[number];
export interface R2StorageConfig extends Omit<
S3StorageConfig,
'endpoint' | 'forcePathStyle'
> {
accountId: string;
jurisdiction?: R2Jurisdiction;
usePresignedURL?: {
enabled: boolean;
urlPrefix?: string;
signKey?: string;
};
}
export class R2StorageProvider extends S3StorageProvider {
private readonly encoder = new TextEncoder();
private readonly key: Uint8Array;
constructor(
private readonly config: R2StorageConfig,
bucket: string
) {
assert(config.accountId, 'accountId is required for R2 storage provider');
const account = config.jurisdiction
? `${config.accountId}.${config.jurisdiction}`
: config.accountId;
const endpoint = `https://${account}.r2.cloudflarestorage.com`;
super(
{
...config,
forcePathStyle: true,
endpoint,
},
bucket
);
this.logger = new Logger(`${R2StorageProvider.name}:${bucket}`);
this.key = this.encoder.encode(config.usePresignedURL?.signKey ?? '');
}
private get shouldUseProxyUpload() {
const { usePresignedURL } = this.config;
return (
!!usePresignedURL?.enabled &&
!!usePresignedURL.signKey &&
this.key.length > 0
);
}
private parseWorkspaceKey(fullKey: string) {
const [workspaceId, ...rest] = fullKey.split('/');
if (!workspaceId || rest.length !== 1) {
return null;
}
return { workspaceId, key: rest.join('/') };
}
private async signPayload(payload: string) {
const key = await crypto.subtle.importKey(
'raw',
this.key,
{ name: 'HMAC', hash: 'SHA-256' },
false,
['sign', 'verify']
);
const mac = await crypto.subtle.sign(
'HMAC',
key,
this.encoder.encode(payload)
);
return Buffer.from(mac).toString('base64');
}
private async signUrl(url: URL): Promise<string> {
const timestamp = Math.floor(Date.now() / 1000);
const base64Mac = await this.signPayload(`${url.pathname}${timestamp}`);
url.searchParams.set('sign', `${timestamp}-${base64Mac}`);
return url.toString();
}
private async createProxyUrl(
path: string,
canonicalFields: (string | number | undefined)[],
query: Record<string, string | number | undefined>
) {
const exp = Math.floor(Date.now() / 1000) + SIGNED_URL_EXPIRED;
const canonical = [
path,
...canonicalFields.map(field =>
field === undefined ? '' : field.toString()
),
exp.toString(),
].join('\n');
const token = await this.signPayload(canonical);
const url = new URL(`http://localhost${path}`);
for (const [key, value] of Object.entries(query)) {
if (value === undefined) continue;
url.searchParams.set(key, value.toString());
}
url.searchParams.set('exp', exp.toString());
url.searchParams.set('token', `${exp}-${token}`);
return { url: url.pathname + url.search, expiresAt: new Date(exp * 1000) };
}
override async presignPut(
key: string,
metadata: PutObjectMetadata = {}
): Promise<PresignedUpload | undefined> {
if (!this.shouldUseProxyUpload) {
return super.presignPut(key, metadata);
}
const parsed = this.parseWorkspaceKey(key);
if (!parsed) {
return super.presignPut(key, metadata);
}
const contentType = metadata.contentType ?? 'application/octet-stream';
const { url, expiresAt } = await this.createProxyUrl(
PROXY_UPLOAD_PATH,
[parsed.workspaceId, parsed.key, contentType, metadata.contentLength],
{
workspaceId: parsed.workspaceId,
key: parsed.key,
contentType,
contentLength: metadata.contentLength,
}
);
return {
url,
headers: { 'Content-Type': contentType },
expiresAt,
};
}
override async presignUploadPart(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined> {
if (!this.shouldUseProxyUpload) {
return super.presignUploadPart(key, uploadId, partNumber);
}
const parsed = this.parseWorkspaceKey(key);
if (!parsed) {
return super.presignUploadPart(key, uploadId, partNumber);
}
return this.createProxyUrl(
PROXY_MULTIPART_PATH,
[parsed.workspaceId, parsed.key, uploadId, partNumber],
{
workspaceId: parsed.workspaceId,
key: parsed.key,
uploadId,
partNumber,
}
);
}
async proxyPutObject(
key: string,
body: Readable | Buffer | Uint8Array | string,
options: { contentType?: string; contentLength?: number } = {}
) {
return this.client.putObject(key, this.normalizeBody(body), {
contentType: options.contentType,
contentLength: options.contentLength,
});
}
async proxyUploadPart(
key: string,
uploadId: string,
partNumber: number,
body: Readable | Buffer | Uint8Array | string,
options: { contentLength?: number } = {}
) {
const result = await this.client.uploadPart(
key,
uploadId,
partNumber,
this.normalizeBody(body),
{ contentLength: options.contentLength }
);
return result.etag;
}
private normalizeBody(body: Readable | Buffer | Uint8Array | string) {
// s3mini does not accept Node.js Readable directly.
// Convert it to Web ReadableStream for compatibility.
if (body instanceof Readable) {
return Readable.toWeb(body);
} else if (typeof body === 'string') {
return this.encoder.encode(body);
}
return body;
}
override async get(
key: string,
signedUrl?: boolean
): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
redirectUrl?: string;
}> {
const { usePresignedURL: { enabled, urlPrefix } = {} } = this.config;
if (signedUrl && enabled && urlPrefix) {
const metadata = await this.head(key);
const url = await this.signUrl(new URL(`/${key}`, urlPrefix));
if (metadata) {
return {
redirectUrl: url.toString(),
metadata,
};
}
// object not found
return {};
}
// fallback to s3 get
return super.get(key, signedUrl);
}
}
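Reading aid for `signPayload` and `signUrl` in the removed file above: a sketch of the same `timestamp-mac` scheme, assuming Node's synchronous `createHmac` as a stand-in for the `crypto.subtle` calls in the diff. `verifyUrl` and its `maxAgeSeconds` parameter are hypothetical; the diff only shows the signing side.

```typescript
import { createHmac, timingSafeEqual } from 'node:crypto';

// Synchronous stand-in for the diff's crypto.subtle-based signPayload.
function signPayload(secret: string, payload: string): string {
  return createHmac('sha256', secret).update(payload).digest('base64');
}

// Mirrors signUrl: the MAC covers `${pathname}${timestamp}`.
function signUrl(url: URL, secret: string, nowSeconds: number): string {
  const mac = signPayload(secret, `${url.pathname}${nowSeconds}`);
  url.searchParams.set('sign', `${nowSeconds}-${mac}`);
  return url.toString();
}

// Hypothetical verifier (not part of the diff): recomputes the MAC and checks age.
function verifyUrl(
  signed: string,
  secret: string,
  nowSeconds: number,
  maxAgeSeconds: number
): boolean {
  const url = new URL(signed);
  const sign = url.searchParams.get('sign') ?? '';
  // Standard base64 never contains '-', so the first dash ends the timestamp.
  const dash = sign.indexOf('-');
  if (dash < 0) return false;
  const timestamp = Number(sign.slice(0, dash));
  if (!Number.isInteger(timestamp) || nowSeconds - timestamp > maxAgeSeconds) {
    return false;
  }
  const expected = Buffer.from(signPayload(secret, `${url.pathname}${timestamp}`));
  const actual = Buffer.from(sign.slice(dash + 1));
  // Compare in constant time; lengths must match before timingSafeEqual.
  return expected.length === actual.length && timingSafeEqual(expected, actual);
}
```

Only the path and the timestamp are covered by the MAC here, so other query parameters are not protected; `createProxyUrl` in the diff signs a fuller canonical string for the upload paths.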
@@ -1,363 +0,0 @@
/* oxlint-disable @typescript-eslint/no-non-null-assertion */
import { Readable } from 'node:stream';
import type {
S3CompatClient,
S3CompatConfig,
S3CompatCredentials,
} from '@affine/s3-compat';
import { createS3CompatClient } from '@affine/s3-compat';
import { Logger } from '@nestjs/common';
import {
BlobInputType,
GetObjectMetadata,
ListObjectsMetadata,
MultipartUploadInit,
MultipartUploadPart,
PresignedUpload,
PutObjectMetadata,
StorageProvider,
} from './provider';
import { autoMetadata, SIGNED_URL_EXPIRED, toBuffer } from './utils';
export interface S3StorageConfig {
endpoint?: string;
region: string;
credentials: S3CompatCredentials;
forcePathStyle?: boolean;
requestTimeoutMs?: number;
minPartSize?: number;
presign?: {
expiresInSeconds?: number;
signContentTypeForPut?: boolean;
};
usePresignedURL?: {
enabled: boolean;
};
}
function resolveEndpoint(config: S3StorageConfig) {
if (config.endpoint) {
return config.endpoint;
}
if (config.region === 'us-east-1') {
return 'https://s3.amazonaws.com';
}
return `https://s3.${config.region}.amazonaws.com`;
}
function joinPath(basePath: string, suffix: string) {
const trimmedBase = basePath.endsWith('/') ? basePath.slice(0, -1) : basePath;
const trimmedSuffix = suffix.startsWith('/') ? suffix.slice(1) : suffix;
if (!trimmedBase) {
return `/${trimmedSuffix}`;
}
if (!trimmedSuffix) {
return trimmedBase;
}
return `${trimmedBase}/${trimmedSuffix}`;
}
function composeEndpointUrl(config: S3CompatConfig) {
const url = new URL(config.endpoint);
if (config.forcePathStyle) {
const firstSegment = url.pathname.split('/').find(Boolean);
if (firstSegment !== config.bucket) {
url.pathname = joinPath(url.pathname, config.bucket);
}
return url.toString();
}
const firstSegment = url.pathname.split('/').find(Boolean);
const hostHasBucket = url.hostname.startsWith(`${config.bucket}.`);
const pathHasBucket = firstSegment === config.bucket;
if (!hostHasBucket && !pathHasBucket) {
url.hostname = `${config.bucket}.${url.hostname}`;
}
return url.toString();
}
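Reading aid for `composeEndpointUrl` above: a condensed sketch of the two S3 addressing styles it chooses between. `bucketUrl` is a hypothetical name, and unlike the function in the diff it does not check whether the endpoint already names the bucket.

```typescript
// Condensed sketch of path-style versus virtual-hosted-style bucket URLs.
function bucketUrl(
  endpoint: string,
  bucket: string,
  forcePathStyle: boolean
): string {
  const url = new URL(endpoint);
  if (forcePathStyle) {
    // Path style: the bucket becomes the first path segment.
    url.pathname = `${url.pathname.replace(/\/$/, '')}/${bucket}`;
  } else {
    // Virtual-hosted style: the bucket becomes a hostname prefix.
    url.hostname = `${bucket}.${url.hostname}`;
  }
  return url.toString();
}

const pathStyle = bucketUrl('https://s3.eu-west-1.amazonaws.com', 'blobs', true);
// → 'https://s3.eu-west-1.amazonaws.com/blobs'
const hostStyle = bucketUrl('https://s3.eu-west-1.amazonaws.com', 'blobs', false);
// → 'https://blobs.s3.eu-west-1.amazonaws.com/'
```

The R2 provider in the diff always passes `forcePathStyle: true`, so it takes the first branch.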
export class S3StorageProvider implements StorageProvider {
protected logger: Logger;
protected client: S3CompatClient;
private readonly usePresignedURL: boolean;
private readonly endpoint: string;
get endpointUrl() {
return this.endpoint;
}
constructor(
config: S3StorageConfig,
public readonly bucket: string
) {
const { usePresignedURL, presign, credentials, ...clientConfig } = config;
const compatConfig: S3CompatConfig = {
...clientConfig,
endpoint: resolveEndpoint(config),
bucket,
requestTimeoutMs: clientConfig.requestTimeoutMs ?? 60_000,
presign: {
expiresInSeconds: presign?.expiresInSeconds ?? SIGNED_URL_EXPIRED,
signContentTypeForPut: presign?.signContentTypeForPut ?? true,
},
};
this.endpoint = composeEndpointUrl(compatConfig);
this.client = createS3CompatClient(compatConfig, credentials);
this.usePresignedURL = usePresignedURL?.enabled ?? false;
this.logger = new Logger(`${S3StorageProvider.name}:${bucket}`);
}
async put(
key: string,
body: BlobInputType,
metadata: PutObjectMetadata = {}
): Promise<void> {
const blob = await toBuffer(body);
metadata = autoMetadata(blob, metadata);
try {
await this.client.putObject(key, blob, {
contentType: metadata.contentType,
contentLength: metadata.contentLength,
});
this.logger.verbose(`Object \`${key}\` put`);
} catch (e) {
this.logger.error(
`Failed to put object (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})})`
);
throw e;
}
}
async presignPut(
key: string,
metadata: PutObjectMetadata = {}
): Promise<PresignedUpload | undefined> {
try {
const contentType = metadata.contentType ?? 'application/octet-stream';
const result = await this.client.presignPutObject(key, { contentType });
return {
url: result.url,
headers: result.headers,
expiresAt: result.expiresAt,
};
} catch (e) {
this.logger.error(
`Failed to presign put object (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})}`
);
throw e;
}
}
async createMultipartUpload(
key: string,
metadata: PutObjectMetadata = {}
): Promise<MultipartUploadInit | undefined> {
try {
const contentType = metadata.contentType ?? 'application/octet-stream';
const response = await this.client.createMultipartUpload(key, {
contentType,
});
if (!response.uploadId) {
return;
}
return {
uploadId: response.uploadId,
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
};
} catch (e) {
this.logger.error(
`Failed to create multipart upload (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})}`
);
throw e;
}
}
async presignUploadPart(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined> {
try {
const result = await this.client.presignUploadPart(
key,
uploadId,
partNumber
);
return {
url: result.url,
expiresAt: result.expiresAt,
};
} catch (e) {
this.logger.error(
`Failed to presign upload part (${JSON.stringify({ key, bucket: this.bucket, uploadId, partNumber })}`
);
throw e;
}
}
async listMultipartUploadParts(
key: string,
uploadId: string
): Promise<MultipartUploadPart[] | undefined> {
try {
return await this.client.listParts(key, uploadId);
} catch (e) {
this.logger.error(`Failed to list multipart upload parts for \`${key}\``);
throw e;
}
}
async completeMultipartUpload(
key: string,
uploadId: string,
parts: MultipartUploadPart[]
): Promise<void> {
try {
const orderedParts = [...parts].sort(
(left, right) => left.partNumber - right.partNumber
);
await this.client.completeMultipartUpload(key, uploadId, orderedParts);
} catch (e) {
this.logger.error(`Failed to complete multipart upload for \`${key}\``);
throw e;
}
}
async abortMultipartUpload(key: string, uploadId: string): Promise<void> {
try {
await this.client.abortMultipartUpload(key, uploadId);
} catch (e) {
this.logger.error(`Failed to abort multipart upload for \`${key}\``);
throw e;
}
}
async head(key: string) {
try {
const obj = await this.client.headObject(key);
if (!obj) {
this.logger.verbose(`Object \`${key}\` not found`);
return undefined;
}
return {
contentType: obj.contentType ?? 'application/octet-stream',
contentLength: obj.contentLength ?? 0,
lastModified: obj.lastModified ?? new Date(0),
checksumCRC32: obj.checksumCRC32,
};
} catch (e) {
this.logger.error(`Failed to head object \`${key}\``);
throw e;
}
}
async get(
key: string,
signedUrl?: boolean
): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
redirectUrl?: string;
}> {
try {
if (this.usePresignedURL && signedUrl) {
const metadata = await this.head(key);
if (metadata) {
const result = await this.client.presignGetObject(key);
return {
redirectUrl: result.url,
metadata,
};
}
// object not found
return {};
}
const obj = await this.client.getObjectResponse(key);
if (!obj || !obj.body) {
this.logger.verbose(`Object \`${key}\` not found`);
return {};
}
const contentType = obj.headers.get('content-type') ?? undefined;
const contentLengthHeader = obj.headers.get('content-length');
const contentLength = contentLengthHeader
? Number(contentLengthHeader)
: undefined;
const lastModifiedHeader = obj.headers.get('last-modified');
const lastModified = lastModifiedHeader
? new Date(lastModifiedHeader)
: undefined;
this.logger.verbose(`Read object \`${key}\``);
return {
body: Readable.fromWeb(obj.body),
metadata: {
contentType: contentType ?? 'application/octet-stream',
contentLength: contentLength ?? 0,
lastModified: lastModified ?? new Date(0),
checksumCRC32: obj.headers.get('x-amz-checksum-crc32') ?? undefined,
},
};
} catch (e) {
this.logger.error(`Failed to read object \`${key}\``);
throw e;
}
}
async list(prefix?: string): Promise<ListObjectsMetadata[]> {
try {
const result = await this.client.listObjectsV2(prefix);
this.logger.verbose(
`List ${result.length} objects with prefix \`${prefix}\``
);
return result;
} catch (e) {
this.logger.error(`Failed to list objects with prefix \`${prefix}\``);
throw e;
}
}
async delete(key: string): Promise<void> {
try {
await this.client.deleteObject(key);
this.logger.verbose(`Deleted object \`${key}\``);
} catch (e) {
this.logger.error(`Failed to delete object \`${key}\``, {
bucket: this.bucket,
key,
cause: e,
});
throw e;
}
}
}
@@ -0,0 +1,32 @@
import type { Readable } from 'node:stream';
export interface GetObjectMetadata {
/**
* @default 'application/octet-stream'
*/
contentType: string;
contentLength: number;
lastModified: Date;
checksumCRC32?: string;
}
export interface PutObjectMetadata {
contentType?: string;
contentLength?: number;
checksumCRC32?: string;
}
export interface ListObjectsMetadata {
key: string;
lastModified: Date;
contentLength: number;
}
export type BlobInputType = Buffer | Readable | string;
export type BlobOutputType = Readable;
export interface PresignedUpload {
url: string;
headers?: Record<string, string>;
expiresAt: Date;
}
@@ -1,11 +1,10 @@
import { Readable } from 'node:stream';
import { crc32 } from '@node-rs/crc32';
import type { Response } from 'express';
import { getStreamAsBuffer } from 'get-stream';
import { getMime } from '../../../native';
import { BlobInputType, PutObjectMetadata } from './provider';
import { getMime } from '../../native';
import type { BlobInputType } from './types';
export async function toBuffer(input: BlobInputType): Promise<Buffer> {
return input instanceof Readable
@@ -15,35 +14,6 @@ export async function toBuffer(input: BlobInputType): Promise<Buffer> {
: Buffer.from(input as string);
}
export function autoMetadata(
blob: Buffer,
raw: PutObjectMetadata = {}
): PutObjectMetadata {
const metadata = {
...raw,
};
if (!metadata.contentLength) {
metadata.contentLength = blob.byteLength;
}
try {
// checksum
if (!metadata.checksumCRC32) {
metadata.checksumCRC32 = crc32(blob).toString(16);
}
// mime type
if (!metadata.contentType) {
metadata.contentType = getMime(blob);
}
} catch {
// noop
}
return metadata;
}
const DANGEROUS_INLINE_MIME_PREFIXES = [
'text/html',
'application/xhtml+xml',
@@ -2,6 +2,7 @@ import './config';
import { Module } from '@nestjs/common';
import { BackendRuntimeModule } from '../backend-runtime';
import { FeatureModule } from '../features';
import { MailModule } from '../mail';
import { QuotaModule } from '../quota';
@@ -20,7 +21,13 @@ import { AuthService } from './service';
import { SessionIssuer } from './session-issuer';
@Module({
imports: [FeatureModule, UserModule, QuotaModule, MailModule],
imports: [
BackendRuntimeModule,
FeatureModule,
UserModule,
QuotaModule,
MailModule,
],
providers: [
AuthService,
AuthResolver,
@@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { JobQueue, OnJob } from '../../base';
import { Models } from '../../models';
import { BackendRuntimeProvider } from '../backend-runtime';
declare global {
interface Jobs {
@@ -13,7 +13,7 @@ declare global {
@Injectable()
export class AuthCronJob {
constructor(
private readonly models: Models,
private readonly rt: BackendRuntimeProvider,
private readonly queue: JobQueue
) {}
@@ -31,6 +31,9 @@ export class AuthCronJob {
@OnJob('nightly.cleanExpiredUserSessions')
async cleanExpiredUserSessions() {
await this.models.session.cleanExpiredUserSessions();
for (;;) {
const count = await this.rt.cleanupExpiredUserSessions(1000);
if (count < 1000) break;
}
}
}
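Reading aid for the new loop in `cleanExpiredUserSessions`: a sketch of the same drain-in-batches pattern in generic form. `drainInBatches` and `fakeCleanup` are hypothetical; the in-memory counter stands in for the runtime call, whose implementation is not shown in the diff.

```typescript
// Hypothetical generic form of the loop: keep calling until a short batch comes back.
async function drainInBatches(
  cleanup: (limit: number) => Promise<number>,
  batchSize: number
): Promise<number> {
  let total = 0;
  for (;;) {
    const count = await cleanup(batchSize);
    total += count;
    // A batch smaller than the limit means nothing more is left to delete.
    if (count < batchSize) break;
  }
  return total;
}

// In-memory stand-in: 2500 expired rows, deleted up to `limit` at a time.
let remaining = 2500;
const calls: number[] = [];
const fakeCleanup = async (limit: number): Promise<number> => {
  const deleted = Math.min(limit, remaining);
  remaining -= deleted;
  calls.push(deleted);
  return deleted;
};
```

With this stopping rule, a backlog that is an exact multiple of the batch size costs one extra call that returns 0.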
@@ -51,6 +51,13 @@ export class AuthService implements OnApplicationBootstrap {
};
}
private getServerName() {
return (
this.config.server.name ??
(env.selfhosted ? 'AFFiNE Self-hosted' : 'AFFiNE Cloud')
);
}
async onApplicationBootstrap() {
if (env.dev) {
await createDevUsers(this.models);
@@ -372,6 +379,7 @@ export class AuthService implements OnApplicationBootstrap {
props: {
url: link,
otp,
serverName: this.getServerName(),
},
});
}
@@ -0,0 +1,57 @@
import { ScheduleModule } from '@nestjs/schedule';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import {
createTestingModule,
type TestingModule,
} from '../../../__tests__/utils';
import { BackendRuntimeModule, BackendRuntimeProvider } from '../index';
import { BackendRuntimeHousekeepingJob } from '../job';
interface Context {
module: TestingModule;
job: BackendRuntimeHousekeepingJob;
runtime: {
cleanupExpiredRuntimeStates: Sinon.SinonStub;
cleanupExpiredRuntimeGates: Sinon.SinonStub;
};
}
const test = ava as TestFn<Context>;
test.before(async t => {
t.context.runtime = {
cleanupExpiredRuntimeStates: Sinon.stub(),
cleanupExpiredRuntimeGates: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), BackendRuntimeModule],
tapModule: builder => {
builder
.overrideProvider(BackendRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.job = t.context.module.get(BackendRuntimeHousekeepingJob);
});
test.beforeEach(t => {
t.context.runtime.cleanupExpiredRuntimeStates.reset();
t.context.runtime.cleanupExpiredRuntimeGates.reset();
});
test.after.always(async t => {
await t.context.module.close();
});
test('backend-runtime housekeeping cleans runtime state and gate batches', async t => {
t.context.runtime.cleanupExpiredRuntimeStates.onCall(0).resolves(1000);
t.context.runtime.cleanupExpiredRuntimeStates.onCall(1).resolves(2);
t.context.runtime.cleanupExpiredRuntimeGates.resolves(1);
await t.context.job.cleanExpiredRuntimeHousekeeping();
t.is(t.context.runtime.cleanupExpiredRuntimeStates.callCount, 2);
t.is(t.context.runtime.cleanupExpiredRuntimeGates.callCount, 1);
});
@@ -0,0 +1,41 @@
import test from 'ava';
import Sinon from 'sinon';
import { BackendRuntimeProvider } from '../provider';
test('backend-runtime provider starts once, runs migrations once, and reports health', async t => {
const provider = new BackendRuntimeProvider();
const runtime = {
start: Sinon.stub().resolves(),
stop: Sinon.stub().resolves(),
runMigrations: Sinon.stub().resolves(),
health: Sinon.stub().resolves({
started: true,
databaseConnected: true,
}),
};
(provider as any).runtime = runtime;
await provider.start();
await provider.start();
const health = await provider.health();
await provider.stop();
t.is(runtime.start.callCount, 2);
t.is(runtime.runMigrations.callCount, 1);
t.true(health.databaseConnected);
t.is(runtime.stop.callCount, 1);
});
test('backend-runtime provider measures explicit typed methods', async t => {
const provider = new BackendRuntimeProvider();
const runtime = {
cleanupExpiredRuntimeStates: Sinon.stub().resolves(3),
};
(provider as any).runtime = runtime;
const result = await provider.cleanupExpiredRuntimeStates(1000);
t.is(result, 3);
t.true(runtime.cleanupExpiredRuntimeStates.calledOnceWithExactly(1000));
});

Some files were not shown because too many files have changed in this diff.