From 36cedb26d6aa677f69718ab52ade7f2b2272bef2 Mon Sep 17 00:00:00 2001
From: Najib Ishaq
Date: Thu, 9 Nov 2023 13:46:04 -0500
Subject: [PATCH] feat: brought back batched search functions

---
 abd-clam/README.md                   |   4 +-
 abd-clam/benches/genomic.rs          |  15 +---
 abd-clam/benches/knn-vs-rnn.rs       |  11 ++-
 abd-clam/benches/rnn-search.rs       |  15 +---
 abd-clam/src/cakes/mod.rs            | 104 ++++++++++++++++++++++++++-
 abd-clam/src/cakes/sharded.rs        |  14 ----
 cakes-results/Cargo.toml             |   1 -
 cakes-results/src/knn_reports.rs     |  11 +--
 cakes-results/src/scaling_reports.rs |   6 +-
 9 files changed, 116 insertions(+), 65 deletions(-)

diff --git a/abd-clam/README.md b/abd-clam/README.md
index e69e90140..e5e0fb29b 100644
--- a/abd-clam/README.md
+++ b/abd-clam/README.md
@@ -12,7 +12,7 @@ CLAM is a library crate so you can add it to your crate using `cargo add abd_cla
 ### Cakes: Nearest Neighbor Search
 
 ```rust
-use abd_clam::{knn, rnn ,Cakes, PartitionCriteria, VecDataset};
+use abd_clam::{knn, rnn, Cakes, PartitionCriteria, VecDataset};
 
 /// The distance function with with to perform clustering and search.
 ///
@@ -67,7 +67,7 @@ let data = VecDataset::<Vec<f32>, f32>::new(
 let criteria = PartitionCriteria::default();
 
 // The Cakes struct provides the functionality described in the CHESS paper.
-// It performs search assuming that the dataset is a single shard.
+// We use a single shard here because the demo data is small.
 let model = Cakes::new_single_shard(data, Some(seed), &criteria);
 // This line performs a non-trivial amount of work. #understatement
 
diff --git a/abd-clam/benches/genomic.rs b/abd-clam/benches/genomic.rs
index 71f130a9a..c79f58e3a 100644
--- a/abd-clam/benches/genomic.rs
+++ b/abd-clam/benches/genomic.rs
@@ -1,6 +1,5 @@
 use criterion::*;
-use rayon::prelude::*;
 use symagen::random_data;
 
 use abd_clam::{rnn, Cakes, PartitionCriteria, VecDataset};
 
@@ -47,24 +46,14 @@ fn genomic(c: &mut Criterion) {
     for radius in radii {
         let id = BenchmarkId::new("Clustered", radius);
         group.bench_with_input(id, &radius, |b, &radius| {
-            b.iter_with_large_drop(|| {
-                queries
-                    .par_iter()
-                    .map(|q| cakes.rnn_search(q, radius, rnn::Algorithm::Clustered))
-                    .collect::<Vec<_>>()
-            });
+            b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, rnn::Algorithm::Clustered));
         });
     }
 
     group.sample_size(10);
     let id = BenchmarkId::new("Linear", radii[0]);
     group.bench_with_input(id, &radii[0], |b, _| {
-        b.iter_with_large_drop(|| {
-            queries
-                .par_iter()
-                .map(|q| cakes.rnn_search(q, radii[0], rnn::Algorithm::Linear))
-                .collect::<Vec<_>>()
-        });
+        b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radii[0], rnn::Algorithm::Linear));
     });
 
     group.finish();
diff --git a/abd-clam/benches/knn-vs-rnn.rs b/abd-clam/benches/knn-vs-rnn.rs
index 437865fa1..b83f732cb 100644
--- a/abd-clam/benches/knn-vs-rnn.rs
+++ b/abd-clam/benches/knn-vs-rnn.rs
@@ -36,12 +36,11 @@ fn cakes(c: &mut Criterion) {
     let cakes = Cakes::new_single_shard(dataset, Some(seed), &criteria);
 
     for k in (0..=8).map(|v| 2usize.pow(v)) {
-        let radii = queries
-            .par_iter()
-            .map(|query| {
-                cakes
-                    .linear_rnn_search(query, 0.0)
-                    .into_iter()
+        let radii = cakes
+            .batch_knn_search(&queries, k, knn::Algorithm::Linear)
+            .into_iter()
+            .map(|hits| {
+                hits.into_iter()
                     .map(|(_, d)| d)
                     .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Less))
                     .unwrap()
diff --git a/abd-clam/benches/rnn-search.rs b/abd-clam/benches/rnn-search.rs
index d5caf336f..fbb86e78b 100644
--- a/abd-clam/benches/rnn-search.rs
+++ b/abd-clam/benches/rnn-search.rs
@@ -1,6 +1,5 @@
 use criterion::*;
-use rayon::prelude::*;
 use symagen::random_data;
 
 use abd_clam::{rnn, Cakes, PartitionCriteria, VecDataset};
 
@@ -45,12 +44,7 @@ fn cakes(c: &mut Criterion) {
 
             let id = BenchmarkId::new(variant.name(), radius);
             group.bench_with_input(id, &radius, |b, _| {
-                b.iter_with_large_drop(|| {
-                    queries
-                        .par_iter()
-                        .map(|query| cakes.rnn_search(query, radius, variant))
-                        .collect::<Vec<_>>()
-                });
+                b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, variant));
             });
         }
     }
@@ -58,12 +52,7 @@ fn cakes(c: &mut Criterion) {
     group.sample_size(10);
     let id = BenchmarkId::new("Linear", radius);
     group.bench_with_input(id, &radius, |b, _| {
-        b.iter_with_large_drop(|| {
-            queries
-                .par_iter()
-                .map(|query| cakes.rnn_search(query, radius, rnn::Algorithm::Linear))
-                .collect::<Vec<_>>()
-        });
+        b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, rnn::Algorithm::Linear));
     });
 
     group.finish();
diff --git a/abd-clam/src/cakes/mod.rs b/abd-clam/src/cakes/mod.rs
index 50e2fab20..d824c05aa 100644
--- a/abd-clam/src/cakes/mod.rs
+++ b/abd-clam/src/cakes/mod.rs
@@ -9,6 +9,7 @@ mod sharded;
 mod singular;
 
 use distances::Number;
+use rayon::prelude::*;
 use search::Search;
 use sharded::RandomlySharded;
 use singular::SingleShard;
@@ -75,6 +76,22 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         }
     }
 
+    /// Performs RNN search on a batch of queries with the given algorithm.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `radius` - The search radius.
+    /// * `algo` - The algorithm to use.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_rnn_search(&self, queries: &[&I], radius: U, algo: rnn::Algorithm) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.rnn_search(q, radius, algo)).collect()
+    }
+
     /// Performs an RNN search with the given algorithm.
     ///
     /// # Arguments
@@ -85,7 +102,8 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
     ///
     /// # Returns
     ///
-    /// A vector of tuples containing the index of the instance and the distance to the query.
+    /// A vector of tuples containing the index of the instance and the distance
+    /// to the query.
     pub fn rnn_search(&self, query: &I, radius: U, algo: rnn::Algorithm) -> Vec<(usize, U)> {
         match self {
             Self::SingleShard(ss) => ss.rnn_search(query, radius, algo),
@@ -93,6 +111,21 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         }
     }
 
+    /// Performs Linear RNN search on a batch of queries.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `radius` - The search radius.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_linear_rnn_search(&self, queries: &[&I], radius: U) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.linear_rnn_search(q, radius)).collect()
+    }
+
     /// Performs a linear RNN search.
     ///
     /// # Arguments
@@ -102,7 +135,8 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
     ///
     /// # Returns
     ///
-    /// A vector of tuples containing the index of the instance and the distance to the query.
+    /// A vector of tuples containing the index of the instance and the distance
+    /// to the query.
     pub fn linear_rnn_search(&self, query: &I, radius: U) -> Vec<(usize, U)> {
         match self {
             Self::SingleShard(ss) => ss.linear_rnn_search(query, radius),
@@ -118,6 +152,22 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         }
     }
 
+    /// Performs KNN search on a batch of queries with the given algorithm.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `k` - The number of nearest neighbors to return.
+    /// * `algo` - The algorithm to use.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_knn_search(&self, queries: &[&I], k: usize, algo: knn::Algorithm) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.knn_search(q, k, algo)).collect()
+    }
+
     /// Performs a KNN search with the given algorithm.
     ///
     /// # Arguments
@@ -162,6 +212,21 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         }
     }
 
+    /// Performs Linear KNN search on a batch of queries.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `k` - The number of nearest neighbors to return.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_linear_knn_search(&self, queries: &[&I], k: usize) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.linear_knn_search(q, k)).collect()
+    }
+
     /// Performs a linear KNN search.
     ///
     /// # Arguments
@@ -179,6 +244,23 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         }
     }
 
+    /// Performs RNN search on a batch of queries with the tuned algorithm.
+    ///
+    /// If the algorithm has not been tuned, this will use the default algorithm.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `radius` - The search radius.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_tuned_rnn_search(&self, queries: &[&I], radius: U) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.tuned_rnn_search(q, radius)).collect()
+    }
+
     /// Performs a RNN search with the tuned algorithm.
     ///
     /// If the algorithm has not been tuned, this will use the default algorithm.
@@ -196,6 +278,23 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
         self.rnn_search(query, radius, algo)
     }
 
+    /// Performs KNN search on a batch of queries with the tuned algorithm.
+    ///
+    /// If the algorithm has not been tuned, this will use the default algorithm.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - The queries to search.
+    /// * `k` - The number of nearest neighbors to return.
+    ///
+    /// # Returns
+    ///
+    /// A vector of vectors of tuples containing the index of the instance and
+    /// the distance to the query.
+    pub fn batch_tuned_knn_search(&self, queries: &[&I], k: usize) -> Vec<Vec<(usize, U)>> {
+        queries.par_iter().map(|q| self.tuned_knn_search(q, k)).collect()
+    }
+
     /// Performs a KNN search with the tuned algorithm.
     ///
     /// If the algorithm has not been tuned, this will use the default algorithm.
@@ -233,6 +332,7 @@ where
                     .find(|(_, &o)| o > index)
                     .map_or_else(|| rs.num_shards() - 1, |(i, _)| i - 1);
 
+                let index = index - rs.offsets()[i];
                 rs.shards()[i].data().index(index)
             }
         }
diff --git a/abd-clam/src/cakes/sharded.rs b/abd-clam/src/cakes/sharded.rs
index cd9c4309f..29c2c1c05 100644
--- a/abd-clam/src/cakes/sharded.rs
+++ b/abd-clam/src/cakes/sharded.rs
@@ -65,20 +65,6 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> RandomlySharded<I, U, D> {
     pub fn offsets(&self) -> &[usize] {
         &self.offsets
     }
-
-    // /// K-nearest neighbor search.
-    // pub fn knn_search(&self, query: &I, k: usize) -> Vec<(usize, U)> {
-    //     let hits = self
-    //         .sample_shard
-    //         .knn_search(query, k, self.best_knn_algorithm().unwrap_or_default());
-    //     let mut hits = knn::Hits::from_vec(k, hits);
-    //     for (shard, &o) in self.shards.iter().zip(self.offsets.iter()) {
-    //         let radius = hits.peek();
-    //         let new_hits = shard.rnn_search(query, radius, rnn::Algorithm::Clustered);
-    //         hits.push_batch(new_hits.into_iter().map(|(i, d)| (i + o, d)));
-    //     }
-    //     hits.extract()
-    // }
 }
 
 impl<I: Instance, U: Number, D: Dataset<I, U>> Search<I, U, D> for RandomlySharded<I, U, D> {
diff --git a/cakes-results/Cargo.toml b/cakes-results/Cargo.toml
index 877e3f98c..e14f937cf 100644
--- a/cakes-results/Cargo.toml
+++ b/cakes-results/Cargo.toml
@@ -21,7 +21,6 @@ log = "0.4.19"
 env_logger = "0.10.0"
 num-format = "0.4.4"
 csv = "1.2.2"
-rayon = "1.8.0"
 
 [[bin]]
 name = "knn-results"
diff --git a/cakes-results/src/knn_reports.rs b/cakes-results/src/knn_reports.rs
index 66ea5bf26..2a391c4b9 100644
--- a/cakes-results/src/knn_reports.rs
+++ b/cakes-results/src/knn_reports.rs
@@ -26,7 +26,6 @@ use clap::Parser;
 use distances::Number;
 use log::info;
 use num_format::ToFormattedString;
-use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
 mod ann_datasets;
@@ -182,19 +181,13 @@ fn make_reports(
     info!("k: {k}");
 
     let start = Instant::now();
-    let hits = queries
-        .par_iter()
-        .map(|q| cakes.tuned_knn_search(q, k))
-        .collect::<Vec<_>>();
+    let hits = cakes.batch_tuned_knn_search(&queries, k);
     let elapsed = start.elapsed().as_secs_f32();
     let throughput = queries.len().as_f32() / elapsed;
     info!("Throughput: {} QPS", format_f32(throughput));
 
     let start = Instant::now();
-    let linear_hits = queries
-        .par_iter()
-        .map(|q| cakes.linear_knn_search(q, k))
-        .collect::<Vec<_>>();
+    let linear_hits = cakes.batch_linear_knn_search(&queries, k);
     let linear_elapsed = start.elapsed().as_secs_f32();
     let linear_throughput = queries.len().as_f32() / linear_elapsed;
     info!("Linear throughput: {} QPS", format_f32(linear_throughput));
diff --git a/cakes-results/src/scaling_reports.rs b/cakes-results/src/scaling_reports.rs
index 18babfce7..0aec354d3 100644
--- a/cakes-results/src/scaling_reports.rs
+++ b/cakes-results/src/scaling_reports.rs
@@ -8,7 +8,6 @@ use clap::Parser;
 use distances::Number;
 use log::{debug, error, info, warn};
 use num_format::ToFormattedString;
-use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
 use symagen::augmentation;
@@ -314,10 +313,7 @@ fn measure_algorithm<'a>(
 ) -> (Vec<Vec<(usize, f32)>>, f32) {
     let num_queries = queries.len();
     let start = Instant::now();
-    let hits = queries
-        .par_iter()
-        .map(|q| cakes.knn_search(q, k, algorithm))
-        .collect::<Vec<_>>();
+    let hits = cakes.batch_knn_search(queries, k, algorithm);
     let elapsed = start.elapsed().as_secs_f32();
     let throughput = num_queries.as_f32() / elapsed;
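
The sketch below shows one way to call the batched search methods this patch brings back. It is not part of the patch: the helper name `run_batched_queries`, the concrete `Vec<f32>`/`f32` type parameters on `Cakes` and `VecDataset`, and the assumption that the model was built with `Cakes::new_single_shard` (as in the README snippet above) are illustrative only.

```rust
use abd_clam::{knn, rnn, Cakes, VecDataset};

// Hypothetical helper, not part of this patch. It assumes a model built as in
// the README example, i.e. a single-shard `Cakes` over a `VecDataset<Vec<f32>, f32>`.
fn run_batched_queries(
    model: &Cakes<Vec<f32>, f32, VecDataset<Vec<f32>, f32>>,
    queries: &[&Vec<f32>],
    radius: f32,
    k: usize,
) {
    // One Vec of (index, distance) hits per query, in the same order as `queries`.
    let rnn_hits = model.batch_rnn_search(queries, radius, rnn::Algorithm::Clustered);
    let knn_hits = model.batch_knn_search(queries, k, knn::Algorithm::Linear);

    // The tuned variant falls back to the default algorithm if tuning was never run.
    let tuned_hits = model.batch_tuned_knn_search(queries, k);

    assert_eq!(rnn_hits.len(), queries.len());
    assert_eq!(knn_hits.len(), queries.len());
    assert_eq!(tuned_hits.len(), queries.len());
}
```

Each `batch_*` method fans the queries out over a `rayon` parallel iterator and calls the corresponding single-query method, so the per-query result vectors line up one-to-one with the input slice.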