fix: test & lint

This commit is contained in:
DarkSky
2026-07-01 04:17:30 +08:00
parent 0e20823311
commit f051bc64a8
22 changed files with 818 additions and 96 deletions
+1
View File
@@ -68,6 +68,7 @@ export declare class StorageRuntime {
rebuildWorkspaceDocBlobRefs(workspaceId: string, limit: number): Promise<RuntimeDocBlobRefsResult>
constructor()
start(): Promise<void>
configure(configJson: string): void
stop(): Promise<void>
runMigrations(): Promise<void>
health(): Promise<StorageRuntimeHealth>
@@ -94,6 +94,13 @@ impl BackendRuntime {
owner: String,
ttl_ms: i64,
) -> RuntimeResult<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(RuntimeError::invalid_input("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(RuntimeError::invalid_input("coordination lease owner is required"));
}
CoordinationLeaseStore::new(self.pool().await?)
.acquire(key, owner, ttl_ms)
.await
@@ -117,13 +124,6 @@ impl BackendRuntime {
owner: String,
ttl_ms: i64,
) -> Result<Option<CoordinationLeaseGrant>> {
if ttl_ms <= 0 {
return Err(napi_error("coordination lease ttl must be positive"));
}
if owner.is_empty() {
return Err(napi_error("coordination lease owner is required"));
}
self
.acquire_coordination_lease_inner(key, owner, ttl_ms)
.await
@@ -208,21 +208,21 @@ pub(super) async fn list(
scope: &str,
prefix: Option<String>,
) -> RuntimeResult<Vec<ObjectListEntry>> {
let prefix = prefix.unwrap_or_default();
if !prefix.is_empty() {
super::normalize_storage_prefix(&prefix)?;
}
let prefix = prefix
.map(|prefix| super::normalize_storage_prefix(&prefix))
.transpose()?
.unwrap_or_default();
let store = open_store(config).await?;
let rows = sqlx::query(
r#"
SELECT key, content_length, last_modified_ms
FROM storage_assetpack_blobs
WHERE scope = ?1 AND key LIKE ?2
WHERE scope = ?1 AND key LIKE ?2 ESCAPE '\'
ORDER BY key ASC
"#,
)
.bind(scope)
.bind(format!("{prefix}%"))
.bind(format!("{}%", escape_sqlite_like(&prefix)))
.fetch_all(store.pool())
.await
.map_err(|err| RuntimeError::database("Assetpack manifest list failed", err))?;
@@ -239,6 +239,20 @@ pub(super) async fn list(
.collect()
}
/// Escapes SQLite `LIKE` metacharacters in `value` so the text matches literally.
///
/// Every `%`, `_` and `\` is prefixed with a backslash; all other characters are
/// copied through unchanged. The result is only meaningful inside a `LIKE`
/// expression that declares backslash as its escape character (`ESCAPE '\'`).
fn escape_sqlite_like(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if matches!(ch, '%' | '_' | '\\') {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
pub(super) async fn delete(config: &FsStorageConfig, scope: &str, key: &str) -> RuntimeResult<()> {
normalize_storage_key(key)?;
let store = open_store(config).await?;
@@ -6,45 +6,6 @@ use super::{
napi_error,
};
// Shape checks for the blob-cleanup result structs: each test builds a value by
// hand and asserts on the fields it just set, without calling any runtime code.
#[cfg(test)]
mod tests {
use super::*;
// A plan result built with a run id exposes that id and the marked-candidate count.
#[test]
fn blob_cleanup_plan_result_keeps_run_id_for_execute() {
let result = RuntimeBlobCleanupPlanResult {
run_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
scanned_blobs: 1,
candidates_marked: 1,
protected_by_doc_refs: 0,
protected_by_metadata: 0,
protected_by_other_refs: 0,
next_cursor: None,
};
assert!(result.run_id.is_some());
assert_eq!(result.candidates_marked, 1);
}
// In this fixture the skipped, failed and deleted counts add up to the number
// of scanned candidates.
#[test]
fn blob_cleanup_execute_result_tracks_skipped_and_failed_counts() {
let result = RuntimeBlobCleanupExecuteResult {
scanned_candidates: 3,
deleted_objects: 1,
deleted_metadata: 1,
skipped_still_referenced: 1,
failed: 1,
workspace_ids: vec!["workspace".to_string()],
};
assert_eq!(result.scanned_candidates, 3);
assert_eq!(
result.skipped_still_referenced + result.failed + result.deleted_objects,
3
);
}
}
#[derive(FromRow)]
struct BlobCandidateRow {
workspace_id: String,
@@ -111,7 +111,10 @@ struct CopilotConfigFile {
impl StorageRuntimeConfig {
fn from_config_files() -> RuntimeResult<Self> {
let app_config = app_config_from_config_files()?;
Self::from_app_config_file(app_config_from_config_files()?)
}
fn from_app_config_file(app_config: AppConfigFile) -> RuntimeResult<Self> {
let database_url = database_url_from_env()
.or(app_config.database_url())
.unwrap_or_else(|| "postgresql://localhost:5432/affine".to_string());
@@ -120,11 +123,12 @@ impl StorageRuntimeConfig {
}
async fn with_db_overrides(&self, pool: &PgPool) -> RuntimeResult<Self> {
let mut app_config = app_config_from_config_files()?;
app_config.apply_file_config(load_app_config_overrides_from_db(pool).await?);
let app_config = load_app_config_overrides_from_db(pool).await?;
let mut backends = self.backends.clone();
backends.extend(app_config.storage_backends()?);
Ok(Self {
database_url: self.database_url.clone(),
backends: app_config.storage_backends()?,
backends,
})
}
}
@@ -300,6 +304,14 @@ impl StorageRuntime {
self.start_inner().await.map_err(to_napi_error)
}
// Replaces the runtime configuration with one parsed from a JSON document.
//
// `config_json` is deserialized into an `AppConfigFile`, converted through
// `StorageRuntimeConfig::from_app_config_file`, and handed to `update_config`.
// Every failure is converted with `to_napi_error`; a JSON parse failure is
// reported with the context "invalid storage runtime config".
#[napi]
pub fn configure(&self, config_json: String) -> napi::Result<()> {
    let parsed = serde_json::from_str::<AppConfigFile>(&config_json);
    let app_config = parsed
        .map_err(|err| to_napi_error(RuntimeError::json("invalid storage runtime config", err)))?;

    let built = StorageRuntimeConfig::from_app_config_file(app_config);
    let config = built.map_err(to_napi_error)?;

    let applied = self.update_config(config);
    applied.map_err(to_napi_error)
}
async fn start_inner(&self) -> RuntimeResult<()> {
let mut guard = self.pool.lock().await;
if guard.is_some() {
@@ -1295,11 +1307,36 @@ mod tests {
#[test]
fn fs_key_normalization_rejects_traversal() {
for key in ["", "/a", "a//b", "a/./b", "a/../b", "..\\secret"] {
assert!(normalize_storage_key(key).is_err(), "{key}");
for (key, valid) in [
("", false),
("/a", false),
("a//b", false),
("a/./b", false),
("a/../b", false),
("..\\secret", false),
("workspace/blob", true),
("workspace\\blob", true),
] {
assert_eq!(normalize_storage_key(key).is_ok(), valid, "{key}");
}
assert_eq!(normalize_storage_key("workspace/blob").unwrap(), ["workspace", "blob"]);
assert_eq!(normalize_storage_prefix("workspace/").unwrap(), "workspace/");
}
/// `normalize_storage_prefix` accepts empty and well-formed prefixes (turning
/// backslashes into forward slashes) and rejects absolute paths, empty segments,
/// `.` segments and any `..` traversal.
#[test]
fn fs_prefix_normalization_rejects_traversal() {
    let accepted = [
        ("", ""),
        ("workspace/", "workspace/"),
        ("workspace\\blob", "workspace/blob"),
    ];
    for (prefix, normalized) in accepted {
        assert_eq!(
            normalize_storage_prefix(prefix).ok().as_deref(),
            Some(normalized),
            "{prefix}"
        );
    }

    let rejected = [
        "../escape",
        "nested/../../escape",
        "/absolute",
        "nested//escape",
        "nested/./escape",
        "nested/../escape",
    ];
    for prefix in rejected {
        assert_eq!(normalize_storage_prefix(prefix).ok().as_deref(), None, "{prefix}");
    }
}
#[test]
@@ -1468,6 +1505,185 @@ mod tests {
assert_eq!(keys, ["workspace/blob-a", "workspace/nested/blob-b"]);
}
/// Listing on the fs backend treats the prefix as a plain string prefix of the
/// key rather than a directory: `a/b/t` matches both `a/b/t/item` and
/// `a/b/tail`, and results come back sorted by key.
#[test]
fn fs_backend_lists_old_node_prefix_semantics() {
    let scratch = tempfile::tempdir().unwrap();
    let config = FsStorageConfig {
        provider: "fs".to_string(),
        root: scratch.path().to_string_lossy().to_string(),
        bucket: "bucket".to_string(),
    };

    // Seed objects whose body is simply the key itself.
    let seeded_keys = ["root-a", "a/item", "a/b/item", "a/b/t/item", "a/b/tail", "z/item"];
    for key in seeded_keys {
        fs_put(&config, key, key.as_bytes().to_vec(), ObjectPutMetadata::default()).unwrap();
    }

    let cases = [
        (
            None,
            vec!["a/b/item", "a/b/t/item", "a/b/tail", "a/item", "root-a", "z/item"],
        ),
        (Some("a"), vec!["a/b/item", "a/b/t/item", "a/b/tail", "a/item"]),
        (Some("a/b"), vec!["a/b/item", "a/b/t/item", "a/b/tail"]),
        (Some("a/b/"), vec!["a/b/item", "a/b/t/item", "a/b/tail"]),
        (Some("a/b/t"), vec!["a/b/t/item", "a/b/tail"]),
        (Some("missing"), vec![]),
    ];
    for (prefix, expected) in cases {
        let listed = fs_list(&config, prefix.map(ToString::to_string)).unwrap();
        let keys: Vec<_> = listed.into_iter().map(|entry| entry.key).collect();
        assert_eq!(keys, expected, "{prefix:?}");
    }
}
/// Deleting an object twice succeeds both times, and afterwards neither the
/// object nor its `.metadata.json` sidecar exists on disk.
#[test]
fn fs_backend_delete_removes_object_and_sidecar_idempotently() {
    let scratch = tempfile::tempdir().unwrap();
    let config = FsStorageConfig {
        provider: "fs".to_string(),
        root: scratch.path().to_string_lossy().to_string(),
        bucket: "bucket".to_string(),
    };
    let key = "workspace/blob";

    fs_put(&config, key, b"body".to_vec(), ObjectPutMetadata::default()).unwrap();

    // The second delete hits an already-missing object and must still succeed.
    for _ in 0..2 {
        fs_delete(&config, key).unwrap();
    }

    assert!(fs_head(&config, key).unwrap().is_none());
    assert!(fs_get(&config, key).unwrap().is_none());

    let object_path = scratch.path().join("bucket/workspace/blob");
    let sidecar_path = scratch.path().join("bucket/workspace/blob.metadata.json");
    assert!(!object_path.exists());
    assert!(!sidecar_path.exists());
}
/// Builds a `StorageRuntime` for tests: no backends, no database pool, and a
/// placeholder database URL.
fn test_storage_runtime() -> StorageRuntime {
    let config = StorageRuntimeConfig {
        database_url: String::from("postgresql://unused"),
        backends: HashMap::new(),
    };
    StorageRuntime {
        config: RwLock::new(config),
        pool: Mutex::new(None),
    }
}
// Walks `complete_fs_workspace_blob` through each failure reason it can report
// ("not_found", "size_mismatch", "mime_mismatch", "checksum_mismatch",
// "size_too_large") against a temporary fs bucket. Statement order matters:
// objects are written between calls to set up the next scenario.
// NOTE(review): the runtime comes from `test_storage_runtime()`, which has no
// pool; presumably every path exercised here returns before touching the
// database, as the test name suggests — confirm.
#[tokio::test]
async fn fs_workspace_blob_complete_returns_native_failure_reasons_before_db_upsert() {
let temp = tempfile::tempdir().unwrap();
let config = FsStorageConfig {
provider: "fs".to_string(),
root: temp.path().to_string_lossy().to_string(),
bucket: "bucket".to_string(),
};
let runtime = test_storage_runtime();
// Nothing has been stored yet, so completing any key reports "not_found".
let result = runtime
.complete_fs_workspace_blob(
config.clone(),
"workspace".to_string(),
"missing".to_string(),
1,
"text/plain".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("not_found"));
// Store a 4-byte text/plain object under workspace/blob for the next two checks.
fs_put(
&config,
"workspace/blob",
b"body".to_vec(),
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
content_length: Some(4),
checksum_crc32: None,
},
)
.unwrap();
// Expected size 5 against a stored length of 4 reports "size_mismatch".
let result = runtime
.complete_fs_workspace_blob(
config.clone(),
"workspace".to_string(),
"blob".to_string(),
5,
"text/plain".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("size_mismatch"));
// Correct size but expected mime image/png against stored text/plain reports "mime_mismatch".
let result = runtime
.complete_fs_workspace_blob(
config.clone(),
"workspace".to_string(),
"blob".to_string(),
4,
"image/png".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("mime_mismatch"));
// No object exists under this key yet, so the result is still "not_found".
let result = runtime
.complete_fs_workspace_blob(
config.clone(),
"workspace".to_string(),
"not-the-sha-key".to_string(),
4,
"text/plain".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("not_found"));
// Store matching size and mime under a key that is not derived from the body.
fs_put(
&config,
"workspace/not-the-sha-key",
b"body".to_vec(),
ObjectPutMetadata {
content_type: Some("text/plain".to_string()),
content_length: Some(4),
checksum_crc32: None,
},
)
.unwrap();
// Size and mime now match, so the remaining failure is "checksum_mismatch".
let result = runtime
.complete_fs_workspace_blob(
config.clone(),
"workspace".to_string(),
"not-the-sha-key".to_string(),
4,
"text/plain".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("checksum_mismatch"));
// After the checksum failure the object can no longer be read back.
assert!(fs_get(&config, "workspace/not-the-sha-key").unwrap().is_none());
// An expected size above MAX_BLOB_SIZE reports "size_too_large".
let result = runtime
.complete_fs_workspace_blob(
config,
"workspace".to_string(),
"too-large".to_string(),
MAX_BLOB_SIZE + 1,
"text/plain".to_string(),
)
.await
.unwrap();
assert!(!result.ok);
assert_eq!(result.reason.as_deref(), Some("size_too_large"));
}
#[test]
fn fs_backend_rejects_metadata_mismatch() {
let temp = tempfile::tempdir().unwrap();
@@ -1543,6 +1759,26 @@ mod tests {
1
);
let percent_key = "workspace/%literal.txt";
let wildcard_collision_key = "workspace/aliteral.txt";
for key in [percent_key, wildcard_collision_key] {
assetpack::put(
&config,
&scope,
key,
b"literal prefix body".to_vec(),
ObjectPutMetadata {
content_type: None,
content_length: None,
checksum_crc32: None,
},
)
.await?;
}
let percent_matches = assetpack::list(&config, &scope, Some("workspace/%".to_string())).await?;
assert_eq!(percent_matches.len(), 1);
assert_eq!(percent_matches[0].key, percent_key);
assetpack::delete(&config, &scope, key).await?;
assert!(assetpack::head(&config, &scope, key).await?.is_none());
Ok(())
@@ -27,6 +27,8 @@ use super::{
},
};
/// Request timeout in milliseconds (30 s), used when the caller supplies none.
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
/// Largest multipart part number accepted; valid part numbers run from 1 to this value.
const MAX_MULTIPART_PART_NUMBER: i32 = 10_000;
// NOTE(review): presumably an upper bound on buffered response bodies, sized to
// fit in an i32 — the use site is not visible here, confirm.
const MAX_RESPONSE_BODY_BYTES: usize = i32::MAX as usize;
/// Boxed, pinned, `Send` future resolving to an `ObjectStorageResult<StorageHttpResponse>`.
type StorageHttpFuture<'a> = Pin<Box<dyn Future<Output = ObjectStorageResult<StorageHttpResponse>> + Send + 'a>>;
@@ -57,10 +59,9 @@ struct ReqwestStorageHttpClient {
impl ReqwestStorageHttpClient {
fn new(request_timeout_ms: Option<u64>) -> ObjectStorageResult<Self> {
let mut builder = ReqwestClient::builder();
if let Some(request_timeout_ms) = request_timeout_ms {
builder = builder.timeout(Duration::from_millis(request_timeout_ms));
}
let builder = ReqwestClient::builder().timeout(Duration::from_millis(
request_timeout_ms.unwrap_or(DEFAULT_REQUEST_TIMEOUT_MS),
));
Ok(Self {
client: builder.build().map_err(ObjectStorageError::HttpClientBuild)?,
})
@@ -627,7 +628,12 @@ fn response_header_name(headers: &HeaderMap, name: &str) -> Option<String> {
}
fn checked_part_number(part_number: i32) -> ObjectStorageResult<u16> {
u16::try_from(part_number).map_err(|_| ObjectStorageError::InvalidInput("part number must fit u16".to_string()))
if !(1..=MAX_MULTIPART_PART_NUMBER).contains(&part_number) {
return Err(ObjectStorageError::InvalidInput(
"multipart part number must be between 1 and 10000".to_string(),
));
}
Ok(part_number as u16)
}
fn validate_completed_parts(parts: &[MultipartUploadPart]) -> ObjectStorageResult<()> {
@@ -677,3 +683,168 @@ fn parse_rfc3339_ms(value: &str) -> i64 {
.map(|value| value.timestamp_millis())
.unwrap_or(0)
}
// Unit tests for the pure helpers of the S3 HTTP client: header-to-metadata
// mapping, not-found body detection, ListParts XML parsing, the
// CompleteMultipartUpload body builder, and RFC 3339 timestamp parsing.
// None of them performs network I/O.
#[cfg(test)]
mod tests {
use reqwest::header::HeaderValue;
use super::*;
// Populated headers map straight through to metadata; an empty header map
// yields "application/octet-stream", zero length, zero timestamp and no checksum.
#[test]
fn metadata_from_headers_uses_s3_defaults_and_checksum() {
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
headers.insert(CONTENT_LENGTH, HeaderValue::from_static("42"));
headers.insert(LAST_MODIFIED, HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"));
headers.insert("x-amz-checksum-crc32", HeaderValue::from_static("checksum"));
let metadata = metadata_from_headers(&headers);
assert_eq!(metadata.content_type, "text/plain");
assert_eq!(metadata.content_length, 42);
assert_eq!(metadata.last_modified_ms, 1_445_412_480_000);
assert_eq!(metadata.checksum_crc32.as_deref(), Some("checksum"));
let defaults = metadata_from_headers(&HeaderMap::new());
assert_eq!(defaults.content_type, "application/octet-stream");
assert_eq!(defaults.content_length, 0);
assert_eq!(defaults.last_modified_ms, 0);
assert!(defaults.checksum_crc32.is_none());
}
// Only the NoSuchKey, NotFound and NoSuchUpload error codes count as "not found";
// an empty body or an unrelated code such as AccessDenied does not.
#[test]
fn not_found_body_accepts_object_missing_codes_only() {
for body in [
"<Error><Code>NoSuchKey</Code></Error>",
"<Error><Code>NotFound</Code></Error>",
"<Error><Code>NoSuchUpload</Code></Error>",
] {
assert!(is_not_found_body(body.as_bytes()), "{body}");
}
assert!(!is_not_found_body(b""));
assert!(!is_not_found_body(b"<Error><Code>AccessDenied</Code></Error>"));
}
// `ListParts::parse_response` handles a truncated two-part listing (exposing the
// next part-number marker) as well as a single-part listing with no marker.
// ETags are normalised with `trim_etag`, so quoted and unquoted forms compare equal.
#[test]
fn list_parts_xml_handles_array_single_part_and_pagination() {
let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<PartNumberMarker>0</PartNumberMarker>
<NextPartNumberMarker>3</NextPartNumberMarker>
<MaxParts>2</MaxParts>
<IsTruncated>true</IsTruncated>
<Part>
<PartNumber>1</PartNumber>
<LastModified>2010-11-10T20:48:34.000Z</LastModified>
<ETag>"etag-1"</ETag>
<Size>10485760</Size>
</Part>
<Part>
<PartNumber>2</PartNumber>
<LastModified>2010-11-10T20:48:33.000Z</LastModified>
<ETag>etag-2</ETag>
<Size>10485760</Size>
</Part>
</ListPartsResult>"#;
let parsed = ListParts::parse_response(xml).unwrap();
let parts = parsed
.parts
.into_iter()
.map(|part| MultipartUploadPart {
part_number: i32::from(part.number),
etag: trim_etag(&part.etag),
})
.collect::<Vec<_>>();
assert_eq!(
parts,
vec![
MultipartUploadPart {
part_number: 1,
etag: "etag-1".to_string()
},
MultipartUploadPart {
part_number: 2,
etag: "etag-2".to_string()
}
]
);
assert_eq!(parsed.next_part_number_marker, Some(3));
// Single <Part> element and no NextPartNumberMarker.
let single = r#"<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Bucket>test</Bucket>
<Key>key</Key>
<UploadId>upload-id</UploadId>
<MaxParts>1</MaxParts>
<IsTruncated>false</IsTruncated>
<Part>
<PartNumber>5</PartNumber>
<LastModified>2010-11-10T20:48:34.000Z</LastModified>
<ETag>"etag-5"</ETag>
<Size>10485760</Size>
</Part>
</ListPartsResult>"#;
let parsed = ListParts::parse_response(single).unwrap();
assert_eq!(parsed.parts.len(), 1);
assert_eq!(parsed.parts[0].number, 5);
assert_eq!(trim_etag(&parsed.parts[0].etag), "etag-5");
assert_eq!(parsed.next_part_number_marker, None);
}
// Parts given out of order are emitted by ascending part number with XML-escaped
// ETags. Validation rejects an empty ETag and part numbers -1, 0 and 10_001.
#[test]
fn complete_multipart_body_orders_and_escapes_parts() {
let mut parts = completed_multipart_parts(vec![
MultipartUploadPart {
part_number: 2,
etag: "b&c".to_string(),
},
MultipartUploadPart {
part_number: 1,
etag: "a<tag>".to_string(),
},
]);
validate_completed_parts(&parts).unwrap();
let body = complete_multipart_body(&parts);
assert_eq!(
body,
"<CompleteMultipartUpload><Part><ETag>a&lt;tag&gt;</ETag><PartNumber>1</PartNumber></Part><Part><ETag>b&amp;c</\
ETag><PartNumber>2</PartNumber></Part></CompleteMultipartUpload>"
);
// Blank out the first part's ETag to trigger the empty-ETag rejection.
parts[0].etag.clear();
assert!(validate_completed_parts(&parts).is_err());
assert!(
validate_completed_parts(&[MultipartUploadPart {
part_number: -1,
etag: "etag".to_string(),
}])
.is_err()
);
assert!(
validate_completed_parts(&[MultipartUploadPart {
part_number: 0,
etag: "etag".to_string(),
}])
.is_err()
);
assert!(
validate_completed_parts(&[MultipartUploadPart {
part_number: 10_001,
etag: "etag".to_string(),
}])
.is_err()
);
}
// A valid RFC 3339 timestamp converts to epoch milliseconds; unparsable input
// falls back to 0 instead of failing.
#[test]
fn parse_rfc3339_ms_returns_zero_for_invalid_values() {
assert_eq!(parse_rfc3339_ms("2024-01-02T03:04:05Z"), 1_704_164_645_000);
assert_eq!(parse_rfc3339_ms("not a date"), 0);
}
}
@@ -45,7 +45,8 @@ impl ObjectStorageError {
match self {
Self::Operation { source, .. } => source.is_not_found(),
Self::HttpStatus { status, body, .. } => {
*status == StatusCode::NOT_FOUND && (body.contains("NoSuchKey") || body.contains("NotFound"))
*status == StatusCode::NOT_FOUND
&& (body.contains("NoSuchKey") || body.contains("NoSuchUpload") || body.contains("NotFound"))
}
_ => false,
}
@@ -6,6 +6,14 @@ use super::{
types::{MultipartUploadPart, ObjectPutMetadata, StorageProviderConfig, completed_multipart_parts, trim_etag},
};
/// Test helper: wraps a provider name and raw JSON config into a
/// `StorageProviderConfig` that always targets the bucket `test-bucket`.
fn storage_config(provider: &str, config: serde_json::Value) -> StorageProviderConfig {
    let provider = provider.to_owned();
    let bucket = String::from("test-bucket");
    StorageProviderConfig {
        provider,
        bucket,
        config,
    }
}
#[test]
fn resolves_r2_config_from_config_json_shape() {
let storage = StorageProviderConfig {
@@ -38,6 +46,66 @@ fn resolves_r2_config_from_config_json_shape() {
assert_eq!(config.access_key_id.as_deref(), Some("key"));
}
/// R2 configs derive their endpoint from `accountId` and the optional
/// `jurisdiction`, always force path-style addressing, and fail to resolve
/// when `accountId` is missing.
#[test]
fn resolves_r2_endpoint_cases_from_config_json_shape() {
    let cases = [
        (
            "default account endpoint",
            serde_json::json!({
                "accountId": "account",
                "credentials": {
                    "accessKeyId": "key",
                    "secretAccessKey": "secret"
                }
            }),
            Some("https://account.r2.cloudflarestorage.com"),
        ),
        (
            "explicit null jurisdiction",
            serde_json::json!({
                "accountId": "account",
                "jurisdiction": null,
                "credentials": {
                    "accessKeyId": "key",
                    "secretAccessKey": "secret"
                }
            }),
            Some("https://account.r2.cloudflarestorage.com"),
        ),
        (
            "eu jurisdiction",
            serde_json::json!({
                "accountId": "account",
                "jurisdiction": "eu",
                "credentials": {
                    "accessKeyId": "key",
                    "secretAccessKey": "secret"
                }
            }),
            Some("https://account.eu.r2.cloudflarestorage.com"),
        ),
    ];
    for (case, raw_config, expected_endpoint) in cases {
        let storage = storage_config("cloudflare-r2", raw_config);
        let resolved = ObjectStorageConfig::from_r2_config(storage)
            .unwrap()
            .unwrap();
        assert_eq!(resolved.endpoint.as_deref(), expected_endpoint, "{case}");
        assert!(resolved.force_path_style, "{case}");
    }

    // Credentials alone are not enough: the account id is mandatory.
    let without_account_id = storage_config(
        "cloudflare-r2",
        serde_json::json!({
            "credentials": {
                "accessKeyId": "key",
                "secretAccessKey": "secret"
            }
        }),
    );
    assert!(ObjectStorageConfig::from_r2_config(without_account_id).is_err());
}
#[test]
fn object_storage_not_found_requires_object_error_code() {
let bucket_or_route_missing = ObjectStorageError::HttpStatus {
@@ -50,9 +118,15 @@ fn object_storage_not_found_requires_object_error_code() {
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchKey</Code></Error>".to_string(),
};
let upload_missing = ObjectStorageError::HttpStatus {
context: "abort failed".to_string(),
status: StatusCode::NOT_FOUND,
body: "<Error><Code>NoSuchUpload</Code></Error>".to_string(),
};
assert!(!bucket_or_route_missing.is_not_found());
assert!(object_missing.is_not_found());
assert!(upload_missing.is_not_found());
}
#[test]
@@ -113,6 +187,28 @@ fn resolves_s3_config_from_config_json_shape() {
assert_eq!(config.presign_sign_content_type_for_put, Some(false));
}
/// Without an explicit endpoint, an S3 config falls back to the global
/// endpoint for `us-east-1` and to the regional endpoint for other regions.
#[test]
fn resolves_s3_default_endpoint_cases_from_config_json_shape() {
    let cases = [
        ("us-east-1", "https://s3.amazonaws.com"),
        ("us-west-2", "https://s3.us-west-2.amazonaws.com"),
    ];
    for (region, expected_endpoint) in cases {
        let raw_config = serde_json::json!({
            "region": region,
            "credentials": {
                "accessKeyId": "key",
                "secretAccessKey": "secret"
            }
        });
        let resolved = ObjectStorageConfig::from_s3_config(storage_config("aws-s3", raw_config))
            .unwrap()
            .unwrap();
        assert_eq!(resolved.endpoint.as_deref(), Some(expected_endpoint), "{region}");
    }
}
#[tokio::test]
async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
let storage = StorageProviderConfig {
@@ -155,6 +251,47 @@ async fn object_storage_presign_put_returns_sigv4_url_and_headers() {
assert!(result.expires_at_ms > 0);
}
/// With `signContentTypeForPut: false`, a presigned PUT still returns the
/// Content-Type and Content-Length headers, but only content-length appears
/// in the signed URL.
#[tokio::test]
async fn object_storage_presign_put_respects_content_length_and_signed_content_type_flag() {
    let raw_config = serde_json::json!({
        "region": "us-east-1",
        "endpoint": "https://s3.us-east-1.amazonaws.com",
        "credentials": {
            "accessKeyId": "key",
            "secretAccessKey": "secret"
        },
        "presign": {
            "expiresInSeconds": 60,
            "signContentTypeForPut": false
        }
    });
    let config = ObjectStorageConfig::from_s3_config(storage_config("aws-s3", raw_config))
        .unwrap()
        .unwrap();
    let client = config.build_client().unwrap();

    let metadata = ObjectPutMetadata {
        content_type: Some("text/plain".to_string()),
        content_length: Some(42),
        ..Default::default()
    };
    let presigned = client.presign_put("key", metadata).await.unwrap();

    let content_type = presigned.headers.get("Content-Type").map(String::as_str);
    let content_length = presigned.headers.get("Content-Length").map(String::as_str);
    assert_eq!(content_type, Some("text/plain"));
    assert_eq!(content_length, Some("42"));
    assert!(!presigned.url.contains("content-type"));
    assert!(presigned.url.contains("content-length"));
}
#[tokio::test]
async fn object_storage_presign_get_returns_sigv4_url_without_headers() {
let storage = StorageProviderConfig {
@@ -182,6 +319,34 @@ async fn object_storage_presign_get_returns_sigv4_url_without_headers() {
assert!(result.expires_at_ms > 0);
}
/// A presigned upload-part URL is SigV4-signed, carries the part number and
/// upload id as query parameters, and needs no extra request headers.
#[tokio::test]
async fn object_storage_presign_upload_part_returns_sigv4_url() {
    let raw_config = serde_json::json!({
        "region": "us-east-1",
        "endpoint": "https://s3.us-east-1.amazonaws.com",
        "credentials": {
            "accessKeyId": "key",
            "secretAccessKey": "secret"
        },
        "presign": {
            "expiresInSeconds": 60
        }
    });
    let config = ObjectStorageConfig::from_s3_config(storage_config("aws-s3", raw_config))
        .unwrap()
        .unwrap();
    let client = config.build_client().unwrap();

    let presigned = client.presign_upload_part("key", "upload-1", 3).await.unwrap();

    for fragment in ["X-Amz-Algorithm=AWS4-HMAC-SHA256", "partNumber=3", "uploadId=upload-1"] {
        assert!(presigned.url.contains(fragment));
    }
    assert!(presigned.headers.is_empty());
    assert!(presigned.expires_at_ms > 0);
}
#[test]
fn object_storage_orders_completed_multipart_parts_and_trims_etags() {
let parts = completed_multipart_parts(vec![
@@ -13,6 +13,7 @@ import {
AFFiNELogger,
CacheInterceptor,
CloudThrottlerGuard,
ConfigFactory,
EventBus,
GlobalExceptionFilter,
JobQueue,
@@ -250,6 +251,31 @@ export async function createApp(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
module.useCustomApplicationConstructor(TestingApp);
@@ -362,6 +362,7 @@ e2e.serial('should proxy single upload with valid signature', async t => {
t.is(init.method, 'PRESIGNED');
t.truthy(init.uploadUrl);
const uploadUrl = new URL(init.uploadUrl, app.url);
t.is(uploadUrl.origin, 'https://cdn.example.com');
t.is(uploadUrl.pathname, PROXY_UPLOAD_PATH);
const res = await app
@@ -391,6 +392,7 @@ e2e.serial('should proxy multipart upload and return etag', async t => {
const part = await getBlobUploadPartUrl(workspace.id, key, init.uploadId, 1);
const partUrl = new URL(part.uploadUrl, app.url);
t.is(partUrl.origin, 'https://cdn.example.com');
t.is(partUrl.pathname, PROXY_MULTIPART_PATH);
const payload = Buffer.from('part-body');
@@ -9,7 +9,7 @@ import {
import { PrismaClient } from '@prisma/client';
import { buildAppModule, FunctionalityModules } from '../../app.module';
import { AFFiNELogger, JobQueue } from '../../base';
import { AFFiNELogger, ConfigFactory, JobQueue } from '../../base';
import { GqlModule } from '../../base/graphql';
import { ServerConfigModule } from '../../core';
import { AuthGuard, AuthModule } from '../../core/auth';
@@ -99,6 +99,31 @@ export async function createTestingModule(
}
const module = await builder.compile();
module.get(ConfigFactory).override({
storages: {
avatar: {
storage: {
provider: 'assetpack',
bucket: 'avatars',
config: { path: '/tmp/affine-test-storage' },
},
},
blob: {
storage: {
provider: 'assetpack',
bucket: 'blobs',
config: { path: '/tmp/affine-test-storage' },
},
},
},
copilot: {
storage: {
provider: 'assetpack',
bucket: 'copilot',
config: { path: '/tmp/affine-test-storage' },
},
},
});
const testingModule = module as TestingModule;
@@ -35,6 +35,22 @@ const RESTRICTED_QUOTA = {
let app: TestingApp;
let model: WorkspaceFeatureModel;
/**
 * Outcome of completing a workspace blob upload, discriminated on `ok`.
 *
 * - `ok: true` carries the stored object's content type, length and
 *   last-modified timestamp (milliseconds).
 * - `ok: false` carries the `reason` the completion was rejected.
 */
type CompleteResult =
| {
ok: true;
contentType: string;
contentLength: number;
lastModifiedMs: number;
}
| {
ok: false;
reason:
| 'not_found'
| 'size_mismatch'
| 'mime_mismatch'
| 'checksum_mismatch'
| 'size_too_large';
};
const objects = new Map<
string,
{
@@ -46,6 +62,7 @@ const objects = new Map<
};
}
>();
const completeResults = new Map<string, CompleteResult>();
const storageRuntime = {
providerCapabilities: async () => ({
put: true,
@@ -100,25 +117,12 @@ const storageRuntime = {
presignUploadPart: async () => undefined,
listMultipartUploadParts: async () => undefined,
completeMultipartUpload: async () => undefined,
completeWorkspaceBlobUpload: async (
workspaceId: string,
key: string,
expected: { size: number; mime: string }
) => {
completeWorkspaceBlobUpload: async (workspaceId: string, key: string) => {
const objectKey = `${workspaceId}/${key}`;
const configured = completeResults.get(objectKey);
if (configured) return configured;
const object = objects.get(objectKey);
if (!object) {
return { ok: false, reason: 'not_found' };
}
if (object.metadata.contentLength !== expected.size) {
return { ok: false, reason: 'size_mismatch' };
}
if (object.metadata.contentType !== expected.mime) {
return { ok: false, reason: 'mime_mismatch' };
}
if (sha256Base64urlWithPadding(object.body) !== key) {
return { ok: false, reason: 'checksum_mismatch' };
}
if (!object) return { ok: false, reason: 'not_found' };
await app.get(BlobModel).upsert({
workspaceId,
key,
@@ -159,6 +163,7 @@ test.before(async () => {
test.beforeEach(async () => {
await app.initTestingDB();
objects.clear();
completeResults.clear();
});
test.after.always(async () => {
@@ -333,6 +338,10 @@ test('should reject complete when blob key mismatched', async t => {
contentType: mime,
contentLength: buffer.length,
});
completeResults.set(`${workspace.id}/${wrongKey}`, {
ok: false,
reason: 'checksum_mismatch',
});
await t.throwsAsync(() => completeBlobUpload(app, workspace.id, wrongKey), {
message: 'Blob key mismatch',
@@ -21,6 +21,7 @@ export type JSONSchema = { description?: string } & (
| {
type: 'object';
properties?: Record<string, JSONSchema>;
required?: string[];
}
);
@@ -232,8 +232,10 @@ export const StorageJSONSchema: JSONSchema = {
type: 'string',
},
},
required: ['path'],
},
},
required: ['provider', 'bucket', 'config'],
},
],
};
@@ -4,8 +4,36 @@ import Sinon from 'sinon';
import { StorageRuntimeProvider } from '../provider';
function createProvider() {
const provider = new StorageRuntimeProvider();
const provider = new StorageRuntimeProvider({
db: {
datasourceUrl: 'postgresql://localhost:5432/affine',
},
storages: {
blob: {
storage: {
provider: 'fs',
bucket: 'blobs',
config: { path: '~/.affine/storage' },
},
},
avatar: {
storage: {
provider: 'fs',
bucket: 'avatars',
config: { path: '~/.affine/storage' },
},
},
},
copilot: {
storage: {
provider: 'fs',
bucket: 'copilot',
config: { path: '~/.affine/storage' },
},
},
} as any);
const runtime = {
configure: Sinon.stub(),
start: Sinon.stub().resolves(),
stop: Sinon.stub().resolves(),
runMigrations: Sinon.stub().resolves(),
@@ -26,6 +54,29 @@ test('storage-runtime provider restarts on storage config changes', async t => {
await provider.onConfigChanged({ updates: { storages: {} } });
t.is(runtime.stop.callCount, 1);
t.is(runtime.configure.callCount, 2);
t.is(runtime.start.callCount, 2);
t.is(runtime.runMigrations.callCount, 2);
});
// A `config.changed` event that only touches `copilot.storage` must restart the
// runtime as well: one stop, and a second configure/start/runMigrations on top
// of the calls made by the initial `provider.start()`.
test('storage-runtime provider restarts on copilot storage config changes', async t => {
const { provider, runtime } = createProvider();
await provider.start();
await provider.onConfigChanged({
updates: {
copilot: {
storage: {
provider: 'fs',
bucket: 'new-copilot',
config: { path: '~/.affine/storage' },
},
},
},
});
t.is(runtime.stop.callCount, 1);
t.is(runtime.configure.callCount, 2);
t.is(runtime.start.callCount, 2);
t.is(runtime.runMigrations.callCount, 2);
});
@@ -14,7 +14,7 @@ import type {
PresignedUpload,
PutObjectMetadata,
} from '../../base';
import { OnEvent } from '../../base';
import { Config, OnEvent } from '../../base';
import { wrapCallMetric } from '../../base/metrics';
import {
type RuntimeObjectGetResult,
@@ -36,6 +36,8 @@ export class StorageRuntimeProvider
private readonly runtime: RuntimeInstance = new StorageRuntime();
private migrationsStarted = false;
constructor(private readonly config: Config) {}
async onApplicationBootstrap() {
await this.start();
}
@@ -45,6 +47,7 @@ export class StorageRuntimeProvider
}
async start() {
this.configureRuntime();
await this.runtime.start();
await this.runMigrationsOnce();
const health = await this.runtime.health();
@@ -65,7 +68,11 @@ export class StorageRuntimeProvider
@OnEvent('config.changed')
async onConfigChanged({ updates }: Events['config.changed']) {
if (!('storages' in updates) && !('db' in updates)) {
if (
!('storages' in updates) &&
!('db' in updates) &&
!updates.copilot?.storage
) {
return;
}
await this.restart();
@@ -293,6 +300,23 @@ export class StorageRuntimeProvider
this.migrationsStarted = false;
await this.start();
}
/**
 * Serializes the database URL and the blob, avatar and copilot storage
 * settings from the injected config and passes them to the native runtime
 * as a JSON string.
 */
private configureRuntime() {
  const { db, storages, copilot } = this.config;
  const runtimeConfig = {
    db: {
      datasourceUrl: db.datasourceUrl,
    },
    storages: {
      'blob.storage': storages.blob.storage,
      'avatar.storage': storages.avatar.storage,
    },
    copilot: {
      storage: copilot.storage,
    },
  };
  this.runtime.configure(JSON.stringify(runtimeConfig));
}
}
function toRuntimeMetadata(metadata?: PutObjectMetadata) {
@@ -32,6 +32,7 @@ test.beforeEach(t => {
health: Sinon.stub().resolves({
databaseConnected: true,
providerConfigured: true,
provider: 'fs',
}),
backfillMissingBlobMetadata: Sinon.stub(),
rebuildWorkspaceDocBlobRefs: Sinon.stub(),
@@ -103,7 +104,8 @@ for (const scenario of objectStorageRequiredCases) {
test(`${scenario.name} skips when object storage is not configured`, async t => {
t.context.runtime.health.resolves({
databaseConnected: true,
providerConfigured: false,
providerConfigured: true,
provider: undefined,
});
await scenario.run(t.context);
@@ -411,7 +411,7 @@ export class StorageBlobJob {
private async hasObjectStorage(operation: string) {
const health = await this.rt.health();
if (health.providerConfigured) {
if (health.provider) {
return true;
}
@@ -50,6 +50,11 @@ type BlobGetResult = {
metadata?: GetObjectMetadata;
};
/** Settings needed to build signed proxy upload URLs. */
type R2ProxyConfig = {
// Key handed to the proxy signing step.
signKey: string;
// Base URL that the proxy upload paths are appended to.
urlPrefix: string;
};
@Injectable()
export class WorkspaceBlobStorage {
private readonly logger = new Logger(WorkspaceBlobStorage.name);
@@ -317,7 +322,10 @@ export class WorkspaceBlobStorage {
) {
return;
}
return { signKey: usePresignedURL.signKey };
return {
signKey: usePresignedURL.signKey,
urlPrefix: usePresignedURL.urlPrefix,
};
}
private signProxy(
@@ -340,7 +348,7 @@ export class WorkspaceBlobStorage {
workspaceId: string,
key: string,
metadata: PutObjectMetadata | undefined,
proxy: { signKey: string }
proxy: R2ProxyConfig
) {
const contentType = metadata?.contentType ?? 'application/octet-stream';
const contentLength = metadata?.contentLength;
@@ -353,7 +361,7 @@ export class WorkspaceBlobStorage {
proxy.signKey
);
return {
url: this.url.link(PROXY_UPLOAD_PATH, {
url: this.linkProxyUrl(proxy.urlPrefix, PROXY_UPLOAD_PATH, {
workspaceId,
key,
contentType,
@@ -371,7 +379,7 @@ export class WorkspaceBlobStorage {
key: string,
uploadId: string,
partNumber: number,
proxy: { signKey: string }
proxy: R2ProxyConfig
) {
const expiresAt = new Date(Date.now() + SIGNED_URL_EXPIRED * 1000);
const exp = Math.floor(expiresAt.getTime() / 1000);
@@ -382,7 +390,7 @@ export class WorkspaceBlobStorage {
proxy.signKey
);
return {
url: this.url.link(PROXY_MULTIPART_PATH, {
url: this.linkProxyUrl(proxy.urlPrefix, PROXY_MULTIPART_PATH, {
workspaceId,
key,
uploadId,
@@ -394,4 +402,20 @@ export class WorkspaceBlobStorage {
expiresAt,
};
}
/**
 * Builds an absolute proxy URL from `urlPrefix` and `path`, joined with
 * exactly one slash, and appends every defined `query` entry as a search
 * parameter. Entries whose value is `undefined` are omitted.
 */
private linkProxyUrl(
  urlPrefix: string,
  path: string,
  query: Record<string, string | number | undefined>
) {
  const base = urlPrefix.replace(/\/+$/, '');
  const suffix = path.startsWith('/') ? path : `/${path}`;
  const url = new URL(`${base}${suffix}`);

  for (const [name, value] of Object.entries(query)) {
    if (value === undefined) {
      continue;
    }
    url.searchParams.set(name, value.toString());
  }

  return url.toString();
}
}
@@ -16,9 +16,10 @@ export class UserAvatarController {
@Get('/:id')
async getAvatar(@Res() res: Response, @Param('id') id: string) {
if (this.storage.config.storage.provider !== 'fs') {
const provider = this.storage.config.storage.provider;
if (!['assetpack', 'fs'].includes(provider)) {
throw new ActionForbidden(
'Only available when avatar storage provider set to fs.'
'Only available when avatar storage provider is fs or assetpack.'
);
}
@@ -398,6 +398,9 @@ export class WorkspaceBlobResolver {
if (result.reason === 'mime_mismatch') {
throw new BlobInvalid('Blob mime mismatch');
}
if (result.reason === 'size_too_large') {
throw new BlobInvalid('Blob size too large');
}
throw new BlobInvalid('Blob key mismatch');
}
@@ -39,7 +39,10 @@ export class CopilotStorage {
) {
const name = `${userId}/${workspaceId}/${key}`;
const buffer = await toBuffer(blob);
await this.rt.putObject('copilot', name, buffer);
await this.rt.putObject('copilot', name, buffer, {
contentType: mimeType,
contentLength: buffer.length,
});
if (!env.prod) {
// return image base64url for dev environment
return `data:${mimeType};base64,${buffer.toString('base64')}`;