Compare commits

..

11 Commits

Author SHA1 Message Date
DarkSky 5b7f83a6e3 feat(server): batch blob gc (#15183) 2026-07-02 06:14:01 +08:00
DarkSky 6f9269498f fix(server): blob gc planning 2026-07-02 03:28:56 +08:00
DarkSky e5d44b8ff2 fix(server): s3 metadata encode 2026-07-02 00:27:17 +08:00
DarkSky 8c68319094 feat(server): improve client builder 2026-07-01 23:27:20 +08:00
DarkSky 8ebdb7452f feat(server): impl storage runtime (#15181)
#### PR Dependency Tree


* **PR #15181** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added an additional storage backend option: asset-pack based storage
(provider for avatar, blob, and copilot).
* Introduced a dedicated storage runtime with provider capability
reporting and expanded object operations (put/head/get/list/delete),
including presigned and multipart flows where supported.
* Cloudflare R2 `jurisdiction` now uses an explicit default when
omitted.
* **Bug Fixes**
  * Broadened avatar access to allow both fs and asset-pack providers.
* Improved workspace blob upload completion validation and handling when
stored objects are missing or mismatched.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-07-01 22:24:10 +08:00
FailSafe da7d438377 fix: enforce quota for comment attachments (#15149)
## Summary

This change includes comment attachments in workspace storage usage and
checks workspace storage quota before accepting a new comment attachment
upload.

## Impact

Comment attachments already had a per-file size limit, but they were not
counted in the same workspace storage usage path as other uploaded
blobs. A user with comment permission could keep adding attachments
without those bytes participating in workspace storage quota
calculations.

## Fix

- Count comment attachment bytes in workspace storage usage
reconciliation.
- Check the workspace quota before storing a new comment attachment.
- Return the existing comment attachment quota error when the upload
would exceed limits.

## Validation

- `git diff --check`
- Full test/lint suite was not run locally because dependencies are not
installed in this checkout.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Workspace attachment uploads now respect storage and file quota limits
more accurately.
* Workspace storage tracking now includes comment attachments, improving
quota enforcement.

* **Bug Fixes**
* Attachment uploads now fail with a clear quota error when a workspace
is out of space or blob capacity.
* Storage usage calculations now better reflect actual workspace
content, including non-deleted files.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: failsafesecurity <190101117+failsafesecurity@users.noreply.github.com>
2026-07-01 08:45:33 +08:00
DarkSky a821f67fc9 fix: config override 2026-06-30 04:37:52 +08:00
DarkSky a1363b3873 fix(server): config & update handle (#15173)
#### PR Dependency Tree


* **PR #15173** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added native document update validation to check incoming Yjs updates
for decodability before applying them.
* Introduced support for validation timeouts and cancellation during
update checks.
* Blob maintenance jobs now detect when object storage is unavailable
and skip related work gracefully.

* **Bug Fixes**
* Invalid (and oversized) updates are now filtered out earlier during
document ingestion.
* Background blob maintenance continues processing other work even if
one workspace fails.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 22:59:17 +08:00
DarkSky 1b9e21f2de fix(core): handle unsupported server error (#15164)
fix #15160
fix #15161
fix #15158
fix #15166


#### PR Dependency Tree


* **PR #15164** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added a “server version too old” message for self-hosted servers,
including the required upgrade version.
* Sign-in and OAuth-related preflight steps now verify server
compatibility before proceeding.
* **Bug Fixes**
* Improved error handling for missing/invalid server version responses
and schema/type mismatches, mapping them to the upgrade instruction.
* **Tests**
* Added coverage for server version guarding and the resulting
user-friendly error payload.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 00:03:02 +08:00
DarkSky 0a422aa158 feat(server): blob reconciliation (#15165)
#### PR Dependency Tree


* **PR #15165** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added automated backend maintenance for missing blob metadata
backfill, document-to-blob reference rebuilding, and unreferenced blob
cleanup planning/execution.
* Introduced scheduled batch processing (workspace-paged) and paginated
object-storage listing.
* **Bug Fixes**
* Improved reliability of object-storage reads by treating expected “not
found” results as non-errors.
* Strengthened blob/expired cleanup flows with runtime-driven batching
and reduced coupling to metadata synchronization.
* **Tests**
* Expanded unit and e2e coverage for partial blob metadata and updated
runtime/job cleanup test assertions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-29 00:02:38 +08:00
DarkSky 4a7c931eca fix(server): member loading (#15156)
#### PR Dependency Tree


* **PR #15156** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Fixed Stripe subscription syncing for Team plans to update the correct
existing local subscription (avoiding duplicates) while refreshing
quantity and billing/trial period details.
* **UI/UX Improvements**
* Improved workspace member list loading/error states with shared UI
components and steadier pagination behavior.
  * Refined fallback styling for cleaner, more stable layout.
* **Tests**
* Expanded subscription and projection coverage and adjusted
seat-allocation/e2e assertions to be more robust.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-06-26 21:29:15 +08:00
152 changed files with 10318 additions and 4644 deletions
+105
View File
@@ -121,6 +121,18 @@
"default": {
"concurrency": 1
}
},
"queues.backendRuntime": {
"type": "object",
"description": "The config for backend runtime job queue\n@default {\"concurrency\":1}",
"properties": {
"concurrency": {
"type": "number"
}
},
"default": {
"concurrency": 1
}
}
}
},
@@ -493,6 +505,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -518,6 +531,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
@@ -691,6 +734,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -716,6 +760,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
@@ -1332,6 +1406,7 @@
"jurisdiction": {
"type": "string",
"enum": [
"default",
"eu"
],
"description": "Optional jurisdiction for the cloudflare r2 endpoint. Set to \"eu\" for EU buckets."
@@ -1357,6 +1432,36 @@
}
}
}
},
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"assetpack"
]
},
"bucket": {
"type": "string"
},
"config": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
"required": [
"provider",
"bucket",
"config"
]
}
],
"default": {
Generated
+387 -589
View File
File diff suppressed because it is too large Load Diff
+2 -24
View File
@@ -17,17 +17,12 @@ resolver = "3"
aes-gcm = "0.10"
affine_common = { path = "./packages/common/native" }
affine_nbstore = { path = "./packages/frontend/native/nbstore" }
ahash = "0.8"
anyhow = "1"
arbitrary = { version = "1.3", features = ["derive"] }
assert-json-diff = "2.0"
base64 = "0.22.1"
base64-simd = "0.8"
bitvec = "1.0"
block2 = "0.6"
byteorder = "1.5"
chrono = "0.4"
clap = { version = "4.4", features = ["derive"] }
core-foundation = "0.10"
coreaudio-rs = "0.12"
cpal = "0.15"
@@ -48,14 +43,10 @@ resolver = "3"
"webp",
] }
infer = { version = "0.19.0" }
lasso = { version = "0.7", features = ["multi-threaded"] }
lib0 = { version = "0.16", features = ["lib0-serde"] }
libc = "0.2"
libwebp-sys = "0.14.2"
little_exif = "0.6.23"
llm_adapter = { version = "0.2", default-features = false }
llm_runtime = { version = "0.2", default-features = false }
log = "0.4"
lru = "0.16"
matroska = "0.30"
memory-indexer = "0.3.1"
@@ -72,33 +63,22 @@ resolver = "3"
] }
napi-build = { version = "2" }
napi-derive = { version = "3.4" }
nom = "8"
notify = { version = "8", features = ["serde"] }
objc2 = "0.6"
objc2-foundation = "0.3"
ogg = "0.9"
once_cell = "1"
ordered-float = "5"
p256 = { version = "0.13", features = ["ecdsa", "pem"] }
parking_lot = "0.12"
phf = { version = "0.11", features = ["macros"] }
proptest = "1.3"
proptest-derive = "0.5"
pulldown-cmark = "0.13"
rand = "0.9"
rand_chacha = "0.9"
rand_distr = "0.5"
rayon = "1.10"
regex = "1.10"
rubato = "0.16"
safefetch = "0.1.0"
schemars = "0.8"
screencapturekit = "0.3"
serde = "1"
serde_json = "1"
sha2 = "0.10"
sha3 = "0.10"
smol_str = "0.3"
sha2 = "0.11"
sha3 = "0.11"
sqlx = { version = "0.8", default-features = false, features = [
"chrono",
"macros",
@@ -136,8 +116,6 @@ resolver = "3"
] }
windows-core = { version = "0.61" }
y-octo = "0.0.3"
y-sync = { version = "0.4" }
yrs = "0.23.0"
[profile.dev.package.sqlx-macros]
opt-level = 3
+20 -5
View File
@@ -16,17 +16,20 @@ affine_common = { workspace = true, features = [
"ydoc-loader",
] }
anyhow = { workspace = true }
aws-sdk-s3 = "1.137.0"
assetpack-core = "0.1.0"
assetpack-transform-precomp2 = "0.1.0"
base64 = { workspace = true }
chrono = { workspace = true }
crc32fast = "1.5.0"
doc_extractor = { workspace = true }
file-format = { workspace = true }
hex = { workspace = true }
homedir = { workspace = true }
image = { workspace = true }
infer = { workspace = true }
instant-xml = "0.7.5"
jsonschema = "0.46"
libwebp-sys = { workspace = true }
libwebp-sys = "0.9.6"
little_exif = { workspace = true }
llm_adapter = { workspace = true, features = ["schema", "ureq-client"] }
llm_runtime = { workspace = true, features = ["schema", "ureq-client"] }
@@ -36,6 +39,15 @@ napi = { workspace = true, features = ["async", "serde-json"] }
napi-derive = { workspace = true }
p256 = { workspace = true }
rand = { workspace = true }
reqwest = { version = "0.13.4", default-features = false, features = [
"rustls",
] }
rustls = { version = "0.23", default-features = false, features = [
"aws-lc-rs",
"std",
"tls12",
] }
rusty-s3 = "0.10.0"
safefetch = { workspace = true }
schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
@@ -50,11 +62,13 @@ sqlx = { workspace = true, default-features = false, features = [
"postgres",
"runtime-tokio",
] }
thiserror.workspace = true
tiktoken-rs = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
tokio = { workspace = true, features = ["rt-multi-thread", "sync", "time"] }
url = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
v_htmlescape = { workspace = true }
webpki-roots = "1.0"
y-octo = { workspace = true, features = ["large_refs"] }
[target.'cfg(not(target_os = "linux"))'.dependencies]
@@ -64,8 +78,9 @@ mimalloc = { workspace = true }
mimalloc = { workspace = true, features = ["local_dynamic_tls"] }
[dev-dependencies]
rayon = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
rayon = { workspace = true }
tempfile = "3.27.0"
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[build-dependencies]
napi-build = { workspace = true }
+6
View File
@@ -0,0 +1,6 @@
/* auto-generated by NAPI-RS */
/* eslint-disable */
declare const _default: typeof import('./index')
export default _default
+102 -33
View File
@@ -1,10 +1,10 @@
/* auto-generated by NAPI-RS */
/* eslint-disable */
declare const _default: typeof import('./index')
export default _default
export declare class BackendRuntime {
completeBlobUpload(workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
completeFsBlobUpload(root: string, bucket: string, workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
cleanupExpiredPendingBlobs(cutoffMs: number, limit: number): Promise<RuntimeBlobCleanupResult>
releaseDeletedBlobs(workspaceId: string, limit: number): Promise<RuntimeBlobCleanupResult>
acquireCoordinationLease(key: string, owner: string, ttlMs: number): Promise<CoordinationLeaseGrant | null>
releaseCoordinationLease(key: string, owner: string, fencingToken: bigint | number): Promise<boolean>
renewCoordinationLease(key: string, owner: string, fencingToken: bigint | number, ttlMs: number): Promise<boolean>
@@ -27,18 +27,6 @@ export declare class BackendRuntime {
cleanupExpiredRuntimeGates(limit: number): Promise<number>
cleanupExpiredUserSessions(limit: number): Promise<number>
cleanupExpiredSnapshotHistories(limit: number): Promise<number>
objectStorageHealth(): RuntimeObjectStorageHealth
objectStoragePut(key: string, body: Buffer, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<void>
objectStoragePresignPut(key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimePresignedObjectRequest>
objectStorageCreateMultipartUpload(key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeMultipartUploadInit | null>
objectStoragePresignUploadPart(key: string, uploadId: string, partNumber: number): Promise<RuntimePresignedObjectRequest>
objectStorageListMultipartUploadParts(key: string, uploadId: string): Promise<Array<RuntimeMultipartUploadPart>>
objectStorageCompleteMultipartUpload(key: string, uploadId: string, parts: Array<RuntimeMultipartUploadPart>): Promise<void>
objectStorageAbortMultipartUpload(key: string, uploadId: string): Promise<void>
objectStorageHead(key: string): Promise<RuntimeObjectMetadata | null>
objectStorageGet(key: string): Promise<RuntimeObjectGetResult | null>
objectStorageList(prefix?: string | undefined | null): Promise<Array<RuntimeObjectListEntry>>
objectStorageDelete(key: string): Promise<void>
createAuthChallenge(purpose: string, token: string, payload: any, ttlMs: number): Promise<boolean>
getAuthChallenge(purpose: string, token: string): Promise<any | null>
consumeAuthChallenge(purpose: string, token: string): Promise<any | null>
@@ -70,6 +58,37 @@ export declare class LlmStreamHandle {
abort(): void
}
export declare class StorageRuntime {
planUnreferencedWorkspaceBlobs(workspaceId: string, gracePeriodDays: number, limit: number): Promise<RuntimeBlobCleanupPlanResult>
executeBlobCleanupCandidates(runId: string, gracePeriodDays: number, limit: number): Promise<RuntimeBlobCleanupExecuteResult>
cleanupExpiredPendingBlobs(cutoffMs: number, limit: number): Promise<RuntimeBlobCleanupResult>
releaseDeletedBlobs(workspaceId: string, limit: number): Promise<RuntimeBlobCleanupResult>
backfillMissingBlobMetadata(workspaceId: string | undefined | null, limit: number): Promise<RuntimeBlobMetadataBackfillResult>
rebuildDocBlobRefs(workspaceId: string, docId: string): Promise<RuntimeDocBlobRefsResult>
rebuildWorkspaceDocBlobRefs(workspaceId: string, limit: number): Promise<RuntimeDocBlobRefsResult>
constructor()
start(): Promise<void>
configure(configJson: string): void
stop(): Promise<void>
runMigrations(): Promise<void>
health(): Promise<StorageRuntimeHealth>
providerCapabilities(scope: string): Promise<StorageProviderCapabilities>
putObject(scope: string, key: string, body: Buffer, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeObjectMetadata>
headObject(scope: string, key: string): Promise<RuntimeObjectMetadata | null>
getObject(scope: string, key: string): Promise<RuntimeObjectGetResult | null>
listObjects(scope: string, prefix?: string | undefined | null): Promise<Array<RuntimeObjectListEntry>>
deleteObject(scope: string, key: string): Promise<void>
presignPut(scope: string, key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimePresignedObjectRequest | null>
presignGet(scope: string, key: string): Promise<RuntimePresignedObjectRequest | null>
createMultipartUpload(scope: string, key: string, metadata?: RuntimeObjectStoragePutOptions | undefined | null): Promise<RuntimeMultipartUploadInit | null>
presignUploadPart(scope: string, key: string, uploadId: string, partNumber: number): Promise<RuntimePresignedObjectRequest | null>
proxyUploadPart(scope: string, key: string, uploadId: string, partNumber: number, body: Buffer, contentLength?: number | undefined | null): Promise<string | null>
listMultipartUploadParts(scope: string, key: string, uploadId: string): Promise<Array<RuntimeMultipartUploadPart> | null>
completeMultipartUpload(scope: string, key: string, uploadId: string, parts: Array<RuntimeMultipartUploadPart>): Promise<boolean>
abortMultipartUpload(scope: string, key: string, uploadId: string): Promise<boolean>
completeWorkspaceBlobUpload(workspaceId: string, key: string, expectedSize: number, expectedMime: string): Promise<RuntimeBlobCompleteResult>
}
export declare class Tokenizer {
count(content: string, allowedSpecial?: Array<string> | undefined | null): number
}
@@ -143,7 +162,6 @@ export interface AssertSafeUrlRequest {
export interface BackendRuntimeHealth {
started: boolean
databaseConnected: boolean
objectStorageConfigured: boolean
}
export declare function buildPublicRootDoc(rootDocBin: Buffer, docMetas: Array<PublicDocMetaInput>): Buffer
@@ -816,6 +834,25 @@ export declare function resolveEntitlementV1(input: ResolveEntitlementInput): Re
export declare function runNativeActionRecipePreparedStream(input: ActionRuntimeInput, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
export interface RuntimeBlobCleanupExecuteResult {
scannedCandidates: number
deletedObjects: number
deletedMetadata: number
skippedStillReferenced: number
failed: number
workspaceIds: Array<string>
}
export interface RuntimeBlobCleanupPlanResult {
runId?: string
scannedBlobs: number
candidatesMarked: number
protectedByDocRefs: number
protectedByMetadata: number
protectedByOtherRefs: number
nextCursor?: string
}
export interface RuntimeBlobCleanupResult {
scanned: number
deleted: number
@@ -831,12 +868,32 @@ export interface RuntimeBlobCompleteResult {
lastModifiedMs?: number
}
export interface RuntimeBlobMetadataBackfillResult {
scannedObjects: number
headedObjects: number
upsertedMetadata: number
skippedExisting: number
skippedWorkspaceMissing: number
failed: number
nextCursor?: string
workspaceIds: Array<string>
}
export interface RuntimeByokLocalLeaseRecord {
leaseId: string
payload: any
expiresAtMs: number
}
export interface RuntimeDocBlobRefsResult {
scannedDocs: number
parsedDocs: number
refsWritten: number
refsDeleted: number
failedDocs: number
nextCursor?: string
}
export interface RuntimeDocCompactionResult {
leaseAcquired: boolean
merged: boolean
@@ -891,22 +948,6 @@ export interface RuntimeObjectMetadata {
checksumCrc32?: string
}
export interface RuntimeObjectStorageHealth {
configured: boolean
provider?: string
bucket?: string
endpoint?: string
region?: string
hasCredentials: boolean
forcePathStyle: boolean
requestTimeoutMs?: number
minPartSize?: number
presignExpiresInSeconds?: number
presignSignContentTypeForPut?: boolean
usePresignedUrl: boolean
clientBuildable: boolean
}
export interface RuntimeObjectStoragePutOptions {
contentType?: string
contentLength?: number
@@ -989,6 +1030,28 @@ export interface SafeFetchResponse {
body: Buffer
}
export interface StorageProviderCapabilities {
put: boolean
get: boolean
head: boolean
list: boolean
delete: boolean
presignPut: boolean
presignGet: boolean
multipartDirect: boolean
proxyUpload: boolean
assetpack: boolean
serverMediatedOnly: boolean
}
export interface StorageRuntimeHealth {
started: boolean
databaseConnected: boolean
providerConfigured: boolean
provider?: string
bucket?: string
}
export interface ToolContract {
name: string
description?: string
@@ -1055,4 +1118,10 @@ export declare function updateLicenseSeats(request: LicenseSeatsRequest): Promis
*/
export declare function updateRootDocMetaTitle(rootDocBin: Buffer, docId: string, title: string): Buffer
/**
* Check whether a Yjs update binary can be decoded without applying it to a
* document state.
*/
export declare function validateDocUpdate(update: Buffer): Promise<boolean>
export declare function verifyChallengeResponse(response: string, bits: number, resource: string): Promise<boolean>
+1
View File
@@ -16,6 +16,7 @@
},
"napi": {
"binaryName": "server-native",
"dtsHeaderFile": "dts-header.d.ts",
"targets": [
"aarch64-apple-darwin",
"aarch64-unknown-linux-gnu",
@@ -1,296 +0,0 @@
use std::{
fs,
path::{Path, PathBuf},
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use napi::Result;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use super::{BackendRuntime, error::napi_error, types::RuntimeBlobCompleteResult};
const MAX_BLOB_SIZE: i64 = i32::MAX as i64;
fn object_missing_error(err: &napi::Error) -> bool {
let message = err.to_string();
message.contains("NoSuchKey") || message.contains("NotFound") || message.contains("not found")
}
fn blob_complete_failure(reason: &str) -> RuntimeBlobCompleteResult {
RuntimeBlobCompleteResult {
ok: false,
reason: Some(reason.to_string()),
content_type: None,
content_length: None,
last_modified_ms: None,
}
}
fn blob_complete_success(
content_type: String,
content_length: i64,
last_modified_ms: i64,
) -> RuntimeBlobCompleteResult {
RuntimeBlobCompleteResult {
ok: true,
reason: None,
content_type: Some(content_type),
content_length: Some(content_length),
last_modified_ms: Some(last_modified_ms),
}
}
fn normalize_base64_url_key(key: &str) -> &str {
key.trim_end_matches('=')
}
fn sha256_base64_url(body: &[u8]) -> String {
URL_SAFE_NO_PAD.encode(Sha256::digest(body))
}
fn sha256_base64_url_matches(body: &[u8], key: &str) -> bool {
sha256_base64_url(body) == normalize_base64_url_key(key)
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct FsBlobMetadata {
content_type: String,
content_length: i64,
last_modified: i64,
}
fn normalize_storage_key(key: &str) -> Result<Vec<String>> {
let normalized = key.replace('\\', "/");
let segments = normalized.split('/').map(ToString::to_string).collect::<Vec<_>>();
if normalized.is_empty()
|| normalized.starts_with('/')
|| segments
.iter()
.any(|segment| segment.is_empty() || segment == "." || segment == "..")
{
return Err(napi_error(format!("Invalid storage key: {key}")));
}
Ok(segments)
}
fn fs_bucket_path(root: &str, bucket: &str) -> PathBuf {
if let Some(stripped) = root.strip_prefix("~/")
&& let Ok(Some(home)) = homedir::my_home()
{
return home.join(stripped).join(bucket);
}
Path::new(root).join(bucket)
}
fn fs_object_path(root: &str, bucket: &str, key: &str) -> Result<PathBuf> {
let mut path = fs_bucket_path(root, bucket);
for segment in normalize_storage_key(key)? {
path.push(segment);
}
Ok(path)
}
fn read_fs_metadata(path: &Path) -> Result<Option<FsBlobMetadata>> {
let metadata_path = PathBuf::from(format!("{}.metadata.json", path.display()));
let raw = match fs::read_to_string(metadata_path) {
Ok(raw) => raw,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => {
return Err(napi_error(format!("BlobComplete read fs metadata failed: {err}")));
}
};
serde_json::from_str(&raw).map(Some).map_err(|err| {
napi_error(format!(
"BlobComplete parse fs metadata failed for {}: {err}",
path.display()
))
})
}
async fn upsert_completed_blob(
runtime: &BackendRuntime,
workspace_id: &str,
key: &str,
mime: &str,
size: i64,
) -> Result<()> {
if !(0..=MAX_BLOB_SIZE).contains(&size) {
return Err(napi_error("BlobComplete size exceeds limit"));
}
let size = i32::try_from(size).map_err(|_| napi_error("BlobComplete size exceeds limit"))?;
sqlx::query(
r#"
INSERT INTO blobs (workspace_id, key, mime, size, status, upload_id)
VALUES ($1, $2, $3, $4, 'completed', NULL)
ON CONFLICT (workspace_id, key)
DO UPDATE SET
mime = EXCLUDED.mime,
size = EXCLUDED.size,
status = EXCLUDED.status,
upload_id = NULL
"#,
)
.bind(workspace_id)
.bind(key)
.bind(mime)
.bind(size)
.execute(&runtime.pool().await?)
.await
.map_err(|err| napi_error(format!("BlobComplete upsert metadata failed: {err}")))?;
Ok(())
}
#[napi_derive::napi]
impl BackendRuntime {
#[napi]
pub async fn complete_blob_upload(
&self,
workspace_id: String,
key: String,
expected_size: i64,
expected_mime: String,
) -> Result<RuntimeBlobCompleteResult> {
if !(0..=MAX_BLOB_SIZE).contains(&expected_size) {
return Ok(blob_complete_failure("size_too_large"));
}
let object_key = format!("{workspace_id}/{key}");
let object = match self.object_storage_get(object_key.clone()).await {
Ok(Some(object)) => object,
Ok(None) => return Ok(blob_complete_failure("not_found")),
Err(err) if object_missing_error(&err) => return Ok(blob_complete_failure("not_found")),
Err(err) => return Err(err),
};
if !(0..=MAX_BLOB_SIZE).contains(&object.metadata.content_length) {
match self.object_storage_delete(object_key).await {
Ok(()) => {}
Err(err) if object_missing_error(&err) => {}
Err(err) => return Err(err),
}
return Ok(blob_complete_failure("size_too_large"));
}
if object.metadata.content_length != expected_size {
return Ok(blob_complete_failure("size_mismatch"));
}
if !expected_mime.is_empty() && object.metadata.content_type != expected_mime {
return Ok(blob_complete_failure("mime_mismatch"));
}
if !sha256_base64_url_matches(&object.body, &key) {
match self.object_storage_delete(object_key).await {
Ok(()) => {}
Err(err) if object_missing_error(&err) => {}
Err(err) => return Err(err),
}
return Ok(blob_complete_failure("checksum_mismatch"));
}
upsert_completed_blob(
self,
&workspace_id,
&key,
&object.metadata.content_type,
object.metadata.content_length,
)
.await?;
Ok(blob_complete_success(
object.metadata.content_type,
object.metadata.content_length,
object.metadata.last_modified_ms,
))
}
#[napi]
pub async fn complete_fs_blob_upload(
&self,
root: String,
bucket: String,
workspace_id: String,
key: String,
expected_size: i64,
expected_mime: String,
) -> Result<RuntimeBlobCompleteResult> {
if !(0..=MAX_BLOB_SIZE).contains(&expected_size) {
return Ok(blob_complete_failure("size_too_large"));
}
let storage_key = format!("{workspace_id}/{key}");
let path = fs_object_path(&root, &bucket, &storage_key)?;
let metadata = match read_fs_metadata(&path)? {
Some(metadata) => metadata,
None => return Ok(blob_complete_failure("not_found")),
};
if !(0..=MAX_BLOB_SIZE).contains(&metadata.content_length) {
let _ = fs::remove_file(&path);
let _ = fs::remove_file(PathBuf::from(format!("{}.metadata.json", path.display())));
return Ok(blob_complete_failure("size_too_large"));
}
if metadata.content_length != expected_size {
return Ok(blob_complete_failure("size_mismatch"));
}
if !expected_mime.is_empty() && metadata.content_type != expected_mime {
return Ok(blob_complete_failure("mime_mismatch"));
}
let body = match fs::read(&path) {
Ok(body) => body,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(blob_complete_failure("not_found")),
Err(err) => return Err(napi_error(format!("BlobComplete read fs object failed: {err}"))),
};
if !sha256_base64_url_matches(&body, &key) {
let _ = fs::remove_file(&path);
let _ = fs::remove_file(PathBuf::from(format!("{}.metadata.json", path.display())));
return Ok(blob_complete_failure("checksum_mismatch"));
}
upsert_completed_blob(
self,
&workspace_id,
&key,
&metadata.content_type,
metadata.content_length,
)
.await?;
Ok(blob_complete_success(
metadata.content_type,
metadata.content_length,
metadata.last_modified,
))
}
}
#[cfg(test)]
mod tests {
use super::{sha256_base64_url, sha256_base64_url_matches};
#[test]
fn sha256_base64_url_omits_padding() {
assert_eq!(
sha256_base64_url(b"hello"),
"LPJNul-wow4m6DsqxbninhsWHlwfp0JecwQzYpOLmCQ"
);
}
#[test]
fn sha256_base64_url_matches_legacy_padding() {
assert!(sha256_base64_url_matches(
b"hello",
"LPJNul-wow4m6DsqxbninhsWHlwfp0JecwQzYpOLmCQ="
));
}
}
@@ -1,128 +0,0 @@
use std::{
collections::HashMap,
env, fs,
path::{Path, PathBuf},
};
use napi::Result;
use serde::Deserialize;
use super::{
error::napi_error,
object_storage::{ObjectStorageConfig, StorageProviderConfig},
};
#[derive(Clone, Debug)]
pub(super) struct RuntimeConfig {
pub(super) database_url: String,
pub(super) storage: Option<ObjectStorageConfig>,
}
impl RuntimeConfig {
pub(super) fn from_config_files() -> Result<Self> {
let database_url =
database_url_from_config_files()?.unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string());
let storage = ObjectStorageConfig::from_config_files()?;
Ok(Self { database_url, storage })
}
}
#[derive(Debug, Deserialize)]
struct AppConfigFile {
db: Option<DbConfigFile>,
storages: Option<HashMap<String, StorageProviderConfig>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DbConfigFile {
datasource_url: Option<String>,
}
fn database_url_from_config_files() -> Result<Option<String>> {
let mut database_url = None;
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path)
.map_err(|err| napi_error(format!("failed to read config file {}: {err}", path.display())))?;
let config: AppConfigFile = serde_json::from_str(&raw)
.map_err(|err| napi_error(format!("failed to parse config file {}: {err}", path.display())))?;
if let Some(next) = config.db.and_then(|db| db.datasource_url)
&& !next.trim().is_empty()
{
database_url = Some(next);
}
}
Ok(database_url)
}
pub(super) fn blob_storage_config_from_config_files() -> Result<Option<StorageProviderConfig>> {
let mut storage = None;
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path)
.map_err(|err| napi_error(format!("failed to read config file {}: {err}", path.display())))?;
let config: AppConfigFile = serde_json::from_str(&raw)
.map_err(|err| napi_error(format!("failed to parse config file {}: {err}", path.display())))?;
if let Some(next) = config.storages.and_then(|mut storages| storages.remove("blob.storage")) {
storage = Some(next);
}
}
Ok(storage)
}
pub(super) fn config_json_paths() -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Ok(exe) = env::current_exe()
&& let Some(dir) = exe.parent()
{
paths.push(config_in(dir));
}
if let Ok(cwd) = env::current_dir() {
paths.push(config_in(&cwd));
}
dedupe_paths(paths)
}
fn config_in(dir: &Path) -> PathBuf {
dir.join("config.json")
}
fn dedupe_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
let mut deduped = Vec::new();
for path in paths {
if !deduped.contains(&path) {
deduped.push(path);
}
}
deduped
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn config_paths_are_limited_to_executable_dir_and_cwd() {
let paths = config_json_paths();
assert!(!paths.is_empty());
assert!(paths.len() <= 2);
assert!(
paths
.iter()
.all(|path| path.file_name().is_some_and(|name| name == "config.json"))
);
assert!(paths.iter().all(|path| !path.to_string_lossy().contains(".affine")));
assert!(
paths
.iter()
.all(|path| !path.to_string_lossy().contains("packages/backend/server"))
);
}
}
@@ -1,5 +0,0 @@
use napi::{Error, Status};
pub(super) fn napi_error(message: impl Into<String>) -> Error {
Error::new(Status::GenericFailure, message.into())
}
@@ -1,353 +0,0 @@
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use aws_sdk_s3::{
Client as S3Client, presigning::PresigningConfig, primitives::ByteStream, types::CompletedMultipartUpload,
};
use napi::Result;
use super::types::{
MultipartUploadInitResult, MultipartUploadPart, ObjectGetResult, ObjectListEntry, ObjectMetadata, ObjectPutMetadata,
PresignedObjectRequest, completed_multipart_parts, trim_etag,
};
use crate::backend_runtime::error::napi_error;
#[derive(Clone)]
pub(super) struct ObjectStorageClient {
client: S3Client,
bucket: String,
presign_expires_in_seconds: u64,
presign_sign_content_type_for_put: bool,
}
impl ObjectStorageClient {
pub(super) fn new(
config: aws_sdk_s3::Config,
bucket: String,
presign_expires_in_seconds: u64,
presign_sign_content_type_for_put: bool,
) -> Self {
Self {
client: S3Client::from_conf(config),
bucket,
presign_expires_in_seconds,
presign_sign_content_type_for_put,
}
}
pub(super) fn non_destructive_health(&self) -> bool {
let _ = &self.client;
!self.bucket.is_empty()
}
pub(super) async fn put(&self, key: &str, body: Vec<u8>, metadata: ObjectPutMetadata) -> Result<()> {
let content_length = metadata.content_length.unwrap_or(body.len() as i64);
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let mut request = self
.client
.put_object()
.bucket(&self.bucket)
.key(key)
.body(ByteStream::from(body))
.content_type(content_type)
.content_length(content_length);
if let Some(checksum) = metadata.checksum_crc32 {
request = request.checksum_crc32(checksum);
}
request
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage put failed for {key}: {err:?}")))?;
Ok(())
}
pub(super) async fn presign_put(&self, key: &str, metadata: ObjectPutMetadata) -> Result<PresignedObjectRequest> {
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
let config = PresigningConfig::expires_in(Duration::from_secs(self.presign_expires_in_seconds))
.map_err(|err| napi_error(format!("ObjectStorage presign config failed: {err}")))?;
let mut request = self.client.put_object().bucket(&self.bucket).key(key);
if self.presign_sign_content_type_for_put {
request = request.content_type(content_type.clone());
}
if let Some(content_length) = metadata.content_length {
request = request.content_length(content_length);
}
let presigned = request
.presigned(config)
.await
.map_err(|err| napi_error(format!("ObjectStorage presign put failed for {key}: {err}")))?;
let mut headers = presigned_headers(&presigned);
headers.insert("Content-Type".to_string(), content_type);
Ok(PresignedObjectRequest {
url: presigned.uri().to_string(),
headers,
expires_at_ms,
})
}
pub(super) async fn create_multipart_upload(
&self,
key: &str,
metadata: ObjectPutMetadata,
) -> Result<Option<MultipartUploadInitResult>> {
let content_type = metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let result = self
.client
.create_multipart_upload()
.bucket(&self.bucket)
.key(key)
.content_type(content_type)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage create multipart upload failed for {key}: {err:?}"
))
})?;
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
Ok(result.upload_id.map(|upload_id| MultipartUploadInitResult {
upload_id,
expires_at_ms,
}))
}
pub(super) async fn presign_upload_part(
&self,
key: &str,
upload_id: &str,
part_number: i32,
) -> Result<PresignedObjectRequest> {
let expires_at_ms = expires_at_ms(self.presign_expires_in_seconds)?;
let config = PresigningConfig::expires_in(Duration::from_secs(self.presign_expires_in_seconds))
.map_err(|err| napi_error(format!("ObjectStorage presign config failed: {err}")))?;
let presigned = self
.client
.upload_part()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.part_number(part_number)
.presigned(config)
.await
.map_err(|err| napi_error(format!("ObjectStorage presign upload part failed for {key}: {err}")))?;
Ok(PresignedObjectRequest {
url: presigned.uri().to_string(),
headers: presigned_headers(&presigned),
expires_at_ms,
})
}
pub(super) async fn list_multipart_upload_parts(
&self,
key: &str,
upload_id: &str,
) -> Result<Vec<MultipartUploadPart>> {
let result = self
.client
.list_parts()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage list multipart upload parts failed for {key}: {err}"
))
})?;
Ok(
result
.parts()
.iter()
.filter_map(|part| {
Some(MultipartUploadPart {
part_number: part.part_number?,
etag: trim_etag(part.e_tag.as_deref().unwrap_or_default()),
})
})
.collect(),
)
}
pub(super) async fn complete_multipart_upload(
&self,
key: &str,
upload_id: &str,
parts: Vec<MultipartUploadPart>,
) -> Result<()> {
let ordered_parts = completed_multipart_parts(parts);
self
.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.multipart_upload(
CompletedMultipartUpload::builder()
.set_parts(Some(ordered_parts))
.build(),
)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage complete multipart upload failed for {key}: {err}"
))
})?;
Ok(())
}
pub(super) async fn abort_multipart_upload(&self, key: &str, upload_id: &str) -> Result<()> {
self
.client
.abort_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(upload_id)
.send()
.await
.map_err(|err| {
napi_error(format!(
"ObjectStorage abort multipart upload failed for {key}: {err:?}"
))
})?;
Ok(())
}
pub(super) async fn head(&self, key: &str) -> Result<Option<ObjectMetadata>> {
let result = self
.client
.head_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage head failed for {key}: {err:?}")))?;
Ok(Some(ObjectMetadata {
content_type: result
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: result.content_length.unwrap_or(0),
last_modified_ms: optional_datetime_ms(result.last_modified),
checksum_crc32: result.checksum_crc32,
}))
}
pub(super) async fn get(&self, key: &str) -> Result<Option<ObjectGetResult>> {
let result = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage get failed for {key}: {err:?}")))?;
let metadata = ObjectMetadata {
content_type: result
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: result.content_length.unwrap_or(0),
last_modified_ms: optional_datetime_ms(result.last_modified),
checksum_crc32: result.checksum_crc32,
};
let body = result
.body
.collect()
.await
.map_err(|err| napi_error(format!("ObjectStorage read body failed for {key}: {err}")))?
.into_bytes()
.to_vec();
Ok(Some(ObjectGetResult { body, metadata }))
}
pub(super) async fn list(&self, prefix: Option<String>) -> Result<Vec<ObjectListEntry>> {
let mut entries = Vec::new();
let mut token = None;
loop {
let mut request = self.client.list_objects_v2().bucket(&self.bucket);
if let Some(prefix) = &prefix {
request = request.prefix(prefix);
}
if let Some(next_token) = token {
request = request.continuation_token(next_token);
}
let result = request
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage list failed: {err:?}")))?;
entries.extend(result.contents().iter().filter_map(|object| {
Some(ObjectListEntry {
key: object.key.as_ref()?.clone(),
content_length: object.size.unwrap_or(0),
last_modified_ms: optional_datetime_ms(object.last_modified),
})
}));
if result.is_truncated.unwrap_or(false) {
token = result.next_continuation_token;
} else {
break;
}
}
Ok(entries)
}
pub(super) async fn delete(&self, key: &str) -> Result<()> {
self
.client
.delete_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| napi_error(format!("ObjectStorage delete failed for {key}: {err:?}")))?;
Ok(())
}
}
fn expires_at_ms(expires_in_seconds: u64) -> Result<i64> {
let expires_at = SystemTime::now()
.checked_add(Duration::from_secs(expires_in_seconds))
.ok_or_else(|| napi_error("ObjectStorage presign expiration overflow"))?;
system_time_ms(expires_at)
}
fn system_time_ms(time: SystemTime) -> Result<i64> {
let duration = time
.duration_since(UNIX_EPOCH)
.map_err(|err| napi_error(format!("system time before unix epoch: {err}")))?;
Ok(duration.as_millis() as i64)
}
fn optional_datetime_ms(time: Option<aws_sdk_s3::primitives::DateTime>) -> i64 {
time.and_then(|value| value.to_millis().ok()).unwrap_or(0)
}
fn presigned_headers(request: &aws_sdk_s3::presigning::PresignedRequest) -> HashMap<String, String> {
request
.headers()
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect()
}
@@ -1,184 +0,0 @@
mod client;
mod config;
#[cfg(test)]
mod tests;
mod types;
use client::ObjectStorageClient;
pub(super) use config::ObjectStorageConfig;
use napi::{Result, bindgen_prelude::Buffer};
pub(super) use types::StorageProviderConfig;
use super::{
BackendRuntime,
types::{
RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStorageHealth, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest,
},
};
#[napi_derive::napi]
impl BackendRuntime {
fn object_storage_client(&self) -> Result<ObjectStorageClient> {
self
.config
.storage
.as_ref()
.ok_or_else(|| super::error::napi_error("ObjectStorageClient is not configured"))?
.build_client()
}
pub(super) async fn object_storage_delete_object(&self, key: &str) -> Result<()> {
self.object_storage_client()?.delete(key).await
}
pub(super) async fn object_storage_abort_upload(&self, key: &str, upload_id: &str) -> Result<()> {
self
.object_storage_client()?
.abort_multipart_upload(key, upload_id)
.await
}
#[napi]
pub fn object_storage_health(&self) -> RuntimeObjectStorageHealth {
match &self.config.storage {
Some(storage) => storage.health(),
None => RuntimeObjectStorageHealth {
configured: false,
provider: None,
bucket: None,
endpoint: None,
region: None,
has_credentials: false,
force_path_style: false,
request_timeout_ms: None,
min_part_size: None,
presign_expires_in_seconds: None,
presign_sign_content_type_for_put: None,
use_presigned_url: false,
client_buildable: false,
},
}
}
#[napi]
pub async fn object_storage_put(
&self,
key: String,
body: Buffer,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<()> {
self
.object_storage_client()?
.put(&key, body.to_vec(), metadata.map(Into::into).unwrap_or_default())
.await
}
#[napi]
pub async fn object_storage_presign_put(
&self,
key: String,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<RuntimePresignedObjectRequest> {
self
.object_storage_client()?
.presign_put(&key, metadata.map(Into::into).unwrap_or_default())
.await?
.try_into()
}
#[napi]
pub async fn object_storage_create_multipart_upload(
&self,
key: String,
metadata: Option<RuntimeObjectStoragePutOptions>,
) -> Result<Option<RuntimeMultipartUploadInit>> {
Ok(
self
.object_storage_client()?
.create_multipart_upload(&key, metadata.map(Into::into).unwrap_or_default())
.await?
.map(Into::into),
)
}
#[napi]
pub async fn object_storage_presign_upload_part(
&self,
key: String,
upload_id: String,
part_number: i32,
) -> Result<RuntimePresignedObjectRequest> {
self
.object_storage_client()?
.presign_upload_part(&key, &upload_id, part_number)
.await?
.try_into()
}
#[napi]
pub async fn object_storage_list_multipart_upload_parts(
&self,
key: String,
upload_id: String,
) -> Result<Vec<RuntimeMultipartUploadPart>> {
Ok(
self
.object_storage_client()?
.list_multipart_upload_parts(&key, &upload_id)
.await?
.into_iter()
.map(Into::into)
.collect(),
)
}
#[napi]
pub async fn object_storage_complete_multipart_upload(
&self,
key: String,
upload_id: String,
parts: Vec<RuntimeMultipartUploadPart>,
) -> Result<()> {
self
.object_storage_client()?
.complete_multipart_upload(&key, &upload_id, parts.into_iter().map(Into::into).collect())
.await
}
#[napi]
pub async fn object_storage_abort_multipart_upload(&self, key: String, upload_id: String) -> Result<()> {
self
.object_storage_client()?
.abort_multipart_upload(&key, &upload_id)
.await
}
#[napi]
pub async fn object_storage_head(&self, key: String) -> Result<Option<RuntimeObjectMetadata>> {
Ok(self.object_storage_client()?.head(&key).await?.map(Into::into))
}
#[napi]
pub async fn object_storage_get(&self, key: String) -> Result<Option<RuntimeObjectGetResult>> {
Ok(self.object_storage_client()?.get(&key).await?.map(Into::into))
}
#[napi]
pub async fn object_storage_list(&self, prefix: Option<String>) -> Result<Vec<RuntimeObjectListEntry>> {
Ok(
self
.object_storage_client()?
.list(prefix)
.await?
.into_iter()
.map(Into::into)
.collect(),
)
}
#[napi]
pub async fn object_storage_delete(&self, key: String) -> Result<()> {
self.object_storage_client()?.delete(&key).await
}
}
@@ -1,129 +0,0 @@
use super::{
config::ObjectStorageConfig,
types::{MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, completed_multipart_parts, trim_etag},
};
#[test]
fn resolves_r2_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "cloudflare-r2");
assert_eq!(config.bucket, "workspace-blobs");
assert_eq!(
config.endpoint.as_deref(),
Some("https://account.eu.r2.cloudflarestorage.com")
);
assert_eq!(config.region.as_deref(), Some("auto"));
assert!(config.force_path_style);
assert!(config.use_presigned_url);
assert_eq!(config.access_key_id.as_deref(), Some("key"));
}
#[test]
fn resolves_s3_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"region": "us-west-2",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret",
"sessionToken": "session"
},
"forcePathStyle": true,
"requestTimeoutMs": 1000,
"minPartSize": 1024,
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "aws-s3");
assert_eq!(config.endpoint.as_deref(), Some("https://s3.us-west-2.amazonaws.com"));
assert_eq!(config.session_token.as_deref(), Some("session"));
assert!(config.force_path_style);
assert_eq!(config.request_timeout_ms, Some(1000));
assert_eq!(config.min_part_size, Some(1024));
assert_eq!(config.presign_expires_in_seconds, Some(60));
assert_eq!(config.presign_sign_content_type_for_put, Some(false));
}
#[tokio::test]
async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
let Ok(Ok(client)) = std::panic::catch_unwind(|| config.build_client()) else {
eprintln!("skipping object storage presign test: S3 client cannot be built in this environment");
return;
};
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders="));
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert!(result.expires_at_ms > 0);
}
#[test]
fn object_storage_orders_completed_multipart_parts_and_trims_etags() {
let parts = completed_multipart_parts(vec![
MultipartUploadPart {
part_number: 2,
etag: trim_etag("\"b\""),
},
MultipartUploadPart {
part_number: 1,
etag: trim_etag("a"),
},
]);
assert_eq!(parts[0].part_number, Some(1));
assert_eq!(parts[0].e_tag.as_deref(), Some("a"));
assert_eq!(parts[1].part_number, Some(2));
assert_eq!(parts[1].e_tag.as_deref(), Some("b"));
}
@@ -1,165 +0,0 @@
use std::collections::HashMap;
use aws_sdk_s3::types::CompletedPart;
use napi::Result;
use serde::Deserialize;
use crate::backend_runtime::{
error::napi_error,
types::{
RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest,
},
};
#[derive(Clone, Debug, Default)]
pub(super) struct ObjectPutMetadata {
pub(super) content_type: Option<String>,
pub(super) content_length: Option<i64>,
pub(super) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectMetadata {
pub(super) content_type: String,
pub(super) content_length: i64,
pub(super) last_modified_ms: i64,
pub(super) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectListEntry {
pub(super) key: String,
pub(super) content_length: i64,
pub(super) last_modified_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct ObjectGetResult {
pub(super) body: Vec<u8>,
pub(super) metadata: ObjectMetadata,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct PresignedObjectRequest {
pub(super) url: String,
pub(super) headers: HashMap<String, String>,
pub(super) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct MultipartUploadInitResult {
pub(super) upload_id: String,
pub(super) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(super) struct MultipartUploadPart {
pub(super) part_number: i32,
pub(super) etag: String,
}
#[derive(Clone, Debug, Deserialize)]
pub(in crate::backend_runtime) struct StorageProviderConfig {
pub(super) provider: String,
pub(super) bucket: String,
#[serde(default)]
pub(super) config: serde_json::Value,
}
pub(super) fn trim_etag(etag: &str) -> String {
etag.trim_matches('"').to_string()
}
pub(super) fn completed_multipart_parts(mut parts: Vec<MultipartUploadPart>) -> Vec<CompletedPart> {
parts.sort_by_key(|part| part.part_number);
parts
.into_iter()
.map(|part| {
CompletedPart::builder()
.part_number(part.part_number)
.e_tag(part.etag)
.build()
})
.collect()
}
impl From<RuntimeObjectStoragePutOptions> for ObjectPutMetadata {
fn from(options: RuntimeObjectStoragePutOptions) -> Self {
Self {
content_type: options.content_type,
content_length: options.content_length,
checksum_crc32: options.checksum_crc32,
}
}
}
impl From<ObjectMetadata> for RuntimeObjectMetadata {
fn from(metadata: ObjectMetadata) -> Self {
Self {
content_type: metadata.content_type,
content_length: metadata.content_length,
last_modified_ms: metadata.last_modified_ms,
checksum_crc32: metadata.checksum_crc32,
}
}
}
impl From<ObjectListEntry> for RuntimeObjectListEntry {
fn from(entry: ObjectListEntry) -> Self {
Self {
key: entry.key,
content_length: entry.content_length,
last_modified_ms: entry.last_modified_ms,
}
}
}
impl TryFrom<PresignedObjectRequest> for RuntimePresignedObjectRequest {
type Error = napi::Error;
fn try_from(request: PresignedObjectRequest) -> Result<Self> {
Ok(Self {
url: request.url,
headers_json: serde_json::to_string(&request.headers)
.map_err(|err| napi_error(format!("ObjectStorage headers serialization failed: {err}")))?,
expires_at_ms: request.expires_at_ms,
})
}
}
impl From<ObjectGetResult> for RuntimeObjectGetResult {
fn from(result: ObjectGetResult) -> Self {
Self {
body: result.body.into(),
metadata: result.metadata.into(),
}
}
}
impl From<MultipartUploadInitResult> for RuntimeMultipartUploadInit {
fn from(init: MultipartUploadInitResult) -> Self {
Self {
upload_id: init.upload_id,
expires_at_ms: init.expires_at_ms,
}
}
}
impl From<RuntimeMultipartUploadPart> for MultipartUploadPart {
fn from(part: RuntimeMultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
impl From<MultipartUploadPart> for RuntimeMultipartUploadPart {
fn from(part: MultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
@@ -1,40 +0,0 @@
CREATE TABLE IF NOT EXISTS runtime_states (
purpose TEXT NOT NULL,
token_hash TEXT NOT NULL,
lookup_key TEXT,
payload JSONB NOT NULL,
attempts INTEGER NOT NULL DEFAULT 0,
consumed_at TIMESTAMPTZ(3),
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (purpose, token_hash)
);
CREATE INDEX IF NOT EXISTS runtime_states_lookup_idx
ON runtime_states (purpose, lookup_key)
WHERE lookup_key IS NOT NULL AND consumed_at IS NULL;
CREATE INDEX IF NOT EXISTS runtime_states_expires_at_idx
ON runtime_states (expires_at);
CREATE TABLE IF NOT EXISTS runtime_gates (
key TEXT PRIMARY KEY,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_gates_expires_at_idx
ON runtime_gates (expires_at);
CREATE TABLE IF NOT EXISTS runtime_leases (
key TEXT PRIMARY KEY,
owner TEXT NOT NULL,
fencing_token BIGINT NOT NULL,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_leases_expires_at_idx
ON runtime_leases (expires_at);
+18 -2
View File
@@ -2,7 +2,6 @@
mod utils;
pub mod backend_runtime;
pub mod doc;
pub mod doc_loader;
pub mod entitlement;
@@ -13,12 +12,13 @@ pub mod image;
pub mod license;
pub mod llm;
pub mod permission;
pub mod runtime;
pub mod safe_fetch;
pub mod tiktoken;
use affine_common::napi_utils::map_napi_err;
use napi::{Result, Status, bindgen_prelude::*};
use y_octo::Doc;
use y_octo::{Doc, Update};
#[cfg(not(target_arch = "arm"))]
#[global_allocator]
@@ -41,6 +41,16 @@ pub fn merge_updates_in_apply_way(updates: Vec<Buffer>) -> Result<Buffer> {
Ok(buf.into())
}
/// Check whether a Yjs update binary can be decoded without applying it to a
/// document state.
#[napi(catch_unwind)]
pub async fn validate_doc_update(update: Buffer) -> Result<bool> {
let update = update.to_vec();
tokio::task::spawn_blocking(move || Update::decode_v1(update).is_ok())
.await
.map_err(|err| napi::Error::from_reason(format!("Doc update validation task failed: {err}")))
}
#[napi]
pub const AFFINE_PRO_PUBLIC_KEY: Option<&'static str> = std::option_env!("AFFINE_PRO_PUBLIC_KEY");
@@ -59,4 +69,10 @@ mod tests {
};
assert_eq!(err.status, Status::GenericFailure);
}
#[test]
fn y_octo_update_decode_accepts_valid_update_and_rejects_invalid_update() {
assert!(Update::decode_v1(vec![0, 0]).is_ok());
assert!(Update::decode_v1(vec![0]).is_err());
}
}
+2 -5
View File
@@ -1,7 +1,7 @@
use std::{
collections::HashMap,
sync::{Mutex, OnceLock},
time::{Duration, SystemTime, UNIX_EPOCH},
time::{Duration, SystemTime},
};
use anyhow::{Context, Result as AnyResult, bail};
@@ -458,10 +458,7 @@ fn parse_future_end_at(value: &serde_json::Value) -> AnyResult<f64> {
}
fn now_millis() -> f64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as f64
crate::utils::system_time_millis(SystemTime::now()).unwrap_or_default() as f64
}
fn affine_pro_ech_config() -> AnyResult<Vec<u8>> {
@@ -7,4 +7,3 @@ pub(super) const WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE: &str = "workspace_invi
pub(super) const WORKSPACE_STATS_LEASE_KEY: &str = "workspace:admin-stats:refresh";
pub(super) const WORKSPACE_STATS_LOCK_NAMESPACE: i64 = 97_301;
pub(super) const WORKSPACE_STATS_REFRESH_LOCK_KEY: i64 = 1;
pub(super) const RUNTIME_MIGRATIONS: &str = include_str!("sql/runtime_migrations.sql");
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::{FromRow, PgPool};
use super::{BackendRuntime, error::napi_error, types::CoordinationLeaseGrant};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::CoordinationLeaseGrant};
#[derive(FromRow)]
struct LeaseGrantRow {
@@ -17,7 +17,7 @@ impl CoordinationLeaseStore {
Self { pool }
}
async fn acquire(&self, key: String, owner: String, ttl_ms: i64) -> Result<Option<CoordinationLeaseGrant>> {
async fn acquire(&self, key: String, owner: String, ttl_ms: i64) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
let row = sqlx::query_as::<_, LeaseGrantRow>(
r#"
INSERT INTO runtime_leases (key, owner, fencing_token, expires_at)
@@ -36,7 +36,7 @@ impl CoordinationLeaseStore {
.bind(ttl_ms as f64)
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease acquire failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease acquire failed", err))?;
Ok(row.map(|row| CoordinationLeaseGrant {
key,
@@ -45,7 +45,7 @@ impl CoordinationLeaseStore {
}))
}
async fn release(&self, key: &str, owner: &str, fencing_token: i64) -> Result<bool> {
async fn release(&self, key: &str, owner: &str, fencing_token: i64) -> RuntimeResult<bool> {
let result = sqlx::query(
r#"
DELETE FROM runtime_leases
@@ -57,12 +57,12 @@ impl CoordinationLeaseStore {
.bind(fencing_token)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease release failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease release failed", err))?;
Ok(result.rows_affected() == 1)
}
async fn renew(&self, key: &str, owner: &str, fencing_token: i64, ttl_ms: i64) -> Result<bool> {
async fn renew(&self, key: &str, owner: &str, fencing_token: i64, ttl_ms: i64) -> RuntimeResult<bool> {
let result = sqlx::query(
r#"
UPDATE runtime_leases
@@ -80,7 +80,7 @@ impl CoordinationLeaseStore {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("CoordinationLease renew failed: {err}")))?;
.map_err(|err| RuntimeError::database("CoordinationLease renew failed", err))?;
Ok(result.rows_affected() == 1)
}
@@ -88,6 +88,35 @@ impl CoordinationLeaseStore {
#[napi_derive::napi]
impl BackendRuntime {
pub(crate) async fn acquire_coordination_lease_inner(
&self,
key: String,
owner: String,
ttl_ms: i64,
) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(RuntimeError::invalid_input("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(RuntimeError::invalid_input("coordination lease owner is required"));
}
CoordinationLeaseStore::new(self.pool().await?)
.acquire(key, owner, ttl_ms)
.await
}
pub(crate) async fn release_coordination_lease_inner(
&self,
key: String,
owner: String,
fencing_token: i64,
) -> RuntimeResult<bool> {
CoordinationLeaseStore::new(self.pool().await?)
.release(&key, &owner, fencing_token)
.await
}
#[napi]
pub async fn acquire_coordination_lease(
&self,
@@ -95,16 +124,10 @@ impl BackendRuntime {
owner: String,
ttl_ms: i64,
) -> Result<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(napi_error("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(napi_error("coordination lease owner is required"));
}
CoordinationLeaseStore::new(self.pool().await?)
.acquire(key, owner, ttl_ms)
self
.acquire_coordination_lease_inner(key, owner, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -114,9 +137,10 @@ impl BackendRuntime {
owner: String,
#[napi(ts_arg_type = "bigint | number")] fencing_token: i64,
) -> Result<bool> {
CoordinationLeaseStore::new(self.pool().await?)
.release(&key, &owner, fencing_token)
self
.release_coordination_lease_inner(key, owner, fencing_token)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -134,5 +158,6 @@ impl BackendRuntime {
CoordinationLeaseStore::new(self.pool().await?)
.renew(&key, &owner, fencing_token, ttl_ms)
.await
.map_err(napi::Error::from)
}
}
@@ -1,9 +1,8 @@
use chrono::{DateTime, Duration, Utc};
use napi::Result;
use sqlx::{FromRow, PgPool, Postgres, Row, Transaction};
use y_octo::Doc;
use super::{BackendRuntime, error::napi_error, types::RuntimeDocCompactionResult};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::RuntimeDocCompactionResult};
#[derive(FromRow)]
struct SnapshotRow {
@@ -35,7 +34,7 @@ impl DocCompactorStore {
batch_limit: i64,
history_min_interval_ms: i64,
history_max_age_seconds: i64,
) -> Result<(i64, bool)> {
) -> RuntimeResult<(i64, bool)> {
compact_doc(
self.pool.clone(),
workspace_id,
@@ -52,31 +51,32 @@ fn is_empty_doc(bin: &[u8]) -> bool {
bin.is_empty() || (bin.len() == 1 && bin[0] == 0) || (bin.len() == 2 && bin[0] == 0 && bin[1] == 0)
}
fn apply_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> Result<Vec<u8>> {
fn apply_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> RuntimeResult<Vec<u8>> {
let mut doc = Doc::default();
for update in updates {
doc
.apply_update_from_binary_v1(&update)
.map_err(|err| napi_error(format!("DocCompactor merge failed: {err}")))?;
.map_err(|err| RuntimeError::invalid_state(format!("DocCompactor merge failed: {err}")))?;
}
doc
.encode_update_v1()
.map_err(|err| napi_error(format!("DocCompactor encode failed: {err}")))
.map_err(|err| RuntimeError::invalid_state(format!("DocCompactor encode failed: {err}")))
}
fn checked_milliseconds(value: i64, field: &str) -> Result<Duration> {
Duration::try_milliseconds(value).ok_or_else(|| napi_error(format!("DocCompactor {field} is too large")))
fn checked_milliseconds(value: i64, field: &str) -> RuntimeResult<Duration> {
Duration::try_milliseconds(value)
.ok_or_else(|| RuntimeError::invalid_input(format!("DocCompactor {field} is too large")))
}
fn checked_seconds(value: i64, field: &str) -> Result<Duration> {
Duration::try_seconds(value).ok_or_else(|| napi_error(format!("DocCompactor {field} is too large")))
fn checked_seconds(value: i64, field: &str) -> RuntimeResult<Duration> {
Duration::try_seconds(value).ok_or_else(|| RuntimeError::invalid_input(format!("DocCompactor {field} is too large")))
}
async fn load_snapshot(
tx: &mut Transaction<'_, Postgres>,
workspace_id: &str,
doc_id: &str,
) -> Result<Option<SnapshotRow>> {
) -> RuntimeResult<Option<SnapshotRow>> {
sqlx::query_as::<_, SnapshotRow>(
r#"
SELECT blob, updated_at, updated_by
@@ -89,7 +89,7 @@ async fn load_snapshot(
.bind(doc_id)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load snapshot failed: {err}")))
.map_err(|err| RuntimeError::database("DocCompactor load snapshot failed", err))
}
async fn load_updates(
@@ -97,7 +97,7 @@ async fn load_updates(
workspace_id: &str,
doc_id: &str,
batch_limit: i64,
) -> Result<Vec<UpdateRow>> {
) -> RuntimeResult<Vec<UpdateRow>> {
sqlx::query_as::<_, UpdateRow>(
r#"
SELECT blob, created_at, created_by
@@ -113,7 +113,7 @@ async fn load_updates(
.bind(batch_limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load updates failed: {err}")))
.map_err(|err| RuntimeError::database("DocCompactor load updates failed", err))
}
async fn upsert_snapshot(
@@ -123,7 +123,7 @@ async fn upsert_snapshot(
blob: &[u8],
timestamp: DateTime<Utc>,
editor: Option<&str>,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if is_empty_doc(blob) {
return Ok(false);
}
@@ -154,7 +154,7 @@ async fn upsert_snapshot(
.bind(editor)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor upsert snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor upsert snapshot failed", err))?;
Ok(row.is_some())
}
@@ -165,7 +165,7 @@ async fn should_create_history(
workspace_id: &str,
doc_id: &str,
history_min_interval_ms: i64,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if is_empty_doc(&snapshot.blob) {
return Ok(false);
}
@@ -183,7 +183,7 @@ async fn should_create_history(
.bind(doc_id)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor load latest history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor load latest history failed", err))?;
let Some(row) = row else {
return Ok(true);
@@ -198,7 +198,7 @@ async fn should_create_history(
let threshold = snapshot
.updated_at
.checked_sub_signed(min_interval)
.ok_or_else(|| napi_error("DocCompactor history interval is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history interval is out of range"))?;
Ok(last_timestamp < threshold)
}
@@ -209,7 +209,7 @@ async fn create_history(
doc_id: &str,
snapshot: &SnapshotRow,
max_age_seconds: i64,
) -> Result<bool> {
) -> RuntimeResult<bool> {
if max_age_seconds <= 0 {
return Ok(false);
}
@@ -217,7 +217,7 @@ async fn create_history(
let max_age = checked_seconds(max_age_seconds, "history max age")?;
let expired_at = Utc::now()
.checked_add_signed(max_age)
.ok_or_else(|| napi_error("DocCompactor history max age is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history max age is out of range"))?;
sqlx::query(
r#"
INSERT INTO snapshot_histories
@@ -236,7 +236,7 @@ async fn create_history(
.bind(snapshot.updated_by.as_deref())
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor create history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor create history failed", err))?;
Ok(true)
}
@@ -246,7 +246,7 @@ async fn delete_updates(
workspace_id: &str,
doc_id: &str,
timestamps: &[DateTime<Utc>],
) -> Result<i64> {
) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM updates
@@ -260,7 +260,7 @@ async fn delete_updates(
.bind(timestamps)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("DocCompactor delete updates failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor delete updates failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -272,18 +272,18 @@ async fn compact_doc(
batch_limit: i64,
history_min_interval_ms: i64,
history_max_age_seconds: i64,
) -> Result<(i64, bool)> {
) -> RuntimeResult<(i64, bool)> {
let mut tx = pool
.begin()
.await
.map_err(|err| napi_error(format!("DocCompactor begin transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor begin transaction failed", err))?;
let snapshot = load_snapshot(&mut tx, workspace_id, doc_id).await?;
let updates = load_updates(&mut tx, workspace_id, doc_id, batch_limit).await?;
if updates.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("DocCompactor commit transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor commit transaction failed", err))?;
return Ok((0, false));
}
@@ -323,7 +323,7 @@ async fn compact_doc(
tx.commit()
.await
.map_err(|err| napi_error(format!("DocCompactor commit transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocCompactor commit transaction failed", err))?;
Ok((deleted, history_created))
}
@@ -350,7 +350,7 @@ impl BackendRuntime {
history_max_age_seconds: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeDocCompactionResult> {
) -> napi::Result<RuntimeDocCompactionResult> {
if batch_limit <= 0 {
return Err(napi_error("doc compactor batch limit must be positive"));
}
@@ -365,7 +365,7 @@ impl BackendRuntime {
let max_age = checked_seconds(history_max_age_seconds, "history max age")?;
Utc::now()
.checked_add_signed(max_age)
.ok_or_else(|| napi_error("DocCompactor history max age is out of range"))?;
.ok_or_else(|| RuntimeError::invalid_input("DocCompactor history max age is out of range"))?;
}
let lease_key = format!("doc:update:{workspace_id}:{doc_id}");
@@ -394,7 +394,7 @@ impl BackendRuntime {
.release_coordination_lease(lease.key, lease.owner, lease.fencing_token)
.await?;
if !released {
return Err(napi_error("DocCompactor failed to release coordination lease"));
return Err(RuntimeError::invalid_state("DocCompactor failed to release coordination lease").into());
}
let (updates_merged, history_created) = result?;
@@ -1,14 +1,18 @@
use chrono::{DateTime, Duration, Utc};
use napi::{Result, bindgen_prelude::Buffer};
use napi::bindgen_prelude::Buffer;
use sqlx::{PgPool, Row};
use super::{BackendRuntime, error::napi_error, types::RuntimeDocHistoryInput};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error, types::RuntimeDocHistoryInput};
fn is_empty_doc(bin: &[u8]) -> bool {
bin.is_empty() || (bin.len() == 1 && bin[0] == 0) || (bin.len() == 2 && bin[0] == 0 && bin[1] == 0)
}
async fn latest_history_timestamp(pool: &PgPool, workspace_id: &str, doc_id: &str) -> Result<Option<DateTime<Utc>>> {
async fn latest_history_timestamp(
pool: &PgPool,
workspace_id: &str,
doc_id: &str,
) -> RuntimeResult<Option<DateTime<Utc>>> {
sqlx::query(
r#"
SELECT timestamp
@@ -23,7 +27,7 @@ async fn latest_history_timestamp(pool: &PgPool, workspace_id: &str, doc_id: &st
.fetch_optional(pool)
.await
.map(|row| row.map(|row| row.get("timestamp")))
.map_err(|err| napi_error(format!("DocStorage load latest history failed: {err}")))
.map_err(|err| RuntimeError::database("DocStorage load latest history failed", err))
}
#[napi_derive::napi]
@@ -36,13 +40,13 @@ impl BackendRuntime {
blob: Buffer,
timestamp_ms: i64,
editor_id: Option<String>,
) -> Result<bool> {
) -> napi::Result<bool> {
if is_empty_doc(blob.as_ref()) {
return Ok(false);
}
let timestamp = DateTime::<Utc>::from_timestamp_millis(timestamp_ms)
.ok_or_else(|| napi_error(format!("Invalid doc snapshot timestamp: {timestamp_ms}")))?;
.ok_or_else(|| RuntimeError::invalid_input(format!("Invalid doc snapshot timestamp: {timestamp_ms}")))?;
let pool = self.pool().await?;
let row = sqlx::query(
r#"
@@ -70,13 +74,13 @@ impl BackendRuntime {
.bind(editor_id.as_deref())
.fetch_optional(&pool)
.await
.map_err(|err| napi_error(format!("DocStorage upsert snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage upsert snapshot failed", err))?;
Ok(row.is_some())
}
#[napi]
pub async fn create_doc_history(&self, input: RuntimeDocHistoryInput) -> Result<bool> {
pub async fn create_doc_history(&self, input: RuntimeDocHistoryInput) -> napi::Result<bool> {
if input.history_min_interval_ms < 0 {
return Err(napi_error("doc history interval must be non-negative"));
}
@@ -85,7 +89,7 @@ impl BackendRuntime {
}
let timestamp = DateTime::<Utc>::from_timestamp_millis(input.timestamp_ms)
.ok_or_else(|| napi_error(format!("Invalid doc history timestamp: {}", input.timestamp_ms)))?;
.ok_or_else(|| RuntimeError::invalid_input(format!("Invalid doc history timestamp: {}", input.timestamp_ms)))?;
let pool = self.pool().await?;
let should_create = match latest_history_timestamp(&pool, &input.workspace_id, &input.doc_id).await? {
None => true,
@@ -118,41 +122,41 @@ impl BackendRuntime {
.bind(input.editor_id.as_deref())
.execute(&pool)
.await
.map_err(|err| napi_error(format!("DocStorage create history failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage create history failed", err))?;
Ok(true)
}
#[napi]
pub async fn delete_doc_storage(&self, workspace_id: String, doc_id: String) -> Result<()> {
pub async fn delete_doc_storage(&self, workspace_id: String, doc_id: String) -> napi::Result<()> {
let pool = self.pool().await?;
let mut tx = pool
.begin()
.await
.map_err(|err| napi_error(format!("DocStorage delete begin transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete begin transaction failed", err))?;
sqlx::query("DELETE FROM snapshots WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete snapshot failed", err))?;
sqlx::query("DELETE FROM updates WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete updates failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete updates failed", err))?;
sqlx::query("DELETE FROM snapshot_histories WHERE workspace_id = $1 AND guid = $2")
.bind(&workspace_id)
.bind(&doc_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("DocStorage delete histories failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete histories failed", err))?;
tx.commit()
.await
.map_err(|err| napi_error(format!("DocStorage delete commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("DocStorage delete commit failed", err))?;
Ok(())
}
}
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::PgPool;
use super::{BackendRuntime, error::napi_error};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
struct RuntimeGateStore {
pool: PgPool,
@@ -12,18 +12,18 @@ impl RuntimeGateStore {
Self { pool }
}
async fn put_if_absent(&self, key: &str, ttl_ms: i64) -> Result<bool> {
async fn put_if_absent(&self, key: &str, ttl_ms: i64) -> RuntimeResult<bool> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("RuntimeGate transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate transaction failed", err))?;
sqlx::query("DELETE FROM runtime_gates WHERE key = $1 AND expires_at <= CURRENT_TIMESTAMP")
.bind(key)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeGate expired cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate expired cleanup failed", err))?;
let inserted = sqlx::query(
r#"
@@ -36,18 +36,18 @@ impl RuntimeGateStore {
.bind(ttl_ms as f64)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeGate put_if_absent failed: {err}")))?
.map_err(|err| RuntimeError::database("RuntimeGate put_if_absent failed", err))?
.rows_affected()
== 1;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeGate transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate transaction commit failed", err))?;
Ok(inserted)
}
async fn cleanup_expired(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM runtime_gates
@@ -62,7 +62,7 @@ impl RuntimeGateStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("RuntimeGate cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeGate cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -78,6 +78,7 @@ impl BackendRuntime {
RuntimeGateStore::new(self.pool().await?)
.put_if_absent(&key, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -85,6 +86,9 @@ impl BackendRuntime {
if limit <= 0 {
return Err(napi_error("runtime gate cleanup limit must be positive"));
}
RuntimeGateStore::new(self.pool().await?).cleanup_expired(limit).await
RuntimeGateStore::new(self.pool().await?)
.cleanup_expired(limit)
.await
.map_err(napi::Error::from)
}
}
@@ -1,7 +1,7 @@
use napi::Result;
use sqlx::PgPool;
use super::{BackendRuntime, error::napi_error};
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
struct HousekeepingStore {
pool: PgPool,
@@ -12,7 +12,7 @@ impl HousekeepingStore {
Self { pool }
}
async fn cleanup_expired_user_sessions(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired_user_sessions(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM user_sessions
@@ -27,12 +27,12 @@ impl HousekeepingStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("Housekeeping user sessions cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("Housekeeping user sessions cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn cleanup_expired_snapshot_histories(&self, limit: i64) -> Result<i64> {
async fn cleanup_expired_snapshot_histories(&self, limit: i64) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM snapshot_histories
@@ -48,7 +48,7 @@ impl HousekeepingStore {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("Housekeeping snapshot histories cleanup failed: {err}")))?;
.map_err(|err| RuntimeError::database("Housekeeping snapshot histories cleanup failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -65,6 +65,7 @@ impl BackendRuntime {
HousekeepingStore::new(self.pool().await?)
.cleanup_expired_user_sessions(limit)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -76,5 +77,6 @@ impl BackendRuntime {
HousekeepingStore::new(self.pool().await?)
.cleanup_expired_snapshot_histories(limit)
.await
.map_err(napi::Error::from)
}
}
@@ -1,28 +1,25 @@
mod blob_complete;
mod blob_reclaimer;
mod config;
mod constants;
mod coordination_lease;
mod doc_compactor;
mod doc_storage;
mod error;
mod gate;
mod housekeeping;
mod object_storage;
mod runtime_state;
#[cfg(test)]
mod tests;
mod types;
mod workspace_stats;
use std::time::Duration;
use std::{sync::RwLock, time::Duration};
use napi::Result;
use sha2::{Digest, Sha256};
use sqlx::{PgPool, Row, postgres::PgPoolOptions};
use tokio::sync::Mutex;
use self::{config::RuntimeConfig, constants::RUNTIME_MIGRATIONS, error::napi_error, types::BackendRuntimeHealth};
use self::types::BackendRuntimeHealth;
pub(crate) use super::types;
use super::{
BackendRuntimeConfig, RuntimeError, RuntimeResult, migrations::migrate_runtime_tables, napi_error, to_napi_error,
};
pub(super) fn token_hash(token: &str) -> String {
hex::encode(Sha256::digest(token.as_bytes()))
@@ -30,7 +27,7 @@ pub(super) fn token_hash(token: &str) -> String {
#[napi_derive::napi]
pub struct BackendRuntime {
config: RuntimeConfig,
config: RwLock<BackendRuntimeConfig>,
pool: Mutex<Option<PgPool>>,
}
@@ -39,29 +36,37 @@ impl BackendRuntime {
#[napi(constructor)]
pub fn new() -> Result<Self> {
Ok(Self {
config: RuntimeConfig::from_config_files()?,
config: RwLock::new(BackendRuntimeConfig::from_config_files().map_err(to_napi_error)?),
pool: Mutex::new(None),
})
}
#[napi]
pub async fn start(&self) -> Result<()> {
self.start_inner().await.map_err(to_napi_error)
}
async fn start_inner(&self) -> RuntimeResult<()> {
let mut guard = self.pool.lock().await;
if guard.is_some() {
return Ok(());
}
let database_url = self.config()?.database_url;
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(5))
.connect(&self.config.database_url)
.connect(&database_url)
.await
.map_err(|err| napi_error(format!("BackendRuntime failed to connect postgres: {err}")))?;
.map_err(|err| RuntimeError::database("BackendRuntime failed to connect postgres", err))?;
sqlx::query("SELECT 1")
.execute(&pool)
.await
.map_err(|err| napi_error(format!("BackendRuntime postgres health check failed: {err}")))?;
.map_err(|err| RuntimeError::database("BackendRuntime postgres health check failed", err))?;
let config = self.config()?.with_db_overrides(&pool).await?;
self.update_config(config)?;
*guard = Some(pool);
Ok(())
@@ -91,38 +96,38 @@ impl BackendRuntime {
Ok(BackendRuntimeHealth {
started: pool.is_some(),
database_connected,
object_storage_configured: self.config.storage.is_some(),
})
}
#[napi]
pub async fn run_migrations(&self) -> Result<()> {
let pool = self.pool().await?;
migrate_runtime_tables(&pool).await
migrate_runtime_tables(&pool).await.map_err(to_napi_error)
}
async fn pool(&self) -> Result<PgPool> {
pub(crate) async fn pool(&self) -> RuntimeResult<PgPool> {
self
.pool
.lock()
.await
.as_ref()
.cloned()
.ok_or_else(|| napi_error("BackendRuntime must be started before using postgres operations"))
}
}
async fn migrate_runtime_tables(pool: &PgPool) -> Result<()> {
for statement in RUNTIME_MIGRATIONS
.split(';')
.map(str::trim)
.filter(|statement| !statement.is_empty())
{
sqlx::query(statement)
.execute(pool)
.await
.map_err(|err| napi_error(format!("BackendRuntime migration failed: {err}")))?;
.ok_or_else(|| RuntimeError::invalid_state("BackendRuntime must be started before using postgres operations"))
}
Ok(())
pub(crate) fn config(&self) -> RuntimeResult<BackendRuntimeConfig> {
self
.config
.read()
.map(|config| config.clone())
.map_err(|_| RuntimeError::invalid_state("BackendRuntime config lock poisoned"))
}
fn update_config(&self, config: BackendRuntimeConfig) -> RuntimeResult<()> {
*self
.config
.write()
.map_err(|_| RuntimeError::invalid_state("BackendRuntime config lock poisoned"))? = config;
Ok(())
}
}
@@ -1,6 +1,4 @@
use napi::Result;
use super::{auth_challenge_purpose, dto::RuntimeStateRows};
use super::{Result, auth_challenge_purpose, dto::RuntimeStateRows};
pub(super) async fn create(
rows: &RuntimeStateRows,
@@ -1,10 +1,6 @@
use napi::Result;
use super::dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows};
use crate::backend_runtime::{
constants::{BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE},
error::napi_error,
types::RuntimeByokLocalLeaseRecord,
use super::{
BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE, Result, RuntimeByokLocalLeaseRecord, RuntimeError,
dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows},
};
pub(super) async fn get(rows: &RuntimeStateRows, lease_id: String) -> Result<Option<RuntimeByokLocalLeaseRecord>> {
@@ -19,7 +15,7 @@ pub(super) async fn create(
ttl_ms: i64,
) -> Result<RuntimeByokLocalLeaseRecord> {
if ttl_ms <= 0 {
return Err(napi_error("BYOK local lease ttl must be positive"));
return Err(RuntimeError::invalid_input("BYOK local lease ttl must be positive"));
}
let mut tx = rows.begin("RuntimeState BYOK local lease").await?;
@@ -27,7 +23,7 @@ pub(super) async fn create(
.bind(&active_key)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeState BYOK local lease active lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease active lock failed", err))?;
if let Some(active) = rows
.active_payload_with_expires_for_update_in_tx(
@@ -43,11 +39,9 @@ pub(super) async fn create(
None => None,
};
if let Some(lease) = existing_lease {
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState BYOK local lease transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease transaction commit failed", err))?;
return Ok(lease);
}
@@ -89,11 +83,9 @@ pub(super) async fn create(
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState BYOK local lease transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState BYOK local lease transaction commit failed", err))?;
Ok(RuntimeByokLocalLeaseRecord {
lease_id,
@@ -1,7 +1,8 @@
use napi::Result;
use sqlx::{PgPool, Row};
use crate::backend_runtime::{error::napi_error, token_hash};
use super::{RuntimeError, RuntimeResult, token_hash};
type Result<T> = RuntimeResult<T>;
pub(super) struct RuntimeStatePayloadRow {
pub(super) payload: serde_json::Value,
@@ -42,7 +43,7 @@ impl RuntimeStateRows {
.pool
.begin()
.await
.map_err(|err| napi_error(format!("{context} transaction failed: {err}")))
.map_err(|err| RuntimeError::database(format!("{context} transaction failed"), err))
}
pub(super) async fn insert_payload(
@@ -67,7 +68,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -95,7 +96,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?
.map_err(|err| RuntimeError::database(context, err))?
.rows_affected()
== 1;
@@ -131,7 +132,7 @@ impl RuntimeStateRows {
.bind(ttl_ms as f64)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -156,7 +157,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| row.get::<serde_json::Value, _>("payload")))
}
@@ -181,7 +182,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -208,7 +209,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| row.get::<serde_json::Value, _>("payload")))
}
@@ -235,7 +236,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -262,7 +263,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(payload_row))
}
@@ -288,7 +289,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(row.map(|row| RuntimeStateLockedRow {
payload: row.get("payload"),
@@ -316,7 +317,7 @@ impl RuntimeStateRows {
.bind(input.ttl_ms as f64)
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("{} failed: {err}", input.context)))?;
.map_err(|err| RuntimeError::database(input.context, err))?;
Ok(row.get::<i64, _>("expires_at_ms"))
}
@@ -348,7 +349,7 @@ impl RuntimeStateRows {
.bind(input.ttl_ms as f64)
.fetch_optional(&mut **tx)
.await
.map_err(|err| napi_error(format!("{} failed: {err}", input.context)))?;
.map_err(|err| RuntimeError::database(input.context, err))?;
Ok(row.map(|row| row.get::<i64, _>("expires_at_ms")))
}
@@ -375,7 +376,7 @@ impl RuntimeStateRows {
.bind(attempts)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -392,7 +393,7 @@ impl RuntimeStateRows {
.bind(token_hash(token))
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(())
}
@@ -413,7 +414,7 @@ impl RuntimeStateRows {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(result.rows_affected() as i64)
}
@@ -440,7 +441,7 @@ impl RuntimeStateRows {
.bind(limit)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("{context} failed: {err}")))?;
.map_err(|err| RuntimeError::database(context, err))?;
Ok(result.rows_affected() as i64)
}
@@ -1,10 +1,7 @@
use napi::Result;
use super::dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows};
use crate::backend_runtime::{
constants::{WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE},
error::napi_error,
types::RuntimeWorkspaceInviteLinkRecord,
use super::{
Result, RuntimeError, RuntimeWorkspaceInviteLinkRecord, WORKSPACE_INVITE_LINK_ID_PURPOSE,
WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE,
dto::{RuntimeStateInsertPayload, RuntimeStatePayloadRow, RuntimeStateRows},
};
pub(super) async fn get_by_workspace(
@@ -29,7 +26,9 @@ pub(super) async fn create(
ttl_ms: i64,
) -> Result<RuntimeWorkspaceInviteLinkRecord> {
if ttl_ms <= 0 {
return Err(napi_error("workspace invite link ttl must be positive"));
return Err(RuntimeError::invalid_input(
"workspace invite link ttl must be positive",
));
}
let mut tx = rows.begin("RuntimeState workspace invite link").await?;
@@ -37,16 +36,14 @@ pub(super) async fn create(
.bind(&workspace_id)
.execute(&mut *tx)
.await
.map_err(|err| napi_error(format!("RuntimeState workspace invite link active lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link active lock failed", err))?;
if let Some(existing) =
get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?
{
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return Ok(existing);
}
@@ -71,12 +68,11 @@ pub(super) async fn create(
.await?
else {
let existing = get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
return existing.ok_or_else(|| napi_error("RuntimeState workspace invite link active conflict missing row"));
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return existing
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link active conflict missing row"));
};
rows
.insert_payload_returning_expires_in_tx(
@@ -92,11 +88,9 @@ pub(super) async fn create(
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
Ok(RuntimeWorkspaceInviteLinkRecord {
workspace_id,
@@ -110,11 +104,9 @@ pub(super) async fn revoke(rows: &RuntimeStateRows, workspace_id: String) -> Res
let mut tx = rows.begin("RuntimeState workspace invite link").await?;
let existing = get_by_key_in_tx(rows, &mut tx, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, &workspace_id).await?;
let Some(existing) = existing else {
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
return Ok(false);
};
@@ -135,11 +127,9 @@ pub(super) async fn revoke(rows: &RuntimeStateRows, workspace_id: String) -> Res
)
.await?;
tx.commit().await.map_err(|err| {
napi_error(format!(
"RuntimeState workspace invite link transaction commit failed: {err}"
))
})?;
tx.commit()
.await
.map_err(|err| RuntimeError::database("RuntimeState workspace invite link transaction commit failed", err))?;
Ok(true)
}
@@ -175,19 +165,19 @@ fn record_from_row(row: RuntimeStatePayloadRow) -> Result<RuntimeWorkspaceInvite
.payload
.get("workspaceId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing workspaceId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing workspaceId"))?
.to_string(),
invite_id: row
.payload
.get("inviteId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing inviteId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing inviteId"))?
.to_string(),
inviter_user_id: row
.payload
.get("inviterUserId")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState workspace invite link payload missing inviterUserId"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState workspace invite link payload missing inviterUserId"))?
.to_string(),
expires_at_ms: row.expires_at_ms,
})
@@ -1,10 +1,6 @@
use napi::Result;
use super::dto::RuntimeStateRows;
use crate::backend_runtime::{
constants::{MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS},
error::napi_error,
types::RuntimeMagicLinkOtpConsumeResult,
use super::{
MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS, Result, RuntimeError, RuntimeMagicLinkOtpConsumeResult,
dto::RuntimeStateRows,
};
impl RuntimeMagicLinkOtpConsumeResult {
@@ -34,7 +30,7 @@ pub(super) async fn upsert(
ttl_ms: i64,
) -> Result<()> {
if ttl_ms <= 0 {
return Err(napi_error("magic link otp ttl must be positive"));
return Err(RuntimeError::invalid_input("magic link otp ttl must be positive"));
}
let payload = serde_json::json!({
@@ -73,11 +69,9 @@ pub(super) async fn consume(
.await?;
let Some(row) = row else {
tx.rollback().await.map_err(|err| {
napi_error(format!(
"RuntimeState magic link otp transaction rollback failed: {err}"
))
})?;
tx.rollback()
.await
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction rollback failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("not_found"));
};
@@ -96,7 +90,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("expired"));
}
@@ -104,7 +98,7 @@ pub(super) async fn consume(
if stored_client_nonce.is_some() && stored_client_nonce != client_nonce.as_deref() {
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("nonce_mismatch"));
}
@@ -119,7 +113,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("locked"));
}
@@ -137,7 +131,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("locked"));
}
@@ -153,14 +147,14 @@ pub(super) async fn consume(
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
return Ok(RuntimeMagicLinkOtpConsumeResult::fail("invalid_otp"));
}
let token = payload
.get("token")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| napi_error("RuntimeState magic link otp payload missing token"))?
.ok_or_else(|| RuntimeError::invalid_state("RuntimeState magic link otp payload missing token"))?
.to_string();
rows
.delete_by_key_in_tx(
@@ -172,7 +166,7 @@ pub(super) async fn consume(
.await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("RuntimeState magic link otp transaction commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState magic link otp transaction commit failed", err))?;
Ok(RuntimeMagicLinkOtpConsumeResult::ok(token))
}
@@ -1,8 +1,10 @@
use napi::Result;
use super::{
BackendRuntime,
error::napi_error,
use super::{BackendRuntime, RuntimeError, RuntimeResult, napi_error};
pub(super) use super::{
constants::{
BYOK_LOCAL_LEASE_ACTIVE_PURPOSE, BYOK_LOCAL_LEASE_PURPOSE, MAGIC_LINK_OTP_PURPOSE, MAX_MAGIC_LINK_OTP_ATTEMPTS,
WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE,
},
token_hash,
types::{
RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord,
@@ -18,6 +20,8 @@ mod store;
mod verification_token;
use store::RuntimeStateStore;
pub(super) type Result<T> = RuntimeResult<T>;
pub(super) fn auth_challenge_purpose(purpose: &str) -> String {
format!("auth_challenge:{purpose}")
}
@@ -35,27 +39,34 @@ impl BackendRuntime {
token: String,
payload: serde_json::Value,
ttl_ms: i64,
) -> Result<bool> {
) -> napi::Result<bool> {
if ttl_ms <= 0 {
return Err(napi_error("auth challenge ttl must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.create_auth_challenge(&purpose, &token, payload, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_auth_challenge(&self, purpose: String, token: String) -> Result<Option<serde_json::Value>> {
pub async fn get_auth_challenge(&self, purpose: String, token: String) -> napi::Result<Option<serde_json::Value>> {
RuntimeStateStore::new(self.pool().await?)
.get_auth_challenge(&purpose, &token)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn consume_auth_challenge(&self, purpose: String, token: String) -> Result<Option<serde_json::Value>> {
pub async fn consume_auth_challenge(
&self,
purpose: String,
token: String,
) -> napi::Result<Option<serde_json::Value>> {
RuntimeStateStore::new(self.pool().await?)
.consume_auth_challenge(&purpose, &token)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -64,13 +75,14 @@ impl BackendRuntime {
token_type: i32,
credential: Option<String>,
ttl_ms: i64,
) -> Result<String> {
) -> napi::Result<String> {
if ttl_ms <= 0 {
return Err(napi_error("verification token ttl must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.create_verification_token(token_type, credential, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -79,11 +91,12 @@ impl BackendRuntime {
token_type: i32,
token: String,
keep: Option<bool>,
) -> Result<Option<RuntimeVerificationTokenRecord>> {
) -> napi::Result<Option<RuntimeVerificationTokenRecord>> {
let keep = keep.unwrap_or(false);
RuntimeStateStore::new(self.pool().await?)
.get_verification_token(token_type, token, keep)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -93,21 +106,23 @@ impl BackendRuntime {
token: String,
credential: Option<String>,
keep: Option<bool>,
) -> Result<Option<RuntimeVerificationTokenRecord>> {
) -> napi::Result<Option<RuntimeVerificationTokenRecord>> {
let keep = keep.unwrap_or(false);
RuntimeStateStore::new(self.pool().await?)
.verify_verification_token(token_type, token, credential, keep)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn cleanup_expired_verification_tokens(&self, limit: i64) -> Result<i64> {
pub async fn cleanup_expired_verification_tokens(&self, limit: i64) -> napi::Result<i64> {
if limit <= 0 {
return Err(napi_error("verification token cleanup limit must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.cleanup_expired_verification_tokens(limit)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -118,10 +133,11 @@ impl BackendRuntime {
token: String,
client_nonce: Option<String>,
ttl_ms: i64,
) -> Result<()> {
) -> napi::Result<()> {
RuntimeStateStore::new(self.pool().await?)
.upsert_magic_link_otp(email, otp_hash, token, client_nonce, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -130,10 +146,11 @@ impl BackendRuntime {
email: String,
otp_hash: String,
client_nonce: Option<String>,
) -> Result<RuntimeMagicLinkOtpConsumeResult> {
) -> napi::Result<RuntimeMagicLinkOtpConsumeResult> {
RuntimeStateStore::new(self.pool().await?)
.consume_magic_link_otp(email, otp_hash, client_nonce)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -143,37 +160,41 @@ impl BackendRuntime {
invite_id: String,
inviter_user_id: String,
ttl_ms: i64,
) -> Result<RuntimeWorkspaceInviteLinkRecord> {
) -> napi::Result<RuntimeWorkspaceInviteLinkRecord> {
RuntimeStateStore::new(self.pool().await?)
.create_workspace_invite_link(workspace_id, invite_id, inviter_user_id, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_workspace_invite_link(
&self,
workspace_id: String,
) -> Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
) -> napi::Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_workspace_invite_link(workspace_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_workspace_invite_link_by_id(
&self,
invite_id: String,
) -> Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
) -> napi::Result<Option<RuntimeWorkspaceInviteLinkRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_workspace_invite_link_by_id(invite_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn revoke_workspace_invite_link(&self, workspace_id: String) -> Result<bool> {
pub async fn revoke_workspace_invite_link(&self, workspace_id: String) -> napi::Result<bool> {
RuntimeStateStore::new(self.pool().await?)
.revoke_workspace_invite_link(workspace_id)
.await
.map_err(napi::Error::from)
}
#[napi]
@@ -183,35 +204,37 @@ impl BackendRuntime {
lease_id: String,
payload: serde_json::Value,
ttl_ms: i64,
) -> Result<RuntimeByokLocalLeaseRecord> {
) -> napi::Result<RuntimeByokLocalLeaseRecord> {
RuntimeStateStore::new(self.pool().await?)
.create_byok_local_lease(active_key, lease_id, payload, ttl_ms)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn get_byok_local_lease(&self, lease_id: String) -> Result<Option<RuntimeByokLocalLeaseRecord>> {
pub async fn get_byok_local_lease(&self, lease_id: String) -> napi::Result<Option<RuntimeByokLocalLeaseRecord>> {
RuntimeStateStore::new(self.pool().await?)
.get_byok_local_lease(lease_id)
.await
.map_err(napi::Error::from)
}
#[napi]
pub async fn cleanup_expired_runtime_states(&self, limit: i64) -> Result<i64> {
pub async fn cleanup_expired_runtime_states(&self, limit: i64) -> napi::Result<i64> {
if limit <= 0 {
return Err(napi_error("runtime state cleanup limit must be positive"));
}
RuntimeStateStore::new(self.pool().await?)
.cleanup_expired_runtime_states(limit)
.await
.map_err(napi::Error::from)
}
}
#[cfg(test)]
mod tests {
use crate::backend_runtime::{
constants::{MAGIC_LINK_OTP_PURPOSE, WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE},
token_hash,
use super::{
MAGIC_LINK_OTP_PURPOSE, WORKSPACE_INVITE_LINK_ID_PURPOSE, WORKSPACE_INVITE_LINK_WORKSPACE_PURPOSE, token_hash,
};
#[test]
@@ -1,10 +1,9 @@
use napi::Result;
use sqlx::PgPool;
use super::{auth_challenge, byok_local_lease, dto::RuntimeStateRows, invite_link, magic_link_otp, verification_token};
use crate::backend_runtime::types::{
RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord,
use super::{
Result, RuntimeByokLocalLeaseRecord, RuntimeMagicLinkOtpConsumeResult, RuntimeVerificationTokenRecord,
RuntimeWorkspaceInviteLinkRecord, auth_challenge, byok_local_lease, dto::RuntimeStateRows, invite_link,
magic_link_otp, verification_token,
};
pub(super) struct RuntimeStateStore {
@@ -1,12 +1,11 @@
use napi::Result;
use sqlx::{PgPool, Row};
use uuid::Uuid;
use super::{
Result, RuntimeError, RuntimeVerificationTokenRecord,
dto::{RuntimeStatePayloadRow, RuntimeStateRows},
verification_token_purpose,
token_hash, verification_token_purpose,
};
use crate::backend_runtime::{error::napi_error, token_hash, types::RuntimeVerificationTokenRecord};
pub(super) async fn create(
rows: &RuntimeStateRows,
@@ -64,7 +63,7 @@ pub(super) async fn verify(
} else {
consume_payload_with_credential(rows.pool(), &purpose, &token, credential.as_deref()).await
}
.map_err(|err| napi_error(format!("RuntimeState verification token verify failed: {err}")))?;
.map_err(|err| RuntimeError::database("RuntimeState verification token verify failed", err))?;
Ok(row.map(|row| record_from_row(token_type, token, row)))
}
@@ -1,6 +1,10 @@
use anyhow::{Context, Result as AnyResult, anyhow};
use super::{runtime_state::*, *};
use super::{
super::migrations::{RUNTIME_MIGRATIONS, migrate_runtime_tables},
runtime_state::*,
*,
};
static PG_TEST_LOCK: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
const TEST_VERIFICATION_TOKEN_TYPE: i32 = 99_999;
@@ -14,6 +18,10 @@ fn migrations_include_runtime_tables_without_worker_heartbeats() {
assert!(RUNTIME_MIGRATIONS.contains("runtime_states"));
assert!(RUNTIME_MIGRATIONS.contains("runtime_gates"));
assert!(RUNTIME_MIGRATIONS.contains("runtime_leases"));
assert!(RUNTIME_MIGRATIONS.contains("blob_reconciliation_runs"));
assert!(RUNTIME_MIGRATIONS.contains("blob_reconciliation_checkpoints"));
assert!(RUNTIME_MIGRATIONS.contains("doc_blob_refs"));
assert!(RUNTIME_MIGRATIONS.contains("blob_cleanup_candidates"));
assert!(!RUNTIME_MIGRATIONS.contains("runtime_worker_heartbeats"));
}
@@ -66,10 +74,7 @@ async fn runtime_from_database_url() -> AnyResult<Option<BackendRuntime>> {
.context("cleanup runtime_leases for backend runtime tests")?;
Ok(Some(BackendRuntime {
config: RuntimeConfig {
database_url,
storage: None,
},
config: std::sync::RwLock::new(BackendRuntimeConfig { database_url }),
pool: Mutex::new(Some(pool)),
}))
}
@@ -126,7 +131,7 @@ async fn runtime_gate_sql_semantics_are_atomic_and_ttl_bound() {
let mut tasks = Vec::new();
for _ in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
tasks.push(tokio::spawn(async move {
@@ -185,7 +190,7 @@ async fn coordination_lease_sql_semantics_are_fenced_and_ttl_bound() {
let mut tasks = Vec::new();
for index in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
tasks.push(tokio::spawn(async move {
@@ -389,7 +394,7 @@ async fn verification_token_sql_state_machine_handles_keep_verify_and_cleanup()
let mut tasks = Vec::new();
for _ in 0..16 {
let runtime = BackendRuntime {
config: runtime.config.clone(),
config: std::sync::RwLock::new(runtime.config().unwrap()),
pool: Mutex::new(Some(runtime.pool().await.unwrap())),
};
let token = concurrent_token.clone();
@@ -1,11 +1,10 @@
use napi::Result;
use sqlx::{FromRow, PgPool, Postgres, Row, Transaction};
use tokio::time::{Duration as TokioDuration, sleep};
use super::{
BackendRuntime,
BackendRuntime, RuntimeError, RuntimeResult,
constants::{WORKSPACE_STATS_LEASE_KEY, WORKSPACE_STATS_LOCK_NAMESPACE, WORKSPACE_STATS_REFRESH_LOCK_KEY},
error::napi_error,
napi_error,
types::{
CoordinationLeaseGrant, RuntimeWorkspaceStatsDailyRecalibrationResult, RuntimeWorkspaceStatsRecalibrationResult,
RuntimeWorkspaceStatsRefreshResult, RuntimeWorkspaceStatsSnapshotResult,
@@ -22,13 +21,13 @@ impl BackendRuntime {
batch_limit: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsRefreshResult> {
) -> napi::Result<RuntimeWorkspaceStatsRefreshResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats dirty refresh limit must be positive"));
}
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsRefreshResult {
@@ -46,7 +45,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -56,13 +55,13 @@ impl BackendRuntime {
batch_limit: i64,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsRecalibrationResult> {
) -> napi::Result<RuntimeWorkspaceStatsRecalibrationResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats recalibration limit must be positive"));
}
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
@@ -80,7 +79,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -88,9 +87,9 @@ impl BackendRuntime {
&self,
owner: String,
lease_ttl_ms: i64,
) -> Result<RuntimeWorkspaceStatsSnapshotResult> {
) -> napi::Result<RuntimeWorkspaceStatsSnapshotResult> {
let Some(lease) = self
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner, lease_ttl_ms)
.await?
else {
return Ok(RuntimeWorkspaceStatsSnapshotResult {
@@ -107,7 +106,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
#[napi]
@@ -118,7 +117,7 @@ impl BackendRuntime {
lease_ttl_ms: i64,
lock_retry_times: i64,
lock_retry_delay_ms: i64,
) -> Result<RuntimeWorkspaceStatsDailyRecalibrationResult> {
) -> napi::Result<RuntimeWorkspaceStatsDailyRecalibrationResult> {
if batch_limit <= 0 {
return Err(napi_error("workspace stats daily recalibration limit must be positive"));
}
@@ -150,7 +149,7 @@ impl BackendRuntime {
});
};
let result = async {
let result: RuntimeResult<RuntimeWorkspaceStatsDailyRecalibrationResult> = async {
let store = WorkspaceStatsStore::new(self.pool().await?);
let mut processed = 0;
let mut last_sid = 0;
@@ -195,7 +194,7 @@ impl BackendRuntime {
.await;
release_workspace_stats_lease(self, lease).await?;
result
Ok(result?)
}
}
@@ -214,16 +213,16 @@ impl WorkspaceStatsStore {
Self { pool }
}
async fn refresh_dirty(&self, batch_limit: i64) -> Result<RuntimeWorkspaceStatsRefreshResult> {
async fn refresh_dirty(&self, batch_limit: i64) -> RuntimeResult<RuntimeWorkspaceStatsRefreshResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRefreshResult {
processed: 0,
backlog: 0,
@@ -236,7 +235,7 @@ impl WorkspaceStatsStore {
if dirty.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRefreshResult {
processed: 0,
backlog,
@@ -248,7 +247,7 @@ impl WorkspaceStatsStore {
clear_dirty(&mut tx, &dirty).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats dirty refresh commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats dirty refresh commit failed", err))?;
Ok(RuntimeWorkspaceStatsRefreshResult {
processed: dirty.len() as i64,
@@ -257,16 +256,20 @@ impl WorkspaceStatsStore {
})
}
async fn recalibrate(&self, last_sid: i64, batch_limit: i64) -> Result<RuntimeWorkspaceStatsRecalibrationResult> {
async fn recalibrate(
&self,
last_sid: i64,
batch_limit: i64,
) -> RuntimeResult<RuntimeWorkspaceStatsRecalibrationResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: 0,
last_sid,
@@ -278,7 +281,7 @@ impl WorkspaceStatsStore {
if workspaces.is_empty() {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
return Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: 0,
last_sid,
@@ -297,7 +300,7 @@ impl WorkspaceStatsStore {
upsert_stats(&mut tx, &ids).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats recalibration commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats recalibration commit failed", err))?;
Ok(RuntimeWorkspaceStatsRecalibrationResult {
processed: ids.len() as i64,
@@ -306,16 +309,16 @@ impl WorkspaceStatsStore {
})
}
async fn write_daily_snapshot(&self) -> Result<RuntimeWorkspaceStatsSnapshotResult> {
async fn write_daily_snapshot(&self) -> RuntimeResult<RuntimeWorkspaceStatsSnapshotResult> {
let mut tx = self
.pool
.begin()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot transaction failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot transaction failed", err))?;
if !try_transaction_lock(&mut tx).await? {
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot commit failed", err))?;
return Ok(RuntimeWorkspaceStatsSnapshotResult {
snapshotted: 0,
skipped: true,
@@ -324,7 +327,7 @@ impl WorkspaceStatsStore {
let snapshotted = write_daily_snapshot(&mut tx).await?;
tx.commit()
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot commit failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot commit failed", err))?;
Ok(RuntimeWorkspaceStatsSnapshotResult {
snapshotted,
@@ -333,9 +336,9 @@ impl WorkspaceStatsStore {
}
}
async fn release_workspace_stats_lease(runtime: &BackendRuntime, lease: CoordinationLeaseGrant) -> Result<()> {
async fn release_workspace_stats_lease(runtime: &BackendRuntime, lease: CoordinationLeaseGrant) -> RuntimeResult<()> {
let _ = runtime
.release_coordination_lease(lease.key, lease.owner, lease.fencing_token)
.release_coordination_lease_inner(lease.key, lease.owner, lease.fencing_token)
.await?;
Ok(())
}
@@ -346,10 +349,10 @@ async fn acquire_workspace_stats_lease_with_retry(
lease_ttl_ms: i64,
retry_times: i64,
retry_delay_ms: i64,
) -> Result<Option<CoordinationLeaseGrant>> {
) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
for attempt in 0..retry_times {
let lease = runtime
.acquire_coordination_lease(WORKSPACE_STATS_LEASE_KEY.to_string(), owner.clone(), lease_ttl_ms)
.acquire_coordination_lease_inner(WORKSPACE_STATS_LEASE_KEY.to_string(), owner.clone(), lease_ttl_ms)
.await?;
if lease.is_some() {
return Ok(lease);
@@ -367,11 +370,11 @@ async fn retry_workspace_stats_operation<T, F, Fut>(
retry_times: i64,
retry_delay_ms: i64,
mut operation: F,
) -> Result<T>
) -> RuntimeResult<T>
where
T: WorkspaceStatsSkippable,
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
Fut: std::future::Future<Output = RuntimeResult<T>>,
{
for attempt in 0..retry_times {
let result = operation().await?;
@@ -403,7 +406,7 @@ impl WorkspaceStatsSkippable for RuntimeWorkspaceStatsSnapshotResult {
}
}
async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> Result<bool> {
async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<bool> {
let row = sqlx::query(
r#"
SELECT pg_try_advisory_xact_lock(($1::bigint << 32) + $2::bigint) AS locked
@@ -413,12 +416,12 @@ async fn try_transaction_lock(tx: &mut Transaction<'_, Postgres>) -> Result<bool
.bind(WORKSPACE_STATS_REFRESH_LOCK_KEY)
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats transaction lock failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats transaction lock failed", err))?;
Ok(row.get::<bool, _>("locked"))
}
async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> Result<Vec<String>> {
async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> RuntimeResult<Vec<String>> {
let rows = sqlx::query(
r#"
SELECT workspace_id
@@ -431,20 +434,20 @@ async fn load_dirty(tx: &mut Transaction<'_, Postgres>, limit: i64) -> Result<Ve
.bind(limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats load dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats load dirty workspaces failed", err))?;
Ok(rows.into_iter().map(|row| row.get("workspace_id")).collect())
}
async fn count_dirty(tx: &mut Transaction<'_, Postgres>) -> Result<i64> {
async fn count_dirty(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<i64> {
let row = sqlx::query("SELECT COUNT(*) AS total FROM workspace_admin_stats_dirty")
.fetch_one(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats count dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats count dirty workspaces failed", err))?;
Ok(row.get::<i64, _>("total"))
}
async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> Result<()> {
async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> RuntimeResult<()> {
sqlx::query(
r#"
DELETE FROM workspace_admin_stats_dirty
@@ -454,11 +457,11 @@ async fn clear_dirty(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String
.bind(workspace_ids)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats clear dirty workspaces failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats clear dirty workspaces failed", err))?;
Ok(())
}
async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> Result<()> {
async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[String]) -> RuntimeResult<()> {
if workspace_ids.is_empty() {
return Ok(());
}
@@ -467,7 +470,7 @@ async fn upsert_stats(tx: &mut Transaction<'_, Postgres>, workspace_ids: &[Strin
.bind(workspace_ids)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats upsert stats failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats upsert stats failed", err))?;
Ok(())
}
@@ -475,7 +478,7 @@ async fn fetch_workspace_batch(
tx: &mut Transaction<'_, Postgres>,
last_sid: i64,
limit: i64,
) -> Result<Vec<WorkspaceSid>> {
) -> RuntimeResult<Vec<WorkspaceSid>> {
sqlx::query_as::<_, WorkspaceSid>(
r#"
SELECT id, sid
@@ -489,10 +492,10 @@ async fn fetch_workspace_batch(
.bind(limit)
.fetch_all(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats fetch workspace batch failed: {err}")))
.map_err(|err| RuntimeError::database("WorkspaceStats fetch workspace batch failed", err))
}
async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> Result<i64> {
async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
INSERT INTO workspace_admin_stats_daily (
@@ -521,7 +524,7 @@ async fn write_daily_snapshot(tx: &mut Transaction<'_, Postgres>) -> Result<i64>
)
.execute(&mut **tx)
.await
.map_err(|err| napi_error(format!("WorkspaceStats daily snapshot failed: {err}")))?;
.map_err(|err| RuntimeError::database("WorkspaceStats daily snapshot failed", err))?;
Ok(result.rows_affected() as i64)
}
@@ -0,0 +1,200 @@
use std::{
env, fs,
path::{Path, PathBuf},
};
use serde::Deserialize;
use serde_json::Map;
use sqlx::{PgPool, Row};
use super::{RuntimeError, RuntimeResult};
#[derive(Clone, Debug)]
pub(crate) struct BackendRuntimeConfig {
pub(crate) database_url: String,
}
impl BackendRuntimeConfig {
pub(crate) fn from_config_files() -> RuntimeResult<Self> {
let app_config = app_config_from_config_files()?;
let database_url = database_url_from_env()
.or(app_config.database_url())
.unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string());
Ok(Self { database_url })
}
pub(crate) async fn with_db_overrides(&self, pool: &PgPool) -> RuntimeResult<Self> {
let mut app_config = app_config_from_config_files()?;
app_config.apply_file_config(load_app_config_overrides_from_db(pool).await?);
Ok(Self {
// The DB override is loaded after this connection already exists, so it
// must not rewrite the active datasource URL.
database_url: self.database_url.clone(),
})
}
}
#[derive(Debug, Default, Deserialize)]
struct AppConfigFile {
db: Option<DbConfigFile>,
}
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DbConfigFile {
datasource_url: Option<String>,
}
impl AppConfigFile {
fn database_url(&self) -> Option<String> {
self
.db
.as_ref()
.and_then(|db| db.datasource_url.clone())
.and_then(non_empty_string)
}
}
fn database_url_from_env() -> Option<String> {
env::var("DATABASE_URL").ok().and_then(non_empty_string)
}
fn non_empty_string(value: String) -> Option<String> {
if value.trim().is_empty() { None } else { Some(value) }
}
fn app_config_from_config_files() -> RuntimeResult<AppConfigFile> {
let mut merged = AppConfigFile::default();
for path in config_json_paths() {
if !path.exists() {
continue;
}
let raw = fs::read_to_string(&path).map_err(|err| RuntimeError::io("failed to read config file", err))?;
let config: AppConfigFile =
serde_json::from_str(&raw).map_err(|err| RuntimeError::json("failed to parse config file", err))?;
merged.apply_file_config(config);
}
Ok(merged)
}
impl AppConfigFile {
fn apply_file_config(&mut self, config: AppConfigFile) {
if config.db.is_some() {
self.db = config.db;
}
}
}
async fn load_app_config_overrides_from_db(pool: &PgPool) -> RuntimeResult<AppConfigFile> {
let rows = match sqlx::query("SELECT id, value FROM app_configs").fetch_all(pool).await {
Ok(rows) => rows,
Err(sqlx::Error::Database(err)) if err.code().as_deref() == Some("42P01") => return Ok(AppConfigFile::default()),
Err(err) => return Err(RuntimeError::database("failed to load app config overrides", err)),
};
app_config_from_flat_overrides(rows.into_iter().map(|row| {
let id: String = row.get("id");
let value: serde_json::Value = row.get("value");
(id, value)
}))
}
fn app_config_from_flat_overrides<I, S>(rows: I) -> RuntimeResult<AppConfigFile>
where
I: IntoIterator<Item = (S, serde_json::Value)>,
S: AsRef<str>,
{
let mut root = Map::new();
for (path, value) in rows {
let Some((module, key)) = path.as_ref().split_once('.') else {
continue;
};
root
.entry(module.to_string())
.or_insert_with(|| serde_json::Value::Object(Map::new()));
if let Some(serde_json::Value::Object(module_object)) = root.get_mut(module) {
module_object.insert(key.to_string(), value);
}
}
serde_json::from_value(serde_json::Value::Object(root))
.map_err(|err| RuntimeError::json("invalid app config overrides", err))
}
pub(super) fn config_json_paths() -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Ok(exe) = env::current_exe()
&& let Some(dir) = exe.parent()
{
paths.push(config_in(dir));
}
if let Ok(cwd) = env::current_dir() {
paths.push(config_in(&cwd));
}
dedupe_paths(paths)
}
fn config_in(dir: &Path) -> PathBuf {
dir.join("config.json")
}
fn dedupe_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
let mut deduped = Vec::new();
for path in paths {
if !deduped.contains(&path) {
deduped.push(path);
}
}
deduped
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn config_paths_are_limited_to_executable_dir_and_cwd() {
let paths = config_json_paths();
assert!(!paths.is_empty());
assert!(paths.len() <= 2);
assert!(
paths
.iter()
.all(|path| path.file_name().is_some_and(|name| name == "config.json"))
);
assert!(paths.iter().all(|path| !path.to_string_lossy().contains(".affine")));
assert!(
paths
.iter()
.all(|path| !path.to_string_lossy().contains("packages/backend/server"))
);
}
#[test]
fn blank_database_urls_are_ignored() {
assert_eq!(non_empty_string("".to_string()), None);
assert_eq!(non_empty_string(" ".to_string()), None);
assert_eq!(
non_empty_string("postgresql://affine:affine@localhost:5432/affine".to_string()),
Some("postgresql://affine:affine@localhost:5432/affine".to_string())
);
}
#[test]
fn ignores_storage_app_config_values() {
let app_config = app_config_from_flat_overrides([
(
"storages.blob.storage",
serde_json::json!({"provider": "cloudflare-r2"}),
),
("db.datasourceUrl", serde_json::json!("postgresql://example/runtime")),
])
.unwrap();
assert_eq!(
app_config.database_url().as_deref(),
Some("postgresql://example/runtime")
);
}
}
@@ -0,0 +1,126 @@
use napi::{Error, Status};
use super::storage_runtime::object_storage::error::ObjectStorageError;
pub(crate) type RuntimeResult<T> = std::result::Result<T, RuntimeError>;
#[derive(Debug, thiserror::Error)]
pub(crate) enum RuntimeError {
#[error("{0}")]
Config(String),
#[error("{0}")]
InvalidInput(String),
#[error("{0}")]
InvalidState(String),
#[error("{context}: {source}")]
Database {
context: String,
#[source]
source: sqlx::Error,
},
#[error("{context}: {source}")]
Io {
context: String,
#[source]
source: std::io::Error,
},
#[error("{context}: {source}")]
Json {
context: String,
#[source]
source: serde_json::Error,
},
#[error("{context}: {source}")]
Time {
context: String,
#[source]
source: std::time::SystemTimeError,
},
#[error(transparent)]
ObjectStorage(#[from] ObjectStorageError),
#[error("{0}")]
NapiBoundary(String),
}
impl RuntimeError {
pub(crate) fn config(message: impl Into<String>) -> Self {
Self::Config(message.into())
}
pub(crate) fn invalid_input(message: impl Into<String>) -> Self {
Self::InvalidInput(message.into())
}
pub(crate) fn invalid_state(message: impl Into<String>) -> Self {
Self::InvalidState(message.into())
}
pub(crate) fn database(context: impl Into<String>, source: sqlx::Error) -> Self {
Self::Database {
context: context.into(),
source,
}
}
pub(crate) fn io(context: impl Into<String>, source: std::io::Error) -> Self {
Self::Io {
context: context.into(),
source,
}
}
pub(crate) fn json(context: impl Into<String>, source: serde_json::Error) -> Self {
Self::Json {
context: context.into(),
source,
}
}
pub(crate) fn is_object_missing(&self) -> bool {
match self {
Self::ObjectStorage(error) => error.is_not_found(),
Self::Io { source, .. } => source.kind() == std::io::ErrorKind::NotFound,
Self::InvalidState(message)
| Self::InvalidInput(message)
| Self::Config(message)
| Self::NapiBoundary(message) => {
message.contains("NoSuchKey") || message.contains("NotFound") || message.contains("not found")
}
_ => false,
}
}
}
pub(crate) fn to_napi_error(error: RuntimeError) -> Error {
Error::new(Status::GenericFailure, error.to_string())
}
impl From<RuntimeError> for Error {
fn from(error: RuntimeError) -> Self {
to_napi_error(error)
}
}
impl From<ObjectStorageError> for Error {
fn from(error: ObjectStorageError) -> Self {
to_napi_error(RuntimeError::from(error))
}
}
impl From<Error> for RuntimeError {
fn from(error: Error) -> Self {
Self::NapiBoundary(error.to_string())
}
}
pub(crate) fn napi_error(message: impl Into<String>) -> Error {
Error::new(Status::GenericFailure, message.into())
}
@@ -0,0 +1,20 @@
use sqlx::PgPool;
use super::{RuntimeError, RuntimeResult};
pub(crate) const RUNTIME_MIGRATIONS: &str = include_str!("sql/runtime_migrations.sql");
pub(crate) async fn migrate_runtime_tables(pool: &PgPool) -> RuntimeResult<()> {
for statement in RUNTIME_MIGRATIONS
.split(';')
.map(str::trim)
.filter(|statement| !statement.is_empty())
{
sqlx::query(statement)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Runtime migration failed", err))?;
}
Ok(())
}
@@ -0,0 +1,10 @@
pub mod backend_runtime;
pub mod storage_runtime;
pub(crate) mod config;
pub(crate) mod error;
pub(crate) mod migrations;
pub(crate) mod types;
pub(crate) use config::BackendRuntimeConfig;
pub(crate) use error::{RuntimeError, RuntimeResult, napi_error, to_napi_error};
@@ -0,0 +1,112 @@
CREATE TABLE IF NOT EXISTS runtime_states (
purpose TEXT NOT NULL,
token_hash TEXT NOT NULL,
lookup_key TEXT,
payload JSONB NOT NULL,
attempts INTEGER NOT NULL DEFAULT 0,
consumed_at TIMESTAMPTZ(3),
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (purpose, token_hash)
);
CREATE INDEX IF NOT EXISTS runtime_states_lookup_idx
ON runtime_states (purpose, lookup_key)
WHERE lookup_key IS NOT NULL AND consumed_at IS NULL;
CREATE INDEX IF NOT EXISTS runtime_states_expires_at_idx
ON runtime_states (expires_at);
CREATE TABLE IF NOT EXISTS runtime_gates (
key TEXT PRIMARY KEY,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_gates_expires_at_idx
ON runtime_gates (expires_at);
CREATE TABLE IF NOT EXISTS runtime_leases (
key TEXT PRIMARY KEY,
owner TEXT NOT NULL,
fencing_token BIGINT NOT NULL,
expires_at TIMESTAMPTZ(3) NOT NULL,
created_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS runtime_leases_expires_at_idx
ON runtime_leases (expires_at);
CREATE TABLE IF NOT EXISTS blob_reconciliation_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
kind TEXT NOT NULL,
mode TEXT NOT NULL,
status TEXT NOT NULL,
workspace_id TEXT,
started_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
finished_at TIMESTAMPTZ(3),
cursor JSONB NOT NULL DEFAULT '{}',
scanned INTEGER NOT NULL DEFAULT 0,
changed INTEGER NOT NULL DEFAULT 0,
failed INTEGER NOT NULL DEFAULT 0,
metadata JSONB NOT NULL DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS blob_reconciliation_runs_workspace_idx
ON blob_reconciliation_runs (workspace_id, started_at DESC);
CREATE TABLE IF NOT EXISTS blob_reconciliation_checkpoints (
kind TEXT NOT NULL,
scope TEXT NOT NULL,
status TEXT NOT NULL,
cursor JSONB NOT NULL DEFAULT '{}',
last_key TEXT,
last_sid INTEGER,
completed_at TIMESTAMPTZ(3),
updated_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
metadata JSONB NOT NULL DEFAULT '{}',
PRIMARY KEY (kind, scope)
);
CREATE INDEX IF NOT EXISTS blob_reconciliation_checkpoints_status_idx
ON blob_reconciliation_checkpoints (kind, status, updated_at DESC);
CREATE TABLE IF NOT EXISTS doc_blob_refs (
workspace_id TEXT NOT NULL,
doc_id TEXT NOT NULL,
blob_key TEXT NOT NULL,
block_id TEXT NOT NULL,
flavour TEXT NOT NULL,
snapshot_updated_at TIMESTAMPTZ(3) NOT NULL,
indexed_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
parser_version INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'fresh',
error TEXT,
PRIMARY KEY (workspace_id, doc_id, blob_key, block_id)
);
CREATE INDEX IF NOT EXISTS doc_blob_refs_workspace_blob_idx
ON doc_blob_refs (workspace_id, blob_key);
CREATE INDEX IF NOT EXISTS doc_blob_refs_workspace_status_idx
ON doc_blob_refs (workspace_id, status);
CREATE TABLE IF NOT EXISTS blob_cleanup_candidates (
workspace_id TEXT NOT NULL,
blob_key TEXT NOT NULL,
reason TEXT NOT NULL,
status TEXT NOT NULL,
object_size BIGINT NOT NULL,
object_last_modified TIMESTAMPTZ(3),
planned_at TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
executed_at TIMESTAMPTZ(3),
run_id UUID NOT NULL,
evidence JSONB NOT NULL DEFAULT '{}',
error TEXT,
PRIMARY KEY (workspace_id, blob_key)
);
CREATE INDEX IF NOT EXISTS blob_cleanup_candidates_run_idx
ON blob_cleanup_candidates (run_id, status);
@@ -0,0 +1,358 @@
use std::{
io::{BufReader, Cursor},
path::PathBuf,
time::SystemTime,
};
use assetpack_core::{
Codec, FileHint, FileTransformConfig, Hash32, ObjectKind, Pipeline, PipelineConfig, SqliteStore, TransformRegistry,
TransformSelector, build_recipe, pack::ObjectRecord, parse_recipe_checked,
};
use sqlx::Row;
use super::{
FsStorageConfig, MAX_BLOB_SIZE, ObjectGetResult, ObjectListEntry, ObjectMetadata, ObjectPutMetadata, RuntimeError,
RuntimeResult, fs_bucket_path, normalize_storage_key, system_time_ms,
};
pub(super) async fn put(
config: &FsStorageConfig,
scope: &str,
key: &str,
body: Vec<u8>,
metadata: ObjectPutMetadata,
) -> RuntimeResult<ObjectMetadata> {
normalize_storage_key(key)?;
let metadata = metadata.complete_for_body(&body);
let content_length = metadata.content_length.unwrap_or(body.len() as i64);
if content_length != body.len() as i64 {
return Err(RuntimeError::invalid_input(
"Assetpack contentLength does not match body length",
));
}
if !(0..=MAX_BLOB_SIZE).contains(&content_length) {
return Err(RuntimeError::invalid_input(
"Assetpack contentLength exceeds supported blob size",
));
}
let store = open_store(config).await?;
let transform_config = FileTransformConfig::default();
let bucket_path = fs_bucket_path(config);
let selector = TransformSelector::new(
transform_config.clone(),
transform_config.resolved_temp_dir(&bucket_path),
assetpack_transform_precomp2::default_specs(),
);
let original_hash = Hash32::sha3_256(&body);
let hint = FileHint {
size: body.len() as u64,
extension: extension_from_key(key),
head: Some(body.iter().take(4096).copied().collect()),
};
let plan = Pipeline::new(PipelineConfig::default())
.run(body, &hint, original_hash, Some(&selector))
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack pipeline failed: {err}")))?;
let chunks = plan
.chunks
.iter()
.map(|chunk| (chunk.hash, chunk.raw_len))
.collect::<Vec<_>>();
let recipe = build_recipe(
plan.original_size,
&chunks,
plan.original_hash,
plan.transform_id,
plan.transform_version,
);
let recipe_hash = Hash32::sha3_256(&recipe);
let mut objects = Vec::with_capacity(plan.chunks.len() + 1);
for chunk in plan.chunks {
let Some(payload) = chunk.payload else {
return Err(RuntimeError::invalid_state(
"Assetpack pipeline unexpectedly discarded chunk payload",
));
};
objects.push(ObjectRecord {
hash: chunk.hash,
kind: ObjectKind::Chunk,
size: chunk.raw_len as u64,
codec: chunk.codec,
content: payload,
});
}
objects.push(ObjectRecord {
hash: recipe_hash,
kind: ObjectKind::Recipe,
size: recipe.len() as u64,
codec: Codec::Raw,
content: recipe,
});
let mut tx = store
.begin_write_tx()
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack begin write failed: {err}")))?;
store
.put_objects_batch_tx(&mut tx, &objects)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack object write failed: {err}")))?;
store
.put_file_recipe_cache_batch_tx(&mut tx, &[(original_hash, recipe_hash)])
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe cache write failed: {err}")))?;
let object_metadata = ObjectMetadata {
content_type: metadata
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length,
last_modified_ms: system_time_ms(SystemTime::now())?,
checksum_crc32: metadata.checksum_crc32,
};
sqlx::query(
r#"
INSERT INTO storage_assetpack_blobs
(scope, key, recipe_hash, content_type, content_length, checksum_crc32, last_modified_ms)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
ON CONFLICT (scope, key)
DO UPDATE SET
recipe_hash = excluded.recipe_hash,
content_type = excluded.content_type,
content_length = excluded.content_length,
checksum_crc32 = excluded.checksum_crc32,
last_modified_ms = excluded.last_modified_ms
"#,
)
.bind(scope)
.bind(key)
.bind(recipe_hash.to_hex())
.bind(&object_metadata.content_type)
.bind(object_metadata.content_length)
.bind(&object_metadata.checksum_crc32)
.bind(object_metadata.last_modified_ms)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Assetpack manifest write failed", err))?;
tx.commit()
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack commit failed: {err}")))?;
Ok(object_metadata)
}
pub(super) async fn head(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<Option<ObjectMetadata>> {
normalize_storage_key(key)?;
let store = open_store(config).await?;
manifest_row(&store, scope, key)
.await
.map(|row| row.map(|row| row.metadata))
}
pub(super) async fn get(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<Option<ObjectGetResult>> {
normalize_storage_key(key)?;
let store = open_store(config).await?;
let Some(row) = manifest_row(&store, scope, key).await? else {
return Ok(None);
};
let recipe_hash = Hash32::from_hex(&row.recipe_hash)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack manifest recipe hash is invalid: {err}")))?;
let Some(recipe_object) = store
.get_object(&recipe_hash)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe read failed: {err}")))?
else {
return Err(RuntimeError::invalid_state(format!(
"Assetpack recipe object is missing for {key}"
)));
};
let recipe = parse_recipe_checked(&recipe_object.content, &recipe_hash)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack recipe parse failed: {err}")))?;
let mut stored_stream = Vec::with_capacity(recipe.stored_stream_size as usize);
for (chunk_hash, expected_len) in &recipe.chunks {
let Some(chunk) = store
.get_object(chunk_hash)
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack chunk read failed: {err}")))?
else {
return Err(RuntimeError::invalid_state(format!(
"Assetpack chunk is missing for {key}: {chunk_hash}"
)));
};
if chunk.kind != ObjectKind::Chunk || chunk.size != *expected_len as u64 {
return Err(RuntimeError::invalid_state(format!(
"Assetpack chunk metadata mismatch for {key}: {chunk_hash}"
)));
}
stored_stream.extend_from_slice(&chunk.content);
}
let body = decode_stored_stream(recipe.transform_id, stored_stream)?;
if body.len() as u64 != recipe.original_file_size || Hash32::sha3_256(&body) != recipe.original_file_hash {
return Err(RuntimeError::invalid_state(format!(
"Assetpack reconstructed body failed integrity check for {key}"
)));
}
Ok(Some(ObjectGetResult {
body,
metadata: row.metadata,
}))
}
pub(super) async fn list(
config: &FsStorageConfig,
scope: &str,
prefix: Option<String>,
) -> RuntimeResult<Vec<ObjectListEntry>> {
let prefix = prefix
.map(|prefix| super::normalize_storage_prefix(&prefix))
.transpose()?
.unwrap_or_default();
let store = open_store(config).await?;
let rows = sqlx::query(
r#"
SELECT key, content_length, last_modified_ms
FROM storage_assetpack_blobs
WHERE scope = ?1 AND key LIKE ?2 ESCAPE '\'
ORDER BY key ASC
"#,
)
.bind(scope)
.bind(format!("{}%", escape_sqlite_like(&prefix)))
.fetch_all(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest list failed", err))?;
rows
.into_iter()
.map(|row| {
Ok(ObjectListEntry {
key: row.get("key"),
content_length: row.get::<i64, _>("content_length"),
last_modified_ms: row.get("last_modified_ms"),
})
})
.collect()
}
fn escape_sqlite_like(value: &str) -> String {
let mut escaped = String::with_capacity(value.len());
for ch in value.chars() {
match ch {
'%' | '_' | '\\' => {
escaped.push('\\');
escaped.push(ch);
}
_ => escaped.push(ch),
}
}
escaped
}
pub(super) async fn delete(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<()> {
normalize_storage_key(key)?;
let store = open_store(config).await?;
sqlx::query("DELETE FROM storage_assetpack_blobs WHERE scope = ?1 AND key = ?2")
.bind(scope)
.bind(key)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest delete failed", err))?;
Ok(())
}
async fn open_store(config: &FsStorageConfig) -> RuntimeResult<SqliteStore> {
let store = SqliteStore::open(store_path(config))
.await
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack store open failed: {err}")))?;
ensure_manifest_schema(&store).await?;
Ok(store)
}
fn store_path(config: &FsStorageConfig) -> PathBuf {
fs_bucket_path(config).join("assetpack.sqlite")
}
fn extension_from_key(key: &str) -> Option<String> {
key
.rsplit_once('.')
.and_then(|(_, extension)| (!extension.is_empty()).then(|| extension.to_ascii_lowercase()))
}
struct ManifestRow {
recipe_hash: String,
metadata: ObjectMetadata,
}
async fn ensure_manifest_schema(store: &SqliteStore) -> RuntimeResult<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS storage_assetpack_blobs (
scope TEXT NOT NULL,
key TEXT NOT NULL,
recipe_hash TEXT NOT NULL,
content_type TEXT NOT NULL,
content_length INTEGER NOT NULL,
checksum_crc32 TEXT,
last_modified_ms INTEGER NOT NULL,
PRIMARY KEY (scope, key)
)
"#,
)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest schema create failed", err))?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS storage_assetpack_blobs_scope_prefix_idx ON storage_assetpack_blobs (scope, key)",
)
.execute(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest index create failed", err))?;
Ok(())
}
async fn manifest_row(store: &SqliteStore, scope: &str, key: &str) -> RuntimeResult<Option<ManifestRow>> {
let row = sqlx::query(
r#"
SELECT recipe_hash, content_type, content_length, checksum_crc32, last_modified_ms
FROM storage_assetpack_blobs
WHERE scope = ?1 AND key = ?2
"#,
)
.bind(scope)
.bind(key)
.fetch_optional(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest read failed", err))?;
row
.map(|row| {
Ok(ManifestRow {
recipe_hash: row.get("recipe_hash"),
metadata: ObjectMetadata {
content_type: row.get("content_type"),
content_length: row.get("content_length"),
checksum_crc32: row.get("checksum_crc32"),
last_modified_ms: row.get("last_modified_ms"),
},
})
})
.transpose()
}
fn decode_stored_stream(transform_id: u16, stored_stream: Vec<u8>) -> RuntimeResult<Vec<u8>> {
let transform_config = FileTransformConfig::default();
let registry = TransformRegistry::new(&transform_config, assetpack_transform_precomp2::default_specs());
let transform = registry
.get(transform_id)
.ok_or_else(|| RuntimeError::invalid_state(format!("Assetpack transform is not registered: {transform_id}")))?;
let mut out = Vec::new();
transform
.decode(&mut BufReader::new(Cursor::new(stored_stream)), &mut out)
.map_err(|err| RuntimeError::invalid_state(format!("Assetpack transform decode failed: {err}")))?;
Ok(out)
}
@@ -0,0 +1,691 @@
use std::collections::HashMap;
use chrono::{DateTime, Duration, Utc};
use sqlx::{FromRow, PgPool};
use super::{
RuntimeBlobCleanupExecuteResult, RuntimeBlobCleanupPlanResult, RuntimeError, RuntimeResult, StorageRuntime,
napi_error,
};
#[derive(FromRow)]
struct BlobCandidateRow {
workspace_id: String,
key: String,
size: i32,
}
#[derive(FromRow)]
struct MarkedCandidateRow {
workspace_id: String,
blob_key: String,
}
struct DeletableCandidate {
workspace_id: String,
blob_key: String,
object_key: String,
}
fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
if !workspace_ids.iter().any(|id| id == workspace_id) {
workspace_ids.push(workspace_id.to_string());
}
}
async fn checkpoint_completed(pool: &PgPool, kind: &str, scope: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM blob_reconciliation_checkpoints WHERE kind = $1 AND scope = $2 AND status = \
'completed')",
)
.bind(kind)
.bind(scope)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup checkpoint check failed", err))
}
async fn projection_is_stale(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {
let checkpoint_fresh = checkpoint_completed(pool, "doc_blob_refs", workspace_id).await?;
let has_stale_rows = sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM doc_blob_refs WHERE workspace_id = $1 AND status <> 'fresh')",
)
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup projection freshness check failed", err))?;
Ok(!checkpoint_fresh || has_stale_rows)
}
async fn stale_projection_workspaces(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Vec<String>> {
if projection_is_stale(pool, workspace_id).await? {
Ok(vec![workspace_id.to_string()])
} else {
Ok(Vec::new())
}
}
async fn metadata_backfill_is_complete(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {
checkpoint_completed(pool, "blob_metadata_backfill", workspace_id).await
}
async fn has_doc_ref(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM doc_blob_refs WHERE workspace_id = $1 AND blob_key = $2 AND status = 'fresh')",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup doc ref check failed", err))
}
async fn has_other_ref(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {
let required_ref = sqlx::query_scalar::<_, bool>(
r#"
SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND avatar_key = $2)
OR EXISTS(SELECT 1 FROM ai_transcript_tasks WHERE workspace_id = $1 AND blob_id = $2)
OR EXISTS(SELECT 1 FROM ai_jobs WHERE workspace_id = $1 AND blob_id = $2)
OR EXISTS(
SELECT 1
FROM ai_contexts c
JOIN ai_sessions_metadata s ON s.id = c.session_id
WHERE s.workspace_id = $1
AND jsonb_path_exists(
c.config::jsonb,
'$.** ? (@ == $blobKey)',
jsonb_build_object('blobKey', to_jsonb($2::text))
)
)
"#,
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup protected ref check failed", err))?;
if required_ref {
return Ok(true);
}
if table_exists(pool, "ai_workspace_files").await?
&& sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM ai_workspace_files WHERE workspace_id = $1 AND blob_id = $2)",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup workspace file ref check failed", err))?
{
return Ok(true);
}
if table_exists(pool, "ai_workspace_blob_embeddings").await?
&& sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM ai_workspace_blob_embeddings WHERE workspace_id = $1 AND blob_id = $2)",
)
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup workspace blob embedding ref check failed", err))?
{
return Ok(true);
}
Ok(false)
}
async fn table_exists(pool: &PgPool, table: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT to_regclass($1) IS NOT NULL")
.bind(format!("public.{table}"))
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup table existence check failed", err))
}
async fn load_completed_blobs(
pool: &PgPool,
workspace_id: &str,
after_key: Option<&str>,
limit: i64,
) -> RuntimeResult<Vec<BlobCandidateRow>> {
sqlx::query_as::<_, BlobCandidateRow>(
r#"
SELECT workspace_id, key, size
FROM blobs
WHERE workspace_id = $1
AND status = 'completed'
AND deleted_at IS NULL
AND ($2::text IS NULL OR key > $2)
ORDER BY key ASC
LIMIT $3
"#,
)
.bind(workspace_id)
.bind(after_key)
.bind(limit)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup load completed blobs failed", err))
}
async fn load_plan_cursor(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Option<String>> {
let row = sqlx::query_as::<_, (String, serde_json::Value)>(
"SELECT status, cursor FROM blob_reconciliation_checkpoints WHERE kind = 'blob_cleanup_plan' AND scope = $1",
)
.bind(workspace_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup plan checkpoint load failed", err))?;
let Some((status, cursor)) = row else {
return Ok(None);
};
if status == "completed" {
return Ok(None);
}
Ok({
cursor
.get("lastBlobKey")
.and_then(|value| value.as_str())
.map(ToString::to_string)
})
}
async fn upsert_plan_checkpoint(
pool: &PgPool,
workspace_id: &str,
last_blob_key: Option<&str>,
completed: bool,
) -> RuntimeResult<()> {
let status = if completed { "completed" } else { "running" };
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, last_key, completed_at)
VALUES ('blob_cleanup_plan', $1, $2, $3, $4, CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
last_key = COALESCE(EXCLUDED.last_key, blob_reconciliation_checkpoints.last_key),
completed_at = CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(workspace_id)
.bind(status)
.bind(serde_json::json!({ "lastBlobKey": last_blob_key }))
.bind(last_blob_key)
.bind(completed)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup plan checkpoint write failed", err))?;
Ok(())
}
async fn create_run(pool: &PgPool, workspace_id: &str) -> RuntimeResult<String> {
sqlx::query_scalar::<_, String>(
r#"
INSERT INTO blob_reconciliation_runs (kind, mode, status, workspace_id)
VALUES ('blob_cleanup_plan', 'mark_only', 'running', $1)
RETURNING id::text
"#,
)
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup create run failed", err))
}
async fn finish_run(
pool: &PgPool,
run_id: &str,
workspace_id: &str,
result: &RuntimeBlobCleanupPlanResult,
stale_projection_workspaces: Vec<String>,
) -> RuntimeResult<()> {
let candidate_bytes = sqlx::query_scalar::<_, Option<i64>>(
"SELECT SUM(object_size)::bigint FROM blob_cleanup_candidates WHERE run_id = $1::uuid AND status = 'marked'",
)
.bind(run_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup candidate bytes audit failed", err))?
.unwrap_or(0);
sqlx::query(
r#"
UPDATE blob_reconciliation_runs
SET status = 'finished',
finished_at = CURRENT_TIMESTAMP,
scanned = $2,
changed = $3,
metadata = $4
WHERE id = $1::uuid
"#,
)
.bind(run_id)
.bind(result.scanned_blobs as i32)
.bind(result.candidates_marked as i32)
.bind(serde_json::json!({
"protectedByDocRefs": result.protected_by_doc_refs,
"protectedByMetadata": result.protected_by_metadata,
"protectedByOtherRefs": result.protected_by_other_refs,
"topWorkspaceCandidateBytes": [{
"workspaceId": workspace_id,
"candidateBytes": candidate_bytes,
}],
"staleOrFailedProjectionWorkspaces": stale_projection_workspaces,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup finish run failed", err))?;
Ok(())
}
async fn mark_candidate_status(
pool: &PgPool,
run_id: &str,
workspace_id: &str,
blob_key: &str,
status: &str,
evidence: serde_json::Value,
error: Option<&str>,
) -> RuntimeResult<()> {
sqlx::query(
r#"
UPDATE blob_cleanup_candidates
SET status = $3,
executed_at = CURRENT_TIMESTAMP,
evidence = evidence || $4,
error = $5
WHERE workspace_id = $1 AND blob_key = $2 AND run_id = $6::uuid
"#,
)
.bind(workspace_id)
.bind(blob_key)
.bind(status)
.bind(evidence)
.bind(error)
.bind(run_id)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup mark candidate status failed", err))?;
Ok(())
}
async fn finish_execute_run(
pool: &PgPool,
run_id: &str,
result: &RuntimeBlobCleanupExecuteResult,
) -> RuntimeResult<()> {
sqlx::query(
r#"
UPDATE blob_reconciliation_runs
SET status = 'finished',
finished_at = CURRENT_TIMESTAMP,
scanned = $2,
changed = $3,
failed = $4,
metadata = metadata || $5
WHERE id = $1::uuid
"#,
)
.bind(run_id)
.bind(result.scanned_candidates as i32)
.bind(result.deleted_metadata as i32)
.bind(result.failed as i32)
.bind(serde_json::json!({
"deletedObjects": result.deleted_objects,
"deletedMetadata": result.deleted_metadata,
"skippedStillReferenced": result.skipped_still_referenced,
"failed": result.failed,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup execute run finish failed", err))?;
Ok(())
}
async fn mark_candidate(
pool: &PgPool,
run_id: &str,
row: &BlobCandidateRow,
object_size: i64,
object_last_modified: DateTime<Utc>,
) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
INSERT INTO blob_cleanup_candidates
(workspace_id, blob_key, reason, status, object_size, object_last_modified, run_id, evidence)
VALUES ($1, $2, 'unreferenced_completed_blob', 'marked', $3, $4, $5::uuid, $6)
ON CONFLICT (workspace_id, blob_key) DO UPDATE
SET reason = EXCLUDED.reason,
status = 'marked',
object_size = EXCLUDED.object_size,
object_last_modified = EXCLUDED.object_last_modified,
planned_at = CURRENT_TIMESTAMP,
run_id = EXCLUDED.run_id,
evidence = EXCLUDED.evidence,
error = NULL
"#,
)
.bind(&row.workspace_id)
.bind(&row.key)
.bind(object_size)
.bind(object_last_modified)
.bind(run_id)
.bind(serde_json::json!({ "metadataSize": row.size }))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup mark candidate failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn load_marked_candidates(pool: &PgPool, run_id: &str, limit: i64) -> RuntimeResult<Vec<MarkedCandidateRow>> {
sqlx::query_as::<_, MarkedCandidateRow>(
r#"
SELECT workspace_id, blob_key
FROM blob_cleanup_candidates
WHERE run_id = $1::uuid AND status IN ('marked', 'failed')
ORDER BY CASE WHEN status = 'marked' THEN 0 ELSE 1 END, planned_at ASC
LIMIT $2
"#,
)
.bind(run_id)
.bind(limit)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Blob cleanup load marked candidates failed", err))
}
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn plan_unreferenced_workspace_blobs(
&self,
workspace_id: String,
grace_period_days: i64,
limit: i64,
) -> napi::Result<RuntimeBlobCleanupPlanResult> {
if limit <= 0 {
return Err(napi_error("blob cleanup plan limit must be positive"));
}
if grace_period_days < 0 {
return Err(napi_error("blob cleanup grace period must be non-negative"));
}
let pool = self.pool().await?;
let run_id = create_run(&pool, &workspace_id).await?;
let mut result = RuntimeBlobCleanupPlanResult {
run_id: Some(run_id.clone()),
scanned_blobs: 0,
candidates_marked: 0,
protected_by_doc_refs: 0,
protected_by_metadata: 0,
protected_by_other_refs: 0,
next_cursor: None,
};
let cursor = load_plan_cursor(&pool, &workspace_id).await?;
let stale_projection_workspaces = stale_projection_workspaces(&pool, &workspace_id).await?;
if !metadata_backfill_is_complete(&pool, &workspace_id).await? || !stale_projection_workspaces.is_empty() {
result.protected_by_metadata = load_completed_blobs(&pool, &workspace_id, cursor.as_deref(), limit)
.await?
.len() as i64;
finish_run(&pool, &run_id, &workspace_id, &result, stale_projection_workspaces).await?;
return Ok(result);
}
let min_last_modified = Utc::now() - Duration::days(grace_period_days);
let rows = load_completed_blobs(&pool, &workspace_id, cursor.as_deref(), limit).await?;
let has_more = rows.len() == limit as usize;
let mut last_blob_key = None;
for row in rows {
result.scanned_blobs += 1;
last_blob_key = Some(row.key.clone());
if has_doc_ref(&pool, &row.workspace_id, &row.key).await? {
result.protected_by_doc_refs += 1;
continue;
}
if has_other_ref(&pool, &row.workspace_id, &row.key).await? {
result.protected_by_other_refs += 1;
continue;
}
let object_key = format!("{}/{}", row.workspace_id, row.key);
let Some(metadata) = self.object_storage_head(object_key).await? else {
result.protected_by_metadata += 1;
continue;
};
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("blob cleanup object last modified is invalid"))?;
if metadata.content_length != row.size as i64 || last_modified > min_last_modified {
result.protected_by_metadata += 1;
continue;
}
result.candidates_marked += mark_candidate(&pool, &run_id, &row, metadata.content_length, last_modified).await?;
}
if has_more {
result.next_cursor = last_blob_key.clone();
}
upsert_plan_checkpoint(&pool, &workspace_id, last_blob_key.as_deref(), !has_more).await?;
finish_run(&pool, &run_id, &workspace_id, &result, Vec::new()).await?;
Ok(result)
}
#[napi]
pub async fn execute_blob_cleanup_candidates(
&self,
run_id: String,
grace_period_days: i64,
limit: i64,
) -> napi::Result<RuntimeBlobCleanupExecuteResult> {
if limit <= 0 {
return Err(napi_error("blob cleanup execute limit must be positive"));
}
if grace_period_days < 0 {
return Err(napi_error("blob cleanup grace period must be non-negative"));
}
let pool = self.pool().await?;
let min_last_modified = Utc::now() - Duration::days(grace_period_days);
let rows = load_marked_candidates(&pool, &run_id, limit).await?;
let mut result = RuntimeBlobCleanupExecuteResult {
scanned_candidates: rows.len() as i64,
deleted_objects: 0,
deleted_metadata: 0,
skipped_still_referenced: 0,
failed: 0,
workspace_ids: Vec::new(),
};
let mut deletable_candidates = Vec::new();
for row in rows {
if projection_is_stale(&pool, &row.workspace_id).await?
|| has_doc_ref(&pool, &row.workspace_id, &row.blob_key).await?
|| has_other_ref(&pool, &row.workspace_id, &row.blob_key).await?
{
result.skipped_still_referenced += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"skipped",
serde_json::json!({ "skipReason": "referenced_or_projection_stale" }),
None,
)
.await?;
continue;
}
let object_key = format!("{}/{}", row.workspace_id, row.blob_key);
let metadata = match self.object_storage_head(object_key.clone()).await {
Ok(metadata) => metadata,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "object_head_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
if let Some(metadata) = metadata {
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("blob cleanup execute object last modified is invalid"))?;
if last_modified > min_last_modified {
result.skipped_still_referenced += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"skipped",
serde_json::json!({ "skipReason": "object_inside_grace_period" }),
None,
)
.await?;
continue;
}
deletable_candidates.push(DeletableCandidate {
workspace_id: row.workspace_id,
blob_key: row.blob_key,
object_key,
});
continue;
}
let deleted_metadata =
match sqlx::query("DELETE FROM blobs WHERE workspace_id = $1 AND key = $2 AND deleted_at IS NULL")
.bind(&row.workspace_id)
.bind(&row.blob_key)
.execute(&pool)
.await
{
Ok(result) => result.rows_affected() as i64,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "metadata_delete_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
result.deleted_metadata += deleted_metadata;
push_workspace_once(&mut result.workspace_ids, &row.workspace_id);
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"executed",
serde_json::json!({
"deletedMetadata": deleted_metadata,
"objectMissingBeforeDelete": true,
}),
None,
)
.await?;
}
if !deletable_candidates.is_empty() {
let object_keys = deletable_candidates
.iter()
.map(|candidate| candidate.object_key.clone())
.collect::<Vec<_>>();
let outcomes = match self.object_storage_delete_many(object_keys.clone()).await {
Ok(outcomes) => outcomes,
Err(err) => object_keys
.into_iter()
.map(|key| super::object_storage::types::ObjectDeleteOutcome {
key,
error: Some(err.to_string()),
})
.collect(),
};
let mut outcomes_by_key = outcomes
.into_iter()
.map(|outcome| (outcome.key, outcome.error))
.collect::<HashMap<_, _>>();
for row in deletable_candidates {
let delete_error = match outcomes_by_key.remove(&row.object_key) {
Some(Some(error)) => Some(error),
Some(None) => None,
None => Some("DeleteObjects response did not include this key".to_string()),
};
if let Some(error) = delete_error {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "object_delete_failed" }),
Some(&error),
)
.await?;
continue;
}
result.deleted_objects += 1;
let deleted_metadata =
match sqlx::query("DELETE FROM blobs WHERE workspace_id = $1 AND key = $2 AND deleted_at IS NULL")
.bind(&row.workspace_id)
.bind(&row.blob_key)
.execute(&pool)
.await
{
Ok(result) => result.rows_affected() as i64,
Err(err) => {
result.failed += 1;
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"failed",
serde_json::json!({ "failure": "metadata_delete_failed" }),
Some(&err.to_string()),
)
.await?;
continue;
}
};
result.deleted_metadata += deleted_metadata;
push_workspace_once(&mut result.workspace_ids, &row.workspace_id);
mark_candidate_status(
&pool,
&run_id,
&row.workspace_id,
&row.blob_key,
"executed",
serde_json::json!({
"deletedMetadata": deleted_metadata,
"objectMissingBeforeDelete": false,
}),
None,
)
.await?;
}
}
finish_execute_run(&pool, &run_id, &result).await?;
Ok(result)
}
}
@@ -2,7 +2,7 @@ use chrono::{DateTime, Utc};
use napi::Result;
use sqlx::{FromRow, PgPool};
use super::{BackendRuntime, error::napi_error, types::RuntimeBlobCleanupResult};
use super::{RuntimeBlobCleanupResult, RuntimeError, RuntimeResult, StorageRuntime, napi_error};
#[derive(FromRow)]
struct BlobRow {
@@ -20,7 +20,7 @@ impl BlobReclaimerStore {
Self { pool }
}
async fn load_expired_pending(&self, cutoff: DateTime<Utc>, limit: i64) -> Result<Vec<BlobRow>> {
async fn load_expired_pending(&self, cutoff: DateTime<Utc>, limit: i64) -> RuntimeResult<Vec<BlobRow>> {
sqlx::query_as::<_, BlobRow>(
r#"
SELECT workspace_id, key, upload_id
@@ -36,10 +36,10 @@ impl BlobReclaimerStore {
.bind(limit)
.fetch_all(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer load pending blobs failed: {err}")))
.map_err(|err| RuntimeError::database("BlobReclaimer load pending blobs failed", err))
}
async fn load_deleted(&self, workspace_id: &str, limit: i64) -> Result<Vec<BlobRow>> {
async fn load_deleted(&self, workspace_id: &str, limit: i64) -> RuntimeResult<Vec<BlobRow>> {
sqlx::query_as::<_, BlobRow>(
r#"
SELECT workspace_id, key, upload_id
@@ -54,10 +54,10 @@ impl BlobReclaimerStore {
.bind(limit)
.fetch_all(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer load deleted blobs failed: {err}")))
.map_err(|err| RuntimeError::database("BlobReclaimer load deleted blobs failed", err))
}
async fn delete_pending_metadata(&self, workspace_id: &str, key: &str) -> Result<i64> {
async fn delete_pending_metadata(&self, workspace_id: &str, key: &str) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM blobs
@@ -70,11 +70,11 @@ impl BlobReclaimerStore {
.bind(key)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer delete pending blob metadata failed: {err}")))?;
.map_err(|err| RuntimeError::database("BlobReclaimer delete pending blob metadata failed", err))?;
Ok(result.rows_affected() as i64)
}
async fn delete_released_metadata(&self, workspace_id: &str, key: &str) -> Result<i64> {
async fn delete_released_metadata(&self, workspace_id: &str, key: &str) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM blobs
@@ -86,31 +86,23 @@ impl BlobReclaimerStore {
.bind(key)
.execute(&self.pool)
.await
.map_err(|err| napi_error(format!("BlobReclaimer delete blob metadata failed: {err}")))?;
.map_err(|err| RuntimeError::database("BlobReclaimer delete blob metadata failed", err))?;
Ok(result.rows_affected() as i64)
}
}
fn object_missing_error(err: &napi::Error) -> bool {
let message = err.to_string();
message.contains("NoSuchKey")
|| message.contains("NoSuchUpload")
|| message.contains("NotFound")
|| message.contains("not found")
}
async fn delete_object_idempotent(runtime: &BackendRuntime, key: &str) -> Result<()> {
async fn delete_object_idempotent(runtime: &StorageRuntime, key: &str) -> RuntimeResult<()> {
match runtime.object_storage_delete_object(key).await {
Ok(()) => Ok(()),
Err(err) if object_missing_error(&err) => Ok(()),
Err(err) if err.is_object_missing() => Ok(()),
Err(err) => Err(err),
}
}
async fn abort_upload_idempotent(runtime: &BackendRuntime, key: &str, upload_id: &str) -> Result<()> {
async fn abort_upload_idempotent(runtime: &StorageRuntime, key: &str, upload_id: &str) -> RuntimeResult<()> {
match runtime.object_storage_abort_upload(key, upload_id).await {
Ok(()) => Ok(()),
Err(err) if object_missing_error(&err) => Ok(()),
Err(err) if err.is_object_missing() => Ok(()),
Err(err) => Err(err),
}
}
@@ -122,7 +114,7 @@ fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
}
#[napi_derive::napi]
impl BackendRuntime {
impl StorageRuntime {
#[napi]
pub async fn cleanup_expired_pending_blobs(&self, cutoff_ms: i64, limit: i64) -> Result<RuntimeBlobCleanupResult> {
if limit <= 0 {
@@ -130,7 +122,7 @@ impl BackendRuntime {
}
let cutoff = DateTime::<Utc>::from_timestamp_millis(cutoff_ms)
.ok_or_else(|| napi_error("pending blob cleanup cutoff is invalid"))?;
.ok_or_else(|| RuntimeError::invalid_input("pending blob cleanup cutoff is invalid"))?;
let store = BlobReclaimerStore::new(self.pool().await?);
let rows = store.load_expired_pending(cutoff, limit).await?;
@@ -0,0 +1,273 @@
use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool};
use super::{
RuntimeBlobMetadataBackfillResult, RuntimeError, RuntimeObjectMetadata, RuntimeResult, StorageRuntime, napi_error,
};
async fn workspace_exists(pool: &PgPool, workspace_id: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)")
.bind(workspace_id)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill workspace check failed", err))
}
async fn blob_exists(pool: &PgPool, workspace_id: &str, key: &str) -> RuntimeResult<bool> {
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM blobs WHERE workspace_id = $1 AND key = $2)")
.bind(workspace_id)
.bind(key)
.fetch_one(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill blob check failed", err))
}
async fn upsert_blob_metadata(
pool: &PgPool,
workspace_id: &str,
key: &str,
metadata: RuntimeObjectMetadata,
) -> RuntimeResult<i64> {
let last_modified = DateTime::<Utc>::from_timestamp_millis(metadata.last_modified_ms)
.ok_or_else(|| RuntimeError::invalid_state("Blob metadata backfill object last modified is invalid"))?;
let result = sqlx::query(
r#"
INSERT INTO blobs (workspace_id, key, size, mime, status, upload_id, created_at, deleted_at)
VALUES ($1, $2, $3, $4, 'completed', NULL, $5, NULL)
ON CONFLICT (workspace_id, key) DO UPDATE
SET size = EXCLUDED.size,
mime = EXCLUDED.mime,
status = 'completed',
upload_id = NULL,
deleted_at = NULL
WHERE blobs.deleted_at IS NULL
"#,
)
.bind(workspace_id)
.bind(key)
.bind(metadata.content_length as i32)
.bind(metadata.content_type)
.bind(last_modified)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill upsert failed", err))?;
Ok(result.rows_affected() as i64)
}
fn split_workspace_blob_key(full_key: &str) -> Option<(&str, &str)> {
let (workspace_id, key) = full_key.split_once('/')?;
if workspace_id.is_empty() || key.is_empty() || key.contains('/') {
return None;
}
Some((workspace_id, key))
}
fn checkpoint_scope(workspace_id: Option<&str>) -> String {
workspace_id.unwrap_or("__all__").to_string()
}
#[derive(FromRow)]
struct BackfillCheckpoint {
last_key: Option<String>,
cursor: serde_json::Value,
}
impl BackfillCheckpoint {
fn continuation_token(&self) -> Option<String> {
self
.cursor
.get("continuationToken")
.and_then(|value| value.as_str())
.map(ToString::to_string)
}
}
async fn load_checkpoint(pool: &PgPool, scope: &str) -> RuntimeResult<Option<BackfillCheckpoint>> {
sqlx::query_as::<_, BackfillCheckpoint>(
"SELECT last_key, cursor FROM blob_reconciliation_checkpoints WHERE kind = 'blob_metadata_backfill' AND scope = $1",
)
.bind(scope)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill checkpoint load failed", err))
}
async fn upsert_checkpoint(
pool: &PgPool,
scope: &str,
last_key: Option<&str>,
continuation_token: Option<&str>,
completed: bool,
) -> RuntimeResult<()> {
let status = if completed { "completed" } else { "running" };
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, last_key, completed_at, metadata)
VALUES ('blob_metadata_backfill', $1, $2, $3, $4, CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END, $6)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
last_key = COALESCE(EXCLUDED.last_key, blob_reconciliation_checkpoints.last_key),
completed_at = CASE WHEN $5 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(scope)
.bind(status)
.bind(serde_json::json!({
"lastKey": last_key,
"continuationToken": continuation_token,
}))
.bind(last_key)
.bind(completed)
.bind(serde_json::json!({
"quotaReportingReconciliationRequired": true,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill checkpoint write failed", err))?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn blob_metadata_backfill_splits_workspace_blob_keys() {
assert_eq!(
split_workspace_blob_key("workspace/blob-key"),
Some(("workspace", "blob-key"))
);
assert_eq!(split_workspace_blob_key("workspace/nested/blob-key"), None);
assert_eq!(split_workspace_blob_key("workspace/"), None);
assert_eq!(split_workspace_blob_key("blob-key"), None);
}
#[test]
fn blob_metadata_backfill_checkpoint_scope_is_explicit() {
assert_eq!(checkpoint_scope(Some("workspace")), "workspace");
assert_eq!(checkpoint_scope(None), "__all__");
}
}
fn push_workspace_once(workspace_ids: &mut Vec<String>, workspace_id: &str) {
if !workspace_ids.iter().any(|id| id == workspace_id) {
workspace_ids.push(workspace_id.to_string());
}
}
fn checked_list_page_limit(limit: i64) -> RuntimeResult<i32> {
i32::try_from(limit).map_err(|_| RuntimeError::invalid_input("blob metadata backfill limit exceeds i32::MAX"))
}
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn backfill_missing_blob_metadata(
&self,
workspace_id: Option<String>,
limit: i64,
) -> napi::Result<RuntimeBlobMetadataBackfillResult> {
if limit <= 0 {
return Err(napi_error("blob metadata backfill limit must be positive"));
}
let page_limit = checked_list_page_limit(limit)?;
let pool = self.pool().await?;
let prefix = workspace_id.as_ref().map(|id| format!("{id}/"));
let scope = checkpoint_scope(workspace_id.as_deref());
let checkpoint = load_checkpoint(&pool, &scope).await?;
let page = self
.object_storage_list_page(
prefix,
checkpoint.as_ref().and_then(BackfillCheckpoint::continuation_token),
checkpoint.as_ref().and_then(|checkpoint| checkpoint.last_key.clone()),
page_limit,
)
.await?;
let has_more = page.next_continuation_token.is_some();
let mut result = RuntimeBlobMetadataBackfillResult {
scanned_objects: 0,
headed_objects: 0,
upserted_metadata: 0,
skipped_existing: 0,
skipped_workspace_missing: 0,
failed: 0,
next_cursor: None,
workspace_ids: Vec::new(),
};
let mut last_scanned_key = None;
for object in &page.entries {
result.scanned_objects += 1;
last_scanned_key = Some(object.key.clone());
let Some((object_workspace_id, key)) = split_workspace_blob_key(&object.key) else {
result.failed += 1;
continue;
};
if workspace_id.as_deref().is_some_and(|id| id != object_workspace_id) {
result.failed += 1;
continue;
}
if !workspace_exists(&pool, object_workspace_id).await? {
result.skipped_workspace_missing += 1;
continue;
}
if blob_exists(&pool, object_workspace_id, key).await? {
result.skipped_existing += 1;
continue;
}
result.headed_objects += 1;
let Some(metadata) = self.object_storage_head(object.key.clone()).await? else {
result.failed += 1;
continue;
};
let affected = upsert_blob_metadata(&pool, object_workspace_id, key, metadata).await?;
if affected > 0 {
result.upserted_metadata += affected;
push_workspace_once(&mut result.workspace_ids, object_workspace_id);
}
}
if has_more {
result.next_cursor = last_scanned_key.clone();
}
upsert_checkpoint(
&pool,
&scope,
last_scanned_key.as_deref(),
page.next_continuation_token.as_deref(),
!has_more,
)
.await?;
sqlx::query(
r#"
INSERT INTO blob_reconciliation_runs
(kind, mode, status, workspace_id, finished_at, scanned, changed, failed, metadata)
VALUES ('blob_metadata_backfill', 'execute', 'finished', $1, CURRENT_TIMESTAMP, $2, $3, $4, $5)
"#,
)
.bind(workspace_id)
.bind(result.scanned_objects as i32)
.bind(result.upserted_metadata as i32)
.bind(result.failed as i32)
.bind(serde_json::json!({
"headedObjects": result.headed_objects,
"skippedExisting": result.skipped_existing,
"skippedWorkspaceMissing": result.skipped_workspace_missing,
"checkpointScope": scope,
"nextCursor": result.next_cursor,
"quotaReportingReconciliationRequired": true,
}))
.execute(&pool)
.await
.map_err(|err| RuntimeError::database("Blob metadata backfill run record failed", err))?;
Ok(result)
}
}
@@ -0,0 +1,427 @@
use affine_common::doc_parser;
use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool};
use y_octo::Doc;
use super::{RuntimeDocBlobRefsResult, RuntimeError, RuntimeResult, StorageRuntime, napi_error};
const PARSER_VERSION: i32 = 1;
#[derive(FromRow)]
struct SnapshotRow {
workspace_id: String,
doc_id: String,
blob: Vec<u8>,
updated_at: DateTime<Utc>,
}
#[derive(FromRow)]
struct UpdateRow {
blob: Vec<u8>,
created_at: DateTime<Utc>,
}
struct ExtractedRef {
blob_key: String,
block_id: String,
flavour: String,
}
async fn load_snapshot(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Option<SnapshotRow>> {
sqlx::query_as::<_, SnapshotRow>(
r#"
SELECT workspace_id, guid AS doc_id, blob, updated_at
FROM snapshots
WHERE workspace_id = $1 AND guid = $2
"#,
)
.bind(workspace_id)
.bind(doc_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs load snapshot failed", err))
}
async fn load_updates(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Vec<UpdateRow>> {
sqlx::query_as::<_, UpdateRow>(
r#"
SELECT blob, created_at
FROM updates
WHERE workspace_id = $1 AND guid = $2
ORDER BY created_at ASC
"#,
)
.bind(workspace_id)
.bind(doc_id)
.fetch_all(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs load updates failed", err))
}
fn apply_doc_updates(updates: impl IntoIterator<Item = Vec<u8>>) -> RuntimeResult<Vec<u8>> {
let mut doc = Doc::default();
for update in updates {
doc
.apply_update_from_binary_v1(&update)
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs merge failed: {err}")))?;
}
doc
.encode_update_v1()
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs encode failed: {err}")))
}
async fn load_current_doc(pool: &PgPool, workspace_id: &str, doc_id: &str) -> RuntimeResult<Option<SnapshotRow>> {
let snapshot = load_snapshot(pool, workspace_id, doc_id).await?;
let updates = load_updates(pool, workspace_id, doc_id).await?;
if snapshot.is_none() && updates.is_empty() {
return Ok(None);
}
let mut merge_inputs = Vec::with_capacity(updates.len() + usize::from(snapshot.is_some()));
let mut updated_at = snapshot
.as_ref()
.map(|snapshot| snapshot.updated_at)
.unwrap_or_else(Utc::now);
if let Some(snapshot) = snapshot {
merge_inputs.push(snapshot.blob);
}
for update in updates {
updated_at = update.created_at;
merge_inputs.push(update.blob);
}
Ok(Some(SnapshotRow {
workspace_id: workspace_id.to_string(),
doc_id: doc_id.to_string(),
blob: apply_doc_updates(merge_inputs)?,
updated_at,
}))
}
async fn load_workspace_doc_ids(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Vec<String>> {
let Some(root) = load_current_doc(pool, workspace_id, workspace_id).await? else {
return Ok(Vec::new());
};
let ids = doc_parser::get_doc_ids_from_binary(root.blob, false)
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs root doc parse failed: {err}")))?;
let mut ids = ids;
ids.sort();
Ok(ids)
}
async fn upsert_projection_checkpoint(
pool: &PgPool,
workspace_id: &str,
result: &RuntimeDocBlobRefsResult,
) -> RuntimeResult<()> {
let completed = result.next_cursor.is_none();
let status = if completed && result.failed_docs == 0 {
"completed"
} else if result.failed_docs > 0 {
"failed"
} else {
"running"
};
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, completed_at, metadata)
VALUES ('doc_blob_refs', $1, $2, $3, CASE WHEN $4 THEN CURRENT_TIMESTAMP ELSE NULL END, $5)
ON CONFLICT (kind, scope) DO UPDATE
SET status = EXCLUDED.status,
cursor = EXCLUDED.cursor,
completed_at = CASE WHEN $4 THEN CURRENT_TIMESTAMP ELSE NULL END,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(workspace_id)
.bind(status)
.bind(serde_json::json!({ "lastDocId": result.next_cursor }))
.bind(completed && result.failed_docs == 0)
.bind(serde_json::json!({
"parserVersion": PARSER_VERSION,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs checkpoint write failed", err))?;
Ok(())
}
async fn upsert_projection_failure_checkpoint(pool: &PgPool, workspace_id: &str, error: &str) -> RuntimeResult<()> {
sqlx::query(
r#"
INSERT INTO blob_reconciliation_checkpoints
(kind, scope, status, cursor, completed_at, metadata)
VALUES ('doc_blob_refs', $1, 'failed', '{}', NULL, $2)
ON CONFLICT (kind, scope) DO UPDATE
SET status = 'failed',
cursor = '{}',
completed_at = NULL,
updated_at = CURRENT_TIMESTAMP,
metadata = EXCLUDED.metadata
"#,
)
.bind(workspace_id)
.bind(serde_json::json!({
"parserVersion": PARSER_VERSION,
"error": error,
}))
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs failure checkpoint write failed", err))?;
Ok(())
}
async fn load_projection_cursor(pool: &PgPool, workspace_id: &str) -> RuntimeResult<Option<String>> {
let cursor = sqlx::query_scalar::<_, serde_json::Value>(
"SELECT cursor FROM blob_reconciliation_checkpoints WHERE kind = 'doc_blob_refs' AND scope = $1",
)
.bind(workspace_id)
.fetch_optional(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs checkpoint load failed", err))?;
Ok(cursor.and_then(|cursor| {
cursor
.get("lastDocId")
.and_then(|value| value.as_str())
.map(ToString::to_string)
}))
}
async fn purge_removed_doc_refs(pool: &PgPool, workspace_id: &str, current_doc_ids: &[String]) -> RuntimeResult<i64> {
let result = sqlx::query(
r#"
DELETE FROM doc_blob_refs
WHERE workspace_id = $1
AND NOT (doc_id = ANY($2))
"#,
)
.bind(workspace_id)
.bind(current_doc_ids)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs purge removed docs failed", err))?;
Ok(result.rows_affected() as i64)
}
fn extract_refs(snapshot: &SnapshotRow) -> RuntimeResult<Vec<ExtractedRef>> {
let parsed = doc_parser::parse_doc_from_binary(snapshot.blob.clone(), snapshot.doc_id.clone())
.map_err(|err| RuntimeError::invalid_state(format!("Doc blob refs parse failed: {err}")))?;
let mut refs = Vec::new();
for block in parsed.blocks {
let Some(blob_keys) = block.blob else {
continue;
};
for blob_key in blob_keys {
refs.push(ExtractedRef {
blob_key,
block_id: block.block_id.clone(),
flavour: block.flavour.clone(),
});
}
}
Ok(refs)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn doc_blob_refs_extracts_image_refs() {
let doc_id = "doc-blob-ref-test".to_string();
let blob =
doc_parser::build_full_doc("Doc", "![Alt](blob://image-blob-key)", &doc_id).expect("doc fixture should build");
let snapshot = SnapshotRow {
workspace_id: "workspace".to_string(),
doc_id,
blob,
updated_at: Utc::now(),
};
let refs = extract_refs(&snapshot).expect("refs should parse");
assert!(
refs
.iter()
.any(|reference| { reference.blob_key == "image-blob-key" && reference.flavour == "affine:image" })
);
}
}
async fn replace_doc_refs(pool: &PgPool, snapshot: &SnapshotRow, refs: Vec<ExtractedRef>) -> RuntimeResult<(i64, i64)> {
let mut tx = pool
.begin()
.await
.map_err(|err| RuntimeError::database("Doc blob refs transaction failed", err))?;
let deleted = sqlx::query("DELETE FROM doc_blob_refs WHERE workspace_id = $1 AND doc_id = $2")
.bind(&snapshot.workspace_id)
.bind(&snapshot.doc_id)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Doc blob refs delete failed", err))?
.rows_affected() as i64;
let mut written = 0;
for reference in refs {
let affected = sqlx::query(
r#"
INSERT INTO doc_blob_refs
(workspace_id, doc_id, blob_key, block_id, flavour, snapshot_updated_at, parser_version, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'fresh')
ON CONFLICT (workspace_id, doc_id, blob_key, block_id) DO UPDATE
SET flavour = EXCLUDED.flavour,
snapshot_updated_at = EXCLUDED.snapshot_updated_at,
indexed_at = CURRENT_TIMESTAMP,
parser_version = EXCLUDED.parser_version,
status = 'fresh',
error = NULL
"#,
)
.bind(&snapshot.workspace_id)
.bind(&snapshot.doc_id)
.bind(reference.blob_key)
.bind(reference.block_id)
.bind(reference.flavour)
.bind(snapshot.updated_at)
.bind(PARSER_VERSION)
.execute(&mut *tx)
.await
.map_err(|err| RuntimeError::database("Doc blob refs insert failed", err))?
.rows_affected() as i64;
written += affected;
}
tx.commit()
.await
.map_err(|err| RuntimeError::database("Doc blob refs transaction commit failed", err))?;
Ok((written, deleted))
}
async fn mark_doc_failed(pool: &PgPool, workspace_id: &str, doc_id: &str, error: &str) -> RuntimeResult<()> {
sqlx::query(
r#"
INSERT INTO doc_blob_refs
(workspace_id, doc_id, blob_key, block_id, flavour, snapshot_updated_at, parser_version, status, error)
VALUES ($1, $2, '__parse_failed__', '__parse_failed__', '__parse_failed__', CURRENT_TIMESTAMP, $3, 'failed', $4)
ON CONFLICT (workspace_id, doc_id, blob_key, block_id) DO UPDATE
SET indexed_at = CURRENT_TIMESTAMP,
status = 'failed',
error = EXCLUDED.error
"#,
)
.bind(workspace_id)
.bind(doc_id)
.bind(PARSER_VERSION)
.bind(error)
.execute(pool)
.await
.map_err(|err| RuntimeError::database("Doc blob refs mark failure failed", err))?;
Ok(())
}
async fn rebuild_doc_blob_refs_inner(
runtime: &StorageRuntime,
workspace_id: String,
doc_id: String,
) -> RuntimeResult<RuntimeDocBlobRefsResult> {
let pool = runtime.pool().await?;
let mut result = RuntimeDocBlobRefsResult {
scanned_docs: 1,
parsed_docs: 0,
refs_written: 0,
refs_deleted: 0,
failed_docs: 0,
next_cursor: None,
};
let Some(snapshot) = load_current_doc(&pool, &workspace_id, &doc_id).await? else {
result.failed_docs = 1;
mark_doc_failed(&pool, &workspace_id, &doc_id, "snapshot_missing").await?;
return Ok(result);
};
match extract_refs(&snapshot) {
Ok(refs) => {
let (written, deleted) = replace_doc_refs(&pool, &snapshot, refs).await?;
result.parsed_docs = 1;
result.refs_written = written;
result.refs_deleted = deleted;
}
Err(err) => {
result.failed_docs = 1;
mark_doc_failed(&pool, &workspace_id, &doc_id, &err.to_string()).await?;
}
}
Ok(result)
}
#[napi_derive::napi]
impl StorageRuntime {
#[napi]
pub async fn rebuild_doc_blob_refs(
&self,
workspace_id: String,
doc_id: String,
) -> napi::Result<RuntimeDocBlobRefsResult> {
Ok(rebuild_doc_blob_refs_inner(self, workspace_id, doc_id).await?)
}
#[napi]
pub async fn rebuild_workspace_doc_blob_refs(
&self,
workspace_id: String,
limit: i64,
) -> napi::Result<RuntimeDocBlobRefsResult> {
if limit <= 0 {
return Err(napi_error("doc blob refs rebuild limit must be positive"));
}
let pool = self.pool().await?;
let doc_ids = match load_workspace_doc_ids(&pool, &workspace_id).await {
Ok(doc_ids) => doc_ids,
Err(err) => {
upsert_projection_failure_checkpoint(&pool, &workspace_id, &err.to_string()).await?;
return Err(err.into());
}
};
let cursor = load_projection_cursor(&pool, &workspace_id).await?;
let current_doc_ids = doc_ids.clone();
let doc_ids = doc_ids
.into_iter()
.filter(|doc_id| cursor.as_ref().is_none_or(|cursor| doc_id > cursor))
.collect::<Vec<_>>();
let has_more = doc_ids.len() > limit as usize;
let mut total = RuntimeDocBlobRefsResult {
scanned_docs: 0,
parsed_docs: 0,
refs_written: 0,
refs_deleted: 0,
failed_docs: 0,
next_cursor: None,
};
let mut last_doc_id = None;
for doc_id in doc_ids.into_iter().take(limit as usize) {
last_doc_id = Some(doc_id.clone());
let result = rebuild_doc_blob_refs_inner(self, workspace_id.clone(), doc_id).await?;
total.scanned_docs += result.scanned_docs;
total.parsed_docs += result.parsed_docs;
total.refs_written += result.refs_written;
total.refs_deleted += result.refs_deleted;
total.failed_docs += result.failed_docs;
}
if has_more {
total.next_cursor = last_doc_id;
} else if total.failed_docs == 0 {
total.refs_deleted += purge_removed_doc_refs(&pool, &workspace_id, &current_doc_ids).await?;
}
upsert_projection_checkpoint(&pool, &workspace_id, &total).await?;
Ok(total)
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,29 +1,29 @@
use aws_sdk_s3::config::{
BehaviorVersion, Credentials, Region, RequestChecksumCalculation, ResponseChecksumValidation, timeout::TimeoutConfig,
};
use napi::Result;
use rusty_s3::{Bucket, Credentials, UrlStyle};
use serde::Deserialize;
use url::Url;
use super::{client::ObjectStorageClient, types::StorageProviderConfig};
use crate::backend_runtime::{
config::blob_storage_config_from_config_files, error::napi_error, types::RuntimeObjectStorageHealth,
use super::{
client::ObjectStorageClient,
error::{ObjectStorageError, ObjectStorageResult},
types::StorageProviderConfig,
};
#[derive(Clone, Debug)]
pub(in crate::backend_runtime) struct ObjectStorageConfig {
pub(super) provider: String,
pub(super) bucket: String,
pub(super) endpoint: Option<String>,
pub(super) region: Option<String>,
pub(super) access_key_id: Option<String>,
pub(super) secret_access_key: Option<String>,
pub(super) session_token: Option<String>,
pub(super) force_path_style: bool,
pub(super) request_timeout_ms: Option<u64>,
pub(super) min_part_size: Option<u64>,
pub(super) presign_expires_in_seconds: Option<u64>,
pub(super) presign_sign_content_type_for_put: Option<bool>,
pub(super) use_presigned_url: bool,
pub(crate) struct ObjectStorageConfig {
pub(crate) provider: String,
pub(crate) bucket: String,
pub(crate) endpoint: Option<String>,
pub(crate) region: Option<String>,
pub(crate) access_key_id: Option<String>,
pub(crate) secret_access_key: Option<String>,
pub(crate) session_token: Option<String>,
pub(crate) force_path_style: bool,
pub(crate) request_timeout_ms: Option<u64>,
pub(crate) min_part_size: Option<u64>,
pub(crate) presign_expires_in_seconds: Option<u64>,
pub(crate) presign_sign_content_type_for_put: Option<bool>,
pub(crate) use_presigned_url: bool,
pub(crate) proxy_upload: bool,
}
#[derive(Debug, Deserialize)]
@@ -70,13 +70,16 @@ struct S3PresignConfigFile {
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsePresignedUrlConfigFile {
enabled: bool,
url_prefix: Option<String>,
sign_key: Option<String>,
}
impl ObjectStorageConfig {
pub(in crate::backend_runtime) fn from_config_files() -> Result<Option<Self>> {
let Some(storage) = blob_storage_config_from_config_files()? else {
pub(crate) fn from_provider_config(storage: Option<StorageProviderConfig>) -> ObjectStorageResult<Option<Self>> {
let Some(storage) = storage else {
return Ok(None);
};
@@ -84,18 +87,18 @@ impl ObjectStorageConfig {
"aws-s3" => Self::from_s3_config(storage),
"cloudflare-r2" => Self::from_r2_config(storage),
"fs" => Ok(None),
provider => Err(napi_error(format!(
"unsupported blob storage provider for BackendRuntime: {provider}"
provider => Err(ObjectStorageError::Config(format!(
"unsupported blob storage provider for StorageRuntime: {provider}"
))),
}
}
pub(super) fn from_s3_config(storage: StorageProviderConfig) -> Result<Option<Self>> {
pub(crate) fn from_s3_config(storage: StorageProviderConfig) -> ObjectStorageResult<Option<Self>> {
let config: S3ConfigFile = serde_json::from_value(storage.config)
.map_err(|err| napi_error(format!("invalid aws-s3 blob storage config: {err}")))?;
.map_err(|err| ObjectStorageError::Config(format!("invalid aws-s3 blob storage config: {err}")))?;
let region = config
.region
.ok_or_else(|| napi_error("aws-s3 blob storage config requires region"))?;
.ok_or_else(|| ObjectStorageError::Config("aws-s3 blob storage config requires region".to_string()))?;
let endpoint = config.endpoint.or_else(|| Some(resolve_s3_endpoint(&region)));
let credentials = config.credentials.unwrap_or_default();
@@ -113,17 +116,29 @@ impl ObjectStorageConfig {
presign_expires_in_seconds: config.presign.as_ref().and_then(|v| v.expires_in_seconds),
presign_sign_content_type_for_put: config.presign.as_ref().and_then(|v| v.sign_content_type_for_put),
use_presigned_url: config.use_presigned_url.map(|v| v.enabled).unwrap_or(false),
proxy_upload: false,
}))
}
pub(super) fn from_r2_config(storage: StorageProviderConfig) -> Result<Option<Self>> {
pub(crate) fn from_r2_config(storage: StorageProviderConfig) -> ObjectStorageResult<Option<Self>> {
let config: R2ConfigFile = serde_json::from_value(storage.config)
.map_err(|err| napi_error(format!("invalid cloudflare-r2 blob storage config: {err}")))?;
.map_err(|err| ObjectStorageError::Config(format!("invalid cloudflare-r2 blob storage config: {err}")))?;
let account = match config.jurisdiction {
Some(jurisdiction) => format!("{}.{}", config.account_id, jurisdiction),
None => config.account_id,
};
let credentials = config.credentials.unwrap_or_default();
let (use_presigned_url, proxy_upload) = config
.use_presigned_url
.map(|value| {
(
value.enabled,
value.enabled
&& value.url_prefix.as_ref().is_some_and(|prefix| !prefix.is_empty())
&& value.sign_key.as_ref().is_some_and(|key| !key.is_empty()),
)
})
.unwrap_or((false, false));
Ok(Some(Self {
provider: storage.provider,
@@ -138,81 +153,51 @@ impl ObjectStorageConfig {
min_part_size: config.min_part_size,
presign_expires_in_seconds: config.presign.as_ref().and_then(|v| v.expires_in_seconds),
presign_sign_content_type_for_put: config.presign.as_ref().and_then(|v| v.sign_content_type_for_put),
use_presigned_url: config.use_presigned_url.map(|v| v.enabled).unwrap_or(false),
use_presigned_url,
proxy_upload,
}))
}
pub(super) fn build_client(&self) -> Result<ObjectStorageClient> {
pub(crate) fn build_client(&self) -> ObjectStorageResult<ObjectStorageClient> {
let region = self
.region
.clone()
.ok_or_else(|| napi_error("object storage region is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage region is required".to_string()))?;
let access_key_id = self
.access_key_id
.clone()
.ok_or_else(|| napi_error("object storage accessKeyId is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage accessKeyId is required".to_string()))?;
let secret_access_key = self
.secret_access_key
.clone()
.ok_or_else(|| napi_error("object storage secretAccessKey is required"))?;
.ok_or_else(|| ObjectStorageError::Config("object storage secretAccessKey is required".to_string()))?;
let credentials = Credentials::new(
access_key_id,
secret_access_key,
self.session_token.clone(),
None,
"affine-server-config-json",
);
let mut builder = aws_sdk_s3::Config::builder()
.behavior_version(BehaviorVersion::latest())
.region(Region::new(region))
.credentials_provider(credentials)
.force_path_style(self.force_path_style)
.request_checksum_calculation(RequestChecksumCalculation::WhenRequired)
.response_checksum_validation(ResponseChecksumValidation::WhenRequired);
if let Some(endpoint) = &self.endpoint {
builder = builder.endpoint_url(endpoint);
}
if let Some(request_timeout_ms) = self.request_timeout_ms {
builder = builder.timeout_config(
TimeoutConfig::builder()
.operation_timeout(std::time::Duration::from_millis(request_timeout_ms))
.build(),
);
}
Ok(ObjectStorageClient::new(
builder.build(),
let endpoint = self.endpoint.clone().unwrap_or_else(|| resolve_s3_endpoint(&region));
let endpoint = Url::parse(&endpoint)
.map_err(|err| ObjectStorageError::Config(format!("object storage endpoint is invalid: {err}")))?;
let bucket = Bucket::new(
endpoint,
if self.force_path_style {
UrlStyle::Path
} else {
UrlStyle::VirtualHost
},
self.bucket.clone(),
region,
)
.map_err(|err| ObjectStorageError::Config(format!("object storage bucket url is invalid: {err}")))?;
let credentials = match self.session_token.as_ref().filter(|token| !token.is_empty()) {
Some(session_token) => Credentials::new_with_token(access_key_id, secret_access_key, session_token.clone()),
None => Credentials::new(access_key_id, secret_access_key),
};
ObjectStorageClient::new(
bucket,
credentials,
self.request_timeout_ms,
self.presign_expires_in_seconds.unwrap_or(60),
self.presign_sign_content_type_for_put.unwrap_or(true),
))
}
pub(super) fn health(&self) -> RuntimeObjectStorageHealth {
let client_buildable = self
.build_client()
.map(|client| client.non_destructive_health())
.unwrap_or(false);
RuntimeObjectStorageHealth {
configured: true,
provider: Some(self.provider.clone()),
bucket: Some(self.bucket.clone()),
endpoint: self.endpoint.clone(),
region: self.region.clone(),
has_credentials: self.access_key_id.is_some()
&& self.secret_access_key.is_some()
&& self.session_token.as_ref().map(|v| !v.is_empty()).unwrap_or(true),
force_path_style: self.force_path_style,
request_timeout_ms: self.request_timeout_ms.map(|v| v as i64),
min_part_size: self.min_part_size.map(|v| v as i64),
presign_expires_in_seconds: self.presign_expires_in_seconds.map(|v| v as i64),
presign_sign_content_type_for_put: self.presign_sign_content_type_for_put,
use_presigned_url: self.use_presigned_url,
client_buildable,
}
)
}
}
@@ -0,0 +1,64 @@
use reqwest::StatusCode;
#[derive(Debug, thiserror::Error)]
pub(crate) enum ObjectStorageError {
#[error("ObjectStorage config error: {0}")]
Config(String),
#[error("{context}: {source}")]
Operation {
context: String,
#[source]
source: Box<ObjectStorageError>,
},
#[error("ObjectStorage http client build failed: {0}")]
HttpClientBuild(#[source] reqwest::Error),
#[error("ObjectStorage http request failed: {0}")]
HttpRequest(#[source] reqwest::Error),
#[error("ObjectStorage invalid http header: {0}")]
InvalidHeader(String),
#[error("ObjectStorage response body exceeds {limit} bytes")]
BodyTooLarge { limit: usize },
#[error("{context}: status={status} body={body}")]
HttpStatus {
context: String,
status: StatusCode,
body: String,
},
#[error("{context}: invalid utf8 response: {source}")]
InvalidUtf8 {
context: String,
#[source]
source: std::string::FromUtf8Error,
},
#[error("{context}: invalid xml response: {source}")]
InvalidXml {
context: String,
#[source]
source: instant_xml::Error,
},
#[error("ObjectStorage invalid input: {0}")]
InvalidInput(String),
}
impl ObjectStorageError {
pub(crate) fn is_not_found(&self) -> bool {
match self {
Self::Operation { source, .. } => source.is_not_found(),
Self::HttpStatus { status, body, .. } => {
*status == StatusCode::NOT_FOUND
&& (body.contains("NoSuchKey") || body.contains("NoSuchUpload") || body.contains("NotFound"))
}
_ => false,
}
}
pub(crate) fn is_retryable_http_status(&self) -> bool {
match self {
Self::Operation { source, .. } => source.is_retryable_http_status(),
Self::HttpStatus { status, .. } => *status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error(),
_ => false,
}
}
}
pub(crate) type ObjectStorageResult<T> = std::result::Result<T, ObjectStorageError>;
@@ -0,0 +1,9 @@
pub(crate) mod client;
pub(crate) mod config;
pub(crate) mod error;
#[cfg(test)]
mod tests;
pub(crate) mod types;
pub(crate) use config::ObjectStorageConfig;
pub(crate) use types::StorageProviderConfig;
@@ -0,0 +1,376 @@
use reqwest::StatusCode;
use super::{
config::ObjectStorageConfig,
error::ObjectStorageError,
types::{
MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, checksum_crc32_base64, completed_multipart_parts,
trim_etag,
},
};
fn storage_config(provider: &str, config: serde_json::Value) -> StorageProviderConfig {
StorageProviderConfig {
provider: provider.to_string(),
bucket: "test-bucket".to_string(),
config,
}
}
#[test]
fn resolves_r2_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "cloudflare-r2");
assert_eq!(config.bucket, "workspace-blobs");
assert_eq!(
config.endpoint.as_deref(),
Some("https://account.eu.r2.cloudflarestorage.com")
);
assert_eq!(config.region.as_deref(), Some("auto"));
assert!(config.force_path_style);
assert!(config.use_presigned_url);
assert!(!config.proxy_upload);
assert_eq!(config.access_key_id.as_deref(), Some("key"));
}
#[test]
fn resolves_r2_endpoint_cases_from_config_json_shape() {
for (case, config, expected_endpoint) in [
(
"default account endpoint",
serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.r2.cloudflarestorage.com"),
),
(
"explicit null jurisdiction",
serde_json::json!({
"accountId": "account",
"jurisdiction": null,
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.r2.cloudflarestorage.com"),
),
(
"eu jurisdiction",
serde_json::json!({
"accountId": "account",
"jurisdiction": "eu",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
Some("https://account.eu.r2.cloudflarestorage.com"),
),
] {
let config = ObjectStorageConfig::from_r2_config(storage_config("cloudflare-r2", config))
.unwrap()
.unwrap();
assert_eq!(config.endpoint.as_deref(), expected_endpoint, "{case}");
assert!(config.force_path_style, "{case}");
}
assert!(
ObjectStorageConfig::from_r2_config(storage_config(
"cloudflare-r2",
serde_json::json!({
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
})
))
.is_err()
);
}
#[test]
fn object_storage_not_found_requires_object_error_code() {
let bucket_or_route_missing = ObjectStorageError::HttpStatus {
context: "head failed".to_string(),
status: StatusCode::NOT_FOUND,
body: String::new(),
};
let object_missing = ObjectStorageError::HttpStatus {
context: "get failed".to_string(),
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchKey</Code></Error>".to_string(),
};
let upload_missing = ObjectStorageError::HttpStatus {
context: "abort failed".to_string(),
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchUpload</Code></Error>".to_string(),
};
assert!(!bucket_or_route_missing.is_not_found());
assert!(object_missing.is_not_found());
assert!(upload_missing.is_not_found());
}
#[test]
fn resolves_r2_proxy_upload_capability_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"usePresignedURL": {
"enabled": true,
"urlPrefix": "https://cdn.example.com",
"signKey": "secret"
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
assert!(config.use_presigned_url);
assert!(config.proxy_upload);
}
#[test]
fn resolves_s3_config_from_config_json_shape() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "workspace-blobs".to_string(),
config: serde_json::json!({
"region": "us-west-2",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret",
"sessionToken": "session"
},
"forcePathStyle": true,
"requestTimeoutMs": 1000,
"minPartSize": 1024,
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
assert_eq!(config.provider, "aws-s3");
assert_eq!(config.endpoint.as_deref(), Some("https://s3.us-west-2.amazonaws.com"));
assert_eq!(config.session_token.as_deref(), Some("session"));
assert!(config.force_path_style);
assert_eq!(config.request_timeout_ms, Some(1000));
assert_eq!(config.min_part_size, Some(1024));
assert_eq!(config.presign_expires_in_seconds, Some(60));
assert_eq!(config.presign_sign_content_type_for_put, Some(false));
}
#[test]
fn resolves_s3_default_endpoint_cases_from_config_json_shape() {
for (region, expected_endpoint) in [
("us-east-1", "https://s3.amazonaws.com"),
("us-west-2", "https://s3.us-west-2.amazonaws.com"),
] {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": region,
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
}
}),
))
.unwrap()
.unwrap();
assert_eq!(config.endpoint.as_deref(), Some(expected_endpoint), "{region}");
}
}
#[tokio::test]
async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
let storage = StorageProviderConfig {
provider: "aws-s3".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_s3_config(storage).unwrap().unwrap();
let Ok(Ok(client)) = std::panic::catch_unwind(|| config.build_client()) else {
eprintln!("skipping object storage presign test: S3 client cannot be built in this environment");
return;
};
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders="));
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert!(result.expires_at_ms > 0);
}
#[tokio::test]
async fn object_storage_presign_put_respects_content_length_and_signed_content_type_flag() {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60,
"signContentTypeForPut": false
}
}),
))
.unwrap()
.unwrap();
let client = config.build_client().unwrap();
let result = client
.presign_put(
"key",
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
content_length: Some(42),
..Default::default()
},
)
.await
.unwrap();
assert_eq!(
result.headers.get("Content-Type").map(String::as_str),
Some("text/plain")
);
assert_eq!(result.headers.get("Content-Length").map(String::as_str), Some("42"));
assert!(!result.url.contains("content-type"));
assert!(result.url.contains("content-length"));
}
#[tokio::test]
async fn object_storage_presign_get_returns_sigv4_url_without_headers() {
let storage = StorageProviderConfig {
provider: "cloudflare-r2".to_string(),
bucket: "test-bucket".to_string(),
config: serde_json::json!({
"accountId": "account",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
};
let config = ObjectStorageConfig::from_r2_config(storage).unwrap().unwrap();
let client = config.build_client().unwrap();
let result = client.presign_get("workspace/key").await.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("X-Amz-SignedHeaders=host"));
assert!(result.url.contains("/test-bucket/workspace/key?"));
assert!(result.headers.is_empty());
assert!(result.expires_at_ms > 0);
}
#[tokio::test]
async fn object_storage_presign_upload_part_returns_sigv4_url() {
let config = ObjectStorageConfig::from_s3_config(storage_config(
"aws-s3",
serde_json::json!({
"region": "us-east-1",
"endpoint": "https://s3.us-east-1.amazonaws.com",
"credentials": {
"accessKeyId": "key",
"secretAccessKey": "secret"
},
"presign": {
"expiresInSeconds": 60
}
}),
))
.unwrap()
.unwrap();
let client = config.build_client().unwrap();
let result = client.presign_upload_part("key", "upload-1", 3).await.unwrap();
assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
assert!(result.url.contains("partNumber=3"));
assert!(result.url.contains("uploadId=upload-1"));
assert!(result.headers.is_empty());
assert!(result.expires_at_ms > 0);
}
#[test]
fn object_storage_orders_completed_multipart_parts_and_trims_etags() {
let parts = completed_multipart_parts(vec![
MultipartUploadPart {
part_number: 2,
etag: trim_etag("\"b\""),
},
MultipartUploadPart {
part_number: 1,
etag: trim_etag("a"),
},
]);
assert_eq!(parts[0].part_number, 1);
assert_eq!(parts[0].etag, "a");
assert_eq!(parts[1].part_number, 2);
assert_eq!(parts[1].etag, "b");
}
#[test]
fn object_storage_crc32_checksum_uses_s3_base64_format() {
assert_eq!(checksum_crc32_base64(b"hello"), "NhCmhg==");
assert_ne!(checksum_crc32_base64(b"hello"), "3610a686");
}
@@ -0,0 +1,191 @@
use std::collections::HashMap;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use serde::Deserialize;
use super::super::{
RuntimeError, RuntimeMultipartUploadInit, RuntimeMultipartUploadPart, RuntimeObjectGetResult, RuntimeObjectListEntry,
RuntimeObjectMetadata, RuntimeObjectStoragePutOptions, RuntimePresignedObjectRequest, RuntimeResult,
};
#[derive(Clone, Debug, Default)]
pub(crate) struct ObjectPutMetadata {
pub(crate) content_type: Option<String>,
pub(crate) content_length: Option<i64>,
pub(crate) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectMetadata {
pub(crate) content_type: String,
pub(crate) content_length: i64,
pub(crate) last_modified_ms: i64,
pub(crate) checksum_crc32: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectListEntry {
pub(crate) key: String,
pub(crate) content_length: i64,
pub(crate) last_modified_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectListPage {
pub(crate) entries: Vec<ObjectListEntry>,
pub(crate) next_continuation_token: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectDeleteOutcome {
pub(crate) key: String,
pub(crate) error: Option<String>,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ObjectGetResult {
pub(crate) body: Vec<u8>,
pub(crate) metadata: ObjectMetadata,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct PresignedObjectRequest {
pub(crate) url: String,
pub(crate) headers: HashMap<String, String>,
pub(crate) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct MultipartUploadInitResult {
pub(crate) upload_id: String,
pub(crate) expires_at_ms: i64,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct MultipartUploadPart {
pub(crate) part_number: i32,
pub(crate) etag: String,
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct StorageProviderConfig {
pub(crate) provider: String,
pub(crate) bucket: String,
#[serde(default)]
pub(crate) config: serde_json::Value,
}
pub(crate) fn trim_etag(etag: &str) -> String {
etag.trim_matches('"').to_string()
}
pub(crate) fn completed_multipart_parts(mut parts: Vec<MultipartUploadPart>) -> Vec<MultipartUploadPart> {
parts.sort_by_key(|part| part.part_number);
parts
}
impl From<RuntimeObjectStoragePutOptions> for ObjectPutMetadata {
fn from(options: RuntimeObjectStoragePutOptions) -> Self {
Self {
content_type: options.content_type,
content_length: options.content_length,
checksum_crc32: options.checksum_crc32,
}
}
}
impl ObjectPutMetadata {
pub(crate) fn complete_for_body(mut self, body: &[u8]) -> Self {
self.content_length.get_or_insert(body.len() as i64);
self.checksum_crc32.get_or_insert_with(|| checksum_crc32_base64(body));
self
.content_type
.get_or_insert_with(|| crate::file_type::get_mime(body));
self
}
pub(crate) fn into_object_metadata(self, last_modified_ms: i64) -> ObjectMetadata {
ObjectMetadata {
content_type: self
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string()),
content_length: self.content_length.unwrap_or(0),
last_modified_ms,
checksum_crc32: self.checksum_crc32,
}
}
}
pub(crate) fn checksum_crc32_base64(body: &[u8]) -> String {
STANDARD.encode(crc32fast::hash(body).to_be_bytes())
}
impl From<ObjectMetadata> for RuntimeObjectMetadata {
fn from(metadata: ObjectMetadata) -> Self {
Self {
content_type: metadata.content_type,
content_length: metadata.content_length,
last_modified_ms: metadata.last_modified_ms,
checksum_crc32: metadata.checksum_crc32,
}
}
}
impl From<ObjectListEntry> for RuntimeObjectListEntry {
fn from(entry: ObjectListEntry) -> Self {
Self {
key: entry.key,
content_length: entry.content_length,
last_modified_ms: entry.last_modified_ms,
}
}
}
impl TryFrom<PresignedObjectRequest> for RuntimePresignedObjectRequest {
type Error = RuntimeError;
fn try_from(request: PresignedObjectRequest) -> RuntimeResult<Self> {
Ok(Self {
url: request.url,
headers_json: serde_json::to_string(&request.headers)
.map_err(|err| RuntimeError::json("ObjectStorage headers serialization failed", err))?,
expires_at_ms: request.expires_at_ms,
})
}
}
impl From<ObjectGetResult> for RuntimeObjectGetResult {
fn from(result: ObjectGetResult) -> Self {
Self {
body: result.body.into(),
metadata: result.metadata.into(),
}
}
}
impl From<MultipartUploadInitResult> for RuntimeMultipartUploadInit {
fn from(init: MultipartUploadInitResult) -> Self {
Self {
upload_id: init.upload_id,
expires_at_ms: init.expires_at_ms,
}
}
}
impl From<RuntimeMultipartUploadPart> for MultipartUploadPart {
fn from(part: RuntimeMultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
impl From<MultipartUploadPart> for RuntimeMultipartUploadPart {
fn from(part: MultipartUploadPart) -> Self {
Self {
part_number: part.part_number,
etag: part.etag,
}
}
}
@@ -12,24 +12,6 @@ pub struct RuntimeVerificationTokenRecord {
pub struct BackendRuntimeHealth {
pub started: bool,
pub database_connected: bool,
pub object_storage_configured: bool,
}
#[napi_derive::napi(object)]
pub struct RuntimeObjectStorageHealth {
pub configured: bool,
pub provider: Option<String>,
pub bucket: Option<String>,
pub endpoint: Option<String>,
pub region: Option<String>,
pub has_credentials: bool,
pub force_path_style: bool,
pub request_timeout_ms: Option<i64>,
pub min_part_size: Option<i64>,
pub presign_expires_in_seconds: Option<i64>,
pub presign_sign_content_type_for_put: Option<bool>,
pub use_presigned_url: bool,
pub client_buildable: bool,
}
#[napi_derive::napi(object)]
@@ -138,6 +120,49 @@ pub struct RuntimeBlobCompleteResult {
pub last_modified_ms: Option<i64>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobMetadataBackfillResult {
pub scanned_objects: i64,
pub headed_objects: i64,
pub upserted_metadata: i64,
pub skipped_existing: i64,
pub skipped_workspace_missing: i64,
pub failed: i64,
pub next_cursor: Option<String>,
pub workspace_ids: Vec<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeDocBlobRefsResult {
pub scanned_docs: i64,
pub parsed_docs: i64,
pub refs_written: i64,
pub refs_deleted: i64,
pub failed_docs: i64,
pub next_cursor: Option<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobCleanupPlanResult {
pub run_id: Option<String>,
pub scanned_blobs: i64,
pub candidates_marked: i64,
pub protected_by_doc_refs: i64,
pub protected_by_metadata: i64,
pub protected_by_other_refs: i64,
pub next_cursor: Option<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeBlobCleanupExecuteResult {
pub scanned_candidates: i64,
pub deleted_objects: i64,
pub deleted_metadata: i64,
pub skipped_still_referenced: i64,
pub failed: i64,
pub workspace_ids: Vec<String>,
}
#[napi_derive::napi(object)]
pub struct RuntimeDocCompactionResult {
pub lease_acquired: bool,
+6
View File
@@ -1,3 +1,9 @@
use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH};
pub(crate) fn system_time_millis(time: SystemTime) -> Result<u128, SystemTimeError> {
Ok(time.duration_since(UNIX_EPOCH)?.as_millis())
}
fn collapse_whitespace(s: &str) -> String {
let mut result = String::new();
let mut prev_was_whitespace = false;
@@ -2,17 +2,25 @@ import { ScheduleModule } from '@nestjs/schedule';
import { TestingModule } from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import test from 'ava';
import Sinon from 'sinon';
import { AuthModule, AuthService } from '../../core/auth';
import { AuthCronJob } from '../../core/auth/job';
import { BackendRuntimeProvider } from '../../core/backend-runtime';
import { createTestingModule } from '../utils';
let m: TestingModule;
let db: PrismaClient;
const runtime = {
cleanupExpiredUserSessions: Sinon.stub(),
};
test.before(async () => {
m = await createTestingModule({
imports: [ScheduleModule.forRoot(), AuthModule],
tapModule: builder => {
builder.overrideProvider(BackendRuntimeProvider).useValue(runtime);
},
});
db = m.get(PrismaClient);
@@ -32,16 +40,17 @@ test('should clean expired user sessions', async t => {
let userSessions = await db.userSession.findMany();
t.is(userSessions.length, 2);
// no expired sessions
runtime.cleanupExpiredUserSessions.reset();
runtime.cleanupExpiredUserSessions.resolves(0);
await job.cleanExpiredUserSessions();
userSessions = await db.userSession.findMany();
t.is(userSessions.length, 2);
t.true(runtime.cleanupExpiredUserSessions.calledOnce);
t.deepEqual(runtime.cleanupExpiredUserSessions.firstCall.args, [1000]);
// clean all expired sessions
await db.userSession.updateMany({
data: { expiresAt: new Date(Date.now() - 1000) },
});
runtime.cleanupExpiredUserSessions.reset();
runtime.cleanupExpiredUserSessions.onCall(0).resolves(1000);
runtime.cleanupExpiredUserSessions.onCall(1).resolves(2);
await job.cleanExpiredUserSessions();
userSessions = await db.userSession.findMany();
t.is(userSessions.length, 0);
t.is(runtime.cleanupExpiredUserSessions.callCount, 2);
t.deepEqual(runtime.cleanupExpiredUserSessions.firstCall.args, [1000]);
t.deepEqual(runtime.cleanupExpiredUserSessions.secondCall.args, [1000]);
});
@@ -1,13 +1,16 @@
import { randomUUID } from 'node:crypto';
import { Global, Module } from '@nestjs/common';
import type { Prisma } from '@prisma/client';
import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { z } from 'zod';
import { AppModuleBuilder, FunctionalityModules } from '../../app.module';
import { JobModule, JobQueue } from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { AuthService } from '../../core/auth';
import { AuthModule, AuthService } from '../../core/auth';
import { QuotaModule } from '../../core/quota';
import { Models } from '../../models';
import { llmImageDispatchPlan } from '../../native';
@@ -25,6 +28,7 @@ import { ChatSession, ChatSessionService } from '../../plugins/copilot/session';
import { TranscriptPayloadSchema } from '../../plugins/copilot/transcript/schema';
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript/service';
import { TestingPromptService } from '../mocks/prompt-service.mock';
import { MockJobQueue } from '../mocks/queue.mock';
import { createTestingModule, TestingModule } from '../utils';
import { TestAssets } from '../utils/copilot';
import {
@@ -48,6 +52,13 @@ type Tester = {
const test = ava as TestFn<Tester>;
@Global()
@Module({
providers: [{ provide: JobQueue, useClass: MockJobQueue }],
exports: [JobQueue],
})
class MockJobModule {}
let isCopilotConfigured = false;
const runIfCopilotConfigured = test.macro(
async (
@@ -64,8 +75,20 @@ const runIfCopilotConfigured = test.macro(
);
test.serial.before(async t => {
const appModule = new AppModuleBuilder()
.use(
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
AuthModule,
QuotaModule,
CopilotModule
)
.compile();
const module = await createTestingModule({
imports: [QuotaModule, CopilotModule],
imports: [appModule],
tapModule: builder => {
builder.overrideProvider(PromptService).useClass(TestingPromptService);
},
@@ -156,8 +156,6 @@ test.before(async t => {
CopilotModule,
],
tapModule: builder => {
// use real JobQueue for testing
builder.overrideProvider(JobQueue).useClass(JobQueue);
builder.overrideProvider(RequestMutex).useValue({
acquire: async () => ({
async [Symbol.asyncDispose]() {},
@@ -811,7 +809,9 @@ test('should schedule title generation as a background job', async t => {
const chatSession = await session.get(sessionId);
t.truthy(chatSession);
const addJob = Sinon.stub(jobs, 'add').resolves();
const addJob = jobs.add as Sinon.SinonStub;
addJob.resetHistory();
addJob.resolves();
chatSession!.pushTurn(
buildTurn(sessionId, {
@@ -1835,6 +1835,14 @@ test('should be able to manage workspace embedding', async t => {
fileId: file.fileId,
fileName: file.fileName,
});
await jobs.embedPendingFile({
userId,
workspaceId: ws.id,
contextId: undefined,
blobId,
fileId: file.fileId,
fileName: file.fileName,
});
let ret = 0;
while (!ret) {
@@ -2059,7 +2067,9 @@ test('should handle copilot cron jobs correctly', async t => {
copilotSession,
'toBeGenerateTitle'
).resolves(mockSessions);
const jobAddStub = Sinon.stub(cronJobs['jobs'], 'add').resolves();
const jobAddStub = cronJobs['jobs'].add as Sinon.SinonStub;
jobAddStub.resetHistory();
jobAddStub.resolves();
// daily cleanup job scheduling
{
@@ -2107,7 +2117,7 @@ test('should handle copilot cron jobs correctly', async t => {
cleanupStub.restore();
toBeGenerateStub.restore();
jobAddStub.restore();
jobAddStub.resetHistory();
});
test('model selection policy should resolve requested optional models consistently', async t => {
@@ -7,8 +7,13 @@ import {
import { PrismaClient } from '@prisma/client';
import { FunctionalityModules } from '../app.module';
import { AFFiNELogger, EventBus, JobQueue } from '../base';
import { createFactory, MockEventBus, MockJobQueue } from './mocks';
import { AFFiNELogger, EventBus, JobModule, JobQueue } from '../base';
import {
createFactory,
MockEventBus,
MockJobModule,
MockJobQueue,
} from './mocks';
import { TEST_LOG_LEVEL } from './utils';
interface TestingModuleMetadata extends ModuleMetadata {
@@ -26,10 +31,17 @@ export async function createModule(
metadata: TestingModuleMetadata = {}
): Promise<TestingModule> {
const { tapModule, ...meta } = metadata;
const functionalityModules = [
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
];
const builder = Test.createTestingModule({
...meta,
imports: [...FunctionalityModules, ...(meta.imports ?? [])],
imports: [...functionalityModules, ...(meta.imports ?? [])],
});
builder
@@ -1,7 +1,9 @@
import { ScheduleModule } from '@nestjs/schedule';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { BackendRuntimeProvider } from '../../core/backend-runtime';
import { DocStorageModule } from '../../core/doc';
import { DocStorageCronJob } from '../../core/doc/job';
import { createTestingModule, type TestingModule } from '../utils';
@@ -10,14 +12,23 @@ interface Context {
module: TestingModule;
db: PrismaClient;
cronJob: DocStorageCronJob;
runtime: { cleanupExpiredSnapshotHistories: Sinon.SinonStub };
}
const test = ava as TestFn<Context>;
// cleanup database before each test
test.before(async t => {
t.context.runtime = {
cleanupExpiredSnapshotHistories: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), DocStorageModule],
tapModule: builder => {
builder
.overrideProvider(BackendRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.db = t.context.module.get(PrismaClient);
@@ -26,6 +37,7 @@ test.before(async t => {
test.beforeEach(async t => {
await t.context.module.initTestingDB();
t.context.runtime.cleanupExpiredSnapshotHistories.reset();
});
test.after.always(async t => {
@@ -33,7 +45,7 @@ test.after.always(async t => {
});
test('should be able to cleanup expired history', async t => {
const { db } = t.context;
const { db, runtime } = t.context;
const timestamp = Date.now();
// insert expired data
@@ -65,12 +77,10 @@ test('should be able to cleanup expired history', async t => {
let count = await db.snapshotHistory.count();
t.is(count, 20);
runtime.cleanupExpiredSnapshotHistories.onCall(0).resolves(1000);
runtime.cleanupExpiredSnapshotHistories.onCall(1).resolves(10);
await t.context.cronJob.cleanExpiredHistories();
count = await db.snapshotHistory.count();
t.is(count, 10);
const example = await db.snapshotHistory.findFirst();
t.truthy(example);
t.true(example!.expiredAt > new Date());
t.is(runtime.cleanupExpiredSnapshotHistories.callCount, 2);
});
@@ -13,6 +13,7 @@ import {
AFFiNELogger,
CacheInterceptor,
CloudThrottlerGuard,
ConfigFactory,
EventBus,
GlobalExceptionFilter,
JobQueue,
@@ -250,6 +251,31 @@ export async function createApp(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
module.useCustomApplicationConstructor(TestingApp);
@@ -1,4 +1,4 @@
import { createHash } from 'node:crypto';
import { createHash, createHmac } from 'node:crypto';
import { mock } from 'node:test';
import {
@@ -6,22 +6,13 @@ import {
ConfigFactory,
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
StorageProviderConfig,
StorageProviderFactory,
toBuffer,
type R2StorageConfig,
SIGNED_URL_EXPIRED,
type StorageProviderConfig,
} from '../../../base';
import {
R2StorageConfig,
R2StorageProvider,
} from '../../../base/storage/providers/r2';
import { SIGNED_URL_EXPIRED } from '../../../base/storage/providers/utils';
import { EntitlementService } from '../../../core/entitlement';
import {
CommentAttachmentStorage,
WorkspaceBlobStorage,
} from '../../../core/storage';
import { MULTIPART_THRESHOLD } from '../../../core/storage/constants';
import { R2UploadController } from '../../../core/storage/r2-proxy';
import { StorageRuntimeProvider } from '../../../core/storage-runtime';
import {
SubscriptionPlan,
SubscriptionRecurring,
@@ -29,7 +20,7 @@ import {
} from '../../../plugins/payment/types';
import { app, e2e, Mockers } from '../test';
class MockR2Provider extends R2StorageProvider {
class MockStorageRuntime {
createMultipartCalls = 0;
putCalls: {
key: string;
@@ -46,45 +37,72 @@ class MockR2Provider extends R2StorageProvider {
contentLength?: number;
}[] = [];
constructor(config: R2StorageConfig, bucket: string) {
super(config, bucket);
async providerCapabilities() {
const storage = app.get(Config).storages.blob.storage;
const usePresignedURL = (storage.config as R2StorageConfig).usePresignedURL;
if (storage.provider !== 'cloudflare-r2') {
return {
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: false,
presignGet: false,
multipartDirect: false,
proxyUpload: false,
assetpack: false,
serverMediatedOnly: true,
};
}
return {
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: true,
presignGet: false,
multipartDirect: true,
proxyUpload: !!usePresignedURL?.enabled,
assetpack: false,
serverMediatedOnly: false,
};
}
destroy() {}
override async proxyPutObject(
async presignPut(
_scope: string,
key: string,
body: any,
options: { contentType?: string; contentLength?: number } = {}
metadata: { contentType?: string; contentLength?: number } = {}
) {
this.putCalls.push({
key,
body: await toBuffer(body),
contentType: options.contentType,
contentLength: options.contentLength,
});
const storage = app.get(Config).storages.blob.storage;
const r2 = storage.config as R2StorageConfig;
if (!r2.usePresignedURL?.enabled) {
return {
url: 'https://test-bucket.r2.example.com/object?X-Amz-Algorithm=AWS4-HMAC-SHA256',
headers: {},
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
};
}
const [workspaceId, blobKey] = key.split('/');
return createProxyUrl(
PROXY_UPLOAD_PATH,
[
workspaceId,
blobKey,
metadata.contentType ?? 'application/octet-stream',
metadata.contentLength,
],
{
workspaceId,
key: blobKey,
contentType: metadata.contentType ?? 'application/octet-stream',
contentLength: metadata.contentLength,
}
);
}
override async proxyUploadPart(
key: string,
uploadId: string,
partNumber: number,
body: any,
options: { contentLength?: number } = {}
) {
const etag = `etag-${partNumber}`;
this.partCalls.push({
key,
uploadId,
partNumber,
etag,
body: await toBuffer(body),
contentLength: options.contentLength,
});
return etag;
}
override async createMultipartUpload() {
async createMultipartUpload() {
this.createMultipartCalls += 1;
return {
uploadId: 'upload-id',
@@ -92,7 +110,30 @@ class MockR2Provider extends R2StorageProvider {
};
}
override async listMultipartUploadParts(key: string, uploadId: string) {
async presignUploadPart(
_scope: string,
key: string,
uploadId: string,
partNumber: number
) {
const [workspaceId, blobKey] = key.split('/');
return createProxyUrl(
PROXY_MULTIPART_PATH,
[workspaceId, blobKey, uploadId, partNumber],
{
workspaceId,
key: blobKey,
uploadId,
partNumber,
}
);
}
async listMultipartUploadParts(
_scope: string,
key: string,
uploadId: string
) {
const latest = new Map<number, string>();
for (const part of this.partCalls) {
if (part.key !== key || part.uploadId !== uploadId) {
@@ -104,6 +145,45 @@ class MockR2Provider extends R2StorageProvider {
.sort((left, right) => left[0] - right[0])
.map(([partNumber, etag]) => ({ partNumber, etag }));
}
async putObject(
_scope: string,
key: string,
body: Buffer,
options: { contentType?: string; contentLength?: number } = {}
) {
this.putCalls.push({
key,
body,
contentType: options.contentType,
contentLength: options.contentLength,
});
return {
contentType: options.contentType ?? 'application/octet-stream',
contentLength: options.contentLength ?? body.length,
lastModified: new Date(),
};
}
async proxyUploadPart(
_scope: string,
key: string,
uploadId: string,
partNumber: number,
body: Buffer,
contentLength?: number
) {
const etag = `etag-${partNumber}`;
this.partCalls.push({
key,
uploadId,
partNumber,
etag,
body,
contentLength,
});
return etag;
}
}
const baseR2Storage: StorageProviderConfig = {
@@ -125,55 +205,40 @@ const baseR2Storage: StorageProviderConfig = {
};
let defaultBlobStorage: StorageProviderConfig;
let provider: MockR2Provider | null = null;
let factoryCreateUnmocked: StorageProviderFactory['create'];
let runtime: MockStorageRuntime;
e2e.before(() => {
defaultBlobStorage = structuredClone(app.get(Config).storages.blob.storage);
const factory = app.get(StorageProviderFactory);
factoryCreateUnmocked = factory.create.bind(factory);
});
e2e.beforeEach(async () => {
provider?.destroy();
provider = null;
const factory = app.get(StorageProviderFactory);
mock.method(factory, 'create', (config: StorageProviderConfig) => {
if (config.provider === 'cloudflare-r2') {
if (!provider) {
provider = new MockR2Provider(
config.config as R2StorageConfig,
config.bucket
);
}
return provider;
}
return factoryCreateUnmocked(config);
});
runtime = new MockStorageRuntime();
const rt = app.get(StorageRuntimeProvider);
for (const method of [
'providerCapabilities',
'presignPut',
'createMultipartUpload',
'presignUploadPart',
'listMultipartUploadParts',
'putObject',
'proxyUploadPart',
] as const) {
mock.method(rt, method, (...args: any[]) =>
(runtime[method] as any)(...args)
);
}
await useR2Storage();
});
e2e.afterEach.always(async () => {
await setBlobStorage(defaultBlobStorage);
provider?.destroy();
provider = null;
mock.reset();
});
async function setBlobStorage(storage: StorageProviderConfig) {
provider?.destroy();
provider = null;
const configFactory = app.get(ConfigFactory);
configFactory.override({ storages: { blob: { storage } } });
const blobStorage = app.get(WorkspaceBlobStorage);
await blobStorage.onConfigInit();
const commentAttachmentStorage = app.get(CommentAttachmentStorage);
await commentAttachmentStorage.onConfigInit();
const controller = app.get(R2UploadController);
// reset cached provider in controller
(controller as any).provider = null;
}
async function useR2Storage(
@@ -193,11 +258,8 @@ async function useR2Storage(
return storage;
}
function getProvider(): MockR2Provider {
if (!provider) {
throw new Error('R2 provider is not initialized');
}
return provider;
function getRuntime(): MockStorageRuntime {
return runtime;
}
async function createBlobUpload(
@@ -285,7 +347,7 @@ async function gql<QueryData = any>(
return res.body.data;
}
e2e('should proxy single upload with valid signature', async t => {
e2e.serial('should proxy single upload with valid signature', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('r2-proxy');
const key = sha256Base64urlWithPadding(buffer);
@@ -300,6 +362,7 @@ e2e('should proxy single upload with valid signature', async t => {
t.is(init.method, 'PRESIGNED');
t.truthy(init.uploadUrl);
const uploadUrl = new URL(init.uploadUrl, app.url);
t.is(uploadUrl.origin, 'https://cdn.example.com');
t.is(uploadUrl.pathname, PROXY_UPLOAD_PATH);
const res = await app
@@ -309,7 +372,7 @@ e2e('should proxy single upload with valid signature', async t => {
.send(buffer);
t.is(res.status, 200);
const calls = getProvider().putCalls;
const calls = getRuntime().putCalls;
t.is(calls.length, 1);
t.is(calls[0].key, `${workspace.id}/${key}`);
t.is(calls[0].contentType, 'text/plain');
@@ -317,7 +380,7 @@ e2e('should proxy single upload with valid signature', async t => {
t.deepEqual(calls[0].body, buffer);
});
e2e('should proxy multipart upload and return etag', async t => {
e2e.serial('should proxy multipart upload and return etag', async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-object';
const totalSize = MULTIPART_THRESHOLD + 1024;
@@ -329,6 +392,7 @@ e2e('should proxy multipart upload and return etag', async t => {
const part = await getBlobUploadPartUrl(workspace.id, key, init.uploadId, 1);
const partUrl = new URL(part.uploadUrl, app.url);
t.is(partUrl.origin, 'https://cdn.example.com');
t.is(partUrl.pathname, PROXY_MULTIPART_PATH);
const payload = Buffer.from('part-body');
@@ -340,7 +404,7 @@ e2e('should proxy multipart upload and return etag', async t => {
t.is(res.status, 200);
t.is(res.get('etag'), 'etag-1');
const calls = getProvider().partCalls;
const calls = getRuntime().partCalls;
t.is(calls.length, 1);
t.is(calls[0].key, `${workspace.id}/${key}`);
t.is(calls[0].uploadId, 'upload-id');
@@ -349,34 +413,42 @@ e2e('should proxy multipart upload and return etag', async t => {
t.deepEqual(calls[0].body, payload);
});
e2e('should resume multipart upload and return uploaded parts', async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-resume';
const totalSize = MULTIPART_THRESHOLD + 1024;
e2e.serial(
'should resume multipart upload and return uploaded parts',
async t => {
const { workspace } = await setupWorkspace();
const key = 'multipart-resume';
const totalSize = MULTIPART_THRESHOLD + 1024;
const init1 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init1.method, 'MULTIPART');
t.is(init1.uploadId, 'upload-id');
t.deepEqual(init1.uploadedParts, []);
t.is(getProvider().createMultipartCalls, 1);
const init1 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init1.method, 'MULTIPART');
t.is(init1.uploadId, 'upload-id');
t.deepEqual(init1.uploadedParts, []);
t.is(getRuntime().createMultipartCalls, 1);
const part = await getBlobUploadPartUrl(workspace.id, key, init1.uploadId, 1);
const payload = Buffer.from('part-body');
const partUrl = new URL(part.uploadUrl, app.url);
await app
.PUT(partUrl.pathname + partUrl.search)
.set('content-length', payload.length.toString())
.send(payload)
.expect(200);
const part = await getBlobUploadPartUrl(
workspace.id,
key,
init1.uploadId,
1
);
const payload = Buffer.from('part-body');
const partUrl = new URL(part.uploadUrl, app.url);
await app
.PUT(partUrl.pathname + partUrl.search)
.set('content-length', payload.length.toString())
.send(payload)
.expect(200);
const init2 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init2.method, 'MULTIPART');
t.is(init2.uploadId, 'upload-id');
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: 'etag-1' }]);
t.is(getProvider().createMultipartCalls, 1);
});
const init2 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
t.is(init2.method, 'MULTIPART');
t.is(init2.uploadId, 'upload-id');
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: 'etag-1' }]);
t.is(getRuntime().createMultipartCalls, 1);
}
);
e2e('should reject upload when token is invalid', async t => {
e2e.serial('should reject upload when token is invalid', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('payload');
const init = await createBlobUpload(
@@ -396,10 +468,10 @@ e2e('should reject upload when token is invalid', async t => {
t.is(res.status, 400);
t.is(res.body.message, 'Invalid upload token');
t.is(getProvider().putCalls.length, 0);
t.is(getRuntime().putCalls.length, 0);
});
e2e('should reject upload when url is expired', async t => {
e2e.serial('should reject upload when url is expired', async t => {
const { workspace } = await setupWorkspace();
const buffer = Buffer.from('expired');
const init = await createBlobUpload(
@@ -422,10 +494,10 @@ e2e('should reject upload when url is expired', async t => {
t.is(res.status, 400);
t.is(res.body.message, 'Upload URL expired');
t.is(getProvider().putCalls.length, 0);
t.is(getRuntime().putCalls.length, 0);
});
e2e(
e2e.serial(
'should fall back to direct presign when custom domain is disabled',
async t => {
await useR2Storage({
@@ -449,7 +521,7 @@ e2e(
}
);
e2e(
e2e.serial(
'should still fallback to graphql when provider does not support presign',
async t => {
await setBlobStorage({
@@ -473,6 +545,40 @@ e2e(
}
);
function createProxyUrl(
path: string,
canonicalFields: (string | number | undefined)[],
query: Record<string, string | number | undefined>
) {
const signKey = (
app.get(Config).storages.blob.storage.config as R2StorageConfig
).usePresignedURL?.signKey;
if (!signKey) {
throw new Error('missing R2 proxy sign key');
}
const exp = Math.floor(Date.now() / 1000) + SIGNED_URL_EXPIRED;
const canonical = [
path,
...canonicalFields.map(field =>
field === undefined ? '' : field.toString()
),
exp.toString(),
].join('\n');
const token = createHmac('sha256', signKey)
.update(canonical)
.digest('base64');
const url = new URL(`http://localhost${path}`);
for (const [key, value] of Object.entries(query)) {
if (value !== undefined) {
url.searchParams.set(key, value.toString());
}
}
url.searchParams.set('exp', exp.toString());
url.searchParams.set('token', `${exp}-${token}`);
return { url: url.pathname + url.search, expiresAt: new Date(exp * 1000) };
}
function sha256Base64urlWithPadding(buffer: Buffer) {
return createHash('sha256')
.update(buffer)
@@ -1,4 +1,5 @@
import { randomUUID } from 'node:crypto';
import { Readable } from 'node:stream';
import { mock } from 'node:test';
import {
@@ -7,6 +8,8 @@ import {
type StorageProviderConfig,
} from '../../../base';
import { CommentAttachmentStorage } from '../../../core/storage';
import { StorageRuntimeProvider } from '../../../core/storage-runtime';
import { getMime } from '../../../native';
import { Mockers } from '../../mocks';
import { app, e2e } from '../test';
@@ -26,14 +29,68 @@ e2e.afterEach.always(() => {
mock.reset();
});
const objects = new Map<
string,
{
body: Buffer;
metadata?: {
contentLength?: number;
contentType?: string;
checksumCRC32?: string;
lastModified?: Date;
};
}
>();
e2e.beforeEach(() => {
objects.clear();
const rt = app.get(StorageRuntimeProvider);
mock.method(
rt,
'putObject',
async (
_scope: string,
key: string,
body: Buffer,
metadata?: {
contentLength?: number;
contentType?: string;
checksumCRC32?: string;
}
) => {
const object = {
body,
metadata: {
...metadata,
contentType: metadata?.contentType ?? getMime(body),
contentLength: metadata?.contentLength ?? body.length,
lastModified: new Date(),
},
};
objects.set(key, object);
return object.metadata;
}
);
mock.method(rt, 'getObject', async (_scope: string, key: string) => {
const object = objects.get(key);
if (!object) {
return {};
}
return {
body: Readable.from(object.body),
metadata: object.metadata,
};
});
mock.method(rt, 'presignGet', async () => undefined);
});
async function useCommentAttachmentBlobStorage(storage: StorageProviderConfig) {
app.get(ConfigFactory).override({ storages: { blob: { storage } } });
await app.get(CommentAttachmentStorage).onConfigInit();
}
// #region comment attachment
e2e(
e2e.serial(
'should get comment attachment not found when key is not exists',
async t => {
const { owner, workspace } = await createWorkspace();
@@ -50,7 +107,7 @@ e2e(
}
);
e2e(
e2e.serial(
'should get comment attachment no permission when user is not member',
async t => {
const { workspace } = await createWorkspace();
@@ -117,7 +174,7 @@ e2e.serial('should get comment attachment body', async t => {
}
});
e2e('should get comment attachment redirect url', async t => {
e2e.serial('should get comment attachment redirect url', async t => {
const { owner, workspace } = await createWorkspace();
await app.login(owner);
@@ -187,6 +187,7 @@ e2e('should allocate seats', async t => {
source: 'Link',
});
const invitationCount = app.queue.count('notification.sendInvitation');
await app.eventBus.emitAsync('workspace.members.allocateSeats', {
workspaceId: workspace.id,
quantity: 5,
@@ -206,7 +207,7 @@ e2e('should allocate seats', async t => {
WorkspaceMemberStatus.Accepted
);
t.is(app.queue.count('notification.sendInvitation'), 1);
t.is(app.queue.count('notification.sendInvitation') - invitationCount, 1);
});
e2e('should set all rests to NeedMoreSeat', async t => {
@@ -12,7 +12,7 @@ import { MockDocSnapshot } from './doc-snapshot.mock';
import { MockDocUser } from './doc-user.mock';
import { MockEventBus } from './eventbus.mock';
import { MockMailer } from './mailer.mock';
import { MockJobQueue } from './queue.mock';
import { MockJobModule, MockJobQueue } from './queue.mock';
import { MockTeamWorkspace } from './team-workspace.mock';
import { MockUser } from './user.mock';
import { MockUserSettings } from './user-settings.mock';
@@ -35,6 +35,7 @@ export {
installMockCopilotRuntime,
MockCopilotProvider,
MockEventBus,
MockJobModule,
MockJobQueue,
MockMailer,
};
@@ -1,12 +1,16 @@
import { Global, Module } from '@nestjs/common';
import { interval, map, take, takeUntil } from 'rxjs';
import Sinon from 'sinon';
import { JobQueue } from '../../base';
export class MockJobQueue {
add = Sinon.createStubInstance(JobQueue).add.resolves();
remove = Sinon.createStubInstance(JobQueue).remove.resolves();
removeWhere = Sinon.createStubInstance(JobQueue).removeWhere.resolves([]);
private readonly sandbox = Sinon.createSandbox();
add = this.sandbox.stub().resolves();
get = this.sandbox.stub().resolves();
remove = this.sandbox.stub().resolves();
removeWhere = this.sandbox.stub().resolves([]);
last<Job extends JobName>(name: Job): { name: Job; payload: Jobs[Job] } {
const addJobName = this.add.lastCall?.args[0];
@@ -57,3 +61,10 @@ export class MockJobQueue {
return this.add.getCalls().filter(call => call.args[0] === name).length;
}
}
@Global()
@Module({
providers: [{ provide: JobQueue, useClass: MockJobQueue }],
exports: [JobQueue],
})
export class MockJobModule {}
@@ -303,28 +303,3 @@ test('should delete userSession fail when sessionId not match', async t => {
);
t.is(count, 0);
});
test('should cleanup expired userSessions', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const session = await t.context.db.session.create({
data: {},
});
const userSession = await t.context.session.createOrRefreshUserSession(
user.id,
session.id
);
await t.context.session.cleanExpiredUserSessions();
let count = await t.context.db.userSession.count();
t.is(count, 1);
// Set expiresAt to past time
await t.context.db.userSession.update({
where: { id: userSession.id },
data: { expiresAt: new Date('2022-01-01') },
});
await t.context.session.cleanExpiredUserSessions();
count = await t.context.db.userSession.count();
t.is(count, 0);
});
@@ -1552,6 +1552,48 @@ test('should be able to create team subscription', async t => {
t.is(subInDB?.stripeSubscriptionId, sub.id);
});
test('should replace old team subscription row when stripe creates a new subscription', async t => {
const { service, db } = t.context;
const old = await db.subscription.create({
data: {
targetId: 'ws_1',
stripeSubscriptionId: 'sub_old_team',
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Yearly,
status: SubscriptionStatus.Canceled,
start: new Date('2026-03-26T08:23:57.000Z'),
end: new Date('2027-03-26T08:23:57.000Z'),
quantity: 24,
},
});
await service.saveStripeSubscription({
...teamSub,
id: 'sub_new_team',
status: SubscriptionStatus.Active,
items: {
...teamSub.items,
data: [
{
...teamSub.items.data[0],
quantity: 11,
},
],
},
});
const subscriptions = await db.subscription.findMany({
where: { targetId: 'ws_1', plan: SubscriptionPlan.Team },
});
t.is(subscriptions.length, 1);
t.is(subscriptions[0].id, old.id);
t.is(subscriptions[0].stripeSubscriptionId, 'sub_new_team');
t.is(subscriptions[0].status, SubscriptionStatus.Active);
t.is(subscriptions[0].quantity, 11);
});
test('should be able to update team subscription', async t => {
const { service, db, event } = t.context;
@@ -1586,6 +1628,77 @@ test('should be able to update team subscription', async t => {
);
});
test('should persist mutable team subscription fields on same stripe subscription update', async t => {
const { service, db } = t.context;
await service.saveStripeSubscription(teamSub);
await service.saveStripeSubscription({
...teamSub,
current_period_start: 1780000000,
current_period_end: 1811536000,
trial_start: 1780000000,
trial_end: 1780604800,
items: {
...teamSub.items,
data: [
{
...teamSub.items.data[0],
quantity: 9,
price: {
...PRICES[TEAM_YEARLY],
lookup_key: TEAM_YEARLY,
},
},
],
},
});
const subInDB = await db.subscription.findFirst({
where: { targetId: 'ws_1' },
});
const entitlement = await db.entitlement.findFirst({
where: {
source: 'cloud_subscription',
subjectId: teamSub.id,
},
});
const providerFact = await db.providerSubscription.findUnique({
where: {
provider_externalSubscriptionId: {
provider: 'stripe',
externalSubscriptionId: teamSub.id,
},
},
});
t.like(subInDB, {
recurring: SubscriptionRecurring.Yearly,
quantity: 9,
start: new Date(1780000000 * 1000),
end: new Date(1811536000 * 1000),
trialStart: new Date(1780000000 * 1000),
trialEnd: new Date(1780604800 * 1000),
});
t.like(entitlement, {
plan: 'team',
quantity: 9,
startsAt: new Date(1780000000 * 1000),
expiresAt: new Date(1811536000 * 1000),
});
t.like(providerFact, {
recurring: SubscriptionRecurring.Yearly,
externalPriceId: TEAM_YEARLY,
currency: 'usd',
amount: 14400,
quantity: 9,
periodStart: new Date(1780000000 * 1000),
periodEnd: new Date(1811536000 * 1000),
trialStart: new Date(1780000000 * 1000),
trialEnd: new Date(1780604800 * 1000),
});
});
test('should suspend on dispute and restore when dispute won', async t => {
const { service, db, stripe, event } = t.context;
@@ -6,6 +6,7 @@ import Sinon from 'sinon';
import { OneDay } from '../../base';
import { StorageModule, WorkspaceBlobStorage } from '../../core/storage';
import { BlobUploadCleanupJob } from '../../core/storage/job';
import { StorageRuntimeProvider } from '../../core/storage-runtime';
import { MockUser, MockWorkspace } from '../mocks';
import { createTestingModule, TestingModule } from '../utils';
@@ -14,13 +15,22 @@ interface Context {
db: PrismaClient;
job: BlobUploadCleanupJob;
storage: WorkspaceBlobStorage;
runtime: { cleanupExpiredPendingBlobs: Sinon.SinonStub };
}
const test = ava as TestFn<Context>;
test.before(async t => {
t.context.runtime = {
cleanupExpiredPendingBlobs: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), StorageModule],
tapModule: builder => {
builder
.overrideProvider(StorageRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.db = t.context.module.get(PrismaClient);
@@ -30,6 +40,7 @@ test.before(async t => {
test.beforeEach(async t => {
await t.context.module.initTestingDB();
t.context.runtime.cleanupExpiredPendingBlobs.reset();
});
test.after.always(async t => {
@@ -86,24 +97,14 @@ test('should cleanup expired pending blobs', async t => {
],
});
const abortSpy = Sinon.stub(
t.context.storage,
'abortMultipartUpload'
).resolves();
const deleteSpy = Sinon.spy(t.context.storage, 'delete');
t.teardown(() => {
abortSpy.restore();
deleteSpy.restore();
t.context.runtime.cleanupExpiredPendingBlobs.resolves({
scanned: 2,
deleted: 2,
abortedMultipart: 1,
workspaceIds: [workspace.id],
});
await t.context.job.cleanExpiredPendingBlobs();
t.is(abortSpy.callCount, 1);
t.is(deleteSpy.callCount, 2);
const remaining = await t.context.db.blob.findMany({
where: { workspaceId: workspace.id },
});
const remainingKeys = remaining.map(record => record.key).sort();
t.deepEqual(remainingKeys, ['completed-keep', 'pending-active']);
t.true(t.context.runtime.cleanupExpiredPendingBlobs.calledOnce);
});
@@ -9,7 +9,7 @@ import {
import { PrismaClient } from '@prisma/client';
import { buildAppModule, FunctionalityModules } from '../../app.module';
import { AFFiNELogger, JobQueue } from '../../base';
import { AFFiNELogger, ConfigFactory, JobModule, JobQueue } from '../../base';
import { GqlModule } from '../../base/graphql';
import { ServerConfigModule } from '../../core';
import { AuthGuard, AuthModule } from '../../core/auth';
@@ -18,7 +18,7 @@ import { ModelsModule } from '../../models';
// for jsdoc inference
// oxlint-disable-next-line no-unused-vars
import type { createModule } from '../create-module';
import { createFactory, MockJobQueue } from '../mocks';
import { createFactory, MockJobModule, MockJobQueue } from '../mocks';
import { MockMailer } from '../mocks/mailer.mock';
import { initTestingDB, TEST_LOG_LEVEL } from './utils';
@@ -48,6 +48,16 @@ function dedupeModules(modules: NonNullable<ModuleMetadata['imports']>) {
return Array.from(map.values());
}
function testingFunctionalityModules() {
return [
...FunctionalityModules.filter(module => {
const moduleType = 'module' in module ? module.module : module;
return moduleType !== JobModule;
}),
MockJobModule,
];
}
@Resolver(() => String)
class MockResolver {
@Query(() => String)
@@ -70,7 +80,7 @@ export async function createTestingModule(
imports[0].module?.name === 'AppModule'
? imports
: dedupeModules([
...FunctionalityModules,
...testingFunctionalityModules(),
ModelsModule,
AuthModule,
GqlModule,
@@ -99,6 +109,31 @@ export async function createTestingModule(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
const testingModule = module as TestingModule;
@@ -1,12 +1,15 @@
import { createHash } from 'node:crypto';
import { Readable } from 'node:stream';
import test from 'ava';
import Sinon from 'sinon';
import { Config, ConfigFactory, StorageProviderFactory } from '../../base';
import { ConfigFactory } from '../../base';
import { QuotaStateService } from '../../core/quota/state';
import { WorkspaceBlobStorage } from '../../core/storage/wrappers/blob';
import { StorageRuntimeProvider } from '../../core/storage-runtime';
import { BlobModel, WorkspaceFeatureModel } from '../../models';
import { getMime } from '../../native';
import {
collectAllBlobSizes,
completeBlobUpload,
@@ -32,9 +35,117 @@ const RESTRICTED_QUOTA = {
let app: TestingApp;
let model: WorkspaceFeatureModel;
type CompleteResult =
| {
ok: true;
contentType: string;
contentLength: number;
lastModifiedMs: number;
}
| {
ok: false;
reason:
| 'not_found'
| 'size_mismatch'
| 'mime_mismatch'
| 'checksum_mismatch'
| 'size_too_large';
};
const objects = new Map<
string,
{
body: Buffer;
metadata: {
contentType: string;
contentLength: number;
lastModified: Date;
};
}
>();
const completeResults = new Map<string, CompleteResult>();
const storageRuntime = {
providerCapabilities: async () => ({
put: true,
get: true,
head: true,
list: true,
delete: true,
presignPut: false,
presignGet: false,
multipartDirect: false,
proxyUpload: false,
assetpack: false,
serverMediatedOnly: true,
}),
putObject: async (
_scope: string,
key: string,
body: Buffer,
metadata?: { contentType?: string; contentLength?: number }
) => {
const object = {
body,
metadata: {
contentType: metadata?.contentType ?? getMime(body),
contentLength: metadata?.contentLength ?? body.length,
lastModified: new Date(),
},
};
objects.set(key, object);
return object.metadata;
},
headObject: async (_scope: string, key: string) => {
return objects.get(key)?.metadata;
},
getObject: async (_scope: string, key: string) => {
const object = objects.get(key);
return object
? { body: Readable.from(object.body), metadata: object.metadata }
: {};
},
listObjects: async (_scope: string, prefix?: string) => {
return Array.from(objects.entries())
.filter(([key]) => !prefix || key.startsWith(prefix))
.map(([key, object]) => ({ key, ...object.metadata }));
},
deleteObject: async (_scope: string, key: string) => {
objects.delete(key);
},
presignPut: async () => undefined,
presignGet: async () => undefined,
createMultipartUpload: async () => undefined,
presignUploadPart: async () => undefined,
listMultipartUploadParts: async () => undefined,
completeMultipartUpload: async () => undefined,
completeWorkspaceBlobUpload: async (workspaceId: string, key: string) => {
const objectKey = `${workspaceId}/${key}`;
const configured = completeResults.get(objectKey);
if (configured) return configured;
const object = objects.get(objectKey);
if (!object) return { ok: false, reason: 'not_found' };
await app.get(BlobModel).upsert({
workspaceId,
key,
mime: object.metadata.contentType,
size: object.metadata.contentLength,
status: 'completed',
uploadId: null,
});
return {
ok: true,
contentType: object.metadata.contentType,
contentLength: object.metadata.contentLength,
lastModifiedMs: object.metadata.lastModified.getTime(),
};
},
};
test.before(async () => {
app = await createTestingApp();
app = await createTestingApp({
tapModule: builder => {
builder.overrideProvider(StorageRuntimeProvider).useValue(storageRuntime);
},
});
model = app.get(WorkspaceFeatureModel);
app.get(ConfigFactory).override({
storages: {
@@ -47,11 +158,12 @@ test.before(async () => {
},
},
});
await app.get(WorkspaceBlobStorage).onConfigInit();
});
test.beforeEach(async () => {
await app.initTestingDB();
objects.clear();
completeResults.clear();
});
test.after.always(async () => {
@@ -119,6 +231,47 @@ test('should list blobs', async t => {
t.deepEqual(ret.map(x => x.key).sort(), [hash1, hash2].sort());
});
test('should keep partial blob metadata listing on DB path without storage scan', async t => {
await app.signupV1('u1@affine.pro');
const workspace = await createWorkspace(app);
const storage = app.get(WorkspaceBlobStorage);
const rt = app.get(StorageRuntimeProvider);
const listSpy = Sinon.spy(rt, 'listObjects');
t.teardown(() => listSpy.restore());
const buffer1 = Buffer.from('with metadata');
const buffer2 = Buffer.from('without metadata');
const key1 = sha256Base64urlWithPadding(buffer1);
const key2 = sha256Base64urlWithPadding(buffer2);
await rt.putObject('blob', `${workspace.id}/${key1}`, buffer1, {
contentType: 'text/plain',
contentLength: buffer1.length,
});
await rt.putObject('blob', `${workspace.id}/${key2}`, buffer2, {
contentType: 'text/plain',
contentLength: buffer2.length,
});
const blobModel = app.get(BlobModel);
await blobModel.upsert({
workspaceId: workspace.id,
key: key1,
mime: 'text/plain',
size: buffer1.length,
status: 'completed',
uploadId: null,
});
const listed = await storage.list(workspace.id);
t.deepEqual(
listed.map(blob => blob.key),
[key1]
);
t.true(listSpy.notCalled);
});
test('should create pending blob upload with graphql fallback', async t => {
await app.signupV1('u1@affine.pro');
@@ -150,11 +303,9 @@ test('should complete pending blob upload', async t => {
await createBlobUpload(app, workspace.id, key, buffer.length, mime);
const config = app.get(Config);
const factory = app.get(StorageProviderFactory);
const provider = factory.create(config.storages.blob.storage);
const rt = app.get(StorageRuntimeProvider);
await provider.put(`${workspace.id}/${key}`, buffer, {
await rt.putObject('blob', `${workspace.id}/${key}`, buffer, {
contentType: mime,
contentLength: buffer.length,
});
@@ -181,14 +332,16 @@ test('should reject complete when blob key mismatched', async t => {
const wrongKey = sha256Base64urlWithPadding(Buffer.from('other'));
await createBlobUpload(app, workspace.id, wrongKey, buffer.length, mime);
const config = app.get(Config);
const factory = app.get(StorageProviderFactory);
const provider = factory.create(config.storages.blob.storage);
const rt = app.get(StorageRuntimeProvider);
await provider.put(`${workspace.id}/${wrongKey}`, buffer, {
await rt.putObject('blob', `${workspace.id}/${wrongKey}`, buffer, {
contentType: mime,
contentLength: buffer.length,
});
completeResults.set(`${workspace.id}/${wrongKey}`, {
ok: false,
reason: 'checksum_mismatch',
});
await t.throwsAsync(() => completeBlobUpload(app, workspace.id, wrongKey), {
message: 'Blob key mismatch',
@@ -221,10 +374,12 @@ test('should auto delete blobs when workspace is deleted', async t => {
const blobs = await listBlobs(app, workspace.id);
t.is(blobs.length, 2);
const workspaceBlobStorage = Sinon.spy(app.get(WorkspaceBlobStorage));
const rt = app.get(StorageRuntimeProvider);
const listSpy = Sinon.spy(rt, 'listObjects');
t.teardown(() => listSpy.restore());
await deleteWorkspace(app, workspace.id);
// should not emit workspace.blob.sync event
t.is(workspaceBlobStorage.syncBlobMeta.callCount, 0);
t.is(listSpy.callCount, 0);
});
test('should calc blobs size', async t => {
@@ -212,8 +212,7 @@ test('should be able to get permission granted workspace', async t => {
test('should return 404 if blob not found', async t => {
const { app, storage } = t.context;
// @ts-expect-error mock
storage.get.resolves({ body: null });
storage.get.resolves({ body: undefined });
const res = await app.GET('/api/workspaces/public/blobs/test');
t.is(res.status, HttpStatus.NOT_FOUND);
+4 -2
View File
@@ -25,11 +25,11 @@ import { MetricsModule } from './base/metrics';
import { MutexModule } from './base/mutex';
import { PrismaModule } from './base/prisma';
import { RedisModule } from './base/redis';
import { StorageProviderModule } from './base/storage';
import { RateLimiterModule } from './base/throttler';
import { WebSocketModule } from './base/websocket';
import { AccessTokenModule } from './core/access-token';
import { AuthModule } from './core/auth';
import { BackendRuntimeModule } from './core/backend-runtime';
import { CommentModule } from './core/comment';
import { ServerConfigModule, ServerConfigResolverModule } from './core/config';
import { DocStorageModule } from './core/doc';
@@ -46,6 +46,7 @@ import { RealtimeModule } from './core/realtime';
import { SelfhostModule } from './core/selfhost';
import { StaticFileModule } from './core/static-files';
import { StorageModule } from './core/storage';
import { StorageRuntimeModule } from './core/storage-runtime';
import { SyncModule } from './core/sync';
import { TelemetryModule } from './core/telemetry';
import { UserModule } from './core/user';
@@ -113,13 +114,14 @@ export const FunctionalityModules = [
MutexModule,
MetricsModule,
RateLimiterModule,
StorageProviderModule,
HelpersModule,
ErrorModule,
WebSocketModule,
JobModule.forRoot(),
RealtimeModule,
ModelsModule,
BackendRuntimeModule,
StorageRuntimeModule,
ScheduleModule.forRoot(),
MonitorModule,
];
@@ -21,6 +21,7 @@ export type JSONSchema = { description?: string } & (
| {
type: 'object';
properties?: Record<string, JSONSchema>;
required?: string[];
}
);
@@ -907,6 +907,14 @@ export const USER_FRIENDLY_ERRORS = {
message: ({ clientVersion, requiredVersion }) =>
`Unsupported client with version [${clientVersion}], required version is [${requiredVersion}].`,
},
unsupported_server_version: {
type: 'action_forbidden',
args: {
requiredVersion: 'string',
},
message: ({ requiredVersion }) =>
`This AFFiNE server is too old for this client. Please upgrade the server to ${requiredVersion}.`,
},
// Notification Errors
notification_not_found: {
@@ -1059,6 +1059,16 @@ export class UnsupportedClientVersion extends UserFriendlyError {
super('action_forbidden', 'unsupported_client_version', message, args);
}
}
@ObjectType()
class UnsupportedServerVersionDataType {
@Field() requiredVersion!: string
}
export class UnsupportedServerVersion extends UserFriendlyError {
constructor(args: UnsupportedServerVersionDataType, message?: string | ((args: UnsupportedServerVersionDataType) => string)) {
super('action_forbidden', 'unsupported_server_version', message, args);
}
}
export class NotificationNotFound extends UserFriendlyError {
constructor(message?: string) {
@@ -1288,6 +1298,7 @@ export enum ErrorNames {
INVALID_LICENSE_UPDATE_PARAMS,
LICENSE_EXPIRED,
UNSUPPORTED_CLIENT_VERSION,
UNSUPPORTED_SERVER_VERSION,
NOTIFICATION_NOT_FOUND,
MENTION_USER_DOC_ACCESS_DENIED,
MENTION_USER_ONESELF_DENIED,
@@ -1308,5 +1319,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, ImageFormatNotSupportedDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, ImageFormatNotSupportedDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, UnsupportedServerVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
});
+1 -6
View File
@@ -30,11 +30,6 @@ export { Lock, Locker, Mutex, RequestMutex } from './mutex';
export * from './nestjs';
export { type PrismaTransaction } from './prisma';
export * from './storage';
export {
autoMetadata,
type StorageProvider,
type StorageProviderConfig,
StorageProviderFactory,
} from './storage';
export { type StorageProviderConfig } from './storage';
export { CloudThrottlerGuard, SkipThrottle, Throttle } from './throttler';
export * from './utils';
@@ -94,4 +94,12 @@ defineModuleConfig('job', {
},
schema,
},
'queues.backendRuntime': {
desc: 'The config for backend runtime job queue',
default: {
concurrency: 1,
},
schema,
},
});
@@ -29,6 +29,7 @@ export enum Queue {
COPILOT = 'copilot',
INDEXER = 'indexer',
CALENDAR = 'calendar',
BACKENDRUNTIME = 'backendRuntime',
}
export const QUEUES = Object.values(Queue);
@@ -8,6 +8,13 @@ import { Redis as IORedis, RedisOptions } from 'ioredis';
import { Config } from '../config';
function redisOptions(options: RedisOptions) {
return {
...(env.testing ? { lazyConnect: true } : {}),
...options,
};
}
class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(this.constructor.name);
@@ -47,41 +54,47 @@ class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
@Injectable()
export class CacheRedis extends Redis {
constructor(config: Config) {
super({ ...config.redis, ...config.redis.ioredis });
super(redisOptions({ ...config.redis, ...config.redis.ioredis }));
}
}
@Injectable()
export class SessionRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 2,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 2,
})
);
}
}
@Injectable()
export class SocketIoRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 3,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 3,
})
);
}
}
@Injectable()
export class QueueRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 4,
// required explicitly set to `null` by bullmq
maxRetriesPerRequest: null,
});
super(
redisOptions({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 4,
// required explicitly set to `null` by bullmq
maxRetriesPerRequest: null,
})
);
}
}
@@ -1,130 +0,0 @@
import { promises as fs } from 'node:fs';
import { join } from 'node:path';
import test from 'ava';
import { getStreamAsBuffer } from 'get-stream';
import { ListObjectsMetadata } from '../providers';
import { FsStorageProvider } from '../providers/fs';
const config = {
path: join(process.cwd(), 'node_modules', '.cache/affine-test-storage'),
};
function createProvider() {
return new FsStorageProvider(
config,
'test' + Math.random().toString(16).substring(2, 8)
);
}
function keys(list: ListObjectsMetadata[]) {
return list.map(i => i.key);
}
async function randomPut(
provider: FsStorageProvider,
prefix = ''
): Promise<string> {
const key = prefix + 'test-key-' + Math.random().toString(16).substring(2, 8);
const body = Buffer.from(key);
await provider.put(key, body);
return key;
}
test.after.always(() => {
fs.rm(config.path, { recursive: true }).catch(console.error);
});
test('put & get', async t => {
const provider = createProvider();
const key = 'testKey';
const body = Buffer.from('testBody');
await provider.put(key, body);
const result = await provider.get(key);
t.deepEqual(await getStreamAsBuffer(result.body!), body);
t.is(result.metadata?.contentLength, body.length);
});
test('list - one level', async t => {
const provider = createProvider();
const list = await Promise.all(
Array.from({ length: 100 }).map(() => randomPut(provider))
);
list.sort();
// random order, use set
const result = await provider.list();
t.deepEqual(keys(result), list);
const result2 = await provider.list('test-key');
t.deepEqual(keys(result2), list);
const result3 = await provider.list('testKey');
t.is(result3.length, 0);
});
test('list recursively', async t => {
const provider = createProvider();
await Promise.all([
Promise.all(Array.from({ length: 10 }).map(() => randomPut(provider))),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/'))
),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/b/'))
),
Promise.all(
Array.from({ length: 10 }).map(() => randomPut(provider, 'a/b/t/'))
),
]);
const r1 = await provider.list();
t.is(r1.length, 40);
// contains all `a/xxx` and `a/b/xxx` and `a/b/c/xxx`
const r2 = await provider.list('a');
t.is(r2.length, 30);
// contains only `a/b/xxx`
const r3 = await provider.list('a/b');
const r4 = await provider.list('a/b/');
t.is(r3.length, 20);
t.deepEqual(r3, r4);
// prefix is not ended with '/', it's open to all files and sub dirs
// contains all `a/b/t/xxx` and `a/b/t{xxxx}`
const r5 = await provider.list('a/b/t');
t.is(r5.length, 20);
});
test('delete', async t => {
const provider = createProvider();
const key = 'testKey';
const body = Buffer.from('testBody');
await provider.put(key, body);
await provider.delete(key);
await t.throwsAsync(() => fs.access(join(config.path, provider.bucket, key)));
});
test('rejects unsafe object keys', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.put('../escape', Buffer.from('nope')));
await t.throwsAsync(() => provider.get('nested/../escape'));
await t.throwsAsync(() => provider.head('./escape'));
t.throws(() => provider.delete('nested//escape'));
});
test('rejects unsafe list prefixes', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.list('../escape'));
await t.throwsAsync(() => provider.list('nested/../../escape'));
await t.throwsAsync(() => provider.list('/absolute'));
});
@@ -1,82 +0,0 @@
import test from 'ava';
import { R2StorageProvider } from '../providers/r2';
function endpointOf(provider: R2StorageProvider) {
return provider.endpointUrl;
}
test('R2 provider should use account endpoint by default', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.r2.cloudflarestorage.com/test-bucket'
);
});
test('R2 provider should append jurisdiction suffix for EU buckets', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
jurisdiction: 'eu',
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.eu.r2.cloudflarestorage.com/test-bucket'
);
});
test('R2 provider should throw when accountId is missing', t => {
t.throws(
() =>
new R2StorageProvider(
{
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
} as any,
'test-bucket'
)
);
});
test('R2 provider should use default endpoint when jurisdiction is explicitly undefined', t => {
const provider = new R2StorageProvider(
{
accountId: 'test-account',
jurisdiction: undefined,
region: 'auto',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
},
'test-bucket'
);
t.is(
endpointOf(provider),
'https://test-account.r2.cloudflarestorage.com/test-bucket'
);
});
@@ -1,49 +0,0 @@
import { parseListPartsXml } from '@affine/s3-compat';
import test from 'ava';
test('parseListPartsXml handles array parts and pagination', t => {
const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<PartNumberMarker>0</PartNumberMarker>
<NextPartNumberMarker>3</NextPartNumberMarker>
<IsTruncated>true</IsTruncated>
<Part>
<PartNumber>1</PartNumber>
<ETag>"etag-1"</ETag>
</Part>
<Part>
<PartNumber>2</PartNumber>
<ETag>etag-2</ETag>
</Part>
</ListPartsResult>`;
const result = parseListPartsXml(xml);
t.deepEqual(result.parts, [
{ partNumber: 1, etag: 'etag-1' },
{ partNumber: 2, etag: 'etag-2' },
]);
t.true(result.isTruncated);
t.is(result.nextPartNumberMarker, '3');
});
test('parseListPartsXml handles single part', t => {
const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<IsTruncated>false</IsTruncated>
<Part>
<PartNumber>5</PartNumber>
<ETag>"etag-5"</ETag>
</Part>
</ListPartsResult>`;
const result = parseListPartsXml(xml);
t.deepEqual(result.parts, [{ partNumber: 5, etag: 'etag-5' }]);
t.false(result.isTruncated);
t.is(result.nextPartNumberMarker, undefined);
});
@@ -1,90 +0,0 @@
import test from 'ava';
import { S3StorageProvider } from '../providers/s3';
import { SIGNED_URL_EXPIRED } from '../providers/utils';
const config = {
region: 'us-east-1',
endpoint: 'https://s3.us-east-1.amazonaws.com',
credentials: {
accessKeyId: 'test',
secretAccessKey: 'test',
},
};
function createProvider() {
return new S3StorageProvider(config, 'test-bucket');
}
test('presignPut should return url and headers', async t => {
const provider = createProvider();
const result = await provider.presignPut('key', {
contentType: 'text/plain',
});
t.truthy(result);
t.true(result!.url.length > 0);
t.true(result!.url.includes('X-Amz-Algorithm=AWS4-HMAC-SHA256'));
t.true(result!.url.includes('X-Amz-SignedHeaders='));
t.true(result!.url.includes('content-type'));
t.deepEqual(result!.headers, { 'Content-Type': 'text/plain' });
const now = Date.now();
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
t.true(result!.expiresAt.getTime() <= now + SIGNED_URL_EXPIRED * 1000 + 2000);
});
test('presignUploadPart should return url', async t => {
const provider = createProvider();
const result = await provider.presignUploadPart('key', 'upload-1', 3);
t.truthy(result);
t.true(result!.url.length > 0);
t.true(result!.url.includes('X-Amz-Algorithm=AWS4-HMAC-SHA256'));
});
test('createMultipartUpload should return uploadId', async t => {
const provider = createProvider();
let receivedKey: string | undefined;
let receivedMeta: any;
(provider as any).client = {
createMultipartUpload: async (key: string, meta: any) => {
receivedKey = key;
receivedMeta = meta;
return { uploadId: 'upload-1' };
},
};
const now = Date.now();
const result = await provider.createMultipartUpload('key', {
contentType: 'text/plain',
});
t.is(result?.uploadId, 'upload-1');
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
t.true(result!.expiresAt.getTime() <= now + SIGNED_URL_EXPIRED * 1000 + 2000);
t.is(receivedKey, 'key');
t.is(receivedMeta.contentType, 'text/plain');
});
test('completeMultipartUpload should order parts', async t => {
const provider = createProvider();
let receivedParts: any;
(provider as any).client = {
completeMultipartUpload: async (
_key: string,
_uploadId: string,
parts: any
) => {
receivedParts = parts;
},
};
await provider.completeMultipartUpload('key', 'upload-1', [
{ partNumber: 2, etag: 'b' },
{ partNumber: 1, etag: 'a' },
]);
t.deepEqual(receivedParts, [
{ partNumber: 1, etag: 'a' },
{ partNumber: 2, etag: 'b' },
]);
});
@@ -1,20 +0,0 @@
import { Injectable } from '@nestjs/common';
import {
StorageProvider,
StorageProviderConfig,
StorageProviders,
} from './providers';
@Injectable()
export class StorageProviderFactory {
create(config: StorageProviderConfig): StorageProvider {
const Provider = StorageProviders[config.provider];
if (!Provider) {
throw new Error(`Unknown storage provider type: ${config.provider}`);
}
return new Provider(config.config, config.bucket);
}
}
@@ -1,12 +1,3 @@
import { Global, Module } from '@nestjs/common';
import { StorageProviderFactory } from './factory';
@Global()
@Module({
providers: [StorageProviderFactory],
exports: [StorageProviderFactory],
})
export class StorageProviderModule {}
export { StorageProviderFactory } from './factory';
export * from './providers';
export type * from './types';
export * from './utils';
@@ -1,318 +0,0 @@
import {
accessSync,
constants,
createReadStream,
Dirent,
mkdirSync,
readdirSync,
readFileSync,
rmSync,
statSync,
writeFileSync,
} from 'node:fs';
import { homedir } from 'node:os';
import { join, parse } from 'node:path';
import { Readable } from 'node:stream';
import { Logger } from '@nestjs/common';
import {
BlobInputType,
GetObjectMetadata,
ListObjectsMetadata,
PutObjectMetadata,
StorageProvider,
} from './provider';
import { autoMetadata, toBuffer } from './utils';
function normalizeStorageKey(key: string): string {
const normalized = key.replaceAll('\\', '/');
const segments = normalized.split('/');
if (
!normalized ||
normalized.startsWith('/') ||
segments.some(segment => !segment || segment === '.' || segment === '..')
) {
throw new Error(`Invalid storage key: ${key}`);
}
return segments.join('/');
}
function normalizeStoragePrefix(prefix: string): string {
const normalized = prefix.replaceAll('\\', '/');
if (!normalized) {
return normalized;
}
if (normalized.startsWith('/')) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
const segments = normalized.split('/');
const lastSegment = segments.pop();
if (
lastSegment === undefined ||
segments.some(segment => !segment || segment === '.' || segment === '..') ||
lastSegment === '.' ||
lastSegment === '..'
) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
if (lastSegment === '') {
return `${segments.join('/')}/`;
}
return [...segments, lastSegment].join('/');
}
export interface FsStorageConfig {
path: string;
}
export class FsStorageProvider implements StorageProvider {
private readonly path: string;
private readonly logger: Logger;
readonly type = 'fs';
constructor(
config: FsStorageConfig,
public readonly bucket: string
) {
this.path = config.path.startsWith('~/')
? join(homedir(), config.path.slice(2), bucket)
: join(config.path, bucket);
this.ensureAvailability();
this.logger = new Logger(`${FsStorageProvider.name}:${bucket}`);
}
async put(
key: string,
body: BlobInputType,
metadata: PutObjectMetadata = {}
): Promise<void> {
key = normalizeStorageKey(key);
const blob = await toBuffer(body);
// write object
this.writeObject(key, blob);
// write metadata
await this.writeMetadata(key, blob, metadata);
this.logger.verbose(`Object \`${key}\` put`);
}
async head(key: string) {
key = normalizeStorageKey(key);
const metadata = this.readMetadata(key);
if (!metadata) {
this.logger.verbose(`Object \`${key}\` not found`);
return undefined;
}
return metadata;
}
async get(key: string): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
}> {
key = normalizeStorageKey(key);
try {
const metadata = this.readMetadata(key);
const stream = this.readObject(this.join(key));
this.logger.verbose(`Read object \`${key}\``);
return {
body: stream,
metadata,
};
} catch (e) {
this.logger.error(`Failed to read object \`${key}\``, e);
return {};
}
}
async list(prefix?: string): Promise<ListObjectsMetadata[]> {
// prefix cases:
// - `undefined`: list all objects
// - `a/b`: list objects under dir `a` with prefix `b`, `b` might be a dir under `a` as well.
// - `a/b/` list objects under dir `a/b`
// read dir recursively and filter out '.metadata.json' files
let dir = this.path;
if (prefix) {
prefix = normalizeStoragePrefix(prefix);
const parts = prefix.split(/[/\\]/);
// for prefix `a/b/c`, move `a/b` to dir and `c` to key prefix
if (parts.length > 1) {
dir = join(dir, ...parts.slice(0, -1));
prefix = parts[parts.length - 1];
}
}
const results: ListObjectsMetadata[] = [];
async function getFiles(dir: string, prefix?: string): Promise<void> {
try {
const entries: Dirent[] = readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
const res = join(dir, entry.name);
if (entry.isDirectory()) {
if (!prefix || entry.name.startsWith(prefix)) {
await getFiles(res);
}
} else if (
(!prefix || entry.name.startsWith(prefix)) &&
!entry.name.endsWith('.metadata.json')
) {
const stat = statSync(res);
results.push({
key: res,
lastModified: stat.mtime,
contentLength: stat.size,
});
}
}
} catch {
// failed to read dir, stop recursion
}
}
await getFiles(dir, prefix);
// trim path with `this.path` prefix
results.forEach(r => (r.key = r.key.slice(this.path.length + 1)));
return results;
}
delete(key: string): Promise<void> {
key = normalizeStorageKey(key);
try {
rmSync(this.join(key), { force: true });
rmSync(this.join(`${key}.metadata.json`), { force: true });
} catch (e) {
throw new Error(`Failed to delete object \`${key}\``, {
cause: e,
});
}
this.logger.verbose(`Object \`${key}\` deleted`);
return Promise.resolve();
}
ensureAvailability() {
// check stats
const stats = statSync(this.path, {
throwIfNoEntry: false,
});
// not existing, create it
if (!stats) {
try {
mkdirSync(this.path, { recursive: true });
} catch (e) {
throw new Error(
`Failed to create target directory for fs storage provider: ${this.path}`,
{
cause: e,
}
);
}
} else if (stats.isDirectory()) {
// the target directory has already existed, check if it is readable & writable
try {
accessSync(this.path, constants.W_OK | constants.R_OK);
} catch (e) {
throw new Error(
`The target directory for fs storage provider has already existed, but it is not readable & writable: ${this.path}`,
{
cause: e,
}
);
}
} else if (stats.isFile()) {
throw new Error(
`The target directory for fs storage provider is a file: ${this.path}`
);
}
}
private join(...paths: string[]) {
return join(this.path, ...paths);
}
private readObject(file: string): Readable | undefined {
const state = statSync(file, { throwIfNoEntry: false });
if (state?.isFile()) {
return createReadStream(file);
}
return undefined;
}
private writeObject(key: string, blob: Buffer) {
const path = this.join(key);
mkdirSync(parse(path).dir, { recursive: true });
writeFileSync(path, blob);
}
private async writeMetadata(
key: string,
blob: Buffer,
raw: PutObjectMetadata
) {
try {
const metadata = autoMetadata(blob, raw);
if (raw.checksumCRC32 && metadata.checksumCRC32 !== raw.checksumCRC32) {
throw new Error(
'The checksum of the uploaded file is not matched with the one you provide, the file may be corrupted and the uploading will not be processed.'
);
}
if (raw.contentLength && metadata.contentLength !== raw.contentLength) {
throw new Error(
'The content length of the uploaded file is not matched with the one you provide, the file may be corrupted and the uploading will not be processed.'
);
}
writeFileSync(
this.join(`${key}.metadata.json`),
JSON.stringify({
...metadata,
lastModified: Date.now(),
})
);
} catch (e) {
this.logger.warn(`Failed to write metadata of object \`${key}\``, e);
}
}
private readMetadata(key: string): GetObjectMetadata | undefined {
try {
const raw = JSON.parse(
readFileSync(this.join(`${key}.metadata.json`), {
encoding: 'utf-8',
})
);
return {
...raw,
lastModified: new Date(raw.lastModified),
expires: raw.expires ? new Date(raw.expires) : undefined,
};
} catch (e) {
this.logger.warn(`Failed to read metadata of object \`${key}\``, e);
return;
}
}
}
@@ -1,20 +1,45 @@
import { Type } from '@nestjs/common';
import { JSONSchema } from '../../config';
import { FsStorageConfig, FsStorageProvider } from './fs';
import { StorageProvider } from './provider';
import { R2_JURISDICTIONS, R2StorageConfig, R2StorageProvider } from './r2';
import { S3StorageConfig, S3StorageProvider } from './s3';
export type StorageProviderName = 'fs' | 'aws-s3' | 'cloudflare-r2';
export const StorageProviders: Record<
StorageProviderName,
Type<StorageProvider>
> = {
fs: FsStorageProvider,
'aws-s3': S3StorageProvider,
'cloudflare-r2': R2StorageProvider,
};
export type StorageProviderName =
| 'fs'
| 'aws-s3'
| 'cloudflare-r2'
| 'assetpack';
export interface FsStorageConfig {
path: string;
}
export type AssetpackStorageConfig = FsStorageConfig;
export interface S3StorageConfig {
endpoint?: string;
region: string;
credentials?: {
accessKeyId?: string;
secretAccessKey?: string;
sessionToken?: string;
};
forcePathStyle?: boolean;
requestTimeoutMs?: number;
minPartSize?: number;
presign?: {
expiresInSeconds?: number;
signContentTypeForPut?: boolean;
};
}
export const R2_JURISDICTIONS = ['default', 'eu'] as const;
export interface R2StorageConfig extends Omit<S3StorageConfig, 'endpoint'> {
accountId: string;
jurisdiction?: (typeof R2_JURISDICTIONS)[number];
usePresignedURL?: {
enabled: boolean;
urlPrefix?: string;
signKey?: string;
};
}
export type StorageProviderConfig = { bucket: string } & (
| {
@@ -29,6 +54,10 @@ export type StorageProviderConfig = { bucket: string } & (
provider: 'cloudflare-r2';
config: R2StorageConfig;
}
| {
provider: 'assetpack';
config: AssetpackStorageConfig;
}
);
const S3ConfigSchema: JSONSchema = {
@@ -186,16 +215,37 @@ export const StorageJSONSchema: JSONSchema = {
},
},
},
{
type: 'object',
properties: {
provider: {
type: 'string',
enum: ['assetpack'],
},
bucket: {
type: 'string',
},
config: {
type: 'object',
properties: {
path: {
type: 'string',
},
},
required: ['path'],
},
},
required: ['provider', 'bucket', 'config'],
},
],
};
export type * from './provider';
export type * from '../types';
export {
applyAttachHeaders,
autoMetadata,
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
sniffMime,
STORAGE_PROXY_ROOT,
toBuffer,
} from './utils';
} from '../utils';
@@ -1,84 +0,0 @@
import type { Readable } from 'node:stream';
export interface GetObjectMetadata {
/**
* @default 'application/octet-stream'
*/
contentType: string;
contentLength: number;
lastModified: Date;
checksumCRC32?: string;
}
export interface PutObjectMetadata {
contentType?: string;
contentLength?: number;
checksumCRC32?: string;
}
export interface ListObjectsMetadata {
key: string;
lastModified: Date;
contentLength: number;
}
export type BlobInputType = Buffer | Readable | string;
export type BlobOutputType = Readable;
export interface PresignedUpload {
url: string;
headers?: Record<string, string>;
expiresAt: Date;
}
export interface MultipartUploadInit {
uploadId: string;
expiresAt: Date;
}
export interface MultipartUploadPart {
partNumber: number;
etag: string;
}
export interface StorageProvider {
put(
key: string,
body: BlobInputType,
metadata?: PutObjectMetadata
): Promise<void>;
presignPut?(
key: string,
metadata?: PutObjectMetadata
): Promise<PresignedUpload | undefined>;
createMultipartUpload?(
key: string,
metadata?: PutObjectMetadata
): Promise<MultipartUploadInit | undefined>;
presignUploadPart?(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined>;
listMultipartUploadParts?(
key: string,
uploadId: string
): Promise<MultipartUploadPart[] | undefined>;
completeMultipartUpload?(
key: string,
uploadId: string,
parts: MultipartUploadPart[]
): Promise<void>;
abortMultipartUpload?(key: string, uploadId: string): Promise<void>;
head(key: string): Promise<GetObjectMetadata | undefined>;
get(
key: string,
signedUrl?: boolean
): Promise<{
redirectUrl?: string;
body?: BlobOutputType;
metadata?: GetObjectMetadata;
}>;
list(prefix?: string): Promise<ListObjectsMetadata[]>;
delete(key: string): Promise<void>;
}
@@ -1,251 +0,0 @@
import assert from 'node:assert';
import { Readable } from 'node:stream';
import { Logger } from '@nestjs/common';
import {
GetObjectMetadata,
PresignedUpload,
PutObjectMetadata,
} from './provider';
import { S3StorageConfig, S3StorageProvider } from './s3';
import {
PROXY_MULTIPART_PATH,
PROXY_UPLOAD_PATH,
SIGNED_URL_EXPIRED,
} from './utils';
export const R2_JURISDICTIONS = ['eu'] as const;
type R2Jurisdiction = (typeof R2_JURISDICTIONS)[number];
export interface R2StorageConfig extends Omit<
S3StorageConfig,
'endpoint' | 'forcePathStyle'
> {
accountId: string;
jurisdiction?: R2Jurisdiction;
usePresignedURL?: {
enabled: boolean;
urlPrefix?: string;
signKey?: string;
};
}
export class R2StorageProvider extends S3StorageProvider {
private readonly encoder = new TextEncoder();
private readonly key: Uint8Array;
constructor(
private readonly config: R2StorageConfig,
bucket: string
) {
assert(config.accountId, 'accountId is required for R2 storage provider');
const account = config.jurisdiction
? `${config.accountId}.${config.jurisdiction}`
: config.accountId;
const endpoint = `https://${account}.r2.cloudflarestorage.com`;
super(
{
...config,
forcePathStyle: true,
endpoint,
},
bucket
);
this.logger = new Logger(`${R2StorageProvider.name}:${bucket}`);
this.key = this.encoder.encode(config.usePresignedURL?.signKey ?? '');
}
private get shouldUseProxyUpload() {
const { usePresignedURL } = this.config;
return (
!!usePresignedURL?.enabled &&
!!usePresignedURL.signKey &&
this.key.length > 0
);
}
private parseWorkspaceKey(fullKey: string) {
const [workspaceId, ...rest] = fullKey.split('/');
if (!workspaceId || rest.length !== 1) {
return null;
}
return { workspaceId, key: rest.join('/') };
}
private async signPayload(payload: string) {
const key = await crypto.subtle.importKey(
'raw',
this.key,
{ name: 'HMAC', hash: 'SHA-256' },
false,
['sign', 'verify']
);
const mac = await crypto.subtle.sign(
'HMAC',
key,
this.encoder.encode(payload)
);
return Buffer.from(mac).toString('base64');
}
private async signUrl(url: URL): Promise<string> {
const timestamp = Math.floor(Date.now() / 1000);
const base64Mac = await this.signPayload(`${url.pathname}${timestamp}`);
url.searchParams.set('sign', `${timestamp}-${base64Mac}`);
return url.toString();
}
private async createProxyUrl(
path: string,
canonicalFields: (string | number | undefined)[],
query: Record<string, string | number | undefined>
) {
const exp = Math.floor(Date.now() / 1000) + SIGNED_URL_EXPIRED;
const canonical = [
path,
...canonicalFields.map(field =>
field === undefined ? '' : field.toString()
),
exp.toString(),
].join('\n');
const token = await this.signPayload(canonical);
const url = new URL(`http://localhost${path}`);
for (const [key, value] of Object.entries(query)) {
if (value === undefined) continue;
url.searchParams.set(key, value.toString());
}
url.searchParams.set('exp', exp.toString());
url.searchParams.set('token', `${exp}-${token}`);
return { url: url.pathname + url.search, expiresAt: new Date(exp * 1000) };
}
override async presignPut(
key: string,
metadata: PutObjectMetadata = {}
): Promise<PresignedUpload | undefined> {
if (!this.shouldUseProxyUpload) {
return super.presignPut(key, metadata);
}
const parsed = this.parseWorkspaceKey(key);
if (!parsed) {
return super.presignPut(key, metadata);
}
const contentType = metadata.contentType ?? 'application/octet-stream';
const { url, expiresAt } = await this.createProxyUrl(
PROXY_UPLOAD_PATH,
[parsed.workspaceId, parsed.key, contentType, metadata.contentLength],
{
workspaceId: parsed.workspaceId,
key: parsed.key,
contentType,
contentLength: metadata.contentLength,
}
);
return {
url,
headers: { 'Content-Type': contentType },
expiresAt,
};
}
override async presignUploadPart(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined> {
if (!this.shouldUseProxyUpload) {
return super.presignUploadPart(key, uploadId, partNumber);
}
const parsed = this.parseWorkspaceKey(key);
if (!parsed) {
return super.presignUploadPart(key, uploadId, partNumber);
}
return this.createProxyUrl(
PROXY_MULTIPART_PATH,
[parsed.workspaceId, parsed.key, uploadId, partNumber],
{
workspaceId: parsed.workspaceId,
key: parsed.key,
uploadId,
partNumber,
}
);
}
async proxyPutObject(
key: string,
body: Readable | Buffer | Uint8Array | string,
options: { contentType?: string; contentLength?: number } = {}
) {
return this.client.putObject(key, this.normalizeBody(body), {
contentType: options.contentType,
contentLength: options.contentLength,
});
}
async proxyUploadPart(
key: string,
uploadId: string,
partNumber: number,
body: Readable | Buffer | Uint8Array | string,
options: { contentLength?: number } = {}
) {
const result = await this.client.uploadPart(
key,
uploadId,
partNumber,
this.normalizeBody(body),
{ contentLength: options.contentLength }
);
return result.etag;
}
private normalizeBody(body: Readable | Buffer | Uint8Array | string) {
// s3mini does not accept Node.js Readable directly.
// Convert it to Web ReadableStream for compatibility.
if (body instanceof Readable) {
return Readable.toWeb(body);
} else if (typeof body === 'string') {
return this.encoder.encode(body);
}
return body;
}
override async get(
key: string,
signedUrl?: boolean
): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
redirectUrl?: string;
}> {
const { usePresignedURL: { enabled, urlPrefix } = {} } = this.config;
if (signedUrl && enabled && urlPrefix) {
const metadata = await this.head(key);
const url = await this.signUrl(new URL(`/${key}`, urlPrefix));
if (metadata) {
return {
redirectUrl: url.toString(),
metadata,
};
}
// object not found
return {};
}
// fallback to s3 get
return super.get(key, signedUrl);
}
}
@@ -1,363 +0,0 @@
/* oxlint-disable @typescript-eslint/no-non-null-assertion */
import { Readable } from 'node:stream';
import type {
S3CompatClient,
S3CompatConfig,
S3CompatCredentials,
} from '@affine/s3-compat';
import { createS3CompatClient } from '@affine/s3-compat';
import { Logger } from '@nestjs/common';
import {
BlobInputType,
GetObjectMetadata,
ListObjectsMetadata,
MultipartUploadInit,
MultipartUploadPart,
PresignedUpload,
PutObjectMetadata,
StorageProvider,
} from './provider';
import { autoMetadata, SIGNED_URL_EXPIRED, toBuffer } from './utils';
export interface S3StorageConfig {
endpoint?: string;
region: string;
credentials: S3CompatCredentials;
forcePathStyle?: boolean;
requestTimeoutMs?: number;
minPartSize?: number;
presign?: {
expiresInSeconds?: number;
signContentTypeForPut?: boolean;
};
usePresignedURL?: {
enabled: boolean;
};
}
function resolveEndpoint(config: S3StorageConfig) {
if (config.endpoint) {
return config.endpoint;
}
if (config.region === 'us-east-1') {
return 'https://s3.amazonaws.com';
}
return `https://s3.${config.region}.amazonaws.com`;
}
function joinPath(basePath: string, suffix: string) {
const trimmedBase = basePath.endsWith('/') ? basePath.slice(0, -1) : basePath;
const trimmedSuffix = suffix.startsWith('/') ? suffix.slice(1) : suffix;
if (!trimmedBase) {
return `/${trimmedSuffix}`;
}
if (!trimmedSuffix) {
return trimmedBase;
}
return `${trimmedBase}/${trimmedSuffix}`;
}
function composeEndpointUrl(config: S3CompatConfig) {
const url = new URL(config.endpoint);
if (config.forcePathStyle) {
const firstSegment = url.pathname.split('/').find(Boolean);
if (firstSegment !== config.bucket) {
url.pathname = joinPath(url.pathname, config.bucket);
}
return url.toString();
}
const firstSegment = url.pathname.split('/').find(Boolean);
const hostHasBucket = url.hostname.startsWith(`${config.bucket}.`);
const pathHasBucket = firstSegment === config.bucket;
if (!hostHasBucket && !pathHasBucket) {
url.hostname = `${config.bucket}.${url.hostname}`;
}
return url.toString();
}
export class S3StorageProvider implements StorageProvider {
protected logger: Logger;
protected client: S3CompatClient;
private readonly usePresignedURL: boolean;
private readonly endpoint: string;
get endpointUrl() {
return this.endpoint;
}
constructor(
config: S3StorageConfig,
public readonly bucket: string
) {
const { usePresignedURL, presign, credentials, ...clientConfig } = config;
const compatConfig: S3CompatConfig = {
...clientConfig,
endpoint: resolveEndpoint(config),
bucket,
requestTimeoutMs: clientConfig.requestTimeoutMs ?? 60_000,
presign: {
expiresInSeconds: presign?.expiresInSeconds ?? SIGNED_URL_EXPIRED,
signContentTypeForPut: presign?.signContentTypeForPut ?? true,
},
};
this.endpoint = composeEndpointUrl(compatConfig);
this.client = createS3CompatClient(compatConfig, credentials);
this.usePresignedURL = usePresignedURL?.enabled ?? false;
this.logger = new Logger(`${S3StorageProvider.name}:${bucket}`);
}
async put(
key: string,
body: BlobInputType,
metadata: PutObjectMetadata = {}
): Promise<void> {
const blob = await toBuffer(body);
metadata = autoMetadata(blob, metadata);
try {
await this.client.putObject(key, blob, {
contentType: metadata.contentType,
contentLength: metadata.contentLength,
});
this.logger.verbose(`Object \`${key}\` put`);
} catch (e) {
this.logger.error(
`Failed to put object (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})})`
);
throw e;
}
}
async presignPut(
key: string,
metadata: PutObjectMetadata = {}
): Promise<PresignedUpload | undefined> {
try {
const contentType = metadata.contentType ?? 'application/octet-stream';
const result = await this.client.presignPutObject(key, { contentType });
return {
url: result.url,
headers: result.headers,
expiresAt: result.expiresAt,
};
} catch (e) {
this.logger.error(
`Failed to presign put object (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})}`
);
throw e;
}
}
async createMultipartUpload(
key: string,
metadata: PutObjectMetadata = {}
): Promise<MultipartUploadInit | undefined> {
try {
const contentType = metadata.contentType ?? 'application/octet-stream';
const response = await this.client.createMultipartUpload(key, {
contentType,
});
if (!response.uploadId) {
return;
}
return {
uploadId: response.uploadId,
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
};
} catch (e) {
this.logger.error(
`Failed to create multipart upload (${JSON.stringify({
key,
bucket: this.bucket,
metadata,
})}`
);
throw e;
}
}
async presignUploadPart(
key: string,
uploadId: string,
partNumber: number
): Promise<PresignedUpload | undefined> {
try {
const result = await this.client.presignUploadPart(
key,
uploadId,
partNumber
);
return {
url: result.url,
expiresAt: result.expiresAt,
};
} catch (e) {
this.logger.error(
`Failed to presign upload part (${JSON.stringify({ key, bucket: this.bucket, uploadId, partNumber })}`
);
throw e;
}
}
async listMultipartUploadParts(
key: string,
uploadId: string
): Promise<MultipartUploadPart[] | undefined> {
try {
return await this.client.listParts(key, uploadId);
} catch (e) {
this.logger.error(`Failed to list multipart upload parts for \`${key}\``);
throw e;
}
}
async completeMultipartUpload(
key: string,
uploadId: string,
parts: MultipartUploadPart[]
): Promise<void> {
try {
const orderedParts = [...parts].sort(
(left, right) => left.partNumber - right.partNumber
);
await this.client.completeMultipartUpload(key, uploadId, orderedParts);
} catch (e) {
this.logger.error(`Failed to complete multipart upload for \`${key}\``);
throw e;
}
}
async abortMultipartUpload(key: string, uploadId: string): Promise<void> {
try {
await this.client.abortMultipartUpload(key, uploadId);
} catch (e) {
this.logger.error(`Failed to abort multipart upload for \`${key}\``);
throw e;
}
}
async head(key: string) {
try {
const obj = await this.client.headObject(key);
if (!obj) {
this.logger.verbose(`Object \`${key}\` not found`);
return undefined;
}
return {
contentType: obj.contentType ?? 'application/octet-stream',
contentLength: obj.contentLength ?? 0,
lastModified: obj.lastModified ?? new Date(0),
checksumCRC32: obj.checksumCRC32,
};
} catch (e) {
this.logger.error(`Failed to head object \`${key}\``);
throw e;
}
}
async get(
key: string,
signedUrl?: boolean
): Promise<{
body?: Readable;
metadata?: GetObjectMetadata;
redirectUrl?: string;
}> {
try {
if (this.usePresignedURL && signedUrl) {
const metadata = await this.head(key);
if (metadata) {
const result = await this.client.presignGetObject(key);
return {
redirectUrl: result.url,
metadata,
};
}
// object not found
return {};
}
const obj = await this.client.getObjectResponse(key);
if (!obj || !obj.body) {
this.logger.verbose(`Object \`${key}\` not found`);
return {};
}
const contentType = obj.headers.get('content-type') ?? undefined;
const contentLengthHeader = obj.headers.get('content-length');
const contentLength = contentLengthHeader
? Number(contentLengthHeader)
: undefined;
const lastModifiedHeader = obj.headers.get('last-modified');
const lastModified = lastModifiedHeader
? new Date(lastModifiedHeader)
: undefined;
this.logger.verbose(`Read object \`${key}\``);
return {
body: Readable.fromWeb(obj.body),
metadata: {
contentType: contentType ?? 'application/octet-stream',
contentLength: contentLength ?? 0,
lastModified: lastModified ?? new Date(0),
checksumCRC32: obj.headers.get('x-amz-checksum-crc32') ?? undefined,
},
};
} catch (e) {
this.logger.error(`Failed to read object \`${key}\``);
throw e;
}
}
async list(prefix?: string): Promise<ListObjectsMetadata[]> {
try {
const result = await this.client.listObjectsV2(prefix);
this.logger.verbose(
`List ${result.length} objects with prefix \`${prefix}\``
);
return result;
} catch (e) {
this.logger.error(`Failed to list objects with prefix \`${prefix}\``);
throw e;
}
}
async delete(key: string): Promise<void> {
try {
await this.client.deleteObject(key);
this.logger.verbose(`Deleted object \`${key}\``);
} catch (e) {
this.logger.error(`Failed to delete object \`${key}\``, {
bucket: this.bucket,
key,
cause: e,
});
throw e;
}
}
}
@@ -0,0 +1,32 @@
import type { Readable } from 'node:stream';
export interface GetObjectMetadata {
/**
* @default 'application/octet-stream'
*/
contentType: string;
contentLength: number;
lastModified: Date;
checksumCRC32?: string;
}
export interface PutObjectMetadata {
contentType?: string;
contentLength?: number;
checksumCRC32?: string;
}
export interface ListObjectsMetadata {
key: string;
lastModified: Date;
contentLength: number;
}
export type BlobInputType = Buffer | Readable | string;
export type BlobOutputType = Readable;
export interface PresignedUpload {
url: string;
headers?: Record<string, string>;
expiresAt: Date;
}
@@ -1,11 +1,10 @@
import { Readable } from 'node:stream';
import { crc32 } from '@node-rs/crc32';
import type { Response } from 'express';
import { getStreamAsBuffer } from 'get-stream';
import { getMime } from '../../../native';
import { BlobInputType, PutObjectMetadata } from './provider';
import { getMime } from '../../native';
import type { BlobInputType } from './types';
export async function toBuffer(input: BlobInputType): Promise<Buffer> {
return input instanceof Readable
@@ -15,35 +14,6 @@ export async function toBuffer(input: BlobInputType): Promise<Buffer> {
: Buffer.from(input as string);
}
export function autoMetadata(
blob: Buffer,
raw: PutObjectMetadata = {}
): PutObjectMetadata {
const metadata = {
...raw,
};
if (!metadata.contentLength) {
metadata.contentLength = blob.byteLength;
}
try {
// checksum
if (!metadata.checksumCRC32) {
metadata.checksumCRC32 = crc32(blob).toString(16);
}
// mime type
if (!metadata.contentType) {
metadata.contentType = getMime(blob);
}
} catch {
// noop
}
return metadata;
}
const DANGEROUS_INLINE_MIME_PREFIXES = [
'text/html',
'application/xhtml+xml',
@@ -2,6 +2,7 @@ import './config';
import { Module } from '@nestjs/common';
import { BackendRuntimeModule } from '../backend-runtime';
import { FeatureModule } from '../features';
import { MailModule } from '../mail';
import { QuotaModule } from '../quota';
@@ -20,7 +21,13 @@ import { AuthService } from './service';
import { SessionIssuer } from './session-issuer';
@Module({
imports: [FeatureModule, UserModule, QuotaModule, MailModule],
imports: [
BackendRuntimeModule,
FeatureModule,
UserModule,
QuotaModule,
MailModule,
],
providers: [
AuthService,
AuthResolver,
+6 -3
View File
@@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { JobQueue, OnJob } from '../../base';
import { Models } from '../../models';
import { BackendRuntimeProvider } from '../backend-runtime';
declare global {
interface Jobs {
@@ -13,7 +13,7 @@ declare global {
@Injectable()
export class AuthCronJob {
constructor(
private readonly models: Models,
private readonly rt: BackendRuntimeProvider,
private readonly queue: JobQueue
) {}
@@ -31,6 +31,9 @@ export class AuthCronJob {
@OnJob('nightly.cleanExpiredUserSessions')
async cleanExpiredUserSessions() {
await this.models.session.cleanExpiredUserSessions();
for (;;) {
const count = await this.rt.cleanupExpiredUserSessions(1000);
if (count < 1000) break;
}
}
}
@@ -0,0 +1,57 @@
import { ScheduleModule } from '@nestjs/schedule';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import {
createTestingModule,
type TestingModule,
} from '../../../__tests__/utils';
import { BackendRuntimeModule, BackendRuntimeProvider } from '../index';
import { BackendRuntimeHousekeepingJob } from '../job';
interface Context {
module: TestingModule;
job: BackendRuntimeHousekeepingJob;
runtime: {
cleanupExpiredRuntimeStates: Sinon.SinonStub;
cleanupExpiredRuntimeGates: Sinon.SinonStub;
};
}
const test = ava as TestFn<Context>;
test.before(async t => {
t.context.runtime = {
cleanupExpiredRuntimeStates: Sinon.stub(),
cleanupExpiredRuntimeGates: Sinon.stub(),
};
t.context.module = await createTestingModule({
imports: [ScheduleModule.forRoot(), BackendRuntimeModule],
tapModule: builder => {
builder
.overrideProvider(BackendRuntimeProvider)
.useValue(t.context.runtime);
},
});
t.context.job = t.context.module.get(BackendRuntimeHousekeepingJob);
});
test.beforeEach(t => {
t.context.runtime.cleanupExpiredRuntimeStates.reset();
t.context.runtime.cleanupExpiredRuntimeGates.reset();
});
test.after.always(async t => {
await t.context.module.close();
});
test('backend-runtime housekeeping cleans runtime state and gate batches', async t => {
t.context.runtime.cleanupExpiredRuntimeStates.onCall(0).resolves(1000);
t.context.runtime.cleanupExpiredRuntimeStates.onCall(1).resolves(2);
t.context.runtime.cleanupExpiredRuntimeGates.resolves(1);
await t.context.job.cleanExpiredRuntimeHousekeeping();
t.is(t.context.runtime.cleanupExpiredRuntimeStates.callCount, 2);
t.is(t.context.runtime.cleanupExpiredRuntimeGates.callCount, 1);
});
@@ -0,0 +1,41 @@
import test from 'ava';
import Sinon from 'sinon';
import { BackendRuntimeProvider } from '../provider';
test('backend-runtime provider starts once, runs migrations once, and reports health', async t => {
const provider = new BackendRuntimeProvider();
const runtime = {
start: Sinon.stub().resolves(),
stop: Sinon.stub().resolves(),
runMigrations: Sinon.stub().resolves(),
health: Sinon.stub().resolves({
started: true,
databaseConnected: true,
}),
};
(provider as any).runtime = runtime;
await provider.start();
await provider.start();
const health = await provider.health();
await provider.stop();
t.is(runtime.start.callCount, 2);
t.is(runtime.runMigrations.callCount, 1);
t.true(health.databaseConnected);
t.is(runtime.stop.callCount, 1);
});
test('backend-runtime provider measures explicit typed methods', async t => {
const provider = new BackendRuntimeProvider();
const runtime = {
cleanupExpiredRuntimeStates: Sinon.stub().resolves(3),
};
(provider as any).runtime = runtime;
const result = await provider.cleanupExpiredRuntimeStates(1000);
t.is(result, 3);
t.true(runtime.cleanupExpiredRuntimeStates.calledOnceWithExactly(1000));
});
@@ -0,0 +1,13 @@
import { Global, Module } from '@nestjs/common';
import { BackendRuntimeHousekeepingJob } from './job';
import { BackendRuntimeProvider } from './provider';
@Global()
@Module({
providers: [BackendRuntimeProvider, BackendRuntimeHousekeepingJob],
exports: [BackendRuntimeProvider],
})
export class BackendRuntimeModule {}
export { BackendRuntimeProvider } from './provider';
@@ -0,0 +1,58 @@
import { Injectable, Logger } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { JobQueue, OnJob } from '../../base';
import { BackendRuntimeProvider } from './provider';
declare global {
interface Jobs {
'nightly.cleanExpiredBackendRuntimeHousekeeping': {};
}
}
@Injectable()
export class BackendRuntimeHousekeepingJob {
private readonly logger = new Logger(BackendRuntimeHousekeepingJob.name);
constructor(
private readonly rt: BackendRuntimeProvider,
private readonly queue: JobQueue
) {}
@Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT)
async nightlyJob() {
await this.queue.add(
'nightly.cleanExpiredBackendRuntimeHousekeeping',
{},
{
jobId: 'nightly-backend-runtime-housekeeping',
}
);
}
@OnJob('nightly.cleanExpiredBackendRuntimeHousekeeping')
async cleanExpiredRuntimeHousekeeping() {
const states = await this.cleanBatches(() =>
this.rt.cleanupExpiredRuntimeStates(1000)
);
const gates = await this.cleanBatches(() =>
this.rt.cleanupExpiredRuntimeGates(1000)
);
this.logger.log(
`cleaned runtime housekeeping states=${states} gates=${gates}`
);
}
private async cleanBatches(fn: () => Promise<number>) {
let total = 0;
for (;;) {
const count = Number(await fn());
total += count;
if (count < 1000) {
break;
}
}
return total;
}
}
@@ -0,0 +1,88 @@
import {
Injectable,
Logger,
type OnApplicationBootstrap,
type OnApplicationShutdown,
} from '@nestjs/common';
import { wrapCallMetric } from '../../base/metrics';
import { BackendRuntime, type BackendRuntimeHealth } from '../../native';
type RuntimeInstance = InstanceType<typeof BackendRuntime>;
@Injectable()
export class BackendRuntimeProvider
implements OnApplicationBootstrap, OnApplicationShutdown
{
private readonly logger = new Logger(BackendRuntimeProvider.name);
private readonly runtime: RuntimeInstance = new BackendRuntime();
private migrationsStarted = false;
async onApplicationBootstrap() {
await this.start();
}
async onApplicationShutdown() {
await this.stop();
}
async start() {
await this.runtime.start();
await this.runMigrationsOnce();
const health = await this.runtime.health();
this.logger.log(`backend runtime started: db=${health.databaseConnected}`);
}
async stop() {
await this.runtime.stop();
this.logger.log('backend runtime stopped');
}
async health(): Promise<BackendRuntimeHealth> {
return await this.runtime.health();
}
async cleanupExpiredSnapshotHistories(limit: number) {
return await this.measured('cleanupExpiredSnapshotHistories', rt =>
rt.cleanupExpiredSnapshotHistories(limit)
);
}
async cleanupExpiredUserSessions(limit: number) {
return await this.measured('cleanupExpiredUserSessions', rt =>
rt.cleanupExpiredUserSessions(limit)
);
}
async cleanupExpiredRuntimeStates(limit: number) {
return await this.measured('cleanupExpiredRuntimeStates', rt =>
rt.cleanupExpiredRuntimeStates(limit)
);
}
async cleanupExpiredRuntimeGates(limit: number) {
return await this.measured('cleanupExpiredRuntimeGates', rt =>
rt.cleanupExpiredRuntimeGates(limit)
);
}
private async measured<T>(
method: string,
fn: (runtime: RuntimeInstance) => Promise<T>
): Promise<T> {
return await wrapCallMetric(
() => fn(this.runtime),
'storage',
'backend_runtime',
{ method }
)();
}
private async runMigrationsOnce() {
if (this.migrationsStarted) {
return;
}
await this.runtime.runMigrations();
this.migrationsStarted = true;
}
}
@@ -2,6 +2,7 @@ import { Module } from '@nestjs/common';
import { ServerConfigModule } from '../config';
import { PermissionModule } from '../permission';
import { QuotaServiceModule } from '../quota';
import { StorageModule } from '../storage';
import { CommentRealtimeModule } from './realtime.module';
import { CommentResolver } from './resolver';
@@ -9,6 +10,7 @@ import { CommentResolver } from './resolver';
@Module({
imports: [
PermissionModule,
QuotaServiceModule,
StorageModule,
ServerConfigModule,
CommentRealtimeModule,

Some files were not shown because too many files have changed in this diff Show More