chore: improve code

This commit is contained in:
DarkSky
2026-02-20 01:40:48 +08:00
parent 849699e93f
commit e9ea299ce9
3 changed files with 169 additions and 97 deletions

View File

@@ -3,9 +3,7 @@ use affine_nbstore::{Data, pool::SqliteDocStoragePool};
#[cfg(any(target_os = "android", target_os = "ios"))]
pub(crate) mod mobile_blob_cache;
#[cfg(any(target_os = "android", target_os = "ios"))]
use mobile_blob_cache::{
MOBILE_BLOB_INLINE_THRESHOLD_BYTES, MobileBlobCache, is_mobile_binary_file_token, read_mobile_binary_file,
};
use mobile_blob_cache::{MOBILE_BLOB_INLINE_THRESHOLD_BYTES, MobileBlobCache, is_mobile_binary_file_token};
#[derive(uniffi::Error, thiserror::Error, Debug)]
pub enum UniffiError {
@@ -27,13 +25,7 @@ type Result<T> = std::result::Result<T, UniffiError>;
uniffi::setup_scaffolding!("affine_mobile_native");
fn decode_mobile_data(data: &str) -> Result<Vec<u8>> {
#[cfg(any(target_os = "android", target_os = "ios"))]
if is_mobile_binary_file_token(data) {
return read_mobile_binary_file(data)
.map_err(|err| UniffiError::Err(format!("Failed to read mobile file token: {err}")));
}
fn decode_base64_data(data: &str) -> Result<Vec<u8>> {
base64_simd::STANDARD
.decode_to_vec(data)
.map_err(|e| UniffiError::Base64DecodingError(e.to_string()))
@@ -69,7 +61,7 @@ impl TryFrom<DocRecord> for affine_nbstore::DocRecord {
fn try_from(record: DocRecord) -> Result<Self> {
Ok(Self {
doc_id: record.doc_id,
bin: Into::<Data>::into(decode_mobile_data(&record.bin)?),
bin: Into::<Data>::into(decode_base64_data(&record.bin)?),
timestamp: chrono::DateTime::<chrono::Utc>::from_timestamp_millis(record.timestamp)
.ok_or(UniffiError::TimestampDecodingError)?
.naive_utc(),
@@ -105,7 +97,7 @@ impl TryFrom<DocUpdate> for affine_nbstore::DocUpdate {
timestamp: chrono::DateTime::<chrono::Utc>::from_timestamp_millis(update.timestamp)
.ok_or(UniffiError::TimestampDecodingError)?
.naive_utc(),
bin: Into::<Data>::into(decode_mobile_data(&update.bin)?),
bin: Into::<Data>::into(decode_base64_data(&update.bin)?),
})
}
}
@@ -217,7 +209,7 @@ impl TryFrom<SetBlob> for affine_nbstore::SetBlob {
fn try_from(blob: SetBlob) -> Result<Self> {
Ok(Self {
key: blob.key,
data: Into::<Data>::into(decode_mobile_data(&blob.data)?),
data: Into::<Data>::into(decode_base64_data(&blob.data)?),
mime: blob.mime,
})
}
@@ -338,6 +330,20 @@ pub fn new_doc_storage_pool() -> DocStoragePool {
#[uniffi::export(async_runtime = "tokio")]
impl DocStoragePool {
fn decode_mobile_data(&self, universal_id: &str, data: &str) -> Result<Vec<u8>> {
#[cfg(any(target_os = "android", target_os = "ios"))]
if is_mobile_binary_file_token(data) {
return self
.mobile_blob_cache
.read_binary_file(universal_id, data)
.map_err(|err| UniffiError::Err(format!("Failed to read mobile file token: {err}")));
}
#[cfg(not(any(target_os = "android", target_os = "ios")))]
let _ = universal_id;
decode_base64_data(data)
}
fn encode_doc_data(&self, universal_id: &str, doc_id: &str, timestamp: i64, data: &[u8]) -> Result<String> {
#[cfg(any(target_os = "android", target_os = "ios"))]
if data.len() >= MOBILE_BLOB_INLINE_THRESHOLD_BYTES {
@@ -375,12 +381,13 @@ impl DocStoragePool {
}
pub async fn push_update(&self, universal_id: String, doc_id: String, update: String) -> Result<i64> {
let decoded_update = self.decode_mobile_data(&universal_id, &update)?;
Ok(
self
.inner
.get(universal_id)
.await?
.push_update(doc_id, decode_mobile_data(&update)?)
.push_update(doc_id, decoded_update)
.await?
.and_utc()
.timestamp_millis(),
@@ -408,14 +415,14 @@ impl DocStoragePool {
}
pub async fn set_doc_snapshot(&self, universal_id: String, snapshot: DocRecord) -> Result<bool> {
Ok(
self
.inner
.get(universal_id)
.await?
.set_doc_snapshot(snapshot.try_into()?)
.await?,
)
let doc_record = affine_nbstore::DocRecord {
doc_id: snapshot.doc_id,
bin: Into::<Data>::into(self.decode_mobile_data(&universal_id, &snapshot.bin)?),
timestamp: chrono::DateTime::<chrono::Utc>::from_timestamp_millis(snapshot.timestamp)
.ok_or(UniffiError::TimestampDecodingError)?
.naive_utc(),
};
Ok(self.inner.get(universal_id).await?.set_doc_snapshot(doc_record).await?)
}
pub async fn get_doc_updates(&self, universal_id: String, doc_id: String) -> Result<Vec<DocUpdate>> {
@@ -534,12 +541,12 @@ impl DocStoragePool {
pub async fn set_blob(&self, universal_id: String, blob: SetBlob) -> Result<()> {
#[cfg(any(target_os = "android", target_os = "ios"))]
let key = blob.key.clone();
self
.inner
.get(universal_id.clone())
.await?
.set_blob(blob.try_into()?)
.await?;
let blob = affine_nbstore::SetBlob {
key: blob.key,
data: Into::<Data>::into(self.decode_mobile_data(&universal_id, &blob.data)?),
mime: blob.mime,
};
self.inner.get(universal_id.clone()).await?.set_blob(blob).await?;
#[cfg(any(target_os = "android", target_os = "ios"))]
self.mobile_blob_cache.invalidate_blob(&universal_id, &key);
Ok(())

View File

@@ -172,11 +172,17 @@ impl MobileBlobCache {
}
}
self
if let Some(cache_dir) = self
.workspace_dirs
.lock()
.expect("workspace cache lock poisoned")
.remove(universal_id);
.remove(universal_id)
{
let _ = std::fs::remove_dir_all(&cache_dir);
if let Some(parent) = cache_dir.parent() {
let _ = std::fs::remove_dir(parent);
}
}
}
fn cache_key(universal_id: &str, key: &str) -> String {
@@ -284,44 +290,72 @@ pub(crate) fn is_mobile_binary_file_token(value: &str) -> bool {
value.starts_with(MOBILE_BLOB_FILE_PREFIX) || value.starts_with(MOBILE_DOC_FILE_PREFIX)
}
pub(crate) fn read_mobile_binary_file(value: &str) -> std::io::Result<Vec<u8>> {
let path = value
.strip_prefix(MOBILE_BLOB_FILE_PREFIX)
.or_else(|| value.strip_prefix(MOBILE_DOC_FILE_PREFIX))
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid mobile file token"))?;
impl MobileBlobCache {
pub(crate) fn read_binary_file(&self, universal_id: &str, value: &str) -> std::io::Result<Vec<u8>> {
let path = value
.strip_prefix(MOBILE_BLOB_FILE_PREFIX)
.or_else(|| value.strip_prefix(MOBILE_DOC_FILE_PREFIX))
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid mobile file token"))?;
let path = path.strip_prefix("file://").unwrap_or(path);
let canonical = std::fs::canonicalize(path)?;
if !is_valid_mobile_cache_path(&canonical) {
return Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"mobile file token points outside the nbstore cache directory",
));
}
let path = path.strip_prefix("file://").unwrap_or(path);
let canonical = std::fs::canonicalize(path)?;
let workspace_dir = {
self
.workspace_dirs
.lock()
.expect("workspace cache lock poisoned")
.get(universal_id)
.cloned()
}
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "workspace cache directory not registered"))?;
let workspace_dir = std::fs::canonicalize(workspace_dir)?;
let metadata = std::fs::metadata(&canonical)?;
if !metadata.is_file() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"mobile file token does not resolve to a file",
));
}
if metadata.len() > MOBILE_BLOB_MAX_READ_BYTES {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"mobile file token exceeds max size: {} > {}",
metadata.len(),
MOBILE_BLOB_MAX_READ_BYTES
),
));
}
if !is_valid_mobile_cache_path(&canonical, &workspace_dir) {
return Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"mobile file token points outside the workspace cache directory",
));
}
std::fs::read(canonical)
let metadata = std::fs::metadata(&canonical)?;
if !metadata.is_file() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"mobile file token does not resolve to a file",
));
}
if metadata.len() > MOBILE_BLOB_MAX_READ_BYTES {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"mobile file token exceeds max size: {} > {}",
metadata.len(),
MOBILE_BLOB_MAX_READ_BYTES
),
));
}
std::fs::read(canonical)
}
}
fn is_valid_mobile_cache_path(path: &Path) -> bool {
let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
fn is_valid_mobile_cache_path(path: &Path, workspace_dir: &Path) -> bool {
if !path.starts_with(workspace_dir) {
return false;
}
let Ok(relative) = path.strip_prefix(workspace_dir) else {
return false;
};
let mut components = relative.components();
let Some(std::path::Component::Normal(file_name)) = components.next() else {
return false;
};
if components.next().is_some() {
return false;
}
let Some(file_name) = file_name.to_str() else {
return false;
};
let Some((stem, extension)) = file_name.rsplit_once('.') else {
@@ -330,21 +364,5 @@ fn is_valid_mobile_cache_path(path: &Path) -> bool {
if extension != "blob" && extension != "docbin" {
return false;
}
if stem.len() != 16 || !stem.chars().all(|c| c.is_ascii_hexdigit()) {
return false;
}
let Some(workspace_bucket) = path.parent().and_then(Path::file_name).and_then(|name| name.to_str()) else {
return false;
};
if workspace_bucket.len() != 16 || !workspace_bucket.chars().all(|c| c.is_ascii_hexdigit()) {
return false;
}
path
.parent()
.and_then(Path::parent)
.and_then(Path::file_name)
.and_then(|name| name.to_str())
== Some(MOBILE_BLOB_CACHE_DIR)
stem.len() == 16 && stem.chars().all(|c| c.is_ascii_hexdigit())
}

View File

@@ -25,13 +25,21 @@ impl<V> Deref for Ref<V> {
#[derive(Default)]
pub struct SqliteDocStoragePool {
inner: RwLock<HashMap<String, Arc<SqliteDocStorage>>>,
inner: RwLock<HashMap<String, StorageState>>,
}
enum StorageState {
Connecting(Arc<SqliteDocStorage>),
Connected(Arc<SqliteDocStorage>),
}
impl SqliteDocStoragePool {
pub async fn get(&self, universal_id: String) -> Result<Ref<SqliteDocStorage>> {
let lock = self.inner.read().await;
let Some(storage) = lock.get(&universal_id) else {
let Some(state) = lock.get(&universal_id) else {
return Err(Error::InvalidOperation);
};
let StorageState::Connected(storage) = state else {
return Err(Error::InvalidOperation);
};
Ok(Ref {
@@ -43,30 +51,69 @@ impl SqliteDocStoragePool {
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
let storage = {
let mut lock = self.inner.write().await;
match lock.entry(universal_id) {
Entry::Occupied(entry) => Arc::clone(entry.get()),
Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(SqliteDocStorage::new(path)))),
match lock.entry(universal_id.clone()) {
Entry::Occupied(entry) => match entry.get() {
StorageState::Connected(_) => return Ok(()),
StorageState::Connecting(_) => return Err(Error::InvalidOperation),
},
Entry::Vacant(entry) => {
let storage = Arc::new(SqliteDocStorage::new(path));
entry.insert(StorageState::Connecting(Arc::clone(&storage)));
storage
}
}
};
storage.connect().await?;
if let Err(err) = storage.connect().await {
let mut lock = self.inner.write().await;
if matches!(
lock.get(&universal_id),
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
) {
lock.remove(&universal_id);
}
return Err(err);
}
let mut transitioned = false;
{
let mut lock = self.inner.write().await;
if matches!(
lock.get(&universal_id),
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
) {
lock.insert(universal_id, StorageState::Connected(Arc::clone(&storage)));
transitioned = true;
}
}
if !transitioned {
storage.close().await;
return Err(Error::InvalidOperation);
}
Ok(())
}
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
let storage = {
let mut lock = self.inner.write().await;
lock.remove(&universal_id)
};
let Some(storage) = storage else {
return Ok(());
};
match lock.get(&universal_id) {
None => return Ok(()),
Some(StorageState::Connecting(_)) => return Err(Error::InvalidOperation),
Some(StorageState::Connected(storage)) => {
// Prevent shutting down the shared storage while requests still hold refs.
if Arc::strong_count(storage) > 1 {
return Err(Error::InvalidOperation);
}
}
}
// Prevent shutting down the shared storage while requests still hold refs.
if Arc::strong_count(&storage) > 1 {
let mut lock = self.inner.write().await;
lock.insert(universal_id, storage);
return Err(Error::InvalidOperation);
}
let Some(StorageState::Connected(storage)) = lock.remove(&universal_id) else {
return Err(Error::InvalidOperation);
};
storage
};
storage.close().await;
Ok(())