chore: improve codes

This commit is contained in:
DarkSky
2026-02-20 01:40:48 +08:00
parent 849699e93f
commit e9ea299ce9
3 changed files with 169 additions and 97 deletions

View File

@@ -25,13 +25,21 @@ impl<V> Deref for Ref<V> {
#[derive(Default)]
pub struct SqliteDocStoragePool {
inner: RwLock<HashMap<String, Arc<SqliteDocStorage>>>,
inner: RwLock<HashMap<String, StorageState>>,
}
enum StorageState {
Connecting(Arc<SqliteDocStorage>),
Connected(Arc<SqliteDocStorage>),
}
impl SqliteDocStoragePool {
pub async fn get(&self, universal_id: String) -> Result<Ref<SqliteDocStorage>> {
let lock = self.inner.read().await;
let Some(storage) = lock.get(&universal_id) else {
let Some(state) = lock.get(&universal_id) else {
return Err(Error::InvalidOperation);
};
let StorageState::Connected(storage) = state else {
return Err(Error::InvalidOperation);
};
Ok(Ref {
@@ -43,30 +51,69 @@ impl SqliteDocStoragePool {
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
let storage = {
let mut lock = self.inner.write().await;
match lock.entry(universal_id) {
Entry::Occupied(entry) => Arc::clone(entry.get()),
Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(SqliteDocStorage::new(path)))),
match lock.entry(universal_id.clone()) {
Entry::Occupied(entry) => match entry.get() {
StorageState::Connected(_) => return Ok(()),
StorageState::Connecting(_) => return Err(Error::InvalidOperation),
},
Entry::Vacant(entry) => {
let storage = Arc::new(SqliteDocStorage::new(path));
entry.insert(StorageState::Connecting(Arc::clone(&storage)));
storage
}
}
};
storage.connect().await?;
if let Err(err) = storage.connect().await {
let mut lock = self.inner.write().await;
if matches!(
lock.get(&universal_id),
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
) {
lock.remove(&universal_id);
}
return Err(err);
}
let mut transitioned = false;
{
let mut lock = self.inner.write().await;
if matches!(
lock.get(&universal_id),
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
) {
lock.insert(universal_id, StorageState::Connected(Arc::clone(&storage)));
transitioned = true;
}
}
if !transitioned {
storage.close().await;
return Err(Error::InvalidOperation);
}
Ok(())
}
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
let storage = {
let mut lock = self.inner.write().await;
lock.remove(&universal_id)
};
let Some(storage) = storage else {
return Ok(());
};
match lock.get(&universal_id) {
None => return Ok(()),
Some(StorageState::Connecting(_)) => return Err(Error::InvalidOperation),
Some(StorageState::Connected(storage)) => {
// Prevent shutting down the shared storage while requests still hold refs.
if Arc::strong_count(storage) > 1 {
return Err(Error::InvalidOperation);
}
}
}
// Prevent shutting down the shared storage while requests still hold refs.
if Arc::strong_count(&storage) > 1 {
let mut lock = self.inner.write().await;
lock.insert(universal_id, storage);
return Err(Error::InvalidOperation);
}
let Some(StorageState::Connected(storage)) = lock.remove(&universal_id) else {
return Err(Error::InvalidOperation);
};
storage
};
storage.close().await;
Ok(())