fix(native): possible deadlock when batching read/write (#9817)

This commit is contained in:
forehalo
2025-01-21 06:07:03 +00:00
parent 46ee235674
commit 07c32d016d
11 changed files with 208 additions and 169 deletions

View File

@@ -8,7 +8,7 @@ pub mod sync;
use chrono::NaiveDateTime;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use pool::SqliteDocStoragePool;
use pool::{Ref, SqliteDocStoragePool};
use storage::SqliteDocStorage;
#[cfg(feature = "use-as-lib")]
@@ -86,13 +86,17 @@ pub struct DocStoragePool {
#[napi]
impl DocStoragePool {
#[napi(constructor, async_runtime)]
#[napi(constructor)]
pub fn new() -> Result<Self> {
Ok(Self {
pool: SqliteDocStoragePool::default(),
})
}
/// Fetch the storage registered under `universal_id` from the pool,
/// converting the pool's error type into `napi::Result` via `?`.
///
/// Acquires the pool's read lock (see `SqliteDocStoragePool::get`), so
/// concurrent readers do not serialize on a write entry — presumably the
/// core of this commit's deadlock fix; confirm against the pool impl.
async fn get<'a>(&'a self, universal_id: String) -> Result<Ref<'a, SqliteDocStorage>> {
Ok(self.pool.get(universal_id).await?)
}
#[napi]
/// Initialize the database and run migrations.
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
@@ -100,16 +104,6 @@ impl DocStoragePool {
Ok(())
}
#[napi]
pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.set_space_id(space_id)
.await?;
Ok(())
}
#[napi]
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
self.pool.disconnect(universal_id).await?;
@@ -118,7 +112,13 @@ impl DocStoragePool {
#[napi]
pub async fn checkpoint(&self, universal_id: String) -> Result<()> {
self.pool.ensure_storage(universal_id)?.checkpoint().await?;
self.pool.get(universal_id).await?.checkpoint().await?;
Ok(())
}
#[napi]
pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> {
self.get(universal_id).await?.set_space_id(space_id).await?;
Ok(())
}
@@ -131,8 +131,8 @@ impl DocStoragePool {
) -> Result<NaiveDateTime> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.push_update(doc_id, update)
.await?,
)
@@ -146,8 +146,8 @@ impl DocStoragePool {
) -> Result<Option<DocRecord>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_doc_snapshot(doc_id)
.await?,
)
@@ -157,8 +157,8 @@ impl DocStoragePool {
pub async fn set_doc_snapshot(&self, universal_id: String, snapshot: DocRecord) -> Result<bool> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.set_doc_snapshot(snapshot)
.await?,
)
@@ -172,8 +172,8 @@ impl DocStoragePool {
) -> Result<Vec<DocUpdate>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_doc_updates(doc_id)
.await?,
)
@@ -188,8 +188,8 @@ impl DocStoragePool {
) -> Result<u32> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.mark_updates_merged(doc_id, updates)
.await?,
)
@@ -197,11 +197,7 @@ impl DocStoragePool {
#[napi]
pub async fn delete_doc(&self, universal_id: String, doc_id: String) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.delete_doc(doc_id)
.await?;
self.get(universal_id).await?.delete_doc(doc_id).await?;
Ok(())
}
@@ -211,13 +207,7 @@ impl DocStoragePool {
universal_id: String,
after: Option<NaiveDateTime>,
) -> Result<Vec<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get_doc_clocks(after)
.await?,
)
Ok(self.get(universal_id).await?.get_doc_clocks(after).await?)
}
#[napi]
@@ -226,33 +216,17 @@ impl DocStoragePool {
universal_id: String,
doc_id: String,
) -> Result<Option<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get_doc_clock(doc_id)
.await?,
)
Ok(self.get(universal_id).await?.get_doc_clock(doc_id).await?)
}
#[napi]
#[napi(async_runtime)]
pub async fn get_blob(&self, universal_id: String, key: String) -> Result<Option<Blob>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get_blob(key)
.await?,
)
Ok(self.get(universal_id).await?.get_blob(key).await?)
}
#[napi]
pub async fn set_blob(&self, universal_id: String, blob: SetBlob) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.set_blob(blob)
.await?;
self.get(universal_id).await?.set_blob(blob).await?;
Ok(())
}
@@ -264,8 +238,8 @@ impl DocStoragePool {
permanently: bool,
) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.delete_blob(key, permanently)
.await?;
Ok(())
@@ -273,17 +247,13 @@ impl DocStoragePool {
#[napi]
pub async fn release_blobs(&self, universal_id: String) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.release_blobs()
.await?;
self.get(universal_id).await?.release_blobs().await?;
Ok(())
}
#[napi]
pub async fn list_blobs(&self, universal_id: String) -> Result<Vec<ListedBlob>> {
Ok(self.pool.ensure_storage(universal_id)?.list_blobs().await?)
Ok(self.get(universal_id).await?.list_blobs().await?)
}
#[napi]
@@ -294,8 +264,8 @@ impl DocStoragePool {
) -> Result<Vec<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_remote_clocks(peer)
.await?,
)
@@ -310,8 +280,8 @@ impl DocStoragePool {
) -> Result<Option<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_remote_clock(peer, doc_id)
.await?,
)
@@ -326,8 +296,8 @@ impl DocStoragePool {
clock: NaiveDateTime,
) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.set_peer_remote_clock(peer, doc_id, clock)
.await?;
Ok(())
@@ -341,8 +311,8 @@ impl DocStoragePool {
) -> Result<Vec<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_pulled_remote_clocks(peer)
.await?,
)
@@ -357,8 +327,8 @@ impl DocStoragePool {
) -> Result<Option<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_pulled_remote_clock(peer, doc_id)
.await?,
)
@@ -373,8 +343,8 @@ impl DocStoragePool {
clock: NaiveDateTime,
) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.set_peer_pulled_remote_clock(peer, doc_id, clock)
.await?;
Ok(())
@@ -388,8 +358,8 @@ impl DocStoragePool {
) -> Result<Vec<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_pushed_clocks(peer)
.await?,
)
@@ -404,8 +374,8 @@ impl DocStoragePool {
) -> Result<Option<DocClock>> {
Ok(
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.get_peer_pushed_clock(peer, doc_id)
.await?,
)
@@ -420,8 +390,8 @@ impl DocStoragePool {
clock: NaiveDateTime,
) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.get(universal_id)
.await?
.set_peer_pushed_clock(peer, doc_id, clock)
.await?;
Ok(())
@@ -429,11 +399,7 @@ impl DocStoragePool {
#[napi]
pub async fn clear_clocks(&self, universal_id: String) -> Result<()> {
self
.pool
.ensure_storage(universal_id)?
.clear_clocks()
.await?;
self.get(universal_id).await?.clear_clocks().await?;
Ok(())
}
}
@@ -461,7 +427,6 @@ impl DocStorage {
pub async fn set_space_id(&self, space_id: String) -> Result<()> {
self.storage.connect().await?;
self.storage.set_space_id(space_id).await?;
println!("clocks {:?}", self.storage.get_doc_clocks(None).await?);
self.storage.close().await;
Ok(())
}

View File

@@ -1,54 +1,96 @@
use dashmap::{mapref::one::RefMut, DashMap, Entry};
use core::ops::{Deref, DerefMut};
use std::collections::hash_map::{Entry, HashMap};
use tokio::sync::{RwLock, RwLockMappedWriteGuard, RwLockReadGuard, RwLockWriteGuard};
use super::{
error::{Error, Result},
storage::SqliteDocStorage,
};
/// Shared, read-locked handle to a value owned by the pool.
///
/// Wraps a `tokio::sync::RwLockReadGuard`; the pool's read lock is held for
/// the lifetime of this `Ref` and released when it is dropped.
pub struct Ref<'a, V> {
_guard: RwLockReadGuard<'a, V>,
}
// Transparent `&V` access to the guarded value.
impl<'a, V> Deref for Ref<'a, V> {
type Target = V;
fn deref(&self) -> &Self::Target {
self._guard.deref()
}
}
/// Exclusive, write-locked handle to a value owned by the pool.
///
/// Wraps a `tokio::sync::RwLockMappedWriteGuard` (a write guard mapped down
/// to a single entry); the pool's write lock is held until this is dropped.
pub struct RefMut<'a, V> {
_guard: RwLockMappedWriteGuard<'a, V>,
}
// Transparent `&V` access through the mapped write guard.
impl<'a, V> Deref for RefMut<'a, V> {
type Target = V;
fn deref(&self) -> &Self::Target {
&*self._guard
}
}
// Transparent `&mut V` access through the mapped write guard.
impl<'a, V> DerefMut for RefMut<'a, V> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self._guard
}
}
#[derive(Default)]
pub struct SqliteDocStoragePool {
inner: DashMap<String, SqliteDocStorage>,
inner: RwLock<HashMap<String, SqliteDocStorage>>,
}
impl SqliteDocStoragePool {
fn get_or_create_storage<'a>(
async fn get_or_create_storage<'a>(
&'a self,
universal_id: String,
path: &str,
) -> RefMut<'a, String, SqliteDocStorage> {
let entry = self.inner.entry(universal_id);
if let Entry::Occupied(storage) = entry {
return storage.into_ref();
}
let storage = SqliteDocStorage::new(path.to_string());
) -> RefMut<'a, SqliteDocStorage> {
let lock = RwLockWriteGuard::map(self.inner.write().await, |lock| {
match lock.entry(universal_id) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
let storage = SqliteDocStorage::new(path.to_string());
entry.insert(storage)
}
}
});
entry.or_insert(storage)
RefMut { _guard: lock }
}
pub fn ensure_storage<'a>(
&'a self,
universal_id: String,
) -> Result<RefMut<'a, String, SqliteDocStorage>> {
let entry = self.inner.entry(universal_id);
pub async fn get<'a>(&'a self, universal_id: String) -> Result<Ref<'a, SqliteDocStorage>> {
let lock = RwLockReadGuard::try_map(self.inner.read().await, |lock| {
if let Some(storage) = lock.get(&universal_id) {
Some(storage)
} else {
None
}
});
if let Entry::Occupied(storage) = entry {
Ok(storage.into_ref())
} else {
Err(Error::InvalidOperation)
match lock {
Ok(guard) => Ok(Ref { _guard: guard }),
Err(_) => Err(Error::InvalidOperation),
}
}
/// Initialize the database and run migrations.
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
let storage = self.get_or_create_storage(universal_id.to_owned(), &path);
let storage = self
.get_or_create_storage(universal_id.to_owned(), &path)
.await;
storage.connect().await?;
Ok(())
}
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
let entry = self.inner.entry(universal_id);
let mut lock = self.inner.write().await;
if let Entry::Occupied(entry) = entry {
if let Entry::Occupied(entry) = lock.entry(universal_id) {
let storage = entry.remove();
storage.close().await;
}

View File

@@ -16,24 +16,28 @@ impl SqliteDocStorage {
pub fn new(path: String) -> Self {
let sqlite_options = SqliteConnectOptions::new()
.filename(&path)
.foreign_keys(false)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
.foreign_keys(false);
let mut pool_options = SqlitePoolOptions::new();
if cfg!(test) && path == ":memory:" {
if path == ":memory:" {
pool_options = pool_options
.min_connections(1)
.max_connections(1)
.idle_timeout(None)
.max_lifetime(None);
} else {
pool_options = pool_options.max_connections(4);
}
Self {
pool: pool_options.connect_lazy_with(sqlite_options),
path,
Self {
pool: pool_options.connect_lazy_with(sqlite_options),
path,
}
} else {
Self {
pool: pool_options
.max_connections(4)
.connect_lazy_with(sqlite_options.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)),
path,
}
}
}