Mirror of https://github.com/toeverything/AFFiNE.git (synced 2026-02-17 22:37:04 +08:00)

feat: introduce fuzzy search for native indexer (#14109)
@@ -121,7 +121,7 @@ export const Form = () => {
       console.error(err);
       throw err;
     }
-  }, [emailValue, passwordValue, refreshServerConfig]);
+  }, [nameValue, emailValue, passwordValue, refreshServerConfig]);

   const onNext = useAsyncCallback(async () => {
     if (isCreateAdminStep) {
@@ -171,7 +171,7 @@ export interface NbStorePlugin {
     id: string;
     indexName: string;
     query: string;
-  }) => Promise<{ id: string; score: number }[]>;
+  }) => Promise<{ id: string; score: number; terms: Array<string> }[]>;
   ftsGetDocument: (options: {
     id: string;
     indexName: string;
@@ -369,7 +369,7 @@ export const NbStoreNativeDBApis: NativeDBApis = {
     id: string,
     indexName: string,
     query: string
-  ): Promise<{ id: string; score: number }[]> {
+  ): Promise<{ id: string; score: number; terms: Array<string> }[]> {
     return await NbStore.ftsSearch({
       id,
       indexName,
@@ -171,7 +171,7 @@ export interface NbStorePlugin {
     id: string;
     indexName: string;
     query: string;
-  }) => Promise<{ id: string; score: number }[]>;
+  }) => Promise<{ id: string; score: number; terms: Array<string> }[]>;
   ftsGetDocument: (options: {
     id: string;
     indexName: string;
@@ -373,7 +373,7 @@ export const NbStoreNativeDBApis: NativeDBApis = {
     id: string,
     indexName: string,
     query: string
-  ): Promise<{ id: string; score: number }[]> {
+  ): Promise<{ id: string; score: number; terms: Array<string> }[]> {
     return await NbStore.ftsSearch({
       id,
       indexName,
packages/frontend/native/index.d.ts (vendored)
@@ -148,6 +148,7 @@ export interface NativeMatch {
 export interface NativeSearchHit {
   id: string
   score: number
+  terms: Array<string>
 }

 export interface SetBlob {
@@ -171,7 +171,7 @@ impl std::fmt::Display for AudioFormatFlags {

 impl std::fmt::Debug for AudioFormatFlags {
   fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    write!(f, "AudioFormatFlags({})", self)
+    write!(f, "AudioFormatFlags({self})")
   }
 }
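Many of the Rust hunks below are the same mechanical cleanup: positional {} placeholders replaced with captured identifiers, which format!/println! have supported since Rust 1.58 and which clippy's uninlined_format_args lint suggests. A minimal sketch of the equivalence, with an illustrative value:

fn main() {
  let status = -50;
  // Both forms print the same text; the second is the style this commit adopts.
  assert_eq!(
    format!("AudioFormatFlags({})", status),
    format!("AudioFormatFlags({status})")
  );
  println!("DEBUG: WARNING: Input device stop failed with status: {status}");
}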
@@ -354,7 +354,7 @@ impl AggregateDevice {
     input_device_id: CFString,
     output_device_id: CFString,
   ) -> Result<CFDictionary<CFType, CFType>> {
-    let aggregate_device_name = CFString::new(&format!("Tap-{}", tap_id));
+    let aggregate_device_name = CFString::new(&format!("Tap-{tap_id}"));
     let aggregate_device_uid: uuid::Uuid = CFUUID::new().into();
     let aggregate_device_uid_string = aggregate_device_uid.to_string();
@@ -469,18 +469,12 @@ impl AudioTapStream {
       // Ignore errors as device might be disconnected
       let status = unsafe { AudioDeviceStop(self.input_device_id, proc_id) };
       if status != 0 {
-        println!(
-          "DEBUG: WARNING: Input device stop failed with status: {}",
-          status
-        );
+        println!("DEBUG: WARNING: Input device stop failed with status: {status}");
       }

       let status = unsafe { AudioDeviceDestroyIOProcID(self.input_device_id, proc_id) };
       if status != 0 {
-        println!(
-          "DEBUG: WARNING: Input device destroy IO proc failed with status: {}",
-          status
-        );
+        println!("DEBUG: WARNING: Input device destroy IO proc failed with status: {status}");
       }
     }
@@ -489,18 +483,12 @@ impl AudioTapStream {
       // Ignore errors as device might be disconnected
       let status = unsafe { AudioDeviceStop(self.output_device_id, proc_id) };
       if status != 0 {
-        println!(
-          "DEBUG: WARNING: Output device stop failed with status: {}",
-          status
-        );
+        println!("DEBUG: WARNING: Output device stop failed with status: {status}");
       }

       let status = unsafe { AudioDeviceDestroyIOProcID(self.output_device_id, proc_id) };
       if status != 0 {
-        println!(
-          "DEBUG: WARNING: Output device destroy IO proc failed with status: {}",
-          status
-        );
+        println!("DEBUG: WARNING: Output device destroy IO proc failed with status: {status}");
       }
     }
@@ -508,27 +496,18 @@ impl AudioTapStream {
     if device_exists {
       let status = unsafe { AudioDeviceDestroyIOProcID(self.device_id, self.in_proc_id) };
       if status != 0 {
-        println!(
-          "DEBUG: WARNING: Destroy IO proc failed with status: {}",
-          status
-        );
+        println!("DEBUG: WARNING: Destroy IO proc failed with status: {status}");
       }
     }
     let status = unsafe { AudioHardwareDestroyAggregateDevice(self.device_id) };
     if status != 0 {
-      println!(
-        "DEBUG: WARNING: AudioHardwareDestroyAggregateDevice failed with status: {}",
-        status
-      );
+      println!("DEBUG: WARNING: AudioHardwareDestroyAggregateDevice failed with status: {status}");
     }

     // Destroy the process tap - don't fail if this fails
     let status = unsafe { AudioHardwareDestroyProcessTap(self.device_id) };
     if status != 0 {
-      println!(
-        "DEBUG: WARNING: AudioHardwareDestroyProcessTap failed with status: {}",
-        status
-      );
+      println!("DEBUG: WARNING: AudioHardwareDestroyProcessTap failed with status: {status}");
     }

     // destroy the queue
@@ -743,10 +722,7 @@ impl AggregateDeviceManager {
         let stop_result = old_stream.stop();
         match stop_result {
           Ok(_) => {}
-          Err(e) => println!(
-            "DEBUG: Error stopping old stream (proceeding anyway): {}",
-            e
-          ),
+          Err(e) => println!("DEBUG: Error stopping old stream (proceeding anyway): {e}"),
         };
         drop(old_stream); // Ensure it's dropped now
       }
@@ -757,12 +733,12 @@ impl AggregateDeviceManager {
             *stream_guard = Some(new_stream);
           }
           Err(e) => {
-            println!("DEBUG: Failed to start new stream: {}", e);
+            println!("DEBUG: Failed to start new stream: {e}");
           }
         }
       }
       Err(e) => {
-        println!("DEBUG: Failed to create new device: {}", e);
+        println!("DEBUG: Failed to create new device: {e}");
       }
     }
   },
@@ -786,10 +762,7 @@ impl AggregateDeviceManager {
     );

     if status != 0 {
-      println!(
-        "DEBUG: Failed to register input device listener, status: {}",
-        status
-      );
+      println!("DEBUG: Failed to register input device listener, status: {status}");
       return Err(CoreAudioError::AddPropertyListenerBlockFailed(status).into());
     }
@@ -805,10 +778,7 @@ impl AggregateDeviceManager {
     );

     if status != 0 {
-      println!(
-        "DEBUG: Failed to register output device listener, status: {}",
-        status
-      );
+      println!("DEBUG: Failed to register output device listener, status: {status}");
       // Clean up the first listener if the second one fails
       AudioObjectRemovePropertyListenerBlock(
         kAudioObjectSystemObject,
@@ -907,10 +877,7 @@ impl AggregateDeviceManager {
     if let Some(mut stream) = stream_to_stop {
       match stream.stop() {
         Ok(_) => {}
-        Err(e) => println!(
-          "DEBUG: Error stopping stream in stop_capture (ignored): {}",
-          e
-        ),
+        Err(e) => println!("DEBUG: Error stopping stream in stop_capture (ignored): {e}"),
       }
       // Explicitly drop here after stopping
       drop(stream);
@@ -960,7 +927,7 @@ impl AggregateDeviceManager {
       match stream.get_actual_sample_rate() {
         Ok(rate) => Ok(Some(rate)),
         Err(e) => {
-          println!("DEBUG: Error getting actual sample rate from stream: {}", e);
+          println!("DEBUG: Error getting actual sample rate from stream: {e}");
           // Propagate the error
           Err(e)
         }
@@ -976,7 +943,7 @@ impl Drop for AggregateDeviceManager {
     // Call stop_capture which handles listener cleanup and stream stopping
     match self.stop_capture() {
       Ok(_) => {}
-      Err(e) => println!("DEBUG: Error during stop_capture in Drop (ignored): {}", e),
+      Err(e) => println!("DEBUG: Error during stop_capture in Drop (ignored): {e}"),
     }
   }
 }
@@ -73,8 +73,8 @@ impl BufferedResampler {
     } else {
       // interleave
       let out_len = out_blocks[0].len();
-      #[allow(clippy::needless_range_loop)]
       for i in 0..out_len {
+        #[allow(clippy::needless_range_loop)]
+        // apply clippy lint suggestion would regress performance
         for ch in 0..self.channels {
           interleaved_out.push(out_blocks[ch][i]);
@@ -15,10 +15,9 @@ affine_schema = { path = "../schema" }
 anyhow = { workspace = true }
 bincode = { version = "2.0.1", features = ["serde"] }
 chrono = { workspace = true }
-jieba-rs = "0.8.1"
+memory-indexer = { workspace = true }
 napi = { workspace = true }
 napi-derive = { workspace = true }
-once_cell = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
 sqlx = { workspace = true, default-features = false, features = [
   "chrono",
@@ -29,7 +28,6 @@ sqlx = { workspace = true, default-features = false, features = [
   "tls-rustls",
 ] }
 thiserror = { workspace = true }
-tiniestsegmenter = "0.3"
 tokio = { workspace = true, features = ["full"] }
 y-octo = { workspace = true }
 zstd = "0.13"
@@ -1,13 +1,8 @@
-mod memory_indexer;
-mod tokenizer;
-mod types;
-
-use affine_common::doc_parser::{parse_doc_from_binary, ParseError};
-pub use memory_indexer::InMemoryIndex;
+use affine_common::doc_parser::{parse_doc_from_binary, BlockInfo, CrawlResult, ParseError};
+use memory_indexer::{SearchHit, SnapshotData};
 use napi_derive::napi;
 use serde::Serialize;
 use sqlx::Row;
-pub use types::{
-  DocData, NativeBlockInfo, NativeCrawlResult, NativeMatch, NativeSearchHit, SnapshotData,
-};
 use y_octo::DocOptions;

 use super::{
@@ -15,6 +10,88 @@ use super::{
   storage::SqliteDocStorage,
 };

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeBlockInfo {
  pub block_id: String,
  pub flavour: String,
  pub content: Option<Vec<String>>,
  pub blob: Option<Vec<String>>,
  pub ref_doc_id: Option<Vec<String>>,
  pub ref_info: Option<Vec<String>>,
  pub parent_flavour: Option<String>,
  pub parent_block_id: Option<String>,
  pub additional: Option<String>,
}

impl From<BlockInfo> for NativeBlockInfo {
  fn from(value: BlockInfo) -> Self {
    Self {
      block_id: value.block_id,
      flavour: value.flavour,
      content: value.content,
      blob: value.blob,
      ref_doc_id: value.ref_doc_id,
      ref_info: value.ref_info,
      parent_flavour: value.parent_flavour,
      parent_block_id: value.parent_block_id,
      additional: value.additional,
    }
  }
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeCrawlResult {
  pub blocks: Vec<NativeBlockInfo>,
  pub title: String,
  pub summary: String,
}

impl From<CrawlResult> for NativeCrawlResult {
  fn from(value: CrawlResult) -> Self {
    Self {
      blocks: value.blocks.into_iter().map(Into::into).collect(),
      title: value.title,
      summary: value.summary,
    }
  }
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeSearchHit {
  pub id: String,
  pub score: f64,
  pub terms: Vec<String>,
}

impl From<SearchHit> for NativeSearchHit {
  fn from(value: SearchHit) -> Self {
    Self {
      id: value.doc_id,
      score: value.score,
      terms: value.matched_terms.into_iter().map(|t| t.term).collect(),
    }
  }
}
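The From<SearchHit> conversion above is the only place the new workspace memory-indexer crate's hit type appears in this diff. Inferred from the fields it reads, the crate presumably exposes something along these lines; the exact definitions are an assumption, not part of this commit:

// Hypothetical shape of the memory-indexer types consumed by the impl above,
// inferred only from value.doc_id, value.score and value.matched_terms[..].term.
pub struct MatchedTerm {
  pub term: String,
  // ...other per-term data (e.g. positions) may exist but is not visible here
}

pub struct SearchHit {
  pub doc_id: String,
  pub score: f64,
  pub matched_terms: Vec<MatchedTerm>,
}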
#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeMatch {
  pub start: u32,
  pub end: u32,
}

impl From<(u32, u32)> for NativeMatch {
  fn from(value: (u32, u32)) -> Self {
    Self {
      start: value.0,
      end: value.1,
    }
  }
}

impl SqliteDocStorage {
  pub async fn crawl_doc_data(&self, doc_id: &str) -> Result<NativeCrawlResult> {
    let doc_bin = self
@@ -53,14 +130,14 @@ impl SqliteDocStorage {

     {
       let mut index = self.index.write().await;
+      let config = bincode::config::standard();
       for row in snapshots {
         let index_name: String = row.get("index_name");
         let data: Vec<u8> = row.get("data");
         if let Ok(decompressed) = zstd::stream::decode_all(std::io::Cursor::new(&data)) {
-          if let Ok((snapshot, _)) = bincode::serde::decode_from_slice::<SnapshotData, _>(
-            &decompressed,
-            bincode::config::standard(),
-          ) {
+          if let Ok((snapshot, _)) =
+            bincode::serde::decode_from_slice::<SnapshotData, _>(&decompressed, config)
+          {
             index.load_snapshot(&index_name, snapshot);
           }
         }
@@ -79,7 +156,7 @@ impl SqliteDocStorage {
     if let Some(data) = snapshot_data {
       let blob = bincode::serde::encode_to_vec(&data, bincode::config::standard())
         .map_err(|e| Error::Serialization(e.to_string()))?;
-      let compressed = zstd::stream::encode_all(std::io::Cursor::new(&blob), 0)
+      let compressed = zstd::stream::encode_all(std::io::Cursor::new(&blob), 4)
         .map_err(|e| Error::Serialization(e.to_string()))?;

       let mut tx = self.pool.begin().await?;
@@ -147,9 +224,9 @@ impl SqliteDocStorage {
     let idx = self.index.read().await;
     Ok(
       idx
-        .search(index_name, query)
+        .search_hits(index_name, query)
         .into_iter()
-        .map(|(id, score)| NativeSearchHit { id, score })
+        .map(Into::into)
         .collect(),
     )
   }
@@ -165,7 +242,23 @@ impl SqliteDocStorage {
       idx
         .get_matches(index_name, doc_id, query)
         .into_iter()
-        .map(|(start, end)| NativeMatch { start, end })
+        .map(Into::into)
         .collect(),
     )
   }
+
+  pub async fn fts_get_matches_for_terms(
+    &self,
+    index_name: &str,
+    doc_id: &str,
+    terms: Vec<String>,
+  ) -> Result<Vec<NativeMatch>> {
+    let idx = self.index.read().await;
+    Ok(
+      idx
+        .get_matches_for_terms(index_name, doc_id, &terms)
+        .into_iter()
+        .map(Into::into)
+        .collect(),
+    )
+  }
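Combined with the terms now returned from search, this new method enables a two-pass flow: run the query once, then resolve highlight ranges per hit from the already-matched terms instead of re-tokenizing the query per document. A hedged sketch of a caller (the fts_search name and the exact Result alias are assumed from the surrounding hunks rather than shown in full here):

// Hypothetical caller: search, then fetch byte ranges for each hit's matched terms.
async fn search_with_highlights(
  storage: &SqliteDocStorage,
  index_name: &str,
  query: &str,
) -> Result<()> {
  for hit in storage.fts_search(index_name, query).await? {
    let ranges = storage
      .fts_get_matches_for_terms(index_name, &hit.id, hit.terms.clone())
      .await?;
    for m in &ranges {
      println!("{} (score {:.3}): bytes {}..{}", hit.id, hit.score, m.start, m.end);
    }
  }
  Ok(())
}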
@@ -206,8 +299,8 @@ mod tests {

   use super::{super::error::Error, *};

-  const DEMO_BIN: &[u8] = include_bytes!("../../../../../common/native/fixtures/demo.ydoc");
-  const DEMO_JSON: &[u8] = include_bytes!("../../../../../common/native/fixtures/demo.ydoc.json");
+  const DEMO_BIN: &[u8] = include_bytes!("../../../../common/native/fixtures/demo.ydoc");
+  const DEMO_JSON: &[u8] = include_bytes!("../../../../common/native/fixtures/demo.ydoc.json");

   fn temp_workspace_dir() -> PathBuf {
     std::env::temp_dir().join(format!("affine-native-{}", Uuid::new_v4()))
@@ -1,261 +0,0 @@
use std::collections::{HashMap, HashSet};

use super::{
  tokenizer::tokenize,
  types::{DocData, SnapshotData},
};

type DirtyDoc = (String, String, String, i64);
type DeletedDoc = HashMap<String, HashSet<String>>;

#[derive(Default, Debug)]
pub struct InMemoryIndex {
  pub docs: HashMap<String, HashMap<String, DocData>>,
  pub inverted: HashMap<String, HashMap<String, HashMap<String, i64>>>,
  pub total_lens: HashMap<String, i64>,
  pub dirty: HashMap<String, HashSet<String>>,
  pub deleted: HashMap<String, HashSet<String>>,
}

impl InMemoryIndex {
  pub fn add_doc(&mut self, index_name: &str, doc_id: &str, text: &str, index: bool) {
    let tokens = if index { tokenize(text) } else { vec![] };
    // doc_len should be the number of tokens (including duplicates)
    let doc_len = tokens.len() as i64;

    let mut pos_map: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
    for token in tokens {
      pos_map
        .entry(token.term)
        .or_default()
        .push((token.start as u32, token.end as u32));
    }

    if let Some(docs) = self.docs.get_mut(index_name) {
      if let Some(old_data) = docs.remove(doc_id) {
        *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;

        if let Some(inverted) = self.inverted.get_mut(index_name) {
          for (term, _) in old_data.term_pos {
            if let Some(doc_map) = inverted.get_mut(&term) {
              doc_map.remove(doc_id);
              if doc_map.is_empty() {
                inverted.remove(&term);
              }
            }
          }
        }
      }
    }

    let doc_data = DocData {
      content: text.to_string(),
      doc_len,
      term_pos: pos_map.clone(),
    };

    self
      .docs
      .entry(index_name.to_string())
      .or_default()
      .insert(doc_id.to_string(), doc_data);
    *self.total_lens.entry(index_name.to_string()).or_default() += doc_len;

    let inverted = self.inverted.entry(index_name.to_string()).or_default();
    for (term, positions) in pos_map {
      inverted
        .entry(term)
        .or_default()
        .insert(doc_id.to_string(), positions.len() as i64);
    }

    self
      .dirty
      .entry(index_name.to_string())
      .or_default()
      .insert(doc_id.to_string());
    if let Some(deleted) = self.deleted.get_mut(index_name) {
      deleted.remove(doc_id);
    }
  }

  pub fn remove_doc(&mut self, index_name: &str, doc_id: &str) {
    if let Some(docs) = self.docs.get_mut(index_name) {
      if let Some(old_data) = docs.remove(doc_id) {
        *self.total_lens.entry(index_name.to_string()).or_default() -= old_data.doc_len;

        if let Some(inverted) = self.inverted.get_mut(index_name) {
          for (term, _) in old_data.term_pos {
            if let Some(doc_map) = inverted.get_mut(&term) {
              doc_map.remove(doc_id);
              if doc_map.is_empty() {
                inverted.remove(&term);
              }
            }
          }
        }

        self
          .deleted
          .entry(index_name.to_string())
          .or_default()
          .insert(doc_id.to_string());
        if let Some(dirty) = self.dirty.get_mut(index_name) {
          dirty.remove(doc_id);
        }
      }
    }
  }

  pub fn get_doc(&self, index_name: &str, doc_id: &str) -> Option<String> {
    self
      .docs
      .get(index_name)
      .and_then(|docs| docs.get(doc_id))
      .map(|d| d.content.clone())
  }

  pub fn search(&self, index_name: &str, query: &str) -> Vec<(String, f64)> {
    if query == "*" || query.is_empty() {
      if let Some(docs) = self.docs.get(index_name) {
        return docs.keys().map(|k| (k.clone(), 1.0)).collect();
      }
      return vec![];
    }

    let query_terms = tokenize(query);
    if query_terms.is_empty() {
      return vec![];
    }

    let inverted = match self.inverted.get(index_name) {
      Some(i) => i,
      None => return vec![],
    };

    let mut candidates: Option<HashSet<String>> = None;

    for token in &query_terms {
      if let Some(doc_map) = inverted.get(&token.term) {
        let docs: HashSet<String> = doc_map.keys().cloned().collect();
        match candidates {
          None => candidates = Some(docs),
          Some(ref mut c) => {
            c.retain(|id| docs.contains(id));
          }
        }
        if candidates.as_ref().unwrap().is_empty() {
          return vec![];
        }
      } else {
        return vec![];
      }
    }

    let candidates = candidates.unwrap_or_default();
    if candidates.is_empty() {
      return vec![];
    }

    let docs = self.docs.get(index_name).unwrap();
    let total_len = *self.total_lens.get(index_name).unwrap_or(&0);
    let n = docs.len() as f64;
    let avgdl = if n > 0.0 { total_len as f64 / n } else { 0.0 };

    let k1 = 1.2;
    let b = 0.75;

    let mut scores: Vec<(String, f64)> = Vec::with_capacity(candidates.len());

    let mut idfs = HashMap::new();
    for token in &query_terms {
      let n_q = inverted.get(&token.term).map(|m| m.len()).unwrap_or(0) as f64;
      let idf = ((n - n_q + 0.5) / (n_q + 0.5) + 1.0).ln();
      idfs.insert(&token.term, idf);
    }

    for doc_id in candidates {
      let doc_data = docs.get(&doc_id).unwrap();
      let mut score = 0.0;

      for token in &query_terms {
        if let Some(positions) = doc_data.term_pos.get(&token.term) {
          let freq = positions.len() as f64;
          let idf = idfs.get(&token.term).unwrap();
          let numerator = freq * (k1 + 1.0);
          let denominator = freq + k1 * (1.0 - b + b * (doc_data.doc_len as f64 / avgdl));
          score += idf * (numerator / denominator);
        }
      }
      scores.push((doc_id, score));
    }

    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    scores
  }

  pub fn take_dirty_and_deleted(&mut self) -> (Vec<DirtyDoc>, DeletedDoc) {
    let dirty = std::mem::take(&mut self.dirty);
    let deleted = std::mem::take(&mut self.deleted);

    let mut dirty_data = Vec::new();
    for (index_name, doc_ids) in &dirty {
      if let Some(docs) = self.docs.get(index_name) {
        for doc_id in doc_ids {
          if let Some(data) = docs.get(doc_id) {
            dirty_data.push((
              index_name.clone(),
              doc_id.clone(),
              data.content.clone(),
              data.doc_len,
            ));
          }
        }
      }
    }
    (dirty_data, deleted)
  }

  pub fn get_matches(&self, index_name: &str, doc_id: &str, query: &str) -> Vec<(u32, u32)> {
    let mut matches = Vec::new();
    if let Some(docs) = self.docs.get(index_name) {
      if let Some(doc_data) = docs.get(doc_id) {
        let query_tokens = tokenize(query);
        for token in query_tokens {
          if let Some(positions) = doc_data.term_pos.get(&token.term) {
            matches.extend(positions.iter().cloned());
          }
        }
      }
    }
    matches.sort_by(|a, b| a.0.cmp(&b.0));
    matches
  }

  pub fn load_snapshot(&mut self, index_name: &str, snapshot: SnapshotData) {
    let docs = self.docs.entry(index_name.to_string()).or_default();
    let inverted = self.inverted.entry(index_name.to_string()).or_default();
    let total_len = self.total_lens.entry(index_name.to_string()).or_default();

    for (doc_id, doc_data) in snapshot.docs {
      *total_len += doc_data.doc_len;

      for (term, positions) in &doc_data.term_pos {
        inverted
          .entry(term.clone())
          .or_default()
          .insert(doc_id.clone(), positions.len() as i64);
      }

      docs.insert(doc_id, doc_data);
    }
  }

  pub fn get_snapshot_data(&self, index_name: &str) -> Option<SnapshotData> {
    self
      .docs
      .get(index_name)
      .map(|docs| SnapshotData { docs: docs.clone() })
  }
}
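The scoring loop in the removed search implementation above is plain BM25 over the per-index inverted maps, with k1 = 1.2 and b = 0.75. In the notation of that code (n indexed docs, n_q docs containing term q, f_{q,D} the term's frequency in doc D, |D| its doc_len, avgdl the average doc_len):

\[
\operatorname{idf}(q) = \ln\!\left(\frac{n - n_q + 0.5}{n_q + 0.5} + 1\right),
\qquad
\operatorname{score}(D, Q) = \sum_{q \in Q} \operatorname{idf}(q)\,
\frac{f_{q,D}\,(k_1 + 1)}{f_{q,D} + k_1\!\left(1 - b + b\,\frac{|D|}{\mathrm{avgdl}}\right)}
\]

The replacement memory-indexer workspace crate is not shown in this diff, so whether it keeps exactly this scoring is not visible here.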
@@ -1,85 +0,0 @@
use jieba_rs::Jieba;
use once_cell::sync::Lazy;
use tiniestsegmenter::tokenize as ts_tokenize;

static JIEBA: Lazy<Jieba> = Lazy::new(Jieba::new);

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
  pub term: String,
  pub start: usize,
  pub end: usize,
}

pub fn tokenize(text: &str) -> Vec<Token> {
  let mut tokens = Vec::new();

  // Use jieba for Chinese/English
  // Jieba tokenize returns tokens with offsets
  let jieba_tokens = JIEBA.tokenize(text, jieba_rs::TokenizeMode::Search, false);
  for token in jieba_tokens {
    if token.word.chars().any(|c| c.is_alphanumeric()) {
      tokens.push(Token {
        term: token.word.to_lowercase(),
        start: token.start,
        end: token.end,
      });
    }
  }

  // Use TinySegmenter for Japanese
  // TinySegmenter does not provide offsets, so we have to find them manually
  // This is a simplified approach and might not be perfect for repeated terms
  let mut last_pos = 0;
  for term in ts_tokenize(text) {
    if term.chars().any(|c| c.is_alphanumeric()) {
      if let Some(pos) = text[last_pos..].find(term) {
        let start = last_pos + pos;
        let end = start + term.len();
        tokens.push(Token {
          term: term.to_lowercase(),
          start,
          end,
        });
        last_pos = end;
      }
    }
  }

  // Manually handle Korean bigrams and unigrams
  let chars: Vec<char> = text.chars().collect();
  let mut byte_offset = 0;
  for (i, &c) in chars.iter().enumerate() {
    let char_len = c.len_utf8();
    if is_hangul(c) {
      tokens.push(Token {
        term: c.to_string().to_lowercase(),
        start: byte_offset,
        end: byte_offset + char_len,
      });
      if i + 1 < chars.len() {
        let next = chars[i + 1];
        if is_hangul(next) {
          let next_len = next.len_utf8();
          tokens.push(Token {
            term: format!("{}{}", c, next).to_lowercase(),
            start: byte_offset,
            end: byte_offset + char_len + next_len,
          });
        }
      }
    }
    byte_offset += char_len;
  }

  tokens
}

fn is_hangul(c: char) -> bool {
  // Hangul Syllables
  ('\u{AC00}'..='\u{D7AF}').contains(&c)
    // Hangul Jamo
    || ('\u{1100}'..='\u{11FF}').contains(&c)
    // Hangul Compatibility Jamo
    || ('\u{3130}'..='\u{318F}').contains(&c)
}
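The Korean branch of the removed tokenizer indexes every Hangul syllable as a unigram and every adjacent pair as a bigram, tracking byte offsets so matches can be highlighted in the original string. A self-contained sketch of just that pass, using the same range checks as the removed code:

fn is_hangul(c: char) -> bool {
  // Hangul Syllables, Jamo, and Compatibility Jamo, as in the removed code.
  ('\u{AC00}'..='\u{D7AF}').contains(&c)
    || ('\u{1100}'..='\u{11FF}').contains(&c)
    || ('\u{3130}'..='\u{318F}').contains(&c)
}

// Returns (term, start_byte, end_byte) triples: unigrams plus adjacent bigrams.
fn hangul_terms(text: &str) -> Vec<(String, usize, usize)> {
  let chars: Vec<char> = text.chars().collect();
  let mut terms = Vec::new();
  let mut byte_offset = 0;
  for (i, &c) in chars.iter().enumerate() {
    let char_len = c.len_utf8();
    if is_hangul(c) {
      terms.push((c.to_string(), byte_offset, byte_offset + char_len));
      if let Some(&next) = chars.get(i + 1) {
        if is_hangul(next) {
          terms.push((
            format!("{c}{next}"),
            byte_offset,
            byte_offset + char_len + next.len_utf8(),
          ));
        }
      }
    }
    byte_offset += char_len;
  }
  terms
}

fn main() {
  // "한글" yields the unigram "한", the bigram "한글", then the unigram "글";
  // each syllable is 3 UTF-8 bytes, so the ranges are 0..3, 0..6 and 3..6.
  assert_eq!(
    hangul_terms("한글"),
    vec![
      ("한".to_string(), 0, 3),
      ("한글".to_string(), 0, 6),
      ("글".to_string(), 3, 6),
    ]
  );
}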
@@ -1,79 +0,0 @@
use std::collections::HashMap;

use affine_common::doc_parser::{BlockInfo, CrawlResult};
use napi_derive::napi;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocData {
  pub content: String,
  pub doc_len: i64,
  pub term_pos: HashMap<String, Vec<(u32, u32)>>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SnapshotData {
  pub docs: HashMap<String, DocData>,
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeBlockInfo {
  pub block_id: String,
  pub flavour: String,
  pub content: Option<Vec<String>>,
  pub blob: Option<Vec<String>>,
  pub ref_doc_id: Option<Vec<String>>,
  pub ref_info: Option<Vec<String>>,
  pub parent_flavour: Option<String>,
  pub parent_block_id: Option<String>,
  pub additional: Option<String>,
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeCrawlResult {
  pub blocks: Vec<NativeBlockInfo>,
  pub title: String,
  pub summary: String,
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeSearchHit {
  pub id: String,
  pub score: f64,
}

#[napi(object)]
#[derive(Debug, Serialize)]
pub struct NativeMatch {
  pub start: u32,
  pub end: u32,
}

impl From<BlockInfo> for NativeBlockInfo {
  fn from(value: BlockInfo) -> Self {
    Self {
      block_id: value.block_id,
      flavour: value.flavour,
      content: value.content,
      blob: value.blob,
      ref_doc_id: value.ref_doc_id,
      ref_info: value.ref_info,
      parent_flavour: value.parent_flavour,
      parent_block_id: value.parent_block_id,
      additional: value.additional,
    }
  }
}

impl From<CrawlResult> for NativeCrawlResult {
  fn from(value: CrawlResult) -> Self {
    Self {
      blocks: value.blocks.into_iter().map(Into::into).collect(),
      title: value.title,
      summary: value.summary,
    }
  }
}
@@ -1,6 +1,7 @@
 use std::sync::Arc;

 use affine_schema::get_migrator;
+use memory_indexer::InMemoryIndex;
 use sqlx::{
   migrate::MigrateDatabase,
   sqlite::{Sqlite, SqliteConnectOptions, SqlitePoolOptions},
@@ -8,7 +9,7 @@ use sqlx::{
 };
 use tokio::sync::RwLock;

-use super::{error::Result, indexer::InMemoryIndex};
+use super::error::Result;

 pub struct SqliteDocStorage {
   pub pool: Pool<Sqlite>,
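The indexer hunks above access the index through self.index.read().await and self.index.write().await, and this file now imports InMemoryIndex from the workspace crate alongside tokio::sync::RwLock and Arc. The struct presumably carries a field along these lines; its definition is cut off in this view, so the exact type is an assumption:

pub struct SqliteDocStorage {
  pub pool: Pool<Sqlite>,
  // Assumed: the shared in-memory FTS index behind an async RwLock
  // (field name taken from the self.index accesses; the Arc wrapper is unconfirmed).
  pub index: Arc<RwLock<InMemoryIndex>>,
}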