diff --git a/Cargo.lock b/Cargo.lock index 416025f..143efa3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -614,9 +620,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -971,8 +977,8 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "byteorder", "candle-metal-kernels", @@ -996,8 +1002,8 @@ dependencies = [ [[package]] name = "candle-metal-kernels" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "metal 0.27.0", "once_cell", @@ -1007,8 +1013,8 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30" +version = "0.8.0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" dependencies = [ "candle-core", "candle-metal-kernels", @@ -1031,6 +1037,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.1" @@ -1101,6 +1113,33 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1406,6 +1445,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + 
"tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -2273,6 +2348,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "fst" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -3278,12 +3359,32 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.11.0" @@ -3396,9 +3497,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libflate" @@ -3853,7 +3954,7 @@ dependencies = [ [[package]] name = "mistralrs" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "anyhow", "candle-core", @@ -3872,7 +3973,7 @@ dependencies = [ [[package]] name = "mistralrs-core" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "akin", "anyhow", @@ -3939,7 +4040,7 @@ dependencies = [ [[package]] name = "mistralrs-quant" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "byteorder", "candle-core", @@ -3958,7 +4059,7 @@ dependencies = [ [[package]] name = "mistralrs-vision" version = "0.3.2" -source = "git+https://github.com/EricLBuehler/mistral.rs.git#ddc63f1e0433356789cd875c3e39df16df0d0a43" +source = "git+https://github.com/EricLBuehler/mistral.rs.git#72eef3e212ae11a2824f89e72657e7c3dc69fba8" dependencies = [ "candle-core", "image", @@ -4349,6 +4450,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openai" version = 
"1.0.0-alpha.16" @@ -4700,6 +4807,34 @@ dependencies = [ "time", ] +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.14" @@ -5515,9 +5650,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "log", "once_cell", @@ -5741,9 +5876,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -6100,12 +6235,15 @@ version = "0.1.0" dependencies = [ "anyhow", "bincode", + "criterion", "csv", "dashmap", + "fst", "itertools 0.13.0", "nlp", "ptrie 0.7.1 (git+https://github.com/oramasearch/ptrie.git?branch=feat/expose-get-mut)", "radix_trie", + "rand 0.8.5", "rayon", "serde", "smallvec", @@ -6730,6 +6868,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/nlp/src/lib.rs b/nlp/src/lib.rs index 6c8c546..6a43b77 100644 --- a/nlp/src/lib.rs +++ b/nlp/src/lib.rs @@ -34,6 +34,7 @@ impl Clone for TextParser { impl TextParser { pub fn from_language(locale: Locale) -> Self { let (tokenizer, stemmer) = match locale { + Locale::IT => (Tokenizer::italian(), Stemmer::create(Algorithm::Italian)), Locale::EN => (Tokenizer::english(), Stemmer::create(Algorithm::English)), // @todo: manage other locales _ => (Tokenizer::english(), Stemmer::create(Algorithm::English)), diff --git a/nlp/src/tokenizer.rs b/nlp/src/tokenizer.rs index de052d7..b404754 100644 --- a/nlp/src/tokenizer.rs +++ b/nlp/src/tokenizer.rs @@ -17,6 +17,14 @@ impl Tokenizer { stop_words, } } + + pub fn italian() -> Self { + let stop_words: HashSet<&str> = Locale::IT.stop_words().unwrap(); + Tokenizer { + split_regex: Locale::IT.split_regex().unwrap(), + stop_words + } + } pub fn tokenize<'a, 'b>(&'a self, input: &'b str) -> impl Iterator + 'b where diff --git a/rustorama/.gitignore b/rustorama/.gitignore new file mode 100644 index 0000000..de5e3b9 --- /dev/null +++ b/rustorama/.gitignore @@ -0,0 +1 @@ +reviews.json \ No newline at end of file diff --git a/rustorama/README.md b/rustorama/README.md new file mode 100644 index 0000000..0134288 --- /dev/null +++ 
b/rustorama/README.md @@ -0,0 +1,3 @@ +# Rustorama + +Download the dataset from https://www.kaggle.com/datasets/abdallahwagih/amazon-reviews and save it as `reviews.json`. Then run `deno run -A insert.js` to upload all the reviews to a local rustorama instance. \ No newline at end of file diff --git a/rustorama/config.json b/rustorama/config.json new file mode 100644 index 0000000..a5a085a --- /dev/null +++ b/rustorama/config.json @@ -0,0 +1,8 @@ +{ + "data_dir": "/tmp/rustorama", + "http": { + "host": "127.0.0.1", + "port": 8080, + "allow_cors": true + } +} \ No newline at end of file diff --git a/rustorama/insert.js b/rustorama/insert.js new file mode 100644 index 0000000..bbd66ef --- /dev/null +++ b/rustorama/insert.js @@ -0,0 +1,62 @@ +import { readLines } from "https://deno.land/std@0.208.0/io/mod.ts"; + +const INPUT_FILE = "./reviews.json"; +const BATCH_SIZE = 2_000; +const API_URL = "http://127.0.0.1:8080/v0/collections/xyz/documents"; + +async function sendBatch(batch, currentCount, totalCount) { + try { + const response = await fetch(API_URL, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(batch), + }); + + if (response.ok) { + console.log(`Batch upload successful: ${currentCount}/${totalCount} uploaded`); + } else { + console.error(`Batch upload failed: HTTP ${response.status}`); + } + } catch (error) { + console.error("Error during batch upload:", error); + } +} + +async function processJsonlFile() { + const fileReader = await Deno.open(INPUT_FILE); + let batch = []; + let counter = 0; + let totalLines = 0; + + for await (const _ of readLines(await Deno.open(INPUT_FILE))) { + totalLines++; + } + + for await (const line of readLines(fileReader)) { + const json = JSON.parse(line.trim()); + + batch.push({ + id: counter.toString(), + reviewerName: json.reviewerName, + reviewText: json.reviewText, + summary: json.summary, + }); + counter++; + + if (batch.length === BATCH_SIZE) { + await sendBatch(batch, counter, totalLines); + batch = []; + } + } + + if (batch.length > 0) { + await sendBatch(batch, counter, totalLines); + } + + fileReader.close(); + console.log("All data sent."); +} + +await processJsonlFile(); diff --git a/string_index/Cargo.toml b/string_index/Cargo.toml index e466cb9..da37682 100644 --- a/string_index/Cargo.toml +++ b/string_index/Cargo.toml @@ -3,6 +3,11 @@ name = "string_index" version = "0.1.0" edition = "2021" +[[bench]] +name = "string_index_bench" +harness = false +path = "benches/string_index_bench.rs" + [dependencies] anyhow = "1.0.90" bincode = "1.3.3" @@ -18,7 +23,10 @@ thiserror = "1.0.65" smallvec = "1.13.2" types = { path = "../types" } ptrie = { git = "https://github.com/oramasearch/ptrie.git", branch = "feat/expose-get-mut" } +fst = "0.4.7" +rand = "0.8.5" [dev-dependencies] +criterion = "0.5.1" nlp = { path = "../nlp" } - +tempdir = "0.3.7" diff --git a/string_index/benches/string_index_bench.rs b/string_index/benches/string_index_bench.rs new file mode 100644 index 0000000..fcc6cc8 --- /dev/null +++ b/string_index/benches/string_index_bench.rs @@ -0,0 +1,237 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::{seq::SliceRandom, Rng}; +use std::collections::HashMap; +use std::sync::Arc; +use storage::Storage; +use string_index::scorer::bm25::BM25Score; +use string_index::StringIndex; +use types::{DocumentId, FieldId}; + +#[derive(Debug)] +struct BenchmarkResults { + docs_count: usize, + index_size_bytes: u64, + indexing_time_ms: u64, + avg_search_time_ms: f64, 
+ memory_usage_mb: f64, + throughput_docs_per_sec: f64, +} + +fn generate_test_data( + num_docs: usize, +) -> Vec<(DocumentId, Vec<(FieldId, Vec<(String, Vec)>)>)> { + let mut rng = rand::thread_rng(); + + let vocabulary: Vec = (0..5000) + .map(|_| { + let len = rng.gen_range(3..8); + (0..len) + .map(|_| (b'a' + rng.gen_range(0..26)) as char) + .collect() + }) + .collect(); + + (0..num_docs) + .map(|i| { + let doc_id = DocumentId(i as u64); + let num_fields = rng.gen_range(1..=3); + + let fields = (0..num_fields) + .map(|field_i| { + let field_id = FieldId(field_i); + let num_terms = rng.gen_range(5..50); + + let terms = (0..num_terms) + .map(|_| { + let word = vocabulary.choose(&mut rng).unwrap().clone(); + let stems = (0..rng.gen_range(0..2)) + .map(|_| vocabulary.choose(&mut rng).unwrap().clone()) + .collect(); + (word, stems) + }) + .collect(); + + (field_id, terms) + }) + .collect(); + + (doc_id, fields) + }) + .collect() +} + +fn benchmark_indexing(c: &mut Criterion) { + let mut group = c.benchmark_group("indexing"); + group.measurement_time(std::time::Duration::from_secs(30)); + group.sample_size(50); + + for &size in &[1000, 5000, 10_000] { + group.bench_with_input(BenchmarkId::new("simple_index", size), &size, |b, &size| { + let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); + + b.iter(|| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + + index + .insert_multiple(batch.clone()) + .expect("Insertion failed"); + }); + }); + } + + group.finish(); +} + +fn benchmark_batch_indexing(c: &mut Criterion) { + let mut group = c.benchmark_group("batch_indexing"); + group.measurement_time(std::time::Duration::from_secs(30)); + group.sample_size(30); + + for &size in &[1000, 5000, 10_000] { + for &batch_size in &[100, 500] { + group.bench_with_input( + BenchmarkId::new(format!("batch_size_{}", batch_size), size), + &size, + |b, &size| { + let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); + + b.iter(|| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = + Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + index + .insert_multiple(batch.clone()) + .expect("Batch insertion failed"); + }); + }, + ); + } + } + + group.finish(); +} + +fn benchmark_search(c: &mut Criterion) { + let mut group = c.benchmark_group("search"); + group.measurement_time(std::time::Duration::from_secs(20)); + + let mut rng = rand::thread_rng(); + let vocabulary: Vec = (0..500) + .map(|_| { + let len = rng.gen_range(3..8); + (0..len) + .map(|_| (b'a' + rng.gen_range(0..26)) as char) + .collect() + }) + .collect(); + + let queries: Vec> = (0..100) + .map(|_| { + let num_terms = rng.gen_range(1..3); + (0..num_terms) + .map(|_| vocabulary.choose(&mut rng).unwrap().clone()) + .collect() + }) + .collect(); + + for &size in &[1000, 5000] { + group.bench_with_input(BenchmarkId::new("search", size), &size, |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = StringIndex::new(Arc::clone(&storage)); + + let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); + index + .insert_multiple(batch) + .expect("Initial data insertion failed"); + + b.iter(|| { + 
for query in queries.iter() { + let _ = index + .search( + query.clone(), + None, + Default::default(), + BM25Score::default(), + ) + .expect("Search failed"); + } + }); + }); + } + + group.finish(); +} + +fn benchmark_concurrent_ops(c: &mut Criterion) { + let mut group = c.benchmark_group("concurrent_operations"); + group.measurement_time(std::time::Duration::from_secs(40)); + + for &size in &[5000] { + group.bench_with_input(BenchmarkId::new("concurrent", size), &size, |b, &size| { + let tmp_dir = tempdir::TempDir::new("bench_index").unwrap(); + let storage = Arc::new(Storage::from_path(tmp_dir.path().to_str().unwrap())); + let index = Arc::new(StringIndex::new(Arc::clone(&storage))); + + let data = generate_test_data(size); + let batch: HashMap<_, _> = data.into_iter().collect(); + index + .insert_multiple(batch) + .expect("Initial data insertion failed"); + + let queries = generate_test_data(50) + .into_iter() + .map(|(_, fields)| { + fields + .into_iter() + .flat_map(|(_, terms)| terms) + .map(|(term, _)| term) + .collect::>() + }) + .collect::>(); + + b.iter(|| { + use std::thread; + + let search_threads: Vec<_> = (0..2) + .map(|_| { + let index = Arc::clone(&index); + let queries = queries.clone(); + thread::spawn(move || { + for query in &queries { + let _ = index + .search( + query.clone(), + None, + Default::default(), + BM25Score::default(), + ) + .expect("Concurrent search failed"); + } + }) + }) + .collect(); + + for thread in search_threads { + thread.join().expect("Search thread join failed"); + } + }); + }); + } + + group.finish(); +} +criterion_group!( + benches, + benchmark_indexing, + benchmark_batch_indexing, + benchmark_search, + benchmark_concurrent_ops +); +criterion_main!(benches); diff --git a/string_index/src/lib.rs b/string_index/src/lib.rs index deb67b8..4b84da1 100644 --- a/string_index/src/lib.rs +++ b/string_index/src/lib.rs @@ -2,15 +2,14 @@ use std::{ collections::HashMap, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, RwLock, + Arc, RwLock, Mutex }, }; use anyhow::Result; use dictionary::{Dictionary, TermId}; +use fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer}; use posting_storage::{PostingListId, PostingStorage}; -// use radix_trie::Trie; -use ptrie::Trie; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use scorer::Scorer; use serde::{Deserialize, Serialize}; @@ -23,7 +22,7 @@ pub mod scorer; pub type DocumentBatch = HashMap)>)>>; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] pub struct Posting { pub document_id: DocumentId, pub field_id: FieldId, @@ -32,18 +31,20 @@ pub struct Posting { pub doc_length: u16, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct StringIndexValue { - posting_list_id: PostingListId, - term_frequency: usize, + pub posting_list_id: PostingListId, + pub term_frequency: usize, } pub struct StringIndex { - tree: RwLock>, + fst_map: RwLock>>>, + temp_map: RwLock>, posting_storage: PostingStorage, dictionary: Dictionary, total_documents: AtomicUsize, total_document_length: AtomicUsize, + insert_mutex: Mutex<()>, } pub struct GlobalInfo { @@ -54,14 +55,44 @@ pub struct GlobalInfo { impl StringIndex { pub fn new(storage: Arc) -> Self { StringIndex { - tree: RwLock::new(Trie::new()), + fst_map: RwLock::new(None), + temp_map: RwLock::new(HashMap::new()), posting_storage: PostingStorage::new(storage), dictionary: Dictionary::new(), total_documents: AtomicUsize::new(0), total_document_length: AtomicUsize::new(0), + insert_mutex: Mutex::new(()), } } + fn 
build_fst(&self) -> Result>> { + let entries: Vec<_> = { + let temp_map = self.temp_map.read().unwrap(); + temp_map + .iter() + .map(|(key, value)| { + ( + key.clone(), + ((value.posting_list_id.0 as u64) << 32) | (value.term_frequency as u64), + ) + }) + .collect() + }; + + let mut sorted_entries = entries; + sorted_entries.sort_by(|(a, _), (b, _)| a.cmp(b)); + + let mut builder = MapBuilder::memory(); + + for (key, value) in sorted_entries { + builder.insert(key.as_bytes(), value)?; + } + + let fst = builder.into_map(); + + Ok(fst) + } + pub fn get_total_documents(&self) -> usize { self.total_documents.load(Ordering::Relaxed) } @@ -75,85 +106,105 @@ impl StringIndex { ) -> Result> { let total_documents = match self.total_documents.load(Ordering::Relaxed) { 0 => { - println!("total_documents == 0"); return Ok(Default::default()); } total_documents => total_documents, }; + let total_document_length = match self.total_document_length.load(Ordering::Relaxed) { 0 => { - println!("total_document_length == 0"); return Ok(Default::default()); } total_document_length => total_document_length, }; - // let avg_doc_length = total_document_length / total_documents; - - let mut posting_list_ids_with_freq = Vec::::new(); - let tree = self - .tree - .read() - // TODO: better error handling - .expect("Unable to read"); - for token in tokens { - let a = tree.find_postfixes(token.bytes()); - posting_list_ids_with_freq.extend(a.into_iter().cloned()); + let mut all_postings = Vec::new(); + + if let Some(fst) = self.fst_map.read().unwrap().as_ref() { + for token in &tokens { + let automaton = fst::automaton::Str::new(token).starts_with(); + let mut stream = fst.search(automaton).into_stream(); + + while let Some(result) = stream.next() { + let (_, packed_value) = result; + let posting_list_id = PostingListId(((packed_value >> 32) as u32) as usize); + let term_frequency = (packed_value & 0xFFFFFFFF) as usize; + + if let Ok(postings) = self.posting_storage.get(posting_list_id) { + all_postings.push((postings, term_frequency)); + } + } + } } let fields = search_on.as_ref(); - let global_info = GlobalInfo { total_documents, total_document_length, }; - posting_list_ids_with_freq - .into_par_iter() - .filter_map(|string_index_value| { - let output = self - .posting_storage - .get(string_index_value.posting_list_id) - .ok(); - let posting = match output { - Some(v) => v, - None => return None, - }; - - let posting: Vec<_> = posting - .into_iter() - .filter(move |posting| { - fields - .map(|search_on| search_on.contains(&posting.field_id)) - .unwrap_or(true) - }) - .collect(); - - Some((posting, string_index_value.term_frequency)) - }) - // Every thread perform on a separated hashmap - .for_each(|(postings, total_token_count)| { - let total_token_count = total_token_count as f32; - for posting in postings { - let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0); - scorer.add_entry(&global_info, posting, total_token_count, boost_per_field); + let mut filtered_postings: HashMap> = HashMap::new(); + for (postings, term_freq) in all_postings { + for posting in postings { + if fields + .map(|search_on| search_on.contains(&posting.field_id)) + .unwrap_or(true) + { + filtered_postings + .entry(posting.document_id) + .or_insert_with(Vec::new) + .push((posting, term_freq)); } - }); + } + } + + let mut exact_match_documents = Vec::new(); + + for (document_id, postings) in &filtered_postings { + let mut token_positions: Vec> = postings + .iter() + .map(|(posting, _)| posting.positions.clone()) + .collect(); + + 
token_positions.iter_mut().for_each(|positions| positions.sort_unstable()); + + if self.is_phrase_match(&token_positions) { + exact_match_documents.push(*document_id); + } - let scores = scorer.get_scores(); + for (posting, term_freq) in postings { + let boost_per_field = *boost.get(&posting.field_id).unwrap_or(&1.0); + scorer.add_entry( + &global_info, + posting.clone(), + *term_freq as f32, + boost_per_field, + ); + } + } + + let mut scores = scorer.get_scores(); + + let exact_match_boost = 25.0; // @todo: make this configurable. + for document_id in exact_match_documents { + if let Some(score) = scores.get_mut(&document_id) { + *score *= exact_match_boost; + } + } Ok(scores) } pub fn insert_multiple(&self, data: DocumentBatch) -> Result<()> { + let _lock = self.insert_mutex.lock().unwrap(); + self.total_documents .fetch_add(data.len(), Ordering::Relaxed); let dictionary = &self.dictionary; - let t = data + let posting_per_term = data .into_par_iter() - // Parallel .fold( HashMap::>::new, |mut acc, (document_id, strings)| { @@ -185,8 +236,6 @@ impl StringIndex { for (term, field_positions) in term_freqs { let term_id = dictionary.get_or_add(&term); - // println!("Term: {} -> {}", term, term_id.0); - let v = acc.entry(term_id).or_default(); let posting = @@ -197,7 +246,6 @@ impl StringIndex { document_id, field_id, positions, - // original_term: term.clone(), term_frequency, doc_length: doc_length as u16, } @@ -208,28 +256,19 @@ impl StringIndex { acc }, - ); - - let posting_per_term = t.reduce( - HashMap::>::new, - // Merge the hashmap - |mut acc, item| { + ) + .reduce(HashMap::>::new, |mut acc, item| { for (term_id, postings) in item { let vec = acc.entry(term_id).or_default(); vec.extend(postings.into_iter()); } acc - }, - ); + }); let mut postings_per_posting_list_id: HashMap>> = HashMap::with_capacity(posting_per_term.len()); - let mut tree = self.tree.write().expect("Unable to write"); - // NB: We cannot parallelize the tree insertion yet :( - // We could move the tree into a custom implementation to support parallelism - // Once we resolve this issue, we StringIndex is thread safe! - // TODO: move to custom implementation - // For the time being, we can just use the sync tree + let mut temp_map = self.temp_map.write().unwrap(); + for (term_id, postings) in posting_per_term { self.total_document_length.fetch_add( postings.iter().map(|p| p.positions.len()).sum::(), @@ -237,15 +276,9 @@ impl StringIndex { ); let number_of_occurence_of_term = postings.len(); - // Due to this implementation, we have a limitation - // because we "forgot" the term. 
Here we have just the term_id - // This invocation shouldn't exist at all: - // we have the term on the top of this function - // TODO: find a way to avoid this invocation let term = dictionary.retrive(term_id); - let value = tree.get_mut(term.bytes()); - if let Some(value) = value { + if let Some(value) = temp_map.get_mut(&term) { value.term_frequency += number_of_occurence_of_term; let vec = postings_per_posting_list_id .entry(value.posting_list_id) @@ -253,8 +286,8 @@ impl StringIndex { vec.push(postings); } else { let posting_list_id = self.posting_storage.generate_new_id(); - tree.insert( - term.bytes(), + temp_map.insert( + term, StringIndexValue { posting_list_id, term_frequency: number_of_occurence_of_term, @@ -268,30 +301,274 @@ impl StringIndex { } } - postings_per_posting_list_id - .into_par_iter() - .map(|(k, v)| self.posting_storage.add_or_create(k, v)) - // TODO: handle error - .all(|_| true); + drop(temp_map); + let new_fst = self.build_fst()?; + + { + let mut fst_map = self.fst_map.write().unwrap(); + *fst_map = Some(new_fst); + } + + for (k, v) in postings_per_posting_list_id { + self.posting_storage.add_or_create(k, v)?; + } Ok(()) } + + fn is_phrase_match(&self, token_positions: &Vec>) -> bool { + if token_positions.is_empty() { + return false; + } + + let position_sets: Vec> = token_positions + .iter() + .skip(1) + .map(|positions| positions.iter().copied().collect()) + .collect(); + + for &start_pos in &token_positions[0] { + let mut current_pos = start_pos; + let mut matched = true; + + for positions in &position_sets { + let next_pos = current_pos + 1; + if positions.contains(&next_pos) { + current_pos = next_pos; + } else { + matched = false; + break; + } + } + + if matched { + return true; + } + } + + false + } + + } #[cfg(test)] mod tests { + use crate::{scorer::bm25::BM25Score, DocumentId, FieldId, StringIndex}; + use nlp::{locales::Locale, TextParser}; use std::{collections::HashMap, sync::Arc}; + use tempdir::TempDir; - use nlp::{locales::Locale, TextParser}; + #[test] + fn test_empty_search_query() { + let tmp_dir = TempDir::new("string_index_test_empty_search").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); - use tempdir::TempDir; + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); - use crate::{scorer::bm25::BM25Score, DocumentId, FieldId, StringIndex}; + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "This is a test document.".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec![], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty for empty query" + ); + } + + #[test] + fn test_search_nonexistent_term() { + let tmp_dir = TempDir::new("string_index_test_nonexistent_term").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "This 
is a test document.".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["nonexistent".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty for non-existent term" + ); + } + + #[test] + fn test_insert_empty_document() { + let tmp_dir = TempDir::new("string_index_test_empty_document").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["test".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty when only empty documents are indexed" + ); + } + + #[test] + fn test_search_with_field_filter() { + let tmp_dir = TempDir::new("string_index_test_field_filter").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![ + (FieldId(0), "This is a test in field zero.".to_string()), + (FieldId(1), "Another test in field one.".to_string()), + ], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(0)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document when searching in FieldId(0)" + ); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(1)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document when searching in FieldId(1)" + ); + + let output = string_index + .search( + vec!["test".to_string()], + Some(vec![FieldId(2)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Should not find any documents when searching in non-existent FieldId" + ); + } #[test] - fn test_search() { - let tmp_dir = TempDir::new("string_index_test").unwrap(); + fn test_search_with_boosts() { + let tmp_dir = TempDir::new("string_index_test_boosts").unwrap(); let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + let storage = 
Arc::new(crate::Storage::from_path(&tmp_dir)); let string_index = StringIndex::new(storage); let parser = TextParser::from_language(Locale::EN); @@ -299,24 +576,13 @@ mod tests { let batch: HashMap<_, _> = vec![ ( DocumentId(1), - vec![( - FieldId(0), - "Yo, I'm from where Nicky Barnes got rich as fuck, welcome!".to_string(), - )], + vec![(FieldId(0), "Important content in field zero.".to_string())], ), ( DocumentId(2), vec![( - FieldId(0), - "Welcome to Harlem, where you welcome to problems".to_string(), - )], - ), - ( - DocumentId(3), - vec![( - FieldId(0), - "Now bitches, they want to neuter me, niggas, they want to tutor me" - .to_string(), + FieldId(1), + "Less important content in field one.".to_string(), )], ), ] @@ -335,26 +601,347 @@ mod tests { string_index.insert_multiple(batch).unwrap(); + let mut boost = HashMap::new(); + boost.insert(FieldId(0), 2.0); + + let output = string_index + .search( + vec!["content".to_string()], + None, + boost, + BM25Score::default(), + ) + .unwrap(); + + assert_eq!(output.len(), 2, "Should find both documents"); + + let score_doc1 = output.get(&DocumentId(1)).unwrap(); + let score_doc2 = output.get(&DocumentId(2)).unwrap(); + + assert!( + score_doc1 > score_doc2, + "Document with boosted field should have higher score" + ); + } + + #[test] + fn test_insert_document_with_stop_words_only() { + let tmp_dir = TempDir::new("string_index_test_stop_words").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "the and but or".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["the".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty when only stop words are indexed" + ); + } + + #[test] + fn test_search_on_empty_index() { + let tmp_dir = TempDir::new("string_index_test_empty_index").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + + let output = string_index + .search( + vec!["test".to_string()], + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert!( + output.is_empty(), + "Search results should be empty when index is empty" + ); + } + + #[test] + fn test_concurrent_insertions() { + use std::thread; + + let tmp_dir = TempDir::new("string_index_test_concurrent_inserts").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = Arc::new(StringIndex::new(storage)); + + let string_index_clone1 = Arc::clone(&string_index); + let string_index_clone2 = Arc::clone(&string_index); + + let handle1 = thread::spawn(move || { + let parser = TextParser::from_language(Locale::EN); + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![( + FieldId(0), + "Concurrent insertion test document one.".to_string(), 
+ )], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index_clone1.insert_multiple(batch).unwrap(); + }); + + let handle2 = thread::spawn(move || { + let parser = TextParser::from_language(Locale::EN); + let batch: HashMap<_, _> = vec![( + DocumentId(2), + vec![( + FieldId(0), + "Concurrent insertion test document two.".to_string(), + )], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index_clone2.insert_multiple(batch).unwrap(); + }); + + handle1.join().unwrap(); + handle2.join().unwrap(); + + let parser = TextParser::from_language(Locale::EN); + let search_tokens = parser + .tokenize_and_stem("concurrent") + .into_iter() + .map(|(original, _)| original) + .collect::>(); + + let output = string_index + .search( + search_tokens, + None, + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 2, + "Should find both documents after concurrent insertions" + ); + } + + #[test] + fn test_large_documents() { + let tmp_dir = TempDir::new("string_index_test_large_documents").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let large_text = "word ".repeat(10000); + + let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), large_text.clone())])] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + let output = string_index .search( - vec!["welcome".to_string()], + vec!["word".to_string()], None, Default::default(), BM25Score::default(), ) .unwrap(); - assert_eq!(output.len(), 2); + assert_eq!( + output.len(), + 1, + "Should find the document containing the large text" + ); + } + + #[test] + fn test_high_term_frequency() { + let tmp_dir = TempDir::new("string_index_test_high_term_frequency").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let repeated_word = "repeat ".repeat(1000); + + let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), repeated_word.clone())])] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); let output = string_index .search( - vec!["wel".to_string()], + vec!["repeat".to_string()], None, Default::default(), BM25Score::default(), ) .unwrap(); - assert_eq!(output.len(), 2); + assert_eq!( + output.len(), + 1, + "Should find the document with high term frequency" + ); + } + + #[test] + fn test_term_positions() { + let tmp_dir = 
TempDir::new("string_index_test_term_positions").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![( + FieldId(0), + "quick brown fox jumps over the lazy dog".to_string(), + )], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + } + + #[test] + fn test_exact_phrase_match() { + let tmp_dir = TempDir::new("string_index_test_exact_phrase").unwrap(); + let tmp_dir: String = tmp_dir.into_path().to_str().unwrap().to_string(); + + let storage = Arc::new(crate::Storage::from_path(&tmp_dir)); + let string_index = StringIndex::new(storage); + let parser = TextParser::from_language(Locale::EN); + + let batch: HashMap<_, _> = vec![( + DocumentId(1), + vec![(FieldId(0), "5200 mAh battery in disguise".to_string())], + )] + .into_iter() + .map(|(doc_id, fields)| { + let fields: Vec<_> = fields + .into_iter() + .map(|(field_id, data)| { + let tokens = parser.tokenize_and_stem(&data); + (field_id, tokens) + }) + .collect(); + (doc_id, fields) + }) + .collect(); + + string_index.insert_multiple(batch).unwrap(); + + let output = string_index + .search( + vec!["5200".to_string(), "mAh".to_string(), "battery".to_string(), "in".to_string(), "disguise".to_string()], + Some(vec![FieldId(0)]), + Default::default(), + BM25Score::default(), + ) + .unwrap(); + + assert_eq!( + output.len(), + 1, + "Should find the document containing the exact phrase" + ); + + assert!( + output.contains_key(&DocumentId(1)), + "Document with ID 1 should be found" + ); } } diff --git a/string_index/src/posting_storage.rs b/string_index/src/posting_storage.rs index 976df89..d14f402 100644 --- a/string_index/src/posting_storage.rs +++ b/string_index/src/posting_storage.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::Posting; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct PostingListId(usize); +pub struct PostingListId(pub usize); #[derive(Debug, Error)] pub enum PostingStorageError { diff --git a/string_index/src/scorer.rs b/string_index/src/scorer.rs index 6f0334a..281ceca 100644 --- a/string_index/src/scorer.rs +++ b/string_index/src/scorer.rs @@ -28,7 +28,7 @@ pub mod bm25 { scores: DashMap, } impl BM25Score { - fn new() -> Self { + pub fn new() -> Self { Self { scores: DashMap::new(), } @@ -42,10 +42,20 @@ pub mod bm25 { avg_doc_length: f32, boost: f32, ) -> f32 { + if tf == 0.0 || doc_length == 0.0 || avg_doc_length == 0.0 { + return 0.0; + } + let k1 = 1.5; let b = 0.75; let numerator = tf * (k1 + 1.0); let denominator = tf + k1 * (1.0 - b + b * (doc_length / avg_doc_length)); + + // @todo: find a better way to avoid division by 0 + if denominator == 0.0 { + return 0.0; + } + idf * (numerator / denominator) * boost } } @@ -54,32 +64,39 @@ pub mod bm25 { #[inline] fn add_entry( &self, - _global_info: &crate::GlobalInfo, + global_info: &crate::GlobalInfo, posting: crate::Posting, - _total_token_count: f32, - _boost_per_field: f32, + total_token_count: f32, + boost_per_field: f32, ) { let term_frequency = posting.term_frequency; let doc_length = posting.doc_length as f32; - let freq = 1.0; + let total_documents = 
global_info.total_documents as f32; - let total_documents = 1.0; - let avg_doc_length = total_documents / 1.0; + if total_documents == 0.0 { + self.scores.insert(posting.document_id, 0.0); + return; + } + + let avg_doc_length = global_info.total_document_length as f32 / total_documents; + let idf = + ((total_documents - total_token_count + 0.5) / (total_token_count + 0.5)).ln_1p(); - let idf = ((total_documents - freq + 0.5_f32) / (freq + 0.5_f32)).ln_1p(); let score = Self::calculate_score( term_frequency, idf, doc_length, avg_doc_length, - _boost_per_field, + boost_per_field, ); - let mut previous = self.scores.entry(posting.document_id).or_insert(0.0); - *previous += score; + self.scores + .entry(posting.document_id) + .and_modify(|e| *e += score) + .or_insert(score); } - fn get_scores(self) -> HashMap { + fn get_scores(self) -> HashMap { self.scores.into_iter().collect() } } @@ -105,14 +122,20 @@ pub mod code { } impl CodeScore { - fn new() -> Self { + pub fn new() -> Self { Self { scores: DashMap::new(), } } - fn calculate_score(pos: DashMap, boost: f32, doc_lenth: u16) -> f32 { + #[inline] + fn calculate_score(pos: DashMap, boost: f32, doc_length: u16) -> f32 { let mut foo: Vec<_> = pos.into_iter().map(|(p, v)| (p.0, v)).collect(); + + if foo.is_empty() || doc_length == 0 { + return 0.0; + } + foo.sort_by_key(|(p, _)| (*p as isize)); let pos_len = foo.len(); @@ -153,7 +176,12 @@ pub mod code { score += score_for_position; } - score / (pos_len * (doc_lenth as usize)) as f32 + let denominator = (pos_len * (doc_length as usize)) as f32; + if denominator == 0.0 { + 0.0 + } else { + score / denominator + } } } @@ -196,3 +224,206 @@ pub mod code { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::GlobalInfo; + use types::FieldId; + + #[test] + fn test_bm25_basic_scoring() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 5.0, + doc_length: 100, + positions: vec![1, 2, 3, 4, 5], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert!(scores.contains_key(&DocumentId(1))); + assert!(scores[&DocumentId(1)] > 0.0); + } + + #[test] + fn test_bm25_empty_document() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 1, + total_document_length: 0, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 0.0, + doc_length: 0, + positions: vec![], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert_eq!(scores[&DocumentId(1)], 0.0); + } + + #[test] + fn test_bm25_boost_effect() { + let scorer = bm25::BM25Score::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 5.0, + doc_length: 100, + positions: vec![1, 2, 3, 4, 5], + }; + + // Test with different boost values + scorer.add_entry(&global_info, posting.clone(), 1.0, 1.0); + let normal_scores = scorer.get_scores(); + + let scorer = bm25::BM25Score::new(); + scorer.add_entry(&global_info, posting, 1.0, 2.0); + let boosted_scores = scorer.get_scores(); + + assert!(boosted_scores[&DocumentId(1)] > normal_scores[&DocumentId(1)]); + } + + #[test] + fn test_code_score_basic() { + let scorer = code::CodeScore::new(); + let 
global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 3.0, + doc_length: 100, + positions: vec![1, 3, 5], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert!(scores.contains_key(&DocumentId(1))); + assert!(scores[&DocumentId(1)] > 0.0); + } + + #[test] + fn test_code_score_adjacent_positions() { + let scorer = code::CodeScore::new(); + let global_info = GlobalInfo { + total_documents: 10, + total_document_length: 1000, + }; + + let posting_adjacent = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 3.0, + doc_length: 100, + positions: vec![1, 2, 3], + }; + + let posting_spread = Posting { + field_id: FieldId(1), + document_id: DocumentId(2), + term_frequency: 3.0, + doc_length: 100, + positions: vec![1, 10, 20], + }; + + scorer.add_entry(&global_info, posting_adjacent, 1.0, 1.0); + scorer.add_entry(&global_info, posting_spread, 1.0, 1.0); + let scores = scorer.get_scores(); + + assert!(scores[&DocumentId(1)] > scores[&DocumentId(2)]); + } + + #[test] + fn test_code_score_empty_positions() { + let scorer = code::CodeScore::new(); + let global_info = GlobalInfo { + total_documents: 1, + total_document_length: 100, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 0.0, + doc_length: 100, + positions: vec![], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert_eq!(scores[&DocumentId(1)], 0.0); + } + + #[test] + fn test_code_score_single_position() { + let scorer = code::CodeScore::new(); + let global_info = GlobalInfo { + total_documents: 1, + total_document_length: 100, + }; + + let posting = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 1.0, + doc_length: 100, + positions: vec![1], + }; + + scorer.add_entry(&global_info, posting, 1.0, 1.0); + let scores = scorer.get_scores(); + assert!(scores[&DocumentId(1)] > 0.0); + } + + #[test] + fn test_multiple_entries_same_document() { + let scorer = code::CodeScore::new(); + let global_info = GlobalInfo { + total_documents: 1, + total_document_length: 100, + }; + + let posting1 = Posting { + field_id: FieldId(1), + document_id: DocumentId(1), + term_frequency: 2.0, + doc_length: 100, + positions: vec![1, 2], + }; + + let posting2 = Posting { + field_id: FieldId(2), + document_id: DocumentId(1), + term_frequency: 2.0, + doc_length: 100, + positions: vec![3, 4], + }; + + scorer.add_entry(&global_info, posting1, 1.0, 1.0); + scorer.add_entry(&global_info, posting2, 1.0, 1.0); + + let scores = scorer.get_scores(); + assert!(scores[&DocumentId(1)] > 0.0); + } +}
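
A minimal usage sketch tying together the pieces this diff introduces (FST-backed term lookup via insert_multiple/search, BM25 scoring, and the exact-phrase boost). It follows the crate layout the tests and benches above already use (storage, string_index, nlp, types); the storage path, document text, and the wrapping main function are illustrative placeholders, not part of the diff.

use std::{collections::HashMap, sync::Arc};

use nlp::{locales::Locale, TextParser};
use storage::Storage;
use string_index::{scorer::bm25::BM25Score, StringIndex};
use types::{DocumentId, FieldId};

fn main() -> anyhow::Result<()> {
    // Open (or create) the index on a scratch directory.
    let storage = Arc::new(Storage::from_path("/tmp/rustorama-example"));
    let index = StringIndex::new(storage);

    // Tokenize and stem a document, then insert it as a one-document batch.
    let parser = TextParser::from_language(Locale::EN);
    let tokens = parser.tokenize_and_stem("5200 mAh battery in disguise");
    let batch: HashMap<_, _> = vec![(DocumentId(1), vec![(FieldId(0), tokens)])]
        .into_iter()
        .collect();
    index.insert_multiple(batch)?;

    // Search a single field with default (1.0) boosts. Documents whose stored
    // positions contain the query terms in consecutive order also receive the
    // exact-phrase boost (currently a hard-coded 25.0) on top of their BM25 score.
    let scores = index.search(
        vec!["battery".to_string(), "disguise".to_string()],
        Some(vec![FieldId(0)]),
        Default::default(),
        BM25Score::default(),
    )?;
    println!("{:?}", scores);

    Ok(())
}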