From a809643eef29c5e33f1b2ee11d566c8de81ff43f Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 17 Nov 2024 18:03:15 +0100 Subject: [PATCH 01/12] feat: adds FST-based full-text search --- Cargo.lock | 7 ++ string_index/Cargo.toml | 1 + string_index/src/lib.rs | 120 +++++++++++++++------------- string_index/src/posting_storage.rs | 2 +- 4 files changed, 73 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 416025f..0d7107b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2273,6 +2273,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "fst" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -6102,6 +6108,7 @@ dependencies = [ "bincode", "csv", "dashmap", + "fst", "itertools 0.13.0", "nlp", "ptrie 0.7.1 (git+https://github.com/oramasearch/ptrie.git?branch=feat/expose-get-mut)", diff --git a/string_index/Cargo.toml b/string_index/Cargo.toml index e466cb9..c925dcc 100644 --- a/string_index/Cargo.toml +++ b/string_index/Cargo.toml @@ -18,6 +18,7 @@ thiserror = "1.0.65" smallvec = "1.13.2" types = { path = "../types" } ptrie = { git = "https://github.com/oramasearch/ptrie.git", branch = "feat/expose-get-mut" } +fst = "0.4.7" [dev-dependencies] nlp = { path = "../nlp" } diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs index deb67b8..161e7af 100644 --- a/string_index/src/lib.rs +++ b/string_index/src/lib.rs @@ -8,9 +8,8 @@ use std::{ use anyhow::Result; use dictionary::{Dictionary, TermId}; +use fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer}; use posting_storage::{PostingListId, PostingStorage}; -// use radix_trie::Trie; -use ptrie::Trie; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use scorer::Scorer; use serde::{Deserialize, Serialize}; @@ -32,14 +31,15 @@ pub struct Posting { pub doc_length: u16, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct StringIndexValue { - posting_list_id: PostingListId, - term_frequency: usize, + pub posting_list_id: PostingListId, + pub term_frequency: usize, } pub struct StringIndex { - tree: RwLock>, + fst_map: RwLock>>>, + temp_map: RwLock>, posting_storage: PostingStorage, dictionary: Dictionary, total_documents: AtomicUsize, @@ -54,7 +54,8 @@ pub struct GlobalInfo { impl StringIndex { pub fn new(storage: Arc) -> Self { StringIndex { - tree: RwLock::new(Trie::new()), + fst_map: RwLock::new(None), + temp_map: RwLock::new(HashMap::new()), posting_storage: PostingStorage::new(storage), dictionary: Dictionary::new(), total_documents: AtomicUsize::new(0), @@ -62,6 +63,28 @@ impl StringIndex { } } + fn build_fst(&self) -> Result>> { + let temp_map = self.temp_map.read().unwrap(); + + let mut entries: Vec<_> = temp_map + .iter() + .map(|(key, value)| { + ( + key.clone(), + ((value.posting_list_id.0 as u64) << 32) | (value.term_frequency as u64), + ) + }) + .collect(); + entries.sort_by(|(a, _), (b, _)| a.cmp(b)); + + let mut builder = MapBuilder::memory(); + for (key, value) in entries { + builder.insert(key, value)?; + } + + Ok(builder.into_map()) + } + pub fn get_total_documents(&self) -> usize { self.total_documents.load(Ordering::Relaxed) } @@ -88,17 +111,24 @@ impl StringIndex { total_document_length => total_document_length, }; - // let avg_doc_length = total_document_length / total_documents; - - let mut posting_list_ids_with_freq = Vec::::new(); - let tree = self - .tree - .read() - // TODO: better error 
handling - .expect("Unable to read"); - for token in tokens { - let a = tree.find_postfixes(token.bytes()); - posting_list_ids_with_freq.extend(a.into_iter().cloned()); + let mut posting_list_ids_with_freq = Vec::new(); + + if let Some(fst) = self.fst_map.read().unwrap().as_ref() { + for token in tokens { + let automaton = fst::automaton::Str::new(&token).starts_with(); + let mut stream = fst.search(automaton).into_stream(); + + while let Some(result) = stream.next() { + let (_, packed_value) = result; + let posting_list_id = PostingListId((packed_value >> 32).try_into().unwrap()); + let term_frequency = packed_value & 0xFFFFFFFF; + + posting_list_ids_with_freq.push(StringIndexValue { + posting_list_id: PostingListId(posting_list_id.0), + term_frequency: term_frequency as usize, + }); + } + } } let fields = search_on.as_ref(); @@ -131,7 +161,6 @@ impl StringIndex { Some((posting, string_index_value.term_frequency)) }) - // Every thread perform on a separated hashmap .for_each(|(postings, total_token_count)| { let total_token_count = total_token_count as f32; for posting in postings { @@ -140,9 +169,7 @@ impl StringIndex { } }); - let scores = scorer.get_scores(); - - Ok(scores) + Ok(scorer.get_scores()) } pub fn insert_multiple(&self, data: DocumentBatch) -> Result<()> { @@ -151,9 +178,8 @@ impl StringIndex { let dictionary = &self.dictionary; - let t = data + let posting_per_term = data .into_par_iter() - // Parallel .fold( HashMap::>::new, |mut acc, (document_id, strings)| { @@ -185,8 +211,6 @@ impl StringIndex { for (term, field_positions) in term_freqs { let term_id = dictionary.get_or_add(&term); - // println!("Term: {} -> {}", term, term_id.0); - let v = acc.entry(term_id).or_default(); let posting = @@ -197,7 +221,6 @@ impl StringIndex { document_id, field_id, positions, - // original_term: term.clone(), term_frequency, doc_length: doc_length as u16, } @@ -208,28 +231,20 @@ impl StringIndex { acc }, - ); - - let posting_per_term = t.reduce( - HashMap::>::new, - // Merge the hashmap - |mut acc, item| { + ) + .reduce(HashMap::>::new, |mut acc, item| { for (term_id, postings) in item { let vec = acc.entry(term_id).or_default(); vec.extend(postings.into_iter()); } acc - }, - ); + }); let mut postings_per_posting_list_id: HashMap>> = HashMap::with_capacity(posting_per_term.len()); - let mut tree = self.tree.write().expect("Unable to write"); - // NB: We cannot parallelize the tree insertion yet :( - // We could move the tree into a custom implementation to support parallelism - // Once we resolve this issue, we StringIndex is thread safe! - // TODO: move to custom implementation - // For the time being, we can just use the sync tree + + let mut temp_map = self.temp_map.write().unwrap(); + for (term_id, postings) in posting_per_term { self.total_document_length.fetch_add( postings.iter().map(|p| p.positions.len()).sum::(), @@ -237,15 +252,9 @@ impl StringIndex { ); let number_of_occurence_of_term = postings.len(); - // Due to this implementation, we have a limitation - // because we "forgot" the term. 
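
A note on the value layout used by the new FST code above: `build_fst` packs the posting list id into the high 32 bits of the FST value and the term frequency into the low 32 bits, and `search` unpacks them with a shift and a mask. A minimal standalone sketch of that scheme (not part of the patch; note that both fields are silently truncated to 32 bits by the `as` casts, so values above `u32::MAX` would be corrupted):

```rust
/// Packs a posting list id (high 32 bits) and a term frequency
/// (low 32 bits) into the single u64 stored as the FST value.
fn pack(posting_list_id: u32, term_frequency: u32) -> u64 {
    ((posting_list_id as u64) << 32) | (term_frequency as u64)
}

/// Reverses `pack`, mirroring the `>> 32` / `& 0xFFFFFFFF` logic in `search`.
fn unpack(packed: u64) -> (u32, u32) {
    ((packed >> 32) as u32, (packed & 0xFFFF_FFFF) as u32)
}

#[test]
fn pack_round_trips() {
    assert_eq!(unpack(pack(7, 42)), (7, 42));
    assert_eq!(unpack(pack(u32::MAX, u32::MAX)), (u32::MAX, u32::MAX));
}
```

This is also why the entries are sorted before being fed to `MapBuilder::insert`: `fst::Map` builds a sorted minimal automaton and rejects keys inserted out of lexicographic order.
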
Here we have just the term_id - // This invocation shouldn't exist at all: - // we have the term on the top of this function - // TODO: find a way to avoid this invocation let term = dictionary.retrive(term_id); - let value = tree.get_mut(term.bytes()); - if let Some(value) = value { + if let Some(value) = temp_map.get_mut(&term) { value.term_frequency += number_of_occurence_of_term; let vec = postings_per_posting_list_id .entry(value.posting_list_id) @@ -253,8 +262,8 @@ impl StringIndex { vec.push(postings); } else { let posting_list_id = self.posting_storage.generate_new_id(); - tree.insert( - term.bytes(), + temp_map.insert( + term, StringIndexValue { posting_list_id, term_frequency: number_of_occurence_of_term, @@ -268,10 +277,12 @@ impl StringIndex { } } + let new_fst = self.build_fst()?; + *self.fst_map.write().unwrap() = Some(new_fst); + postings_per_posting_list_id .into_par_iter() .map(|(k, v)| self.posting_storage.add_or_create(k, v)) - // TODO: handle error .all(|_| true); Ok(()) @@ -280,14 +291,11 @@ impl StringIndex { #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; - + use crate::{scorer::bm25::BM25Score, DocumentId, FieldId, StringIndex}; use nlp::{locales::Locale, TextParser}; - + use std::{collections::HashMap, sync::Arc}; use tempdir::TempDir; - use crate::{scorer::bm25::BM25Score, DocumentId, FieldId, StringIndex}; - #[test] fn test_search() { let tmp_dir = TempDir::new("string_index_test").unwrap(); diff --git a/string_index/src/posting_storage.rs b/string_index/src/posting_storage.rs index 976df89..d14f402 100644 --- a/string_index/src/posting_storage.rs +++ b/string_index/src/posting_storage.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::Posting; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct PostingListId(usize); +pub struct PostingListId(pub usize); #[derive(Debug, Error)] pub enum PostingStorageError { From cba262daf6873eccb9ad5ecd22ee76577da53e05 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 17 Nov 2024 18:26:20 +0100 Subject: [PATCH 02/12] wip --- string_index/src/lib.rs | 124 ++++++++++++++++++------------------- string_index/src/scorer.rs | 25 ++++---- 2 files changed, 74 insertions(+), 75 deletions(-) diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs index 161e7af..2b6e4e1 100644 --- a/string_index/src/lib.rs +++ b/string_index/src/lib.rs @@ -64,25 +64,31 @@ impl StringIndex { } fn build_fst(&self) -> Result>> { - let temp_map = self.temp_map.read().unwrap(); - - let mut entries: Vec<_> = temp_map - .iter() - .map(|(key, value)| { - ( - key.clone(), - ((value.posting_list_id.0 as u64) << 32) | (value.term_frequency as u64), - ) - }) - .collect(); - entries.sort_by(|(a, _), (b, _)| a.cmp(b)); + let entries: Vec<_> = { + let temp_map = self.temp_map.read().unwrap(); + temp_map + .iter() + .map(|(key, value)| { + ( + key.clone(), + ((value.posting_list_id.0 as u64) << 32) | (value.term_frequency as u64), + ) + }) + .collect() + }; + + let mut sorted_entries = entries; + sorted_entries.sort_by(|(a, _), (b, _)| a.cmp(b)); let mut builder = MapBuilder::memory(); - for (key, value) in entries { - builder.insert(key, value)?; + + for (key, value) in sorted_entries { + builder.insert(key.as_bytes(), value)?; } - Ok(builder.into_map()) + let fst = builder.into_map(); + + Ok(fst) } pub fn get_total_documents(&self) -> usize { @@ -98,78 +104,66 @@ impl StringIndex { ) -> Result> { let total_documents = match self.total_documents.load(Ordering::Relaxed) { 0 => { - println!("total_documents == 0"); return 
Ok(Default::default()); } total_documents => total_documents, }; + let total_document_length = match self.total_document_length.load(Ordering::Relaxed) { 0 => { - println!("total_document_length == 0"); return Ok(Default::default()); } total_document_length => total_document_length, }; - let mut posting_list_ids_with_freq = Vec::new(); + let mut all_postings = Vec::new(); if let Some(fst) = self.fst_map.read().unwrap().as_ref() { - for token in tokens { - let automaton = fst::automaton::Str::new(&token).starts_with(); + for token in &tokens { + let automaton = fst::automaton::Str::new(token).starts_with(); let mut stream = fst.search(automaton).into_stream(); while let Some(result) = stream.next() { let (_, packed_value) = result; - let posting_list_id = PostingListId((packed_value >> 32).try_into().unwrap()); - let term_frequency = packed_value & 0xFFFFFFFF; + let posting_list_id = PostingListId(((packed_value >> 32) as u32) as usize); + let term_frequency = (packed_value & 0xFFFFFFFF) as usize; - posting_list_ids_with_freq.push(StringIndexValue { - posting_list_id: PostingListId(posting_list_id.0), - term_frequency: term_frequency as usize, - }); + match self.posting_storage.get(posting_list_id) { + Ok(postings) => { + all_postings.push((postings, term_frequency)); + } + Err(e) => {} + } } } } let fields = search_on.as_ref(); - let global_info = GlobalInfo { total_documents, total_document_length, }; - posting_list_ids_with_freq - .into_par_iter() - .filter_map(|string_index_value| { - let output = self - .posting_storage - .get(string_index_value.posting_list_id) - .ok(); - let posting = match output { - Some(v) => v, - None => return None, - }; - - let posting: Vec<_> = posting - .into_iter() - .filter(move |posting| { - fields - .map(|search_on| search_on.contains(&posting.field_id)) - .unwrap_or(true) - }) - .collect(); - - Some((posting, string_index_value.term_frequency)) - }) - .for_each(|(postings, total_token_count)| { - let total_token_count = total_token_count as f32; - for posting in postings { - let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0); - scorer.add_entry(&global_info, posting, total_token_count, boost_per_field); + let mut filtered_postings = Vec::new(); + for (postings, term_freq) in all_postings { + for posting in postings { + if fields + .map(|search_on| search_on.contains(&posting.field_id)) + .unwrap_or(true) + { + filtered_postings.push((posting, term_freq)); } - }); + } + } + + for (posting, term_freq) in filtered_postings { + let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0); + scorer.add_entry(&global_info, posting, term_freq as f32, boost_per_field); + } + + let scores = scorer.get_scores(); - Ok(scorer.get_scores()) + Ok(scores) } pub fn insert_multiple(&self, data: DocumentBatch) -> Result<()> { @@ -242,7 +236,6 @@ impl StringIndex { let mut postings_per_posting_list_id: HashMap>> = HashMap::with_capacity(posting_per_term.len()); - let mut temp_map = self.temp_map.write().unwrap(); for (term_id, postings) in posting_per_term { @@ -277,13 +270,17 @@ impl StringIndex { } } + drop(temp_map); let new_fst = self.build_fst()?; - *self.fst_map.write().unwrap() = Some(new_fst); - postings_per_posting_list_id - .into_par_iter() - .map(|(k, v)| self.posting_storage.add_or_create(k, v)) - .all(|_| true); + { + let mut fst_map = self.fst_map.write().unwrap(); + *fst_map = Some(new_fst); + } + + for (k, v) in postings_per_posting_list_id { + self.posting_storage.add_or_create(k, v)?; + } Ok(()) } @@ -300,6 +297,7 @@ mod tests { fn 
test_search() { let tmp_dir = TempDir::new("string_index_test").unwrap(); let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); let string_index = StringIndex::new(storage); let parser = TextParser::from_language(Locale::EN); diff --git a/string_index/src/scorer.rs b/string_index/src/scorer.rs index 6f0334a..4a2171f 100644 --- a/string_index/src/scorer.rs +++ b/string_index/src/scorer.rs @@ -54,32 +54,33 @@ pub mod bm25 { #[inline] fn add_entry( &self, - _global_info: &crate::GlobalInfo, + global_info: &crate::GlobalInfo, posting: crate::Posting, - _total_token_count: f32, - _boost_per_field: f32, + total_token_count: f32, + boost_per_field: f32, ) { let term_frequency = posting.term_frequency; let doc_length = posting.doc_length as f32; - let freq = 1.0; + let total_documents = global_info.total_documents as f32; + let avg_doc_length = global_info.total_document_length as f32 / total_documents; - let total_documents = 1.0; - let avg_doc_length = total_documents / 1.0; - - let idf = ((total_documents - freq + 0.5_f32) / (freq + 0.5_f32)).ln_1p(); + let idf = + ((total_documents - total_token_count + 0.5) / (total_token_count + 0.5)).ln_1p(); let score = Self::calculate_score( term_frequency, idf, doc_length, avg_doc_length, - _boost_per_field, + boost_per_field, ); - let mut previous = self.scores.entry(posting.document_id).or_insert(0.0); - *previous += score; + self.scores + .entry(posting.document_id) + .and_modify(|e| *e += score) + .or_insert(score); } - fn get_scores(self) -> HashMap { + fn get_scores(self) -> HashMap { self.scores.into_iter().collect() } } From 142ea9a22d79b835c486cf81f806f9363b552f6a Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 17 Nov 2024 20:49:08 +0100 Subject: [PATCH 03/12] adds benchmarks --- Cargo.lock | 177 ++++++++++++-- string_index/Cargo.toml | 9 +- string_index/benches/string_index_bench.rs | 268 +++++++++++++++++++++ 3 files changed, 435 insertions(+), 19 deletions(-) create mode 100644 string_index/benches/string_index_bench.rs diff --git a/Cargo.lock b/Cargo.lock index 0d7107b..143efa3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -614,9 +620,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -971,8 +977,8 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "byteorder", "candle-metal-kernels", @@ -996,8 +1002,8 @@ dependencies = [ [[package]] name = "candle-metal-kernels" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = 
"git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "metal 0.27.0", "once_cell", @@ -1007,8 +1013,8 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "candle-core", "candle-metal-kernels", @@ -1031,6 +1037,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.1" @@ -1101,6 +1113,33 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1406,6 +1445,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -3284,12 +3359,32 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.11.0" @@ -3402,9 +3497,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libflate" @@ -3859,7 +3954,7 @@ dependencies = [ [[package]] name = "mistralrs" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "anyhow", "candle-core", @@ -3878,7 +3973,7 @@ dependencies = [ [[package]] name = "mistralrs-core" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "akin", "anyhow", @@ -3945,7 +4040,7 @@ dependencies = [ [[package]] name = "mistralrs-quant" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "byteorder", "candle-core", @@ -3964,7 +4059,7 @@ dependencies = [ [[package]] name = "mistralrs-vision" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "candle-core", "image", @@ -4355,6 +4450,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openai" version = "1.0.0-alpha.16" @@ -4706,6 +4807,34 @@ dependencies = [ "time", ] +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.14" @@ -5521,9 +5650,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "log", "once_cell", @@ -5747,9 +5876,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -6106,6 +6235,7 @@ version = "0.1.0" 
dependencies = [ "anyhow", "bincode", + "criterion", "csv", "dashmap", "fst", @@ -6113,6 +6243,7 @@ dependencies = [ "nlp", "ptrie 0.7.1 (git+https://github.com/oramasearch/ptrie.git?branch=feat/expose-get-mut)", "radix_trie", + "rand 0.8.5", "rayon", "serde", "smallvec", @@ -6737,6 +6868,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/string_index/Cargo.toml b/string_index/Cargo.toml index c925dcc..da37682 100644 --- a/string_index/Cargo.toml +++ b/string_index/Cargo.toml @@ -3,6 +3,11 @@ name = "string_index" version = "0.1.0" edition = "2021" +[[bench]] +name = "string_index_bench" +harness = false +path = "benches/string_index_bench.rs" + [dependencies] anyhow = "1.0.90" bincode = "1.3.3" @@ -19,7 +24,9 @@ smallvec = "1.13.2" types = { path = "../types" } ptrie = { git = "https://github.com/oramasearch/ptrie.git", branch = "feat/expose-get-mut" } fst = "0.4.7" +rand = "0.8.5" [dev-dependencies] +criterion = "0.5.1" nlp = { path = "../nlp" } - +tempdir = "0.3.7" diff --git a/string_index/benches/string_index_bench.rs b/string_index/benches/string_index_bench.rs new file mode 100644 index 0000000..f079dff --- /dev/null +++ b/string_index/benches/string_index_bench.rs @@ -0,0 +1,268 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::{seq::SliceRandom, Rng}; +use std::collections::HashMap; +use std::sync::Arc; +use storage::Storage; +use string_index::scorer::bm25::BM25Score; +use string_index::StringIndex; +use types::{DocumentId, FieldId}; + +#[derive(Debug)] +struct BenchmarkResults { + docs_count: usize, + index_size_bytes: u64, + indexing_time_ms: u64, + avg_search_time_ms: f64, + memory_usage_mb: f64, + throughput_docs_per_sec: f64, +} + +fn generate_test_data( + num_docs: usize, +) -> Vec<(DocumentId, Vec<(FieldId, Vec<(String, Vec)>)>)> { + let mut rng = rand::thread_rng(); + + let vocabulary: Vec = (0..5000) + .map(|_| { + let len = rng.gen_range(3..8); + (0..len) + .map(|_| (b'a' + rng.gen_range(0..26)) as char) + .collect() + }) + .collect(); + + (0..num_docs) + .map(|i| { + let doc_id = DocumentId(i as u64); + let num_fields = rng.gen_range(1..=3); + + let fields = (0..num_fields) + .map(|field_i| { + let field_id = FieldId(field_i); + let num_terms = rng.gen_range(5..50); + + let terms = (0..num_terms) + .map(|_| { + let word = vocabulary.choose(&mut rng).unwrap().clone(); + let stems = (0..rng.gen_range(0..2)) + .map(|_| vocabulary.choose(&mut rng).unwrap().clone()) + .collect(); + (word, stems) + }) + .collect(); + + (field_id, terms) + }) + .collect(); + + (doc_id, fields) + }) + .collect() +} + +fn benchmark_indexing(c: &mut Criterion) { + let mut group = c.benchmark_group("indexing"); + group.measurement_time(std::time::Duration::from_secs(30)); + group.sample_size(50); + + for size in [1000, 5000, 10_000].iter() { + // Reduced upper limit + group.bench_with_input(BenchmarkId::new("simple_index", size), size, |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + let data = generate_test_data(size); + + b.iter(|| { + for (doc_id, fields) in data.iter() { + let mut batch = HashMap::new(); 
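
// For readers following the types here: the batch handed to `insert_multiple`
// is the `DocumentBatch` alias from `string_index`, a map from document id to
// its fields, where every field carries `(token, stems)` pairs as produced by
// `generate_test_data`. A minimal hand-built batch (a sketch; the exact
// generic parameters are assumed from the alias and the test-data generator):

```rust
use std::collections::HashMap;
use types::{DocumentId, FieldId};

// One document with a single field and two tokens: "hello" with no extra
// stems, and "worlds" carrying the stem "world".
let mut batch: HashMap<DocumentId, Vec<(FieldId, Vec<(String, Vec<String>)>)>> =
    HashMap::new();
batch.insert(
    DocumentId(0),
    vec![(
        FieldId(0),
        vec![
            ("hello".to_string(), vec![]),
            ("worlds".to_string(), vec!["world".to_string()]),
        ],
    )],
);
```
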
+ batch.insert(*doc_id, fields.clone()); + index.insert_multiple(batch).expect("Insertion failed"); + } + }); + }); + } + + group.finish(); +} + +fn benchmark_batch_indexing(c: &mut Criterion) { + let mut group = c.benchmark_group("batch_indexing"); + group.measurement_time(std::time::Duration::from_secs(30)); + group.sample_size(30); + + for size in [1000, 5000, 10_000].iter() { + for batch_size in [100, 500].iter() { + group.bench_with_input( + BenchmarkId::new(format!("batch_size_{}", batch_size), size), + size, + |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + let data = generate_test_data(size); + + b.iter(|| { + for chunk in data.chunks(*batch_size) { + let mut batch = HashMap::new(); + for (doc_id, fields) in chunk { + batch.insert(*doc_id, fields.clone()); + } + index + .insert_multiple(batch) + .expect("Batch insertion failed"); + } + }); + }, + ); + } + } + + group.finish(); +} + +fn benchmark_search(c: &mut Criterion) { + let mut group = c.benchmark_group("search"); + group.measurement_time(std::time::Duration::from_secs(20)); + + let mut rng = rand::thread_rng(); + let vocabulary: Vec = (0..500) + .map(|_| { + let len = rng.gen_range(3..8); + (0..len) + .map(|_| (b'a' + rng.gen_range(0..26)) as char) + .collect() + }) + .collect(); + + let queries: Vec> = (0..100) + .map(|_| { + let num_terms = rng.gen_range(1..3); + (0..num_terms) + .map(|_| vocabulary.choose(&mut rng).unwrap().clone()) + .collect() + }) + .collect(); + + for size in [1000, 5000].iter() { + group.bench_with_input(BenchmarkId::new("search", size), size, |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + + let data = generate_test_data(size); + let mut batch = HashMap::new(); + for (doc_id, fields) in data { + batch.insert(doc_id, fields); + } + index + .insert_multiple(batch) + .expect("Initial data insertion failed"); + + b.iter(|| { + for query in queries.iter() { + let _ = index + .search( + query.clone(), + None, + Default::default(), + BM25Score::default(), + ) + .expect("Search failed"); + } + }); + }); + } + + group.finish(); +} + +fn benchmark_concurrent_ops(c: &mut Criterion) { + let mut group = c.benchmark_group("concurrent_operations"); + group.measurement_time(std::time::Duration::from_secs(40)); + + for size in [5000].iter() { + group.bench_with_input(BenchmarkId::new("concurrent", size), size, |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = Arc::new(StringIndex::new(Arc::clone(&storage))); + + let data = generate_test_data(size); + let mut batch = HashMap::new(); + for (doc_id, fields) in data { + batch.insert(doc_id, fields); + } + index + .insert_multiple(batch) + .expect("Initial data insertion failed"); + + let queries = generate_test_data(50) + .into_iter() + .map(|(_, fields)| { + fields + .into_iter() + .flat_map(|(_, terms)| terms) + .map(|(term, _)| term) + .collect::>() + }) + .collect::>(); + + b.iter(|| { + use std::thread; + + let search_threads: Vec<_> = (0..2) + .map(|_| { + let index = Arc::clone(&index); + let queries = queries.clone(); + thread::spawn(move || { + for query in &queries { + let _ = index + .search( + query.clone(), + None, + 
Default::default(), + BM25Score::default(), + ) + .expect("Concurrent search failed"); + } + }) + }) + .collect(); + + let insert_threads: Vec<_> = (0..1) + .map(|_| { + let index = Arc::clone(&index); + let data = generate_test_data(50); + thread::spawn(move || { + for (doc_id, fields) in data { + let mut batch = HashMap::new(); + batch.insert(doc_id, fields); + index + .insert_multiple(batch) + .expect("Concurrent insert failed"); + } + }) + }) + .collect(); + + for thread in search_threads { + thread.join().expect("Search thread join failed"); + } + for thread in insert_threads { + thread.join().expect("Insert thread join failed"); + } + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + benchmark_indexing, + benchmark_batch_indexing, + benchmark_search, + benchmark_concurrent_ops +); +criterion_main!(benches); From 209b21815427b23568f33b7985d596cd4b7c168f Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 17 Nov 2024 21:28:49 +0100 Subject: [PATCH 04/12] adds tests --- string_index/src/lib.rs | 511 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 491 insertions(+), 20 deletions(-) diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs index 2b6e4e1..0ae9571 100644 --- a/string_index/src/lib.rs +++ b/string_index/src/lib.rs @@ -294,8 +294,214 @@ mod tests { use tempdir::TempDir; #[test] - fn test_search() { - let tmp_dir = TempDir::new("string_index_test").unwrap(); + fn test_empty_search_query() { + let tmp_dir = TempDir::new("string_index_test_empty_search").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "This is a test document.".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec![], // Empty search tokens + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty for empty query" + ); + } + + #[test] + fn test_search_nonexistent_term() { + let tmp_dir = TempDir::new("string_index_test_nonexistent_term").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "This is a test document.".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["nonexistent".to_string()], // Term does not exist + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty for non-existent term" + ); + } + + #[test] + fn test_insert_empty_document() { + let tmp_dir = 
TempDir::new("string_index_test_empty_document").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "".to_string())], // Empty document content + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + // Search for any term, should get empty result + let output = string_index + .search( + vec!["test".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty when only empty documents are indexed" + ); + } + + #[test] + fn test_search_with_field_filter() { + let tmp_dir = TempDir::new("string_index_test_field_filter").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![ + (FieldId(0), "This is a test in field zero.".to_string()), + (FieldId(1), "Another test in field one.".to_string()), + ], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(0)]), // Search only in FieldId(0) + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document when searching in FieldId(0)" + ); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(1)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document when searching in FieldId(1)" + ); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(2)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Should not find any documents when searching in non-existent FieldId" + ); + } + + #[test] + fn test_search_with_boosts() { + let tmp_dir = TempDir::new("string_index_test_boosts").unwrap(); let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); @@ -305,24 +511,13 @@ mod tests { let batch: HashMap<_, _> = vec![ ( DocumentId(1), - vec![( - FieldId(0), - "Yo, I'm from where Nicky Barnes got rich as fuck, welcome!".to_string(), - )], + vec![(FieldId(0), "Important content in field zero.".to_string())], ), ( DocumentId(2), vec![( - FieldId(0), - "Welcome to Harlem, where you welcome to problems".to_string(), - )], - ), - ( - DocumentId(3), - vec![( - FieldId(0), - "Now bitches, they want to neuter me, niggas, they want to tutor me" - .to_string(), + FieldId(1), + "Less important content in field one.".to_string(), )], ), ] @@ -341,26 +536,302 @@ mod 
tests { string_index.insert_multiple(batch).unwrap(); + let mut boost = HashMap::new(); + boost.insert(FieldId(0), 2.0); + let output = string_index .search( - vec!["welcome".to_string()], + vec!["content".to_string()], + None, + boost, + BM25Score::default(), + ) + .unwrap(); + + assert_eq!(output.len(), 2, "Should find both documents"); + + let score_doc1 = output.get(&DocumentId(1)).unwrap(); + let score_doc2 = output.get(&DocumentId(2)).unwrap(); + + assert!( + score_doc1 > score_doc2, + "Document with boosted field should have higher score" + ); + } + + #[test] + fn test_insert_document_with_stop_words_only() { + let tmp_dir = TempDir::new("string_index_test_stop_words").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "the and but or".to_string())], // Only stop words + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + // Search for any term, should get empty result since only stop words are indexed + let output = string_index + .search( + vec!["the".to_string()], None, Default::default(), BM25Score::default(), ) .unwrap(); - assert_eq!(output.len(), 2); + assert!( + output.is_empty(), + "Search results should be empty when only stop words are indexed" + ); + } + + #[test] + fn test_search_on_empty_index() { + let tmp_dir = TempDir::new("string_index_test_empty_index").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); let output = string_index .search( - vec!["wel".to_string()], + vec!["test".to_string()], None, Default::default(), BM25Score::default(), ) .unwrap(); - assert_eq!(output.len(), 2); + assert!( + output.is_empty(), + "Search results should be empty when index is empty" + ); + } + + #[test] + fn test_concurrent_insertions() { + use std::thread; + + let tmp_dir = TempDir::new("string_index_test_concurrent_inserts").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = Arc::new(StringIndex::new(storage)); + + let string_index_clone1 = Arc::clone(&string_index); + let string_index_clone2 = Arc::clone(&string_index); + + let handle1 = thread::spawn(move || { + let parser = TextParser::from_language(Locale::EN); + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![( + FieldId(0), + "Concurrent insertion test document one.".to_string(), + )], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index_clone1.insert_multiple(batch).unwrap(); + }); + + let handle2 = thread::spawn(move || { + let parser = TextParser::from_language(Locale::EN); + let batch: HashMap<_, _> = vec![( + DocumentId(2), + vec![( + FieldId(0), + "Concurrent insertion test document two.".to_string(), + )], + )] + 
.into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index_clone2.insert_multiple(batch).unwrap(); + }); + + handle1.join().unwrap(); + handle2.join().unwrap(); + + // After concurrent insertions, search for "concurrent" + let parser = TextParser::from_language(Locale::EN); + let search_tokens = parser + .tokenize_and_stem("concurrent") + .into_iter() + .map(|(original, _)| original) + .collect::>(); + + let output = string_index + .search( + search_tokens, + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 2, + "Should find both documents after concurrent insertions" + ); + } + + #[test] + fn test_large_documents() { + let tmp_dir = TempDir::new("string_index_test_large_documents").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let large_text = "word ".repeat(10000); // Create a large document + + let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), large_text.clone())])] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + // Search for "word" + let output = string_index + .search( + vec!["word".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document containing the large text" + ); + } + + #[test] + fn test_high_term_frequency() { + let tmp_dir = TempDir::new("string_index_test_high_term_frequency").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let repeated_word = "repeat ".repeat(1000); // High term frequency + + let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), repeated_word.clone())])] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + // Search for "repeat" + let output = string_index + .search( + vec!["repeat".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document with high term frequency" + ); + } + + #[test] + fn test_term_positions() { + let tmp_dir = TempDir::new("string_index_test_term_positions").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![( + FieldId(0), + "quick brown fox jumps over the lazy dog".to_string(), + )], + )] + .into_iter() + .map(|(doc_id, fields)| { 
+ let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); } } From d671330e3e8233081d69151c1a11051779a443c7 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 17 Nov 2024 21:54:13 +0100 Subject: [PATCH 05/12] updates benchmarks --- string_index/benches/string_index_bench.rs | 85 +++++++--------------- 1 file changed, 27 insertions(+), 58 deletions(-) diff --git a/string_index/benches/string_index_bench.rs b/string_index/benches/string_index_bench.rs index f079dff..fcc6cc8 100644 --- a/string_index/benches/string_index_bench.rs +++ b/string_index/benches/string_index_bench.rs @@ -65,20 +65,19 @@ fn benchmark_indexing(c: &mut Criterion) { group.measurement_time(std::time::Duration::from_secs(30)); group.sample_size(50); - for size in [1000, 5000, 10_000].iter() { - // Reduced upper limit - group.bench_with_input(BenchmarkId::new("simple_index", size), size, |b, &size| { - let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); - let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); - let index = StringIndex::new(Arc::clone(&storage)); + for &size in &[1000, 5000, 10_000] { + group.bench_with_input(BenchmarkId::new("simple_index", size), &size, |b, &size| { let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); b.iter(|| { - for (doc_id, fields) in data.iter() { - let mut batch = HashMap::new(); - batch.insert(*doc_id, fields.clone()); - index.insert_multiple(batch).expect("Insertion failed"); - } + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + + index + .insert_multiple(batch.clone()) + .expect("Insertion failed"); }); }); } @@ -91,27 +90,23 @@ fn benchmark_batch_indexing(c: &mut Criterion) { group.measurement_time(std::time::Duration::from_secs(30)); group.sample_size(30); - for size in [1000, 5000, 10_000].iter() { - for batch_size in [100, 500].iter() { + for &size in &[1000, 5000, 10_000] { + for &batch_size in &[100, 500] { group.bench_with_input( BenchmarkId::new(format!("batch_size_{}", batch_size), size), - size, + &size, |b, &size| { - let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); - let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); - let index = StringIndex::new(Arc::clone(&storage)); let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); b.iter(|| { - for chunk in data.chunks(*batch_size) { - let mut batch = HashMap::new(); - for (doc_id, fields) in chunk { - batch.insert(*doc_id, fields.clone()); - } - index - .insert_multiple(batch) - .expect("Batch insertion failed"); - } + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = + Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + index + .insert_multiple(batch.clone()) + .expect("Batch insertion failed"); }); }, ); @@ -144,17 +139,14 @@ fn benchmark_search(c: &mut Criterion) { }) .collect(); - for size in [1000, 5000].iter() { - group.bench_with_input(BenchmarkId::new("search", size), size, |b, &size| { + for &size in &[1000, 5000] { + group.bench_with_input(BenchmarkId::new("search", size), &size, |b, &size| { let tmp_dir = 
tempdir::TempDir::new("bench_index").unwrap(); let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); let index = StringIndex::new(Arc::clone(&storage)); let data = generate_test_data(size); - let mut batch = HashMap::new(); - for (doc_id, fields) in data { - batch.insert(doc_id, fields); - } + let batch: HashMap<_, _> = data.into_iter().collect(); index .insert_multiple(batch) .expect("Initial data insertion failed"); @@ -181,17 +173,14 @@ fn benchmark_concurrent_ops(c: &mut Criterion) { let mut group = c.benchmark_group("concurrent_operations"); group.measurement_time(std::time::Duration::from_secs(40)); - for size in [5000].iter() { - group.bench_with_input(BenchmarkId::new("concurrent", size), size, |b, &size| { + for &size in &[5000] { + group.bench_with_input(BenchmarkId::new("concurrent", size), &size, |b, &size| { let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); let index = Arc::new(StringIndex::new(Arc::clone(&storage))); let data = generate_test_data(size); - let mut batch = HashMap::new(); - for (doc_id, fields) in data { - batch.insert(doc_id, fields); - } + let batch: HashMap<_, _> = data.into_iter().collect(); index .insert_multiple(batch) .expect("Initial data insertion failed"); @@ -229,35 +218,15 @@ fn benchmark_concurrent_ops(c: &mut Criterion) { }) .collect(); - let insert_threads: Vec<_> = (0..1) - .map(|_| { - let index = Arc::clone(&index); - let data = generate_test_data(50); - thread::spawn(move || { - for (doc_id, fields) in data { - let mut batch = HashMap::new(); - batch.insert(doc_id, fields); - index - .insert_multiple(batch) - .expect("Concurrent insert failed"); - } - }) - }) - .collect(); - for thread in search_threads { thread.join().expect("Search thread join failed"); } - for thread in insert_threads { - thread.join().expect("Insert thread join failed"); - } }); }); } group.finish(); } - criterion_group!( benches, benchmark_indexing, From ba4bf43a9f390234599cab979658f908971d4f77 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Mon, 18 Nov 2024 11:43:30 +0100 Subject: [PATCH 06/12] adds example --- rustorama/.gitignore | 1 + rustorama/README.md | 3 +++ rustorama/config.json | 8 ++++++ rustorama/insert.js | 62 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+) create mode 100644 rustorama/.gitignore create mode 100644 rustorama/README.md create mode 100644 rustorama/config.json create mode 100644 rustorama/insert.js diff --git a/rustorama/.gitignore b/rustorama/.gitignore new file mode 100644 index 0000000..de5e3b9 --- /dev/null +++ b/rustorama/.gitignore @@ -0,0 +1 @@ +reviews.json \ No newline at end of file diff --git a/rustorama/README.md b/rustorama/README.md new file mode 100644 index 0000000..0134288 --- /dev/null +++ b/rustorama/README.md @@ -0,0 +1,3 @@ +# Rustorama + +Download the dataset from https://www.kaggle.com/datasets/abdallahwagih/amazon-reviews and save it as `reviews.json`. Then run `deno run -A insert.js` to upload all the reviews to a local rustorama instance. 
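
The same upload can be driven from Rust instead of Deno; a minimal sketch (this assumes the `reqwest` blocking client with its `json` feature plus `serde_json`, neither of which is currently a dependency of this repository):

```rust
use serde_json::json;

// Sends one small batch to the endpoint targeted by insert.js, using the
// host and port from config.json.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let batch = vec![json!({
        "id": "0",
        "reviewerName": "Jane Doe",
        "reviewText": "Great product, would buy again.",
        "summary": "Five stars",
    })];

    let res = reqwest::blocking::Client::new()
        .patch("http://127.0.0.1:8080/v0/collections/xyz/documents")
        .json(&batch)
        .send()?;
    println!("status: {}", res.status());
    Ok(())
}
```
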
\ No newline at end of file diff --git a/rustorama/config.json b/rustorama/config.json new file mode 100644 index 0000000..a5a085a --- /dev/null +++ b/rustorama/config.json @@ -0,0 +1,8 @@ +{ + "data_dir": "/tmp/rustorama", + "http": { + "host": "127.0.0.1", + "port": 8080, + "allow_cors": true + } +} \ No newline at end of file diff --git a/rustorama/insert.js b/rustorama/insert.js new file mode 100644 index 0000000..bbd66ef --- /dev/null +++ b/rustorama/insert.js @@ -0,0 +1,62 @@ +import { readLines } from "https://deno.land/std@0.208.0/io/mod.ts"; + +const INPUT_FILE = "./reviews.json"; +const BATCH_SIZE = 2_000; +const API_URL = "http://127.0.0.1:8080/v0/collections/xyz/documents"; + +async function sendBatch(batch, currentCount, totalCount) { + try { + const response = await fetch(API_URL, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(batch), + }); + + if (response.ok) { + console.log(`Batch upload successful: ${currentCount}/${totalCount} uploaded`); + } else { + console.error(`Batch upload failed: HTTP ${response.status}`); + } + } catch (error) { + console.error("Error during batch upload:", error); + } +} + +async function processJsonlFile() { + const fileReader = await Deno.open(INPUT_FILE); + let batch = []; + let counter = 0; + let totalLines = 0; + + for await (const _ of readLines(await Deno.open(INPUT_FILE))) { + totalLines++; + } + + for await (const line of readLines(fileReader)) { + const json = JSON.parse(line.trim()); + + batch.push({ + id: counter.toString(), + reviewerName: json.reviewerName, + reviewText: json.reviewText, + summary: json.summary, + }); + counter++; + + if (batch.length === BATCH_SIZE) { + await sendBatch(batch, counter, totalLines); + batch = []; + } + } + + if (batch.length > 0) { + await sendBatch(batch, counter, totalLines); + } + + fileReader.close(); + console.log("All data sent."); +} + +await processJsonlFile(); From 4f278b442946eda3032ec61fa3fc96f12842f8ab Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Mon, 18 Nov 2024 12:12:47 +0100 Subject: [PATCH 07/12] fixes division by 0 in scorer --- string_index/src/lib.rs | 2 +- string_index/src/scorer.rs | 239 ++++++++++++++++++++++++++++++++++++- 2 files changed, 235 insertions(+), 6 deletions(-) diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs index 0ae9571..dda46f2 100644 --- a/string_index/src/lib.rs +++ b/string_index/src/lib.rs @@ -22,7 +22,7 @@ pub mod scorer; pub type DocumentBatch = HashMap)>)>>; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] pub struct Posting { pub document_id: DocumentId, pub field_id: FieldId, diff --git a/string_index/src/scorer.rs b/string_index/src/scorer.rs index 4a2171f..9338f90 100644 --- a/string_index/src/scorer.rs +++ b/string_index/src/scorer.rs @@ -28,7 +28,7 @@ pub mod bm25 { scores: DashMap, } impl BM25Score { - fn new() -> Self { + pub fn new() -> Self { Self { scores: DashMap::new(), } @@ -42,10 +42,20 @@ pub mod bm25 { avg_doc_length: f32, boost: f32, ) -> f32 { + if tf == 0.0 || doc_length == 0.0 || avg_doc_length == 0.0 { + return 0.0; + } + let k1 = 1.5; let b = 0.75; let numerator = tf * (k1 + 1.0); let denominator = tf + k1 * (1.0 - b + b * (doc_length / avg_doc_length)); + + // @todo: find a better way to avoid division by 0 + if denominator == 0.0 { + return 0.0; + } + idf * (numerator / denominator) * boost } } @@ -62,8 +72,13 @@ pub mod bm25 { let term_frequency = posting.term_frequency; let doc_length = posting.doc_length as f32; 
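
// For reference, the quantity this `add_entry` body computes (with k1 = 1.5
// and b = 0.75 fixed in `calculate_score`, and `ln_1p(x) = ln(1 + x)`) is the
// BM25 score:
//
//   idf   = \ln\left(1 + \frac{N - n + 0.5}{n + 0.5}\right)
//   score = idf \cdot \frac{tf\,(k_1 + 1)}{tf + k_1\left(1 - b + b \cdot \frac{dl}{avgdl}\right)} \cdot boost
//
// where N is `total_documents`, tf the term frequency, dl the document
// length, and avgdl the average document length. Note that, as written,
// `total_token_count` stands in for n, the number of documents containing
// the term.
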
let total_documents = global_info.total_documents as f32; - let avg_doc_length = global_info.total_document_length as f32 / total_documents; + if total_documents == 0.0 { + self.scores.insert(posting.document_id, 0.0); + return; + } + + let avg_doc_length = global_info.total_document_length as f32 / total_documents; let idf = ((total_documents - total_token_count + 0.5) / (total_token_count + 0.5)).ln_1p(); let score = Self::calculate_score( @@ -106,14 +121,20 @@ pub mod code { } impl CodeScore { - fn new() -> Self { + pub fn new() -> Self { Self { scores: DashMap::new(), } } - fn calculate_score(pos: DashMap, boost: f32, doc_lenth: u16) -> f32 { + #[inline] + fn calculate_score(pos: DashMap, boost: f32, doc_length: u16) -> f32 { let mut foo: Vec<_> = pos.into_iter().map(|(p, v)| (p.0, v)).collect(); + + if foo.is_empty() || doc_length == 0 { + return 0.0; + } + foo.sort_by_key(|(p, _)| (*p as isize)); let pos_len = foo.len(); @@ -154,7 +175,12 @@ pub mod code { score += score_for_position; } - score / (pos_len * (doc_lenth as usize)) as f32 + let denominator = (pos_len * (doc_length as usize)) as f32; + if denominator == 0.0 { + 0.0 + } else { + score / denominator + } } } @@ -197,3 +223,206 @@ pub mod code { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::GlobalInfo; + use types::FieldId; + + #[test] + fn test_bm25_basic_scoring() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 5.0, + doc_length: 100, + positions: vec![1, 2, 3, 4, 5], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert!(scores.contains_key(&DocumentId(1))); + assert!(scores[&DocumentId(1)] > 0.0); + } + + #[test] + fn test_bm25_empty_document() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 1, + total_document_length: 0, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 0.0, + doc_length: 0, + positions: vec![], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert_eq!(scores[&DocumentId(1)], 0.0); + } + + #[test] + fn test_bm25_boost_effect() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 5.0, + doc_length: 100, + positions: vec![1, 2, 3, 4, 5], + }; + + // Test with different boost values + scorer.add_entry(&global_info, posting.clone(), 1.0, 1.0); + let normal_scores = scorer.get_scores(); + + let scorer = bm25::BM25Score::new(); + scorer.add_entry(&global_info, posting, 1.0, 2.0); + let boosted_scores = scorer.get_scores(); + + assert!(boosted_scores[&DocumentId(1)] > normal_scores[&DocumentId(1)]); + } + + #[test] + fn test_code_score_basic() { + let scorer = code::CodeScore::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 3.0, + doc_length: 100, + positions: vec![1, 3, 5], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert!(scores.contains_key(&DocumentId(1))); + assert!(scores[&DocumentId(1)] > 0.0); + } + + #[test] + fn 
From c00970356d38f001030d08c95a02e6b2280fad12 Mon Sep 17 00:00:00 2001
From: Michele Riva
Date: Mon, 18 Nov 2024 13:12:56 +0100
Subject: [PATCH 08/12] adds boosting for exact matches

---
 string_index/src/lib.rs    | 48 ++++++++++++++++++++++++++++++++------
 string_index/src/scorer.rs |  1 +
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs
index dda46f2..c7d6eb5 100644
--- a/string_index/src/lib.rs
+++ b/string_index/src/lib.rs
@@ -132,7 +132,7 @@ impl StringIndex {
                     Ok(postings) => {
                         all_postings.push((postings, term_frequency));
                     }
-                    Err(e) => {}
+                    Err(_) => {}
                 }
             }
         }
@@ -144,24 +144,58 @@ impl StringIndex {
             total_document_length,
         };
 
-        let mut filtered_postings = Vec::new();
+        let mut filtered_postings: HashMap<DocumentId, Vec<(Posting, usize)>> = HashMap::new();
 
         for (postings, term_freq) in all_postings {
             for posting in postings {
                 if fields
                     .map(|search_on| search_on.contains(&posting.field_id))
                     .unwrap_or(true)
                 {
-                    filtered_postings.push((posting, term_freq));
+                    filtered_postings
+                        .entry(posting.document_id)
+                        .or_insert_with(Vec::new)
+                        .push((posting, term_freq));
                 }
             }
         }
 
-        for (posting, term_freq) in filtered_postings {
-            let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0);
-            scorer.add_entry(&global_info, posting, term_freq as f32, boost_per_field);
+        let mut exact_match_documents = Vec::new();
+
+        for (document_id, postings) in &filtered_postings {
+            let mut positions: Vec<usize> = postings
+                .iter()
+                .flat_map(|(posting, _)| posting.positions.clone())
+                .collect();
+            positions.sort_unstable();
+
+            if positions.windows(tokens.len()).any(|window| {
+                window
+                    .iter()
+                    .enumerate()
+                    .all(|(i, &pos)| i == 0 || pos == window[i - 1] + 1)
+            }) {
+                exact_match_documents.push(*document_id);
+            }
+
+            for (posting, term_freq) in postings {
+                let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0);
+                scorer.add_entry(
+                    &global_info,
+                    posting.clone(),
+                    *term_freq as f32,
+                    boost_per_field,
+                );
+            }
         }
 
-        let scores = scorer.get_scores();
+        let mut scores = scorer.get_scores();
+
+        let exact_match_boost = 5.0; // @todo: make this configurable.
+        for document_id in exact_match_documents {
+            if let Some(score) = scores.get_mut(&document_id) {
+                *score *= exact_match_boost;
+            }
+        }
 
         Ok(scores)
     }
diff --git a/string_index/src/scorer.rs b/string_index/src/scorer.rs
index 9338f90..281ceca 100644
--- a/string_index/src/scorer.rs
+++ b/string_index/src/scorer.rs
@@ -81,6 +81,7 @@ pub mod bm25 {
 
             let avg_doc_length = global_info.total_document_length as f32 / total_documents;
             let idf = ((total_documents - total_token_count + 0.5) / (total_token_count + 0.5)).ln_1p();
+
             let score = Self::calculate_score(
                 term_frequency,
                 idf,
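Note: the detection above flattens every matched token's positions into one sorted list and flags a document when any window of tokens.len() strictly consecutive positions exists. That predicate in isolation (a sketch, not the crate's code):

    // Patch-08 adjacency predicate over merged, sorted positions.
    fn has_consecutive_run(positions: &[usize], run: usize) -> bool {
        positions.windows(run).any(|window| {
            window
                .iter()
                .enumerate()
                .all(|(i, &pos)| i == 0 || pos == window[i - 1] + 1)
        })
    }

    fn main() {
        // Positions 7, 8, 9 form a run of three, so a three-token query matches.
        assert!(has_consecutive_run(&[3, 7, 8, 9], 3));
        assert!(!has_consecutive_run(&[1, 4, 9], 3));
        // Caveat: the merged list does not record which token produced which
        // position, so a document containing "b a c" can satisfy a query for
        // "a b c"; patch 10 below replaces this with a per-token position walk.
    }
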
From d1ef7b8503163453f6187e604286a78997585823 Mon Sep 17 00:00:00 2001
From: Michele Riva
Date: Mon, 18 Nov 2024 13:15:36 +0100
Subject: [PATCH 09/12] removes useless comments

---
 string_index/src/lib.rs | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs
index c7d6eb5..292a4b4 100644
--- a/string_index/src/lib.rs
+++ b/string_index/src/lib.rs
@@ -357,7 +357,7 @@ mod tests {
 
         let output = string_index
             .search(
-                vec![], // Empty search tokens
+                vec![],
                 None,
                 Default::default(),
                 BM25Score::default(),
@@ -400,7 +400,7 @@ mod tests {
 
         let output = string_index
             .search(
-                vec!["nonexistent".to_string()], // Term does not exist
+                vec!["nonexistent".to_string()],
                 None,
                 Default::default(),
                 BM25Score::default(),
@@ -424,7 +424,7 @@ mod tests {
 
         let batch: HashMap<_, _> = vec![(
             DocumentId(1),
-            vec![(FieldId(0), "".to_string())], // Empty document content
+            vec![(FieldId(0), "".to_string())],
         )]
         .into_iter()
         .map(|(doc_id, fields)| {
@@ -441,7 +441,6 @@ mod tests {
 
         string_index.insert_multiple(batch).unwrap();
 
-        // Search for any term, should get empty result
         let output = string_index
             .search(
                 vec!["test".to_string()],
@@ -491,7 +490,7 @@ mod tests {
         let output = string_index
             .search(
                 vec!["test".to_string()],
-                Some(vec![FieldId(0)]), // Search only in FieldId(0)
+                Some(vec![FieldId(0)]),
                 Default::default(),
                 BM25Score::default(),
             )
@@ -604,7 +603,7 @@ mod tests {
 
         let batch: HashMap<_, _> = vec![(
             DocumentId(1),
-            vec![(FieldId(0), "the and but or".to_string())], // Only stop words
+            vec![(FieldId(0), "the and but or".to_string())],
         )]
         .into_iter()
         .map(|(doc_id, fields)| {
@@ -621,7 +620,6 @@ mod tests {
 
         string_index.insert_multiple(batch).unwrap();
 
-        // Search for any term, should get empty result since only stop words are indexed
        let output = string_index
             .search(
                 vec!["the".to_string()],
@@ -726,7 +724,6 @@ mod tests {
         handle1.join().unwrap();
         handle2.join().unwrap();
 
-        // After concurrent insertions, search for "concurrent"
        let parser = TextParser::from_language(Locale::EN);
        let search_tokens = parser
            .tokenize_and_stem("concurrent")
@@ -759,7 +756,7 @@ mod tests {
         let string_index = StringIndex::new(storage);
         let parser = TextParser::from_language(Locale::EN);
 
-        let large_text = "word ".repeat(10000); // Create a large document
+        let large_text = "word ".repeat(10000);
 
         let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), large_text.clone())])]
             .into_iter()
@@ -777,7 +774,6 @@ mod tests {
 
         string_index.insert_multiple(batch).unwrap();
 
-        // Search for "word"
         let output = string_index
             .search(
                 vec!["word".to_string()],
@@ -803,7 +799,7 @@ mod tests {
         let string_index = StringIndex::new(storage);
         let parser = TextParser::from_language(Locale::EN);
 
-        let repeated_word = "repeat ".repeat(1000); // High term frequency
+        let repeated_word = "repeat ".repeat(1000);
 
         let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), repeated_word.clone())])]
             .into_iter()
@@ -821,7 +817,6 @@ mod tests {
 
         string_index.insert_multiple(batch).unwrap();
 
-        // Search for "repeat"
         let output = string_index
             .search(
                 vec!["repeat".to_string()],
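The comments stripped above mostly restated call sites that the tests already make explicit. Condensed from those tests, the search entry point looks like this (a usage sketch only; the import paths and the anyhow error type are assumptions, not shown in the diff):

    // Usage sketch condensed from the test suite, not new API surface.
    use std::collections::HashMap;
    use string_index::{scorer::bm25::BM25Score, StringIndex};
    use types::{DocumentId, FieldId};

    fn demo(index: &StringIndex) -> anyhow::Result<()> {
        // Tokens normally come from TextParser::tokenize_and_stem, as in the tests.
        let tokens = vec!["test".to_string()];

        let scores: HashMap<DocumentId, f32> = index.search(
            tokens,
            Some(vec![FieldId(0)]), // None searches every field
            Default::default(),     // per-field boost table
            BM25Score::default(),
        )?;

        for (document_id, score) in scores {
            println!("{document_id:?}: {score:.3}");
        }
        Ok(())
    }
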
From da740a9ec9074c5feb77f3157650ebdc7e4bbea0 Mon Sep 17 00:00:00 2001
From: Michele Riva
Date: Mon, 18 Nov 2024 15:41:22 +0100
Subject: [PATCH 10/12] fixes parallel insertion

---
 string_index/src/lib.rs | 115 ++++++++++++++++++++++++++++++++------
 1 file changed, 98 insertions(+), 17 deletions(-)

diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs
index 292a4b4..35c4aae 100644
--- a/string_index/src/lib.rs
+++ b/string_index/src/lib.rs
@@ -2,7 +2,7 @@ use std::{
     collections::HashMap,
     sync::{
         atomic::{AtomicUsize, Ordering},
-        Arc, RwLock,
+        Arc, RwLock, Mutex
     },
 };
@@ -44,6 +44,7 @@ pub struct StringIndex {
     dictionary: Dictionary,
     total_documents: AtomicUsize,
     total_document_length: AtomicUsize,
+    insert_mutex: Mutex<()>,
 }
 
 pub struct GlobalInfo {
@@ -60,6 +61,7 @@ impl StringIndex {
             dictionary: Dictionary::new(),
             total_documents: AtomicUsize::new(0),
             total_document_length: AtomicUsize::new(0),
+            insert_mutex: Mutex::new(()),
         }
     }
@@ -128,11 +130,8 @@ impl StringIndex {
                     let posting_list_id = PostingListId(((packed_value >> 32) as u32) as usize);
                     let term_frequency = (packed_value & 0xFFFFFFFF) as usize;
 
-                    match self.posting_storage.get(posting_list_id) {
-                        Ok(postings) => {
-                            all_postings.push((postings, term_frequency));
-                        }
-                        Err(_) => {}
+                    if let Ok(postings) = self.posting_storage.get(posting_list_id) {
+                        all_postings.push((postings, term_frequency));
                     }
                 }
             }
@@ -162,18 +161,14 @@ impl StringIndex {
         let mut exact_match_documents = Vec::new();
 
         for (document_id, postings) in &filtered_postings {
-            let mut positions: Vec<usize> = postings
+            let mut token_positions: Vec<Vec<usize>> = postings
                 .iter()
-                .flat_map(|(posting, _)| posting.positions.clone())
+                .map(|(posting, _)| posting.positions.clone())
                 .collect();
-            positions.sort_unstable();
-
-            if positions.windows(tokens.len()).any(|window| {
-                window
-                    .iter()
-                    .enumerate()
-                    .all(|(i, &pos)| i == 0 || pos == window[i - 1] + 1)
-            }) {
+
+            token_positions.iter_mut().for_each(|positions| positions.sort_unstable());
+
+            if self.is_phrase_match(&token_positions) {
                 exact_match_documents.push(*document_id);
             }
 
@@ -190,7 +185,7 @@ impl StringIndex {
 
         let mut scores = scorer.get_scores();
 
-        let exact_match_boost = 5.0; // @todo: make this configurable.
+        let exact_match_boost = 12.0; // @todo: make this configurable.
         for document_id in exact_match_documents {
             if let Some(score) = scores.get_mut(&document_id) {
                 *score *= exact_match_boost;
@@ -201,6 +196,8 @@ impl StringIndex {
         Ok(scores)
     }
 
     pub fn insert_multiple(&self, data: DocumentBatch) -> Result<()> {
+        let _lock = self.insert_mutex.lock().unwrap();
+
         self.total_documents
             .fetch_add(data.len(), Ordering::Relaxed);
@@ -318,6 +315,41 @@ impl StringIndex {
 
         Ok(())
     }
+
+    fn is_phrase_match(&self, token_positions: &Vec<Vec<usize>>) -> bool {
+        if token_positions.is_empty() {
+            return false;
+        }
+
+        let position_sets: Vec<HashSet<usize>> = token_positions
+            .iter()
+            .skip(1)
+            .map(|positions| positions.iter().copied().collect())
+            .collect();
+
+        for &start_pos in &token_positions[0] {
+            let mut current_pos = start_pos;
+            let mut matched = true;
+
+            for positions in &position_sets {
+                let next_pos = current_pos + 1;
+                if positions.contains(&next_pos) {
+                    current_pos = next_pos;
+                } else {
+                    matched = false;
+                    break;
+                }
+            }
+
+            if matched {
+                return true;
+            }
+        }
+
+        false
+    }
+
 }
 
 #[cfg(test)]
@@ -863,4 +895,53 @@ mod tests {
 
         string_index.insert_multiple(batch).unwrap();
     }
+
+    #[test]
+    fn test_exact_phrase_match() {
+        let tmp_dir = TempDir::new("string_index_test_exact_phrase").unwrap();
+        let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string();
+
+        let storage = Arc::new(crate::Storage::from_path(&tmp_dir));
+        let string_index = StringIndex::new(storage);
+        let parser = TextParser::from_language(Locale::EN);
+
+        let batch: HashMap<_, _> = vec![(
+            DocumentId(1),
+            vec![(FieldId(0), "5200 mAh battery in disguise".to_string())],
+        )]
+        .into_iter()
+        .map(|(doc_id, fields)| {
+            let fields: Vec<_> = fields
+                .into_iter()
+                .map(|(field_id, data)| {
+                    let tokens = parser.tokenize_and_stem(&data);
+                    (field_id, tokens)
+                })
+                .collect();
+            (doc_id, fields)
+        })
+        .collect();
+
+        string_index.insert_multiple(batch).unwrap();
+
+        let output = string_index
+            .search(
+                vec!["5200".to_string(), "mAh".to_string(), "battery".to_string(), "in".to_string(), "disguise".to_string()],
+                Some(vec![FieldId(0)]),
+                Default::default(),
+                BM25Score::default(),
+            )
+            .unwrap();
+
+        assert_eq!(
+            output.len(),
+            1,
+            "Should find the document containing the exact phrase"
+        );
+
+        assert!(
+            output.contains_key(&DocumentId(1)),
+            "Document with ID 1 should be found"
+        );
+    }
 }
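Note: is_phrase_match walks the first token's positions and requires every subsequent token to sit exactly one position further on, which closes the out-of-order false positives of the flattened window check from patch 08. The same logic restated as a free function (a sketch, with try_fold in place of the explicit loop):

    // Standalone restatement of patch 10's phrase walk (not the crate's code).
    // token_positions[i] holds the sorted positions of the i-th query token.
    use std::collections::HashSet;

    fn is_phrase_match(token_positions: &[Vec<usize>]) -> bool {
        if token_positions.is_empty() {
            return false;
        }
        // Positions of every token after the first, as sets for O(1) lookups.
        let position_sets: Vec<HashSet<usize>> = token_positions
            .iter()
            .skip(1)
            .map(|positions| positions.iter().copied().collect())
            .collect();
        // Try to extend a run start, start + 1, ... through every token in order.
        token_positions[0].iter().any(|&start| {
            position_sets
                .iter()
                .try_fold(start, |pos, set| set.contains(&(pos + 1)).then_some(pos + 1))
                .is_some()
        })
    }

    fn main() {
        // "a b c" at positions 4, 5, 6: a phrase match.
        assert!(is_phrase_match(&[vec![4], vec![5], vec![6]]));
        // Adjacent but out of order ("b a c"): rejected, unlike the patch-08 check.
        assert!(!is_phrase_match(&[vec![5], vec![4], vec![6]]));
    }
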
From 27d45ea8d1c74f9e702e8d295f682009979fa563 Mon Sep 17 00:00:00 2001
From: Michele Riva
Date: Mon, 18 Nov 2024 21:36:24 +0100
Subject: [PATCH 11/12] wip

---
 nlp/src/lib.rs       | 1 +
 nlp/src/tokenizer.rs | 8 ++++++++
 2 files changed, 9 insertions(+)

diff --git a/nlp/src/lib.rs b/nlp/src/lib.rs
index 6c8c546..6a43b77 100644
--- a/nlp/src/lib.rs
+++ b/nlp/src/lib.rs
@@ -34,6 +34,7 @@ impl Clone for TextParser {
 impl TextParser {
     pub fn from_language(locale: Locale) -> Self {
         let (tokenizer, stemmer) = match locale {
+            Locale::IT => (Tokenizer::italian(), Stemmer::create(Algorithm::Italian)),
             Locale::EN => (Tokenizer::english(), Stemmer::create(Algorithm::English)),
             // @todo: manage other locales
             _ => (Tokenizer::english(), Stemmer::create(Algorithm::English)),
diff --git a/nlp/src/tokenizer.rs b/nlp/src/tokenizer.rs
index de052d7..b404754 100644
--- a/nlp/src/tokenizer.rs
+++ b/nlp/src/tokenizer.rs
@@ -17,6 +17,14 @@ impl Tokenizer {
             stop_words,
         }
     }
+
+    pub fn italian() -> Self {
+        let stop_words: HashSet<&str> = Locale::IT.stop_words().unwrap();
+        Tokenizer {
+            split_regex: Locale::IT.split_regex().unwrap(),
+            stop_words
+        }
+    }
 
     pub fn tokenize<'a, 'b>(&'a self, input: &'b str) -> impl Iterator<Item = String> + 'b
     where
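With Locale::IT wired into from_language, Italian text can be tokenized, stop-word filtered, and stemmed end to end. A usage sketch (the crate path and the tuple return shape are assumptions inferred from how the tests call tokenize_and_stem; the unwraps above also presume IT stop-word and split-regex data ships with the Locale type):

    // Sketch: Italian tokenization + stemming through the TextParser API.
    use nlp::{Locale, TextParser};

    fn main() {
        let parser = TextParser::from_language(Locale::IT);
        // tokenize_and_stem pairs each surface token with its stemmed variants,
        // e.g. the plural "batterie" should reduce toward a "batter"-like root.
        for (token, stems) in parser.tokenize_and_stem("Le batterie ricaricabili durano di più") {
            println!("{token} -> {stems:?}");
        }
    }
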
From 8088a6b417423867d72f302635f126955e87449b Mon Sep 17 00:00:00 2001
From: Michele Riva
Date: Mon, 18 Nov 2024 21:38:26 +0100
Subject: [PATCH 12/12] wip

---
 string_index/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs
index 35c4aae..4b84da1 100644
--- a/string_index/src/lib.rs
+++ b/string_index/src/lib.rs
@@ -185,7 +185,7 @@ impl StringIndex {
 
         let mut scores = scorer.get_scores();
 
-        let exact_match_boost = 12.0; // @todo: make this configurable.
+        let exact_match_boost = 25.0; // @todo: make this configurable.
         for document_id in exact_match_documents {
             if let Some(score) = scores.get_mut(&document_id) {
                 *score *= exact_match_boost;
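The multiplier has now moved from 5.0 to 12.0 to 25.0 across three patches, exactly the tuning churn the @todo comment anticipates. One possible shape for lifting it into configuration (entirely hypothetical; SearchConfig and its field name are invented, not code from this series):

    // Hypothetical knob for the hard-coded exact_match_boost.
    use std::collections::HashMap;
    use types::DocumentId; // DocumentId as used throughout string_index

    #[derive(Debug, Clone)]
    pub struct SearchConfig {
        pub exact_match_boost: f32,
    }

    impl Default for SearchConfig {
        fn default() -> Self {
            // Matches the value patch 12 settles on.
            Self { exact_match_boost: 25.0 }
        }
    }

    fn apply_exact_match_boost(
        scores: &mut HashMap<DocumentId, f32>,
        exact_match_documents: &[DocumentId],
        config: &SearchConfig,
    ) {
        // Same multiply-if-present loop as the search method, parameterized.
        for document_id in exact_match_documents {
            if let Some(score) = scores.get_mut(document_id) {
                *score *= config.exact_match_boost;
            }
        }
    }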