diff --git a/Cargo.lock b/Cargo.lock index 5063d3d461..78283dba3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,7 +37,6 @@ dependencies = [ "anyhow", "base64-simd", "chrono", - "dashmap", "homedir", "objc2", "objc2-foundation", @@ -68,7 +67,6 @@ dependencies = [ "affine_schema", "anyhow", "chrono", - "dashmap", "dotenvy", "napi", "napi-build", diff --git a/Cargo.toml b/Cargo.toml index 18af1ca490..9145d02a42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ anyhow = "1" base64-simd = "0.8" chrono = "0.4" criterion2 = { version = "2", default-features = false } -dashmap = "6" dotenvy = "0.15" file-format = { version = "0.26", features = ["reader"] } mimalloc = "0.1" diff --git a/packages/frontend/mobile-native/Cargo.toml b/packages/frontend/mobile-native/Cargo.toml index eeed92cb1b..a00c2ec69c 100644 --- a/packages/frontend/mobile-native/Cargo.toml +++ b/packages/frontend/mobile-native/Cargo.toml @@ -17,7 +17,6 @@ affine_nbstore = { workspace = true, features = ["use-as-lib"] } anyhow = { workspace = true } base64-simd = { workspace = true } chrono = { workspace = true } -dashmap = { workspace = true } sqlx = { workspace = true } thiserror = { workspace = true } uniffi = { workspace = true, features = ["cli", "tokio"] } diff --git a/packages/frontend/mobile-native/src/lib.rs b/packages/frontend/mobile-native/src/lib.rs index ff7be8c800..3c159f0ce6 100644 --- a/packages/frontend/mobile-native/src/lib.rs +++ b/packages/frontend/mobile-native/src/lib.rs @@ -211,7 +211,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_space_id(space_id) .await?, ) @@ -226,7 +227,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .push_update( doc_id, base64_simd::STANDARD @@ -247,7 +249,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_snapshot(doc_id) .await? .map(Into::into), @@ -258,7 +261,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_doc_snapshot(snapshot.try_into()?) .await?, ) @@ -272,7 +276,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_updates(doc_id) .await? .into_iter() @@ -290,7 +295,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .mark_updates_merged( doc_id, updates @@ -310,7 +316,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .delete_doc(doc_id) .await?, ) @@ -324,7 +331,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_clocks( after .map(|t| { @@ -349,7 +357,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_clock(doc_id) .await? .map(Into::into), @@ -360,7 +369,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_blob(key) .await? .map(Into::into), @@ -371,7 +381,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_blob(blob.try_into()?) .await?, ) @@ -386,27 +397,23 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .delete_blob(key, permanently) .await?, ) } pub async fn release_blobs(&self, universal_id: String) -> Result<()> { - Ok( - self - .inner - .ensure_storage(universal_id)? - .release_blobs() - .await?, - ) + Ok(self.inner.get(universal_id).await?.release_blobs().await?) } pub async fn list_blobs(&self, universal_id: String) -> Result> { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .list_blobs() .await? .into_iter() @@ -423,7 +430,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_remote_clocks(peer) .await? .into_iter() @@ -441,7 +449,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_remote_clock(peer, doc_id) .await? .map(Into::into), @@ -458,7 +467,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_remote_clock( peer, doc_id, @@ -478,7 +488,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pulled_remote_clocks(peer) .await? .into_iter() @@ -496,7 +507,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pulled_remote_clock(peer, doc_id) .await? .map(Into::into), @@ -513,7 +525,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_pulled_remote_clock( peer, doc_id, @@ -534,7 +547,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pushed_clock(peer, doc_id) .await? .map(Into::into), @@ -549,7 +563,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pushed_clocks(peer) .await? .into_iter() @@ -568,7 +583,8 @@ impl DocStoragePool { Ok( self .inner - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_pushed_clock( peer, doc_id, @@ -581,12 +597,6 @@ impl DocStoragePool { } pub async fn clear_clocks(&self, universal_id: String) -> Result<()> { - Ok( - self - .inner - .ensure_storage(universal_id)? - .clear_clocks() - .await?, - ) + Ok(self.inner.get(universal_id).await?.clear_clocks().await?) } } diff --git a/packages/frontend/native/__tests__/pool.spec.mts b/packages/frontend/native/__tests__/pool.spec.mts new file mode 100644 index 0000000000..7079afaaba --- /dev/null +++ b/packages/frontend/native/__tests__/pool.spec.mts @@ -0,0 +1,31 @@ +import { DocStoragePool } from '../index'; +import test from 'ava'; + +test('can batch read/write pool', async t => { + const pool = new DocStoragePool(); + await pool.connect('test', ':memory:'); + + const batch = 512; + + await Promise.all( + Array.from({ length: batch }).map(async (_, i) => { + return pool.setBlob('test', { + key: `test-blob-${i}`, + data: new Uint8Array([i % 255]), + mime: 'text/plain', + }); + }) + ); + + const blobs = await Promise.all( + Array.from({ length: batch }).map(async (_, i) => { + return pool.getBlob('test', `test-blob-${i}`); + }) + ); + + t.is(blobs.length, batch); + t.is( + blobs.every((blob, i) => blob!.data.at(0) === i % 255), + true + ); +}); diff --git a/packages/frontend/native/index.d.ts b/packages/frontend/native/index.d.ts index 35e83a01c4..f5ceb6f305 100644 --- a/packages/frontend/native/index.d.ts +++ b/packages/frontend/native/index.d.ts @@ -10,9 +10,9 @@ export declare class DocStoragePool { constructor() /** Initialize the database and run migrations. */ connect(universalId: string, path: string): Promise - setSpaceId(universalId: string, spaceId: string): Promise disconnect(universalId: string): Promise checkpoint(universalId: string): Promise + setSpaceId(universalId: string, spaceId: string): Promise pushUpdate(universalId: string, docId: string, update: Uint8Array): Promise getDocSnapshot(universalId: string, docId: string): Promise setDocSnapshot(universalId: string, snapshot: DocRecord): Promise diff --git a/packages/frontend/native/nbstore/Cargo.toml b/packages/frontend/native/nbstore/Cargo.toml index a3bc0e8146..8cdc424466 100644 --- a/packages/frontend/native/nbstore/Cargo.toml +++ b/packages/frontend/native/nbstore/Cargo.toml @@ -13,7 +13,6 @@ use-as-lib = ["napi-derive/noop", "napi/noop"] affine_schema = { path = "../schema" } anyhow = { workspace = true } chrono = { workspace = true } -dashmap = { workspace = true } napi = { workspace = true } napi-derive = { workspace = true } thiserror = { workspace = true } diff --git a/packages/frontend/native/nbstore/src/lib.rs b/packages/frontend/native/nbstore/src/lib.rs index 0de13c10fc..cf7c4955f2 100644 --- a/packages/frontend/native/nbstore/src/lib.rs +++ b/packages/frontend/native/nbstore/src/lib.rs @@ -8,7 +8,7 @@ pub mod sync; use chrono::NaiveDateTime; use napi::bindgen_prelude::*; use napi_derive::napi; -use pool::SqliteDocStoragePool; +use pool::{Ref, SqliteDocStoragePool}; use storage::SqliteDocStorage; #[cfg(feature = "use-as-lib")] @@ -86,13 +86,17 @@ pub struct DocStoragePool { #[napi] impl DocStoragePool { - #[napi(constructor, async_runtime)] + #[napi(constructor)] pub fn new() -> Result { Ok(Self { pool: SqliteDocStoragePool::default(), }) } + async fn get<'a>(&'a self, universal_id: String) -> Result> { + Ok(self.pool.get(universal_id).await?) + } + #[napi] /// Initialize the database and run migrations. pub async fn connect(&self, universal_id: String, path: String) -> Result<()> { @@ -100,16 +104,6 @@ impl DocStoragePool { Ok(()) } - #[napi] - pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> { - self - .pool - .ensure_storage(universal_id)? - .set_space_id(space_id) - .await?; - Ok(()) - } - #[napi] pub async fn disconnect(&self, universal_id: String) -> Result<()> { self.pool.disconnect(universal_id).await?; @@ -118,7 +112,13 @@ impl DocStoragePool { #[napi] pub async fn checkpoint(&self, universal_id: String) -> Result<()> { - self.pool.ensure_storage(universal_id)?.checkpoint().await?; + self.pool.get(universal_id).await?.checkpoint().await?; + Ok(()) + } + + #[napi] + pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> { + self.get(universal_id).await?.set_space_id(space_id).await?; Ok(()) } @@ -131,8 +131,8 @@ impl DocStoragePool { ) -> Result { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .push_update(doc_id, update) .await?, ) @@ -146,8 +146,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_snapshot(doc_id) .await?, ) @@ -157,8 +157,8 @@ impl DocStoragePool { pub async fn set_doc_snapshot(&self, universal_id: String, snapshot: DocRecord) -> Result { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_doc_snapshot(snapshot) .await?, ) @@ -172,8 +172,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_doc_updates(doc_id) .await?, ) @@ -188,8 +188,8 @@ impl DocStoragePool { ) -> Result { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .mark_updates_merged(doc_id, updates) .await?, ) @@ -197,11 +197,7 @@ impl DocStoragePool { #[napi] pub async fn delete_doc(&self, universal_id: String, doc_id: String) -> Result<()> { - self - .pool - .ensure_storage(universal_id)? - .delete_doc(doc_id) - .await?; + self.get(universal_id).await?.delete_doc(doc_id).await?; Ok(()) } @@ -211,13 +207,7 @@ impl DocStoragePool { universal_id: String, after: Option, ) -> Result> { - Ok( - self - .pool - .ensure_storage(universal_id)? - .get_doc_clocks(after) - .await?, - ) + Ok(self.get(universal_id).await?.get_doc_clocks(after).await?) } #[napi] @@ -226,33 +216,17 @@ impl DocStoragePool { universal_id: String, doc_id: String, ) -> Result> { - Ok( - self - .pool - .ensure_storage(universal_id)? - .get_doc_clock(doc_id) - .await?, - ) + Ok(self.get(universal_id).await?.get_doc_clock(doc_id).await?) } - #[napi] + #[napi(async_runtime)] pub async fn get_blob(&self, universal_id: String, key: String) -> Result> { - Ok( - self - .pool - .ensure_storage(universal_id)? - .get_blob(key) - .await?, - ) + Ok(self.get(universal_id).await?.get_blob(key).await?) } #[napi] pub async fn set_blob(&self, universal_id: String, blob: SetBlob) -> Result<()> { - self - .pool - .ensure_storage(universal_id)? - .set_blob(blob) - .await?; + self.get(universal_id).await?.set_blob(blob).await?; Ok(()) } @@ -264,8 +238,8 @@ impl DocStoragePool { permanently: bool, ) -> Result<()> { self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .delete_blob(key, permanently) .await?; Ok(()) @@ -273,17 +247,13 @@ impl DocStoragePool { #[napi] pub async fn release_blobs(&self, universal_id: String) -> Result<()> { - self - .pool - .ensure_storage(universal_id)? - .release_blobs() - .await?; + self.get(universal_id).await?.release_blobs().await?; Ok(()) } #[napi] pub async fn list_blobs(&self, universal_id: String) -> Result> { - Ok(self.pool.ensure_storage(universal_id)?.list_blobs().await?) + Ok(self.get(universal_id).await?.list_blobs().await?) } #[napi] @@ -294,8 +264,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_remote_clocks(peer) .await?, ) @@ -310,8 +280,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_remote_clock(peer, doc_id) .await?, ) @@ -326,8 +296,8 @@ impl DocStoragePool { clock: NaiveDateTime, ) -> Result<()> { self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_remote_clock(peer, doc_id, clock) .await?; Ok(()) @@ -341,8 +311,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pulled_remote_clocks(peer) .await?, ) @@ -357,8 +327,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pulled_remote_clock(peer, doc_id) .await?, ) @@ -373,8 +343,8 @@ impl DocStoragePool { clock: NaiveDateTime, ) -> Result<()> { self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_pulled_remote_clock(peer, doc_id, clock) .await?; Ok(()) @@ -388,8 +358,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pushed_clocks(peer) .await?, ) @@ -404,8 +374,8 @@ impl DocStoragePool { ) -> Result> { Ok( self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .get_peer_pushed_clock(peer, doc_id) .await?, ) @@ -420,8 +390,8 @@ impl DocStoragePool { clock: NaiveDateTime, ) -> Result<()> { self - .pool - .ensure_storage(universal_id)? + .get(universal_id) + .await? .set_peer_pushed_clock(peer, doc_id, clock) .await?; Ok(()) @@ -429,11 +399,7 @@ impl DocStoragePool { #[napi] pub async fn clear_clocks(&self, universal_id: String) -> Result<()> { - self - .pool - .ensure_storage(universal_id)? - .clear_clocks() - .await?; + self.get(universal_id).await?.clear_clocks().await?; Ok(()) } } @@ -461,7 +427,6 @@ impl DocStorage { pub async fn set_space_id(&self, space_id: String) -> Result<()> { self.storage.connect().await?; self.storage.set_space_id(space_id).await?; - println!("clocks {:?}", self.storage.get_doc_clocks(None).await?); self.storage.close().await; Ok(()) } diff --git a/packages/frontend/native/nbstore/src/pool.rs b/packages/frontend/native/nbstore/src/pool.rs index b2b6eb9cb5..49f9527cfd 100644 --- a/packages/frontend/native/nbstore/src/pool.rs +++ b/packages/frontend/native/nbstore/src/pool.rs @@ -1,54 +1,96 @@ -use dashmap::{mapref::one::RefMut, DashMap, Entry}; +use core::ops::{Deref, DerefMut}; +use std::collections::hash_map::{Entry, HashMap}; + +use tokio::sync::{RwLock, RwLockMappedWriteGuard, RwLockReadGuard, RwLockWriteGuard}; use super::{ error::{Error, Result}, storage::SqliteDocStorage, }; +pub struct Ref<'a, V> { + _guard: RwLockReadGuard<'a, V>, +} + +impl<'a, V> Deref for Ref<'a, V> { + type Target = V; + + fn deref(&self) -> &Self::Target { + self._guard.deref() + } +} + +pub struct RefMut<'a, V> { + _guard: RwLockMappedWriteGuard<'a, V>, +} + +impl<'a, V> Deref for RefMut<'a, V> { + type Target = V; + + fn deref(&self) -> &Self::Target { + &*self._guard + } +} + +impl<'a, V> DerefMut for RefMut<'a, V> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self._guard + } +} + #[derive(Default)] pub struct SqliteDocStoragePool { - inner: DashMap, + inner: RwLock>, } impl SqliteDocStoragePool { - fn get_or_create_storage<'a>( + async fn get_or_create_storage<'a>( &'a self, universal_id: String, path: &str, - ) -> RefMut<'a, String, SqliteDocStorage> { - let entry = self.inner.entry(universal_id); - if let Entry::Occupied(storage) = entry { - return storage.into_ref(); - } - let storage = SqliteDocStorage::new(path.to_string()); + ) -> RefMut<'a, SqliteDocStorage> { + let lock = RwLockWriteGuard::map(self.inner.write().await, |lock| { + match lock.entry(universal_id) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + let storage = SqliteDocStorage::new(path.to_string()); + entry.insert(storage) + } + } + }); - entry.or_insert(storage) + RefMut { _guard: lock } } - pub fn ensure_storage<'a>( - &'a self, - universal_id: String, - ) -> Result> { - let entry = self.inner.entry(universal_id); + pub async fn get<'a>(&'a self, universal_id: String) -> Result> { + let lock = RwLockReadGuard::try_map(self.inner.read().await, |lock| { + if let Some(storage) = lock.get(&universal_id) { + Some(storage) + } else { + None + } + }); - if let Entry::Occupied(storage) = entry { - Ok(storage.into_ref()) - } else { - Err(Error::InvalidOperation) + match lock { + Ok(guard) => Ok(Ref { _guard: guard }), + Err(_) => Err(Error::InvalidOperation), } } /// Initialize the database and run migrations. pub async fn connect(&self, universal_id: String, path: String) -> Result<()> { - let storage = self.get_or_create_storage(universal_id.to_owned(), &path); + let storage = self + .get_or_create_storage(universal_id.to_owned(), &path) + .await; + storage.connect().await?; Ok(()) } pub async fn disconnect(&self, universal_id: String) -> Result<()> { - let entry = self.inner.entry(universal_id); + let mut lock = self.inner.write().await; - if let Entry::Occupied(entry) = entry { + if let Entry::Occupied(entry) = lock.entry(universal_id) { let storage = entry.remove(); storage.close().await; } diff --git a/packages/frontend/native/nbstore/src/storage.rs b/packages/frontend/native/nbstore/src/storage.rs index 056f5ee90c..269d59369c 100644 --- a/packages/frontend/native/nbstore/src/storage.rs +++ b/packages/frontend/native/nbstore/src/storage.rs @@ -16,24 +16,28 @@ impl SqliteDocStorage { pub fn new(path: String) -> Self { let sqlite_options = SqliteConnectOptions::new() .filename(&path) - .foreign_keys(false) - .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); + .foreign_keys(false); let mut pool_options = SqlitePoolOptions::new(); - if cfg!(test) && path == ":memory:" { + if path == ":memory:" { pool_options = pool_options .min_connections(1) .max_connections(1) .idle_timeout(None) .max_lifetime(None); - } else { - pool_options = pool_options.max_connections(4); - } - Self { - pool: pool_options.connect_lazy_with(sqlite_options), - path, + Self { + pool: pool_options.connect_lazy_with(sqlite_options), + path, + } + } else { + Self { + pool: pool_options + .max_connections(4) + .connect_lazy_with(sqlite_options.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)), + path, + } } } diff --git a/packages/frontend/native/package.json b/packages/frontend/native/package.json index 138243ee37..6b61189771 100644 --- a/packages/frontend/native/package.json +++ b/packages/frontend/native/package.json @@ -20,17 +20,9 @@ "extensions": { "mts": "module" }, - "nodeArguments": [ - "--loader", - "ts-node/esm.mjs", - "--es-module-specifier-resolution=node" - ], "files": [ "__tests__/*.spec.mts" - ], - "environmentVariables": { - "TS_NODE_PROJECT": "./tsconfig.json" - } + ] }, "devDependencies": { "@napi-rs/cli": "3.0.0-alpha.68",