mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-24 18:02:47 +08:00
chore: improve codes
This commit is contained in:
@@ -25,13 +25,21 @@ impl<V> Deref for Ref<V> {
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SqliteDocStoragePool {
|
||||
inner: RwLock<HashMap<String, Arc<SqliteDocStorage>>>,
|
||||
inner: RwLock<HashMap<String, StorageState>>,
|
||||
}
|
||||
|
||||
enum StorageState {
|
||||
Connecting(Arc<SqliteDocStorage>),
|
||||
Connected(Arc<SqliteDocStorage>),
|
||||
}
|
||||
|
||||
impl SqliteDocStoragePool {
|
||||
pub async fn get(&self, universal_id: String) -> Result<Ref<SqliteDocStorage>> {
|
||||
let lock = self.inner.read().await;
|
||||
let Some(storage) = lock.get(&universal_id) else {
|
||||
let Some(state) = lock.get(&universal_id) else {
|
||||
return Err(Error::InvalidOperation);
|
||||
};
|
||||
let StorageState::Connected(storage) = state else {
|
||||
return Err(Error::InvalidOperation);
|
||||
};
|
||||
Ok(Ref {
|
||||
@@ -43,30 +51,69 @@ impl SqliteDocStoragePool {
|
||||
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
|
||||
let storage = {
|
||||
let mut lock = self.inner.write().await;
|
||||
match lock.entry(universal_id) {
|
||||
Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(SqliteDocStorage::new(path)))),
|
||||
match lock.entry(universal_id.clone()) {
|
||||
Entry::Occupied(entry) => match entry.get() {
|
||||
StorageState::Connected(_) => return Ok(()),
|
||||
StorageState::Connecting(_) => return Err(Error::InvalidOperation),
|
||||
},
|
||||
Entry::Vacant(entry) => {
|
||||
let storage = Arc::new(SqliteDocStorage::new(path));
|
||||
entry.insert(StorageState::Connecting(Arc::clone(&storage)));
|
||||
storage
|
||||
}
|
||||
}
|
||||
};
|
||||
storage.connect().await?;
|
||||
|
||||
if let Err(err) = storage.connect().await {
|
||||
let mut lock = self.inner.write().await;
|
||||
if matches!(
|
||||
lock.get(&universal_id),
|
||||
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
|
||||
) {
|
||||
lock.remove(&universal_id);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
let mut transitioned = false;
|
||||
{
|
||||
let mut lock = self.inner.write().await;
|
||||
if matches!(
|
||||
lock.get(&universal_id),
|
||||
Some(StorageState::Connecting(existing)) if Arc::ptr_eq(existing, &storage)
|
||||
) {
|
||||
lock.insert(universal_id, StorageState::Connected(Arc::clone(&storage)));
|
||||
transitioned = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !transitioned {
|
||||
storage.close().await;
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
|
||||
let storage = {
|
||||
let mut lock = self.inner.write().await;
|
||||
lock.remove(&universal_id)
|
||||
};
|
||||
let Some(storage) = storage else {
|
||||
return Ok(());
|
||||
};
|
||||
match lock.get(&universal_id) {
|
||||
None => return Ok(()),
|
||||
Some(StorageState::Connecting(_)) => return Err(Error::InvalidOperation),
|
||||
Some(StorageState::Connected(storage)) => {
|
||||
// Prevent shutting down the shared storage while requests still hold refs.
|
||||
if Arc::strong_count(storage) > 1 {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent shutting down the shared storage while requests still hold refs.
|
||||
if Arc::strong_count(&storage) > 1 {
|
||||
let mut lock = self.inner.write().await;
|
||||
lock.insert(universal_id, storage);
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
let Some(StorageState::Connected(storage)) = lock.remove(&universal_id) else {
|
||||
return Err(Error::InvalidOperation);
|
||||
};
|
||||
storage
|
||||
};
|
||||
|
||||
storage.close().await;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user