diff --git a/packages/backend/native/index.d.ts b/packages/backend/native/index.d.ts index 925149d887..8f9e4d77df 100644 --- a/packages/backend/native/index.d.ts +++ b/packages/backend/native/index.d.ts @@ -68,6 +68,7 @@ export declare class StorageRuntime { rebuildWorkspaceDocBlobRefs(workspaceId: string, limit: number): Promise constructor() start(): Promise + configure(configJson: string): void stop(): Promise runMigrations(): Promise health(): Promise diff --git a/packages/backend/native/src/runtime/backend_runtime/coordination_lease.rs b/packages/backend/native/src/runtime/backend_runtime/coordination_lease.rs index 5287167f7e..de770d2e6a 100644 --- a/packages/backend/native/src/runtime/backend_runtime/coordination_lease.rs +++ b/packages/backend/native/src/runtime/backend_runtime/coordination_lease.rs @@ -94,6 +94,13 @@ impl BackendRuntime { owner: String, ttl_ms: i64, ) -> RuntimeResult> { + if ttl_ms <= 0 { + return Err(RuntimeError::invalid_input("coordination lease ttl must be positive")); + } + if owner.is_empty() { + return Err(RuntimeError::invalid_input("coordination lease owner is required")); + } + CoordinationLeaseStore::new(self.pool().await?) .acquire(key, owner, ttl_ms) .await @@ -117,13 +124,6 @@ impl BackendRuntime { owner: String, ttl_ms: i64, ) -> Result> { - if ttl_ms <= 0 { - return Err(napi_error("coordination lease ttl must be positive")); - } - if owner.is_empty() { - return Err(napi_error("coordination lease owner is required")); - } - self .acquire_coordination_lease_inner(key, owner, ttl_ms) .await diff --git a/packages/backend/native/src/runtime/storage_runtime/assetpack.rs b/packages/backend/native/src/runtime/storage_runtime/assetpack.rs index 201f97c50b..c29aed4540 100644 --- a/packages/backend/native/src/runtime/storage_runtime/assetpack.rs +++ b/packages/backend/native/src/runtime/storage_runtime/assetpack.rs @@ -208,21 +208,21 @@ pub(super) async fn list( scope: &str, prefix: Option, ) -> RuntimeResult> { - let prefix = prefix.unwrap_or_default(); - if !prefix.is_empty() { - super::normalize_storage_prefix(&prefix)?; - } + let prefix = prefix + .map(|prefix| super::normalize_storage_prefix(&prefix)) + .transpose()? + .unwrap_or_default(); let store = open_store(config).await?; let rows = sqlx::query( r#" SELECT key, content_length, last_modified_ms FROM storage_assetpack_blobs - WHERE scope = ?1 AND key LIKE ?2 + WHERE scope = ?1 AND key LIKE ?2 ESCAPE '\' ORDER BY key ASC "#, ) .bind(scope) - .bind(format!("{prefix}%")) + .bind(format!("{}%", escape_sqlite_like(&prefix))) .fetch_all(store.pool()) .await .map_err(|err| RuntimeError::database("Assetpack manifest list failed", err))?; @@ -239,6 +239,20 @@ pub(super) async fn list( .collect() } +fn escape_sqlite_like(value: &str) -> String { + let mut escaped = String::with_capacity(value.len()); + for ch in value.chars() { + match ch { + '%' | '_' | '\\' => { + escaped.push('\\'); + escaped.push(ch); + } + _ => escaped.push(ch), + } + } + escaped +} + pub(super) async fn delete(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<()> { normalize_storage_key(key)?; let store = open_store(config).await?; diff --git a/packages/backend/native/src/runtime/storage_runtime/blob_cleanup.rs b/packages/backend/native/src/runtime/storage_runtime/blob_cleanup.rs index 9c6daf789b..c582741157 100644 --- a/packages/backend/native/src/runtime/storage_runtime/blob_cleanup.rs +++ b/packages/backend/native/src/runtime/storage_runtime/blob_cleanup.rs @@ -6,45 +6,6 @@ use super::{ napi_error, }; -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn blob_cleanup_plan_result_keeps_run_id_for_execute() { - let result = RuntimeBlobCleanupPlanResult { - run_id: Some("00000000-0000-0000-0000-000000000000".to_string()), - scanned_blobs: 1, - candidates_marked: 1, - protected_by_doc_refs: 0, - protected_by_metadata: 0, - protected_by_other_refs: 0, - next_cursor: None, - }; - - assert!(result.run_id.is_some()); - assert_eq!(result.candidates_marked, 1); - } - - #[test] - fn blob_cleanup_execute_result_tracks_skipped_and_failed_counts() { - let result = RuntimeBlobCleanupExecuteResult { - scanned_candidates: 3, - deleted_objects: 1, - deleted_metadata: 1, - skipped_still_referenced: 1, - failed: 1, - workspace_ids: vec!["workspace".to_string()], - }; - - assert_eq!(result.scanned_candidates, 3); - assert_eq!( - result.skipped_still_referenced + result.failed + result.deleted_objects, - 3 - ); - } -} - #[derive(FromRow)] struct BlobCandidateRow { workspace_id: String, diff --git a/packages/backend/native/src/runtime/storage_runtime/mod.rs b/packages/backend/native/src/runtime/storage_runtime/mod.rs index de3dd2f867..3dda90f844 100644 --- a/packages/backend/native/src/runtime/storage_runtime/mod.rs +++ b/packages/backend/native/src/runtime/storage_runtime/mod.rs @@ -111,7 +111,10 @@ struct CopilotConfigFile { impl StorageRuntimeConfig { fn from_config_files() -> RuntimeResult { - let app_config = app_config_from_config_files()?; + Self::from_app_config_file(app_config_from_config_files()?) + } + + fn from_app_config_file(app_config: AppConfigFile) -> RuntimeResult { let database_url = database_url_from_env() .or(app_config.database_url()) .unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string()); @@ -120,11 +123,12 @@ impl StorageRuntimeConfig { } async fn with_db_overrides(&self, pool: &PgPool) -> RuntimeResult { - let mut app_config = app_config_from_config_files()?; - app_config.apply_file_config(load_app_config_overrides_from_db(pool).await?); + let app_config = load_app_config_overrides_from_db(pool).await?; + let mut backends = self.backends.clone(); + backends.extend(app_config.storage_backends()?); Ok(Self { database_url: self.database_url.clone(), - backends: app_config.storage_backends()?, + backends, }) } } @@ -300,6 +304,14 @@ impl StorageRuntime { self.start_inner().await.map_err(to_napi_error) } + #[napi] + pub fn configure(&self, config_json: String) -> napi::Result<()> { + let app_config: AppConfigFile = serde_json::from_str(&config_json) + .map_err(|err| to_napi_error(RuntimeError::json("invalid storage runtime config", err)))?; + let config = StorageRuntimeConfig::from_app_config_file(app_config).map_err(to_napi_error)?; + self.update_config(config).map_err(to_napi_error) + } + async fn start_inner(&self) -> RuntimeResult<()> { let mut guard = self.pool.lock().await; if guard.is_some() { @@ -1295,11 +1307,36 @@ mod tests { #[test] fn fs_key_normalization_rejects_traversal() { - for key in ["", "/a", "a//b", "a/./b", "a/../b", "..\\secret"] { - assert!(normalize_storage_key(key).is_err(), "{key}"); + for (key, valid) in [ + ("", false), + ("/a", false), + ("a//b", false), + ("a/./b", false), + ("a/../b", false), + ("..\\secret", false), + ("workspace/blob", true), + ("workspace\\blob", true), + ] { + assert_eq!(normalize_storage_key(key).is_ok(), valid, "{key}"); } assert_eq!(normalize_storage_key("workspace/blob").unwrap(), ["workspace", "blob"]); - assert_eq!(normalize_storage_prefix("workspace/").unwrap(), "workspace/"); + } + + #[test] + fn fs_prefix_normalization_rejects_traversal() { + for (prefix, expected) in [ + ("", Some("")), + ("workspace/", Some("workspace/")), + ("workspace\\blob", Some("workspace/blob")), + ("../escape", None), + ("nested/../../escape", None), + ("/absolute", None), + ("nested//escape", None), + ("nested/./escape", None), + ("nested/../escape", None), + ] { + assert_eq!(normalize_storage_prefix(prefix).ok().as_deref(), expected, "{prefix}"); + } } #[test] @@ -1468,6 +1505,185 @@ mod tests { assert_eq!(keys, ["workspace/blob-a", "workspace/nested/blob-b"]); } + #[test] + fn fs_backend_lists_old_node_prefix_semantics() { + let temp = tempfile::tempdir().unwrap(); + let config = FsStorageConfig { + provider: "fs".to_string(), + root: temp.path().to_string_lossy().to_string(), + bucket: "bucket".to_string(), + }; + for key in ["root-a", "a/item", "a/b/item", "a/b/t/item", "a/b/tail", "z/item"] { + fs_put(&config, key, key.as_bytes().to_vec(), ObjectPutMetadata::default()).unwrap(); + } + + for (prefix, expected) in [ + ( + None, + vec!["a/b/item", "a/b/t/item", "a/b/tail", "a/item", "root-a", "z/item"], + ), + (Some("a"), vec!["a/b/item", "a/b/t/item", "a/b/tail", "a/item"]), + (Some("a/b"), vec!["a/b/item", "a/b/t/item", "a/b/tail"]), + (Some("a/b/"), vec!["a/b/item", "a/b/t/item", "a/b/tail"]), + (Some("a/b/t"), vec!["a/b/t/item", "a/b/tail"]), + (Some("missing"), vec![]), + ] { + let keys = fs_list(&config, prefix.map(ToString::to_string)) + .unwrap() + .into_iter() + .map(|entry| entry.key) + .collect::>(); + assert_eq!(keys, expected, "{prefix:?}"); + } + } + + #[test] + fn fs_backend_delete_removes_object_and_sidecar_idempotently() { + let temp = tempfile::tempdir().unwrap(); + let config = FsStorageConfig { + provider: "fs".to_string(), + root: temp.path().to_string_lossy().to_string(), + bucket: "bucket".to_string(), + }; + + fs_put( + &config, + "workspace/blob", + b"body".to_vec(), + ObjectPutMetadata::default(), + ) + .unwrap(); + fs_delete(&config, "workspace/blob").unwrap(); + fs_delete(&config, "workspace/blob").unwrap(); + + assert!(fs_head(&config, "workspace/blob").unwrap().is_none()); + assert!(fs_get(&config, "workspace/blob").unwrap().is_none()); + assert!(!temp.path().join("bucket/workspace/blob").exists()); + assert!(!temp.path().join("bucket/workspace/blob.metadata.json").exists()); + } + + fn test_storage_runtime() -> StorageRuntime { + StorageRuntime { + config: RwLock::new(StorageRuntimeConfig { + database_url: "postgresql://unused".to_string(), + backends: HashMap::new(), + }), + pool: Mutex::new(None), + } + } + + #[tokio::test] + async fn fs_workspace_blob_complete_returns_native_failure_reasons_before_db_upsert() { + let temp = tempfile::tempdir().unwrap(); + let config = FsStorageConfig { + provider: "fs".to_string(), + root: temp.path().to_string_lossy().to_string(), + bucket: "bucket".to_string(), + }; + let runtime = test_storage_runtime(); + + let result = runtime + .complete_fs_workspace_blob( + config.clone(), + "workspace".to_string(), + "missing".to_string(), + 1, + "text/plain".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("not_found")); + + fs_put( + &config, + "workspace/blob", + b"body".to_vec(), + ObjectPutMetadata { + content_type: Some("text/plain".to_string()), + content_length: Some(4), + checksum_crc32: None, + }, + ) + .unwrap(); + let result = runtime + .complete_fs_workspace_blob( + config.clone(), + "workspace".to_string(), + "blob".to_string(), + 5, + "text/plain".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("size_mismatch")); + + let result = runtime + .complete_fs_workspace_blob( + config.clone(), + "workspace".to_string(), + "blob".to_string(), + 4, + "image/png".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("mime_mismatch")); + + let result = runtime + .complete_fs_workspace_blob( + config.clone(), + "workspace".to_string(), + "not-the-sha-key".to_string(), + 4, + "text/plain".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("not_found")); + + fs_put( + &config, + "workspace/not-the-sha-key", + b"body".to_vec(), + ObjectPutMetadata { + content_type: Some("text/plain".to_string()), + content_length: Some(4), + checksum_crc32: None, + }, + ) + .unwrap(); + let result = runtime + .complete_fs_workspace_blob( + config.clone(), + "workspace".to_string(), + "not-the-sha-key".to_string(), + 4, + "text/plain".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("checksum_mismatch")); + assert!(fs_get(&config, "workspace/not-the-sha-key").unwrap().is_none()); + + let result = runtime + .complete_fs_workspace_blob( + config, + "workspace".to_string(), + "too-large".to_string(), + MAX_BLOB_SIZE + 1, + "text/plain".to_string(), + ) + .await + .unwrap(); + assert!(!result.ok); + assert_eq!(result.reason.as_deref(), Some("size_too_large")); + } + #[test] fn fs_backend_rejects_metadata_mismatch() { let temp = tempfile::tempdir().unwrap(); @@ -1543,6 +1759,26 @@ mod tests { 1 ); + let percent_key = "workspace/%literal.txt"; + let wildcard_collision_key = "workspace/aliteral.txt"; + for key in [percent_key, wildcard_collision_key] { + assetpack::put( + &config, + &scope, + key, + b"literal prefix body".to_vec(), + ObjectPutMetadata { + content_type: None, + content_length: None, + checksum_crc32: None, + }, + ) + .await?; + } + let percent_matches = assetpack::list(&config, &scope, Some("workspace/%".to_string())).await?; + assert_eq!(percent_matches.len(), 1); + assert_eq!(percent_matches[0].key, percent_key); + assetpack::delete(&config, &scope, key).await?; assert!(assetpack::head(&config, &scope, key).await?.is_none()); Ok(()) diff --git a/packages/backend/native/src/runtime/storage_runtime/object_storage/client.rs b/packages/backend/native/src/runtime/storage_runtime/object_storage/client.rs index 48a051e8c4..b225380663 100644 --- a/packages/backend/native/src/runtime/storage_runtime/object_storage/client.rs +++ b/packages/backend/native/src/runtime/storage_runtime/object_storage/client.rs @@ -27,6 +27,8 @@ use super::{ }, }; +const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000; +const MAX_MULTIPART_PART_NUMBER: i32 = 10_000; const MAX_RESPONSE_BODY_BYTES: usize = i32::MAX as usize; type StorageHttpFuture<'a> = Pin> + Send + 'a>>; @@ -57,10 +59,9 @@ struct ReqwestStorageHttpClient { impl ReqwestStorageHttpClient { fn new(request_timeout_ms: Option) -> ObjectStorageResult { - let mut builder = ReqwestClient::builder(); - if let Some(request_timeout_ms) = request_timeout_ms { - builder = builder.timeout(Duration::from_millis(request_timeout_ms)); - } + let builder = ReqwestClient::builder().timeout(Duration::from_millis( + request_timeout_ms.unwrap_or(DEFAULT_REQUEST_TIMEOUT_MS), + )); Ok(Self { client: builder.build().map_err(ObjectStorageError::HttpClientBuild)?, }) @@ -627,7 +628,12 @@ fn response_header_name(headers: &HeaderMap, name: &str) -> Option { } fn checked_part_number(part_number: i32) -> ObjectStorageResult { - u16::try_from(part_number).map_err(|_| ObjectStorageError::InvalidInput("part number must fit u16".to_string())) + if !(1..=MAX_MULTIPART_PART_NUMBER).contains(&part_number) { + return Err(ObjectStorageError::InvalidInput( + "multipart part number must be between 1 and 10000".to_string(), + )); + } + Ok(part_number as u16) } fn validate_completed_parts(parts: &[MultipartUploadPart]) -> ObjectStorageResult<()> { @@ -677,3 +683,168 @@ fn parse_rfc3339_ms(value: &str) -> i64 { .map(|value| value.timestamp_millis()) .unwrap_or(0) } + +#[cfg(test)] +mod tests { + use reqwest::header::HeaderValue; + + use super::*; + + #[test] + fn metadata_from_headers_uses_s3_defaults_and_checksum() { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + headers.insert(CONTENT_LENGTH, HeaderValue::from_static("42")); + headers.insert(LAST_MODIFIED, HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT")); + headers.insert("x-amz-checksum-crc32", HeaderValue::from_static("checksum")); + + let metadata = metadata_from_headers(&headers); + + assert_eq!(metadata.content_type, "text/plain"); + assert_eq!(metadata.content_length, 42); + assert_eq!(metadata.last_modified_ms, 1_445_412_480_000); + assert_eq!(metadata.checksum_crc32.as_deref(), Some("checksum")); + + let defaults = metadata_from_headers(&HeaderMap::new()); + assert_eq!(defaults.content_type, "application/octet-stream"); + assert_eq!(defaults.content_length, 0); + assert_eq!(defaults.last_modified_ms, 0); + assert!(defaults.checksum_crc32.is_none()); + } + + #[test] + fn not_found_body_accepts_object_missing_codes_only() { + for body in [ + "NoSuchKey", + "NotFound", + "NoSuchUpload", + ] { + assert!(is_not_found_body(body.as_bytes()), "{body}"); + } + assert!(!is_not_found_body(b"")); + assert!(!is_not_found_body(b"AccessDenied")); + } + + #[test] + fn list_parts_xml_handles_array_single_part_and_pagination() { + let xml = r#" + + test + key + upload-id + 0 + 3 + 2 + true + + 1 + 2010-11-10T20:48:34.000Z + "etag-1" + 10485760 + + + 2 + 2010-11-10T20:48:33.000Z + etag-2 + 10485760 + +"#; + let parsed = ListParts::parse_response(xml).unwrap(); + let parts = parsed + .parts + .into_iter() + .map(|part| MultipartUploadPart { + part_number: i32::from(part.number), + etag: trim_etag(&part.etag), + }) + .collect::>(); + + assert_eq!( + parts, + vec![ + MultipartUploadPart { + part_number: 1, + etag: "etag-1".to_string() + }, + MultipartUploadPart { + part_number: 2, + etag: "etag-2".to_string() + } + ] + ); + assert_eq!(parsed.next_part_number_marker, Some(3)); + + let single = r#" + + test + key + upload-id + 1 + false + + 5 + 2010-11-10T20:48:34.000Z + "etag-5" + 10485760 + +"#; + let parsed = ListParts::parse_response(single).unwrap(); + assert_eq!(parsed.parts.len(), 1); + assert_eq!(parsed.parts[0].number, 5); + assert_eq!(trim_etag(&parsed.parts[0].etag), "etag-5"); + assert_eq!(parsed.next_part_number_marker, None); + } + + #[test] + fn complete_multipart_body_orders_and_escapes_parts() { + let mut parts = completed_multipart_parts(vec![ + MultipartUploadPart { + part_number: 2, + etag: "b&c".to_string(), + }, + MultipartUploadPart { + part_number: 1, + etag: "a".to_string(), + }, + ]); + validate_completed_parts(&parts).unwrap(); + + let body = complete_multipart_body(&parts); + + assert_eq!( + body, + "a<tag>1b&c2" + ); + + parts[0].etag.clear(); + assert!(validate_completed_parts(&parts).is_err()); + assert!( + validate_completed_parts(&[MultipartUploadPart { + part_number: -1, + etag: "etag".to_string(), + }]) + .is_err() + ); + assert!( + validate_completed_parts(&[MultipartUploadPart { + part_number: 0, + etag: "etag".to_string(), + }]) + .is_err() + ); + assert!( + validate_completed_parts(&[MultipartUploadPart { + part_number: 10_001, + etag: "etag".to_string(), + }]) + .is_err() + ); + } + + #[test] + fn parse_rfc3339_ms_returns_zero_for_invalid_values() { + assert_eq!(parse_rfc3339_ms("2024-01-02T03:04:05Z"), 1_704_164_645_000); + assert_eq!(parse_rfc3339_ms("not a date"), 0); + } +} diff --git a/packages/backend/native/src/runtime/storage_runtime/object_storage/error.rs b/packages/backend/native/src/runtime/storage_runtime/object_storage/error.rs index 232e2610d1..0b2a35f219 100644 --- a/packages/backend/native/src/runtime/storage_runtime/object_storage/error.rs +++ b/packages/backend/native/src/runtime/storage_runtime/object_storage/error.rs @@ -45,7 +45,8 @@ impl ObjectStorageError { match self { Self::Operation { source, .. } => source.is_not_found(), Self::HttpStatus { status, body, .. } => { - *status == StatusCode::NOT_FOUND && (body.contains("NoSuchKey") || body.contains("NotFound")) + *status == StatusCode::NOT_FOUND + && (body.contains("NoSuchKey") || body.contains("NoSuchUpload") || body.contains("NotFound")) } _ => false, } diff --git a/packages/backend/native/src/runtime/storage_runtime/object_storage/tests.rs b/packages/backend/native/src/runtime/storage_runtime/object_storage/tests.rs index 16278f988d..dea7f55972 100644 --- a/packages/backend/native/src/runtime/storage_runtime/object_storage/tests.rs +++ b/packages/backend/native/src/runtime/storage_runtime/object_storage/tests.rs @@ -6,6 +6,14 @@ use super::{ types::{MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, completed_multipart_parts, trim_etag}, }; +fn storage_config(provider: &str, config: serde_json::Value) -> StorageProviderConfig { + StorageProviderConfig { + provider: provider.to_string(), + bucket: "test-bucket".to_string(), + config, + } +} + #[test] fn resolves_r2_config_from_config_json_shape() { let storage = StorageProviderConfig { @@ -38,6 +46,66 @@ fn resolves_r2_config_from_config_json_shape() { assert_eq!(config.access_key_id.as_deref(), Some("key")); } +#[test] +fn resolves_r2_endpoint_cases_from_config_json_shape() { + for (case, config, expected_endpoint) in [ + ( + "default account endpoint", + serde_json::json!({ + "accountId": "account", + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + } + }), + Some("https://account.r2.cloudflarestorage.com"), + ), + ( + "explicit null jurisdiction", + serde_json::json!({ + "accountId": "account", + "jurisdiction": null, + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + } + }), + Some("https://account.r2.cloudflarestorage.com"), + ), + ( + "eu jurisdiction", + serde_json::json!({ + "accountId": "account", + "jurisdiction": "eu", + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + } + }), + Some("https://account.eu.r2.cloudflarestorage.com"), + ), + ] { + let config = ObjectStorageConfig::from_r2_config(storage_config("cloudflare-r2", config)) + .unwrap() + .unwrap(); + assert_eq!(config.endpoint.as_deref(), expected_endpoint, "{case}"); + assert!(config.force_path_style, "{case}"); + } + + assert!( + ObjectStorageConfig::from_r2_config(storage_config( + "cloudflare-r2", + serde_json::json!({ + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + } + }) + )) + .is_err() + ); +} + #[test] fn object_storage_not_found_requires_object_error_code() { let bucket_or_route_missing = ObjectStorageError::HttpStatus { @@ -50,9 +118,15 @@ fn object_storage_not_found_requires_object_error_code() { status: StatusCode::NOT_FOUND, body: "NoSuchKey".to_string(), }; + let upload_missing = ObjectStorageError::HttpStatus { + context: "abort failed".to_string(), + status: StatusCode::NOT_FOUND, + body: "NoSuchUpload".to_string(), + }; assert!(!bucket_or_route_missing.is_not_found()); assert!(object_missing.is_not_found()); + assert!(upload_missing.is_not_found()); } #[test] @@ -113,6 +187,28 @@ fn resolves_s3_config_from_config_json_shape() { assert_eq!(config.presign_sign_content_type_for_put, Some(false)); } +#[test] +fn resolves_s3_default_endpoint_cases_from_config_json_shape() { + for (region, expected_endpoint) in [ + ("us-east-1", "https://s3.amazonaws.com"), + ("us-west-2", "https://s3.us-west-2.amazonaws.com"), + ] { + let config = ObjectStorageConfig::from_s3_config(storage_config( + "aws-s3", + serde_json::json!({ + "region": region, + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + } + }), + )) + .unwrap() + .unwrap(); + assert_eq!(config.endpoint.as_deref(), Some(expected_endpoint), "{region}"); + } +} + #[tokio::test] async fn object_storage_presign_put_returns_sigv4_url_and_headers() { let storage = StorageProviderConfig { @@ -155,6 +251,47 @@ async fn object_storage_presign_put_returns_sigv4_url_and_headers() { assert!(result.expires_at_ms > 0); } +#[tokio::test] +async fn object_storage_presign_put_respects_content_length_and_signed_content_type_flag() { + let config = ObjectStorageConfig::from_s3_config(storage_config( + "aws-s3", + serde_json::json!({ + "region": "us-east-1", + "endpoint": "https://s3.us-east-1.amazonaws.com", + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + }, + "presign": { + "expiresInSeconds": 60, + "signContentTypeForPut": false + } + }), + )) + .unwrap() + .unwrap(); + let client = config.build_client().unwrap(); + let result = client + .presign_put( + "key", + ObjectPutMetadata { + content_type: Some("text/plain".to_string()), + content_length: Some(42), + ..Default::default() + }, + ) + .await + .unwrap(); + + assert_eq!( + result.headers.get("Content-Type").map(String::as_str), + Some("text/plain") + ); + assert_eq!(result.headers.get("Content-Length").map(String::as_str), Some("42")); + assert!(!result.url.contains("content-type")); + assert!(result.url.contains("content-length")); +} + #[tokio::test] async fn object_storage_presign_get_returns_sigv4_url_without_headers() { let storage = StorageProviderConfig { @@ -182,6 +319,34 @@ async fn object_storage_presign_get_returns_sigv4_url_without_headers() { assert!(result.expires_at_ms > 0); } +#[tokio::test] +async fn object_storage_presign_upload_part_returns_sigv4_url() { + let config = ObjectStorageConfig::from_s3_config(storage_config( + "aws-s3", + serde_json::json!({ + "region": "us-east-1", + "endpoint": "https://s3.us-east-1.amazonaws.com", + "credentials": { + "accessKeyId": "key", + "secretAccessKey": "secret" + }, + "presign": { + "expiresInSeconds": 60 + } + }), + )) + .unwrap() + .unwrap(); + let client = config.build_client().unwrap(); + let result = client.presign_upload_part("key", "upload-1", 3).await.unwrap(); + + assert!(result.url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")); + assert!(result.url.contains("partNumber=3")); + assert!(result.url.contains("uploadId=upload-1")); + assert!(result.headers.is_empty()); + assert!(result.expires_at_ms > 0); +} + #[test] fn object_storage_orders_completed_multipart_parts_and_trims_etags() { let parts = completed_multipart_parts(vec![ diff --git a/packages/backend/server/src/__tests__/e2e/create-app.ts b/packages/backend/server/src/__tests__/e2e/create-app.ts index 4d6e157599..9b1e0f69bd 100644 --- a/packages/backend/server/src/__tests__/e2e/create-app.ts +++ b/packages/backend/server/src/__tests__/e2e/create-app.ts @@ -13,6 +13,7 @@ import { AFFiNELogger, CacheInterceptor, CloudThrottlerGuard, + ConfigFactory, EventBus, GlobalExceptionFilter, JobQueue, @@ -250,6 +251,31 @@ export async function createApp( } const module = await builder.compile(); + module.get(ConfigFactory).override({ + storages: { + avatar: { + storage: { + provider: 'assetpack', + bucket: 'avatars', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + blob: { + storage: { + provider: 'assetpack', + bucket: 'blobs', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + }, + copilot: { + storage: { + provider: 'assetpack', + bucket: 'copilot', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + }); module.useCustomApplicationConstructor(TestingApp); diff --git a/packages/backend/server/src/__tests__/e2e/storage/r2-proxy.spec.ts b/packages/backend/server/src/__tests__/e2e/storage/r2-proxy.spec.ts index 5a1e8c317f..f705ffe2e1 100644 --- a/packages/backend/server/src/__tests__/e2e/storage/r2-proxy.spec.ts +++ b/packages/backend/server/src/__tests__/e2e/storage/r2-proxy.spec.ts @@ -362,6 +362,7 @@ e2e.serial('should proxy single upload with valid signature', async t => { t.is(init.method, 'PRESIGNED'); t.truthy(init.uploadUrl); const uploadUrl = new URL(init.uploadUrl, app.url); + t.is(uploadUrl.origin, 'https://cdn.example.com'); t.is(uploadUrl.pathname, PROXY_UPLOAD_PATH); const res = await app @@ -391,6 +392,7 @@ e2e.serial('should proxy multipart upload and return etag', async t => { const part = await getBlobUploadPartUrl(workspace.id, key, init.uploadId, 1); const partUrl = new URL(part.uploadUrl, app.url); + t.is(partUrl.origin, 'https://cdn.example.com'); t.is(partUrl.pathname, PROXY_MULTIPART_PATH); const payload = Buffer.from('part-body'); diff --git a/packages/backend/server/src/__tests__/utils/testing-module.ts b/packages/backend/server/src/__tests__/utils/testing-module.ts index 242a7ae695..06f5eb2e68 100644 --- a/packages/backend/server/src/__tests__/utils/testing-module.ts +++ b/packages/backend/server/src/__tests__/utils/testing-module.ts @@ -9,7 +9,7 @@ import { import { PrismaClient } from '@prisma/client'; import { buildAppModule, FunctionalityModules } from '../../app.module'; -import { AFFiNELogger, JobQueue } from '../../base'; +import { AFFiNELogger, ConfigFactory, JobQueue } from '../../base'; import { GqlModule } from '../../base/graphql'; import { ServerConfigModule } from '../../core'; import { AuthGuard, AuthModule } from '../../core/auth'; @@ -99,6 +99,31 @@ export async function createTestingModule( } const module = await builder.compile(); + module.get(ConfigFactory).override({ + storages: { + avatar: { + storage: { + provider: 'assetpack', + bucket: 'avatars', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + blob: { + storage: { + provider: 'assetpack', + bucket: 'blobs', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + }, + copilot: { + storage: { + provider: 'assetpack', + bucket: 'copilot', + config: { path: '/tmp/affine-test-storage' }, + }, + }, + }); const testingModule = module as TestingModule; diff --git a/packages/backend/server/src/__tests__/workspace/blobs.e2e.ts b/packages/backend/server/src/__tests__/workspace/blobs.e2e.ts index 13c81ea205..6722b98240 100644 --- a/packages/backend/server/src/__tests__/workspace/blobs.e2e.ts +++ b/packages/backend/server/src/__tests__/workspace/blobs.e2e.ts @@ -35,6 +35,22 @@ const RESTRICTED_QUOTA = { let app: TestingApp; let model: WorkspaceFeatureModel; +type CompleteResult = + | { + ok: true; + contentType: string; + contentLength: number; + lastModifiedMs: number; + } + | { + ok: false; + reason: + | 'not_found' + | 'size_mismatch' + | 'mime_mismatch' + | 'checksum_mismatch' + | 'size_too_large'; + }; const objects = new Map< string, { @@ -46,6 +62,7 @@ const objects = new Map< }; } >(); +const completeResults = new Map(); const storageRuntime = { providerCapabilities: async () => ({ put: true, @@ -100,25 +117,12 @@ const storageRuntime = { presignUploadPart: async () => undefined, listMultipartUploadParts: async () => undefined, completeMultipartUpload: async () => undefined, - completeWorkspaceBlobUpload: async ( - workspaceId: string, - key: string, - expected: { size: number; mime: string } - ) => { + completeWorkspaceBlobUpload: async (workspaceId: string, key: string) => { const objectKey = `${workspaceId}/${key}`; + const configured = completeResults.get(objectKey); + if (configured) return configured; const object = objects.get(objectKey); - if (!object) { - return { ok: false, reason: 'not_found' }; - } - if (object.metadata.contentLength !== expected.size) { - return { ok: false, reason: 'size_mismatch' }; - } - if (object.metadata.contentType !== expected.mime) { - return { ok: false, reason: 'mime_mismatch' }; - } - if (sha256Base64urlWithPadding(object.body) !== key) { - return { ok: false, reason: 'checksum_mismatch' }; - } + if (!object) return { ok: false, reason: 'not_found' }; await app.get(BlobModel).upsert({ workspaceId, key, @@ -159,6 +163,7 @@ test.before(async () => { test.beforeEach(async () => { await app.initTestingDB(); objects.clear(); + completeResults.clear(); }); test.after.always(async () => { @@ -333,6 +338,10 @@ test('should reject complete when blob key mismatched', async t => { contentType: mime, contentLength: buffer.length, }); + completeResults.set(`${workspace.id}/${wrongKey}`, { + ok: false, + reason: 'checksum_mismatch', + }); await t.throwsAsync(() => completeBlobUpload(app, workspace.id, wrongKey), { message: 'Blob key mismatch', diff --git a/packages/backend/server/src/base/config/register.ts b/packages/backend/server/src/base/config/register.ts index 4153114969..f2e6b577ff 100644 --- a/packages/backend/server/src/base/config/register.ts +++ b/packages/backend/server/src/base/config/register.ts @@ -21,6 +21,7 @@ export type JSONSchema = { description?: string } & ( | { type: 'object'; properties?: Record; + required?: string[]; } ); diff --git a/packages/backend/server/src/base/storage/providers/index.ts b/packages/backend/server/src/base/storage/providers/index.ts index 9f7c3fb16f..f5451dbd1a 100644 --- a/packages/backend/server/src/base/storage/providers/index.ts +++ b/packages/backend/server/src/base/storage/providers/index.ts @@ -232,8 +232,10 @@ export const StorageJSONSchema: JSONSchema = { type: 'string', }, }, + required: ['path'], }, }, + required: ['provider', 'bucket', 'config'], }, ], }; diff --git a/packages/backend/server/src/core/storage-runtime/__tests__/provider.spec.ts b/packages/backend/server/src/core/storage-runtime/__tests__/provider.spec.ts index aeefd0a947..16b455e0dc 100644 --- a/packages/backend/server/src/core/storage-runtime/__tests__/provider.spec.ts +++ b/packages/backend/server/src/core/storage-runtime/__tests__/provider.spec.ts @@ -4,8 +4,36 @@ import Sinon from 'sinon'; import { StorageRuntimeProvider } from '../provider'; function createProvider() { - const provider = new StorageRuntimeProvider(); + const provider = new StorageRuntimeProvider({ + db: { + datasourceUrl: 'postgresql://localhost:5432/affine', + }, + storages: { + blob: { + storage: { + provider: 'fs', + bucket: 'blobs', + config: { path: '~/.affine/storage' }, + }, + }, + avatar: { + storage: { + provider: 'fs', + bucket: 'avatars', + config: { path: '~/.affine/storage' }, + }, + }, + }, + copilot: { + storage: { + provider: 'fs', + bucket: 'copilot', + config: { path: '~/.affine/storage' }, + }, + }, + } as any); const runtime = { + configure: Sinon.stub(), start: Sinon.stub().resolves(), stop: Sinon.stub().resolves(), runMigrations: Sinon.stub().resolves(), @@ -26,6 +54,29 @@ test('storage-runtime provider restarts on storage config changes', async t => { await provider.onConfigChanged({ updates: { storages: {} } }); t.is(runtime.stop.callCount, 1); + t.is(runtime.configure.callCount, 2); + t.is(runtime.start.callCount, 2); + t.is(runtime.runMigrations.callCount, 2); +}); + +test('storage-runtime provider restarts on copilot storage config changes', async t => { + const { provider, runtime } = createProvider(); + + await provider.start(); + await provider.onConfigChanged({ + updates: { + copilot: { + storage: { + provider: 'fs', + bucket: 'new-copilot', + config: { path: '~/.affine/storage' }, + }, + }, + }, + }); + + t.is(runtime.stop.callCount, 1); + t.is(runtime.configure.callCount, 2); t.is(runtime.start.callCount, 2); t.is(runtime.runMigrations.callCount, 2); }); diff --git a/packages/backend/server/src/core/storage-runtime/provider.ts b/packages/backend/server/src/core/storage-runtime/provider.ts index 905ae99d38..4b05958a22 100644 --- a/packages/backend/server/src/core/storage-runtime/provider.ts +++ b/packages/backend/server/src/core/storage-runtime/provider.ts @@ -14,7 +14,7 @@ import type { PresignedUpload, PutObjectMetadata, } from '../../base'; -import { OnEvent } from '../../base'; +import { Config, OnEvent } from '../../base'; import { wrapCallMetric } from '../../base/metrics'; import { type RuntimeObjectGetResult, @@ -36,6 +36,8 @@ export class StorageRuntimeProvider private readonly runtime: RuntimeInstance = new StorageRuntime(); private migrationsStarted = false; + constructor(private readonly config: Config) {} + async onApplicationBootstrap() { await this.start(); } @@ -45,6 +47,7 @@ export class StorageRuntimeProvider } async start() { + this.configureRuntime(); await this.runtime.start(); await this.runMigrationsOnce(); const health = await this.runtime.health(); @@ -65,7 +68,11 @@ export class StorageRuntimeProvider @OnEvent('config.changed') async onConfigChanged({ updates }: Events['config.changed']) { - if (!('storages' in updates) && !('db' in updates)) { + if ( + !('storages' in updates) && + !('db' in updates) && + !updates.copilot?.storage + ) { return; } await this.restart(); @@ -293,6 +300,23 @@ export class StorageRuntimeProvider this.migrationsStarted = false; await this.start(); } + + private configureRuntime() { + this.runtime.configure( + JSON.stringify({ + db: { + datasourceUrl: this.config.db.datasourceUrl, + }, + storages: { + 'blob.storage': this.config.storages.blob.storage, + 'avatar.storage': this.config.storages.avatar.storage, + }, + copilot: { + storage: this.config.copilot.storage, + }, + }) + ); + } } function toRuntimeMetadata(metadata?: PutObjectMetadata) { diff --git a/packages/backend/server/src/core/storage/__tests__/blob-job.spec.ts b/packages/backend/server/src/core/storage/__tests__/blob-job.spec.ts index 40d5c1297a..2b433ad2a1 100644 --- a/packages/backend/server/src/core/storage/__tests__/blob-job.spec.ts +++ b/packages/backend/server/src/core/storage/__tests__/blob-job.spec.ts @@ -32,6 +32,7 @@ test.beforeEach(t => { health: Sinon.stub().resolves({ databaseConnected: true, providerConfigured: true, + provider: 'fs', }), backfillMissingBlobMetadata: Sinon.stub(), rebuildWorkspaceDocBlobRefs: Sinon.stub(), @@ -103,7 +104,8 @@ for (const scenario of objectStorageRequiredCases) { test(`${scenario.name} skips when object storage is not configured`, async t => { t.context.runtime.health.resolves({ databaseConnected: true, - providerConfigured: false, + providerConfigured: true, + provider: undefined, }); await scenario.run(t.context); diff --git a/packages/backend/server/src/core/storage/blob-job.ts b/packages/backend/server/src/core/storage/blob-job.ts index 13cf8d74e2..235daf59b3 100644 --- a/packages/backend/server/src/core/storage/blob-job.ts +++ b/packages/backend/server/src/core/storage/blob-job.ts @@ -411,7 +411,7 @@ export class StorageBlobJob { private async hasObjectStorage(operation: string) { const health = await this.rt.health(); - if (health.providerConfigured) { + if (health.provider) { return true; } diff --git a/packages/backend/server/src/core/storage/wrappers/blob.ts b/packages/backend/server/src/core/storage/wrappers/blob.ts index 219b9c01bc..09387fc74b 100644 --- a/packages/backend/server/src/core/storage/wrappers/blob.ts +++ b/packages/backend/server/src/core/storage/wrappers/blob.ts @@ -50,6 +50,11 @@ type BlobGetResult = { metadata?: GetObjectMetadata; }; +type R2ProxyConfig = { + signKey: string; + urlPrefix: string; +}; + @Injectable() export class WorkspaceBlobStorage { private readonly logger = new Logger(WorkspaceBlobStorage.name); @@ -317,7 +322,10 @@ export class WorkspaceBlobStorage { ) { return; } - return { signKey: usePresignedURL.signKey }; + return { + signKey: usePresignedURL.signKey, + urlPrefix: usePresignedURL.urlPrefix, + }; } private signProxy( @@ -340,7 +348,7 @@ export class WorkspaceBlobStorage { workspaceId: string, key: string, metadata: PutObjectMetadata | undefined, - proxy: { signKey: string } + proxy: R2ProxyConfig ) { const contentType = metadata?.contentType ?? 'application/octet-stream'; const contentLength = metadata?.contentLength; @@ -353,7 +361,7 @@ export class WorkspaceBlobStorage { proxy.signKey ); return { - url: this.url.link(PROXY_UPLOAD_PATH, { + url: this.linkProxyUrl(proxy.urlPrefix, PROXY_UPLOAD_PATH, { workspaceId, key, contentType, @@ -371,7 +379,7 @@ export class WorkspaceBlobStorage { key: string, uploadId: string, partNumber: number, - proxy: { signKey: string } + proxy: R2ProxyConfig ) { const expiresAt = new Date(Date.now() + SIGNED_URL_EXPIRED * 1000); const exp = Math.floor(expiresAt.getTime() / 1000); @@ -382,7 +390,7 @@ export class WorkspaceBlobStorage { proxy.signKey ); return { - url: this.url.link(PROXY_MULTIPART_PATH, { + url: this.linkProxyUrl(proxy.urlPrefix, PROXY_MULTIPART_PATH, { workspaceId, key, uploadId, @@ -394,4 +402,20 @@ export class WorkspaceBlobStorage { expiresAt, }; } + + private linkProxyUrl( + urlPrefix: string, + path: string, + query: Record + ) { + const url = new URL( + `${urlPrefix.replace(/\/+$/, '')}${path.startsWith('/') ? path : `/${path}`}` + ); + for (const [key, value] of Object.entries(query)) { + if (value !== undefined) { + url.searchParams.set(key, value.toString()); + } + } + return url.toString(); + } } diff --git a/packages/backend/server/src/core/user/controller.ts b/packages/backend/server/src/core/user/controller.ts index 9e7d91bdff..63571f3ac9 100644 --- a/packages/backend/server/src/core/user/controller.ts +++ b/packages/backend/server/src/core/user/controller.ts @@ -16,9 +16,10 @@ export class UserAvatarController { @Get('/:id') async getAvatar(@Res() res: Response, @Param('id') id: string) { - if (this.storage.config.storage.provider !== 'fs') { + const provider = this.storage.config.storage.provider; + if (!['assetpack', 'fs'].includes(provider)) { throw new ActionForbidden( - 'Only available when avatar storage provider set to fs.' + 'Only available when avatar storage provider is fs or assetpack.' ); } diff --git a/packages/backend/server/src/core/workspaces/resolvers/blob.ts b/packages/backend/server/src/core/workspaces/resolvers/blob.ts index b6ca0a6f1b..464d150def 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/blob.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/blob.ts @@ -398,6 +398,9 @@ export class WorkspaceBlobResolver { if (result.reason === 'mime_mismatch') { throw new BlobInvalid('Blob mime mismatch'); } + if (result.reason === 'size_too_large') { + throw new BlobInvalid('Blob size too large'); + } throw new BlobInvalid('Blob key mismatch'); } diff --git a/packages/backend/server/src/plugins/copilot/storage.ts b/packages/backend/server/src/plugins/copilot/storage.ts index 2cb94b1d5f..27189654f7 100644 --- a/packages/backend/server/src/plugins/copilot/storage.ts +++ b/packages/backend/server/src/plugins/copilot/storage.ts @@ -39,7 +39,10 @@ export class CopilotStorage { ) { const name = `${userId}/${workspaceId}/${key}`; const buffer = await toBuffer(blob); - await this.rt.putObject('copilot', name, buffer); + await this.rt.putObject('copilot', name, buffer, { + contentType: mimeType, + contentLength: buffer.length, + }); if (!env.prod) { // return image base64url for dev environment return `data:${mimeType};base64,${buffer.toString('base64')}`;