mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-25 10:22:55 +08:00
fix(native): possible deadlock when batching read/write (#9817)
This commit is contained in:
@@ -17,7 +17,6 @@ affine_nbstore = { workspace = true, features = ["use-as-lib"] }
|
||||
anyhow = { workspace = true }
|
||||
base64-simd = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
uniffi = { workspace = true, features = ["cli", "tokio"] }
|
||||
|
||||
@@ -211,7 +211,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_space_id(space_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -226,7 +227,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.push_update(
|
||||
doc_id,
|
||||
base64_simd::STANDARD
|
||||
@@ -247,7 +249,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_snapshot(doc_id)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -258,7 +261,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_doc_snapshot(snapshot.try_into()?)
|
||||
.await?,
|
||||
)
|
||||
@@ -272,7 +276,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_updates(doc_id)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -290,7 +295,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.mark_updates_merged(
|
||||
doc_id,
|
||||
updates
|
||||
@@ -310,7 +316,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.delete_doc(doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -324,7 +331,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_clocks(
|
||||
after
|
||||
.map(|t| {
|
||||
@@ -349,7 +357,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_clock(doc_id)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -360,7 +369,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_blob(key)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -371,7 +381,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_blob(blob.try_into()?)
|
||||
.await?,
|
||||
)
|
||||
@@ -386,27 +397,23 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.delete_blob(key, permanently)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn release_blobs(&self, universal_id: String) -> Result<()> {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.release_blobs()
|
||||
.await?,
|
||||
)
|
||||
Ok(self.inner.get(universal_id).await?.release_blobs().await?)
|
||||
}
|
||||
|
||||
pub async fn list_blobs(&self, universal_id: String) -> Result<Vec<ListedBlob>> {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.list_blobs()
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -423,7 +430,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_remote_clocks(peer)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -441,7 +449,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_remote_clock(peer, doc_id)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -458,7 +467,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_remote_clock(
|
||||
peer,
|
||||
doc_id,
|
||||
@@ -478,7 +488,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pulled_remote_clocks(peer)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -496,7 +507,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pulled_remote_clock(peer, doc_id)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -513,7 +525,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_pulled_remote_clock(
|
||||
peer,
|
||||
doc_id,
|
||||
@@ -534,7 +547,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pushed_clock(peer, doc_id)
|
||||
.await?
|
||||
.map(Into::into),
|
||||
@@ -549,7 +563,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pushed_clocks(peer)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -568,7 +583,8 @@ impl DocStoragePool {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_pushed_clock(
|
||||
peer,
|
||||
doc_id,
|
||||
@@ -581,12 +597,6 @@ impl DocStoragePool {
|
||||
}
|
||||
|
||||
pub async fn clear_clocks(&self, universal_id: String) -> Result<()> {
|
||||
Ok(
|
||||
self
|
||||
.inner
|
||||
.ensure_storage(universal_id)?
|
||||
.clear_clocks()
|
||||
.await?,
|
||||
)
|
||||
Ok(self.inner.get(universal_id).await?.clear_clocks().await?)
|
||||
}
|
||||
}
|
||||
|
||||
31
packages/frontend/native/__tests__/pool.spec.mts
Normal file
31
packages/frontend/native/__tests__/pool.spec.mts
Normal file
@@ -0,0 +1,31 @@
|
||||
import { DocStoragePool } from '../index';
|
||||
import test from 'ava';
|
||||
|
||||
test('can batch read/write pool', async t => {
|
||||
const pool = new DocStoragePool();
|
||||
await pool.connect('test', ':memory:');
|
||||
|
||||
const batch = 512;
|
||||
|
||||
await Promise.all(
|
||||
Array.from({ length: batch }).map(async (_, i) => {
|
||||
return pool.setBlob('test', {
|
||||
key: `test-blob-${i}`,
|
||||
data: new Uint8Array([i % 255]),
|
||||
mime: 'text/plain',
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
const blobs = await Promise.all(
|
||||
Array.from({ length: batch }).map(async (_, i) => {
|
||||
return pool.getBlob('test', `test-blob-${i}`);
|
||||
})
|
||||
);
|
||||
|
||||
t.is(blobs.length, batch);
|
||||
t.is(
|
||||
blobs.every((blob, i) => blob!.data.at(0) === i % 255),
|
||||
true
|
||||
);
|
||||
});
|
||||
2
packages/frontend/native/index.d.ts
vendored
2
packages/frontend/native/index.d.ts
vendored
@@ -10,9 +10,9 @@ export declare class DocStoragePool {
|
||||
constructor()
|
||||
/** Initialize the database and run migrations. */
|
||||
connect(universalId: string, path: string): Promise<void>
|
||||
setSpaceId(universalId: string, spaceId: string): Promise<void>
|
||||
disconnect(universalId: string): Promise<void>
|
||||
checkpoint(universalId: string): Promise<void>
|
||||
setSpaceId(universalId: string, spaceId: string): Promise<void>
|
||||
pushUpdate(universalId: string, docId: string, update: Uint8Array): Promise<Date>
|
||||
getDocSnapshot(universalId: string, docId: string): Promise<DocRecord | null>
|
||||
setDocSnapshot(universalId: string, snapshot: DocRecord): Promise<boolean>
|
||||
|
||||
@@ -13,7 +13,6 @@ use-as-lib = ["napi-derive/noop", "napi/noop"]
|
||||
affine_schema = { path = "../schema" }
|
||||
anyhow = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
napi = { workspace = true }
|
||||
napi-derive = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -8,7 +8,7 @@ pub mod sync;
|
||||
use chrono::NaiveDateTime;
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use pool::SqliteDocStoragePool;
|
||||
use pool::{Ref, SqliteDocStoragePool};
|
||||
use storage::SqliteDocStorage;
|
||||
|
||||
#[cfg(feature = "use-as-lib")]
|
||||
@@ -86,13 +86,17 @@ pub struct DocStoragePool {
|
||||
|
||||
#[napi]
|
||||
impl DocStoragePool {
|
||||
#[napi(constructor, async_runtime)]
|
||||
#[napi(constructor)]
|
||||
pub fn new() -> Result<Self> {
|
||||
Ok(Self {
|
||||
pool: SqliteDocStoragePool::default(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn get<'a>(&'a self, universal_id: String) -> Result<Ref<'a, SqliteDocStorage>> {
|
||||
Ok(self.pool.get(universal_id).await?)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
/// Initialize the database and run migrations.
|
||||
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
|
||||
@@ -100,16 +104,6 @@ impl DocStoragePool {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.set_space_id(space_id)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
|
||||
self.pool.disconnect(universal_id).await?;
|
||||
@@ -118,7 +112,13 @@ impl DocStoragePool {
|
||||
|
||||
#[napi]
|
||||
pub async fn checkpoint(&self, universal_id: String) -> Result<()> {
|
||||
self.pool.ensure_storage(universal_id)?.checkpoint().await?;
|
||||
self.pool.get(universal_id).await?.checkpoint().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn set_space_id(&self, universal_id: String, space_id: String) -> Result<()> {
|
||||
self.get(universal_id).await?.set_space_id(space_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -131,8 +131,8 @@ impl DocStoragePool {
|
||||
) -> Result<NaiveDateTime> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.push_update(doc_id, update)
|
||||
.await?,
|
||||
)
|
||||
@@ -146,8 +146,8 @@ impl DocStoragePool {
|
||||
) -> Result<Option<DocRecord>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_snapshot(doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -157,8 +157,8 @@ impl DocStoragePool {
|
||||
pub async fn set_doc_snapshot(&self, universal_id: String, snapshot: DocRecord) -> Result<bool> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_doc_snapshot(snapshot)
|
||||
.await?,
|
||||
)
|
||||
@@ -172,8 +172,8 @@ impl DocStoragePool {
|
||||
) -> Result<Vec<DocUpdate>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_doc_updates(doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -188,8 +188,8 @@ impl DocStoragePool {
|
||||
) -> Result<u32> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.mark_updates_merged(doc_id, updates)
|
||||
.await?,
|
||||
)
|
||||
@@ -197,11 +197,7 @@ impl DocStoragePool {
|
||||
|
||||
#[napi]
|
||||
pub async fn delete_doc(&self, universal_id: String, doc_id: String) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.delete_doc(doc_id)
|
||||
.await?;
|
||||
self.get(universal_id).await?.delete_doc(doc_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -211,13 +207,7 @@ impl DocStoragePool {
|
||||
universal_id: String,
|
||||
after: Option<NaiveDateTime>,
|
||||
) -> Result<Vec<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get_doc_clocks(after)
|
||||
.await?,
|
||||
)
|
||||
Ok(self.get(universal_id).await?.get_doc_clocks(after).await?)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
@@ -226,33 +216,17 @@ impl DocStoragePool {
|
||||
universal_id: String,
|
||||
doc_id: String,
|
||||
) -> Result<Option<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get_doc_clock(doc_id)
|
||||
.await?,
|
||||
)
|
||||
Ok(self.get(universal_id).await?.get_doc_clock(doc_id).await?)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
#[napi(async_runtime)]
|
||||
pub async fn get_blob(&self, universal_id: String, key: String) -> Result<Option<Blob>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get_blob(key)
|
||||
.await?,
|
||||
)
|
||||
Ok(self.get(universal_id).await?.get_blob(key).await?)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn set_blob(&self, universal_id: String, blob: SetBlob) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.set_blob(blob)
|
||||
.await?;
|
||||
self.get(universal_id).await?.set_blob(blob).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -264,8 +238,8 @@ impl DocStoragePool {
|
||||
permanently: bool,
|
||||
) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.delete_blob(key, permanently)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -273,17 +247,13 @@ impl DocStoragePool {
|
||||
|
||||
#[napi]
|
||||
pub async fn release_blobs(&self, universal_id: String) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.release_blobs()
|
||||
.await?;
|
||||
self.get(universal_id).await?.release_blobs().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn list_blobs(&self, universal_id: String) -> Result<Vec<ListedBlob>> {
|
||||
Ok(self.pool.ensure_storage(universal_id)?.list_blobs().await?)
|
||||
Ok(self.get(universal_id).await?.list_blobs().await?)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
@@ -294,8 +264,8 @@ impl DocStoragePool {
|
||||
) -> Result<Vec<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_remote_clocks(peer)
|
||||
.await?,
|
||||
)
|
||||
@@ -310,8 +280,8 @@ impl DocStoragePool {
|
||||
) -> Result<Option<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_remote_clock(peer, doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -326,8 +296,8 @@ impl DocStoragePool {
|
||||
clock: NaiveDateTime,
|
||||
) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_remote_clock(peer, doc_id, clock)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -341,8 +311,8 @@ impl DocStoragePool {
|
||||
) -> Result<Vec<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pulled_remote_clocks(peer)
|
||||
.await?,
|
||||
)
|
||||
@@ -357,8 +327,8 @@ impl DocStoragePool {
|
||||
) -> Result<Option<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pulled_remote_clock(peer, doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -373,8 +343,8 @@ impl DocStoragePool {
|
||||
clock: NaiveDateTime,
|
||||
) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_pulled_remote_clock(peer, doc_id, clock)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -388,8 +358,8 @@ impl DocStoragePool {
|
||||
) -> Result<Vec<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pushed_clocks(peer)
|
||||
.await?,
|
||||
)
|
||||
@@ -404,8 +374,8 @@ impl DocStoragePool {
|
||||
) -> Result<Option<DocClock>> {
|
||||
Ok(
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.get_peer_pushed_clock(peer, doc_id)
|
||||
.await?,
|
||||
)
|
||||
@@ -420,8 +390,8 @@ impl DocStoragePool {
|
||||
clock: NaiveDateTime,
|
||||
) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.get(universal_id)
|
||||
.await?
|
||||
.set_peer_pushed_clock(peer, doc_id, clock)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -429,11 +399,7 @@ impl DocStoragePool {
|
||||
|
||||
#[napi]
|
||||
pub async fn clear_clocks(&self, universal_id: String) -> Result<()> {
|
||||
self
|
||||
.pool
|
||||
.ensure_storage(universal_id)?
|
||||
.clear_clocks()
|
||||
.await?;
|
||||
self.get(universal_id).await?.clear_clocks().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -461,7 +427,6 @@ impl DocStorage {
|
||||
pub async fn set_space_id(&self, space_id: String) -> Result<()> {
|
||||
self.storage.connect().await?;
|
||||
self.storage.set_space_id(space_id).await?;
|
||||
println!("clocks {:?}", self.storage.get_doc_clocks(None).await?);
|
||||
self.storage.close().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,54 +1,96 @@
|
||||
use dashmap::{mapref::one::RefMut, DashMap, Entry};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
use std::collections::hash_map::{Entry, HashMap};
|
||||
|
||||
use tokio::sync::{RwLock, RwLockMappedWriteGuard, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
use super::{
|
||||
error::{Error, Result},
|
||||
storage::SqliteDocStorage,
|
||||
};
|
||||
|
||||
pub struct Ref<'a, V> {
|
||||
_guard: RwLockReadGuard<'a, V>,
|
||||
}
|
||||
|
||||
impl<'a, V> Deref for Ref<'a, V> {
|
||||
type Target = V;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self._guard.deref()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RefMut<'a, V> {
|
||||
_guard: RwLockMappedWriteGuard<'a, V>,
|
||||
}
|
||||
|
||||
impl<'a, V> Deref for RefMut<'a, V> {
|
||||
type Target = V;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self._guard
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, V> DerefMut for RefMut<'a, V> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut *self._guard
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SqliteDocStoragePool {
|
||||
inner: DashMap<String, SqliteDocStorage>,
|
||||
inner: RwLock<HashMap<String, SqliteDocStorage>>,
|
||||
}
|
||||
|
||||
impl SqliteDocStoragePool {
|
||||
fn get_or_create_storage<'a>(
|
||||
async fn get_or_create_storage<'a>(
|
||||
&'a self,
|
||||
universal_id: String,
|
||||
path: &str,
|
||||
) -> RefMut<'a, String, SqliteDocStorage> {
|
||||
let entry = self.inner.entry(universal_id);
|
||||
if let Entry::Occupied(storage) = entry {
|
||||
return storage.into_ref();
|
||||
}
|
||||
let storage = SqliteDocStorage::new(path.to_string());
|
||||
) -> RefMut<'a, SqliteDocStorage> {
|
||||
let lock = RwLockWriteGuard::map(self.inner.write().await, |lock| {
|
||||
match lock.entry(universal_id) {
|
||||
Entry::Occupied(entry) => entry.into_mut(),
|
||||
Entry::Vacant(entry) => {
|
||||
let storage = SqliteDocStorage::new(path.to_string());
|
||||
entry.insert(storage)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
entry.or_insert(storage)
|
||||
RefMut { _guard: lock }
|
||||
}
|
||||
|
||||
pub fn ensure_storage<'a>(
|
||||
&'a self,
|
||||
universal_id: String,
|
||||
) -> Result<RefMut<'a, String, SqliteDocStorage>> {
|
||||
let entry = self.inner.entry(universal_id);
|
||||
pub async fn get<'a>(&'a self, universal_id: String) -> Result<Ref<'a, SqliteDocStorage>> {
|
||||
let lock = RwLockReadGuard::try_map(self.inner.read().await, |lock| {
|
||||
if let Some(storage) = lock.get(&universal_id) {
|
||||
Some(storage)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Entry::Occupied(storage) = entry {
|
||||
Ok(storage.into_ref())
|
||||
} else {
|
||||
Err(Error::InvalidOperation)
|
||||
match lock {
|
||||
Ok(guard) => Ok(Ref { _guard: guard }),
|
||||
Err(_) => Err(Error::InvalidOperation),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the database and run migrations.
|
||||
pub async fn connect(&self, universal_id: String, path: String) -> Result<()> {
|
||||
let storage = self.get_or_create_storage(universal_id.to_owned(), &path);
|
||||
let storage = self
|
||||
.get_or_create_storage(universal_id.to_owned(), &path)
|
||||
.await;
|
||||
|
||||
storage.connect().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self, universal_id: String) -> Result<()> {
|
||||
let entry = self.inner.entry(universal_id);
|
||||
let mut lock = self.inner.write().await;
|
||||
|
||||
if let Entry::Occupied(entry) = entry {
|
||||
if let Entry::Occupied(entry) = lock.entry(universal_id) {
|
||||
let storage = entry.remove();
|
||||
storage.close().await;
|
||||
}
|
||||
|
||||
@@ -16,24 +16,28 @@ impl SqliteDocStorage {
|
||||
pub fn new(path: String) -> Self {
|
||||
let sqlite_options = SqliteConnectOptions::new()
|
||||
.filename(&path)
|
||||
.foreign_keys(false)
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
|
||||
.foreign_keys(false);
|
||||
|
||||
let mut pool_options = SqlitePoolOptions::new();
|
||||
|
||||
if cfg!(test) && path == ":memory:" {
|
||||
if path == ":memory:" {
|
||||
pool_options = pool_options
|
||||
.min_connections(1)
|
||||
.max_connections(1)
|
||||
.idle_timeout(None)
|
||||
.max_lifetime(None);
|
||||
} else {
|
||||
pool_options = pool_options.max_connections(4);
|
||||
}
|
||||
|
||||
Self {
|
||||
pool: pool_options.connect_lazy_with(sqlite_options),
|
||||
path,
|
||||
Self {
|
||||
pool: pool_options.connect_lazy_with(sqlite_options),
|
||||
path,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
pool: pool_options
|
||||
.max_connections(4)
|
||||
.connect_lazy_with(sqlite_options.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)),
|
||||
path,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,17 +20,9 @@
|
||||
"extensions": {
|
||||
"mts": "module"
|
||||
},
|
||||
"nodeArguments": [
|
||||
"--loader",
|
||||
"ts-node/esm.mjs",
|
||||
"--es-module-specifier-resolution=node"
|
||||
],
|
||||
"files": [
|
||||
"__tests__/*.spec.mts"
|
||||
],
|
||||
"environmentVariables": {
|
||||
"TS_NODE_PROJECT": "./tsconfig.json"
|
||||
}
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"@napi-rs/cli": "3.0.0-alpha.68",
|
||||
|
||||
Reference in New Issue
Block a user