feat: brought back batched search functions
nishaq503 committed Nov 9, 2023
1 parent 2984a37 commit 36cedb2
Showing 9 changed files with 116 additions and 65 deletions.
4 changes: 2 additions & 2 deletions abd-clam/README.md
@@ -12,7 +12,7 @@ CLAM is a library crate so you can add it to your crate using `cargo add abd_clam`
### Cakes: Nearest Neighbor Search

```rust
-use abd_clam::{knn, rnn ,Cakes, PartitionCriteria, VecDataset};
+use abd_clam::{knn, rnn, Cakes, PartitionCriteria, VecDataset};

/// The distance function with which to perform clustering and search.
///
@@ -67,7 +67,7 @@ let data = VecDataset::<Vec<f32>, f32>::new(
let criteria = PartitionCriteria::default();

// The Cakes struct provides the functionality described in the CHESS paper.
-// It performs search assuming that the dataset is a single shard.
+// We use a single shard here because the demo data is small.
let model = Cakes::new_single_shard(data, Some(seed), &criteria);
// This line performs a non-trivial amount of work. #understatement

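For readers following the README, the batched entry points restored by this commit wrap the single-query searches over a whole slice of queries. A minimal sketch of how they might be called on the `model` built above; the query values, radius, and `k` below are illustrative placeholders, not taken from the source:

```rust
// Queries are passed as a slice of references to instances, here `&Vec<f32>`.
let query_vecs: Vec<Vec<f32>> = vec![vec![0.0; 10], vec![0.5; 10]];
let queries: Vec<&Vec<f32>> = query_vecs.iter().collect();

// One call searches the whole batch; each inner Vec holds (index, distance) hits.
let rnn_hits: Vec<Vec<(usize, f32)>> = model.batch_rnn_search(&queries, 0.05, rnn::Algorithm::Clustered);
let knn_hits: Vec<Vec<(usize, f32)>> = model.batch_knn_search(&queries, 10, knn::Algorithm::Linear);
```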
15 changes: 2 additions & 13 deletions abd-clam/benches/genomic.rs
@@ -1,6 +1,5 @@
use criterion::*;

-use rayon::prelude::*;
use symagen::random_data;

use abd_clam::{rnn, Cakes, PartitionCriteria, VecDataset};
@@ -47,24 +46,14 @@ fn genomic(c: &mut Criterion) {
for radius in radii {
let id = BenchmarkId::new("Clustered", radius);
group.bench_with_input(id, &radius, |b, &radius| {
-b.iter_with_large_drop(|| {
-queries
-.par_iter()
-.map(|q| cakes.rnn_search(q, radius, rnn::Algorithm::Clustered))
-.collect::<Vec<_>>()
-});
+b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, rnn::Algorithm::Clustered));
});
}

group.sample_size(10);
let id = BenchmarkId::new("Linear", radii[0]);
group.bench_with_input(id, &radii[0], |b, _| {
-b.iter_with_large_drop(|| {
-queries
-.par_iter()
-.map(|q| cakes.rnn_search(q, radii[0], rnn::Algorithm::Linear))
-.collect::<Vec<_>>()
-});
+b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radii[0], rnn::Algorithm::Linear));
});

group.finish();
11 changes: 5 additions & 6 deletions abd-clam/benches/knn-vs-rnn.rs
@@ -36,12 +36,11 @@ fn cakes(c: &mut Criterion) {
let cakes = Cakes::new_single_shard(dataset, Some(seed), &criteria);

for k in (0..=8).map(|v| 2usize.pow(v)) {
-let radii = queries
-.par_iter()
-.map(|query| {
-cakes
-.linear_rnn_search(query, 0.0)
-.into_iter()
+let radii = cakes
+.batch_knn_search(&queries, k, knn::Algorithm::Linear)
+.into_iter()
+.map(|hits| {
+hits.into_iter()
.map(|(_, d)| d)
.max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Less))
.unwrap()
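The hunk above is truncated, but the pattern it changes is worth spelling out: the bench derives each query's RNN radius from an exact k-NN pass, taking the distance to the k-th nearest neighbor so that the RNN and k-NN benchmarks operate on comparable result sets. A minimal sketch of that mapping, assuming `cakes`, `queries`, and `k` are in scope as in the bench:

```rust
use std::cmp::Ordering;

// For each query, the radius is the largest of its k nearest-neighbor
// distances, i.e. the distance to the k-th nearest neighbor.
let radii: Vec<f32> = cakes
    .batch_knn_search(&queries, k, knn::Algorithm::Linear)
    .into_iter()
    .map(|hits| {
        hits.into_iter()
            .map(|(_, d)| d)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Less))
            .unwrap()
    })
    .collect();
```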
15 changes: 2 additions & 13 deletions abd-clam/benches/rnn-search.rs
@@ -1,6 +1,5 @@
use criterion::*;

-use rayon::prelude::*;
use symagen::random_data;

use abd_clam::{rnn, Cakes, PartitionCriteria, VecDataset};
@@ -45,25 +44,15 @@ fn cakes(c: &mut Criterion) {

let id = BenchmarkId::new(variant.name(), radius);
group.bench_with_input(id, &radius, |b, _| {
-b.iter_with_large_drop(|| {
-queries
-.par_iter()
-.map(|query| cakes.rnn_search(query, radius, variant))
-.collect::<Vec<_>>()
-});
+b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, variant));
});
}
}

group.sample_size(10);
let id = BenchmarkId::new("Linear", radius);
group.bench_with_input(id, &radius, |b, _| {
-b.iter_with_large_drop(|| {
-queries
-.par_iter()
-.map(|query| cakes.rnn_search(query, radius, rnn::Algorithm::Linear))
-.collect::<Vec<_>>()
-});
+b.iter_with_large_drop(|| cakes.batch_rnn_search(&queries, radius, rnn::Algorithm::Linear));
});

group.finish();
104 changes: 102 additions & 2 deletions abd-clam/src/cakes/mod.rs
@@ -9,6 +9,7 @@ mod sharded;
mod singular;

use distances::Number;
+use rayon::prelude::*;
use search::Search;
use sharded::RandomlySharded;
use singular::SingleShard;
@@ -75,6 +76,22 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
}
}

+/// Performs RNN search on a batch of queries with the given algorithm.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `radius` - The search radius.
+/// * `algo` - The algorithm to use.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_rnn_search(&self, queries: &[&I], radius: U, algo: rnn::Algorithm) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.rnn_search(q, radius, algo)).collect()
+}
+
/// Performs an RNN search with the given algorithm.
///
/// # Arguments
Expand All @@ -85,14 +102,30 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
///
/// # Returns
///
-/// A vector of tuples containing the index of the instance and the distance to the query.
+/// A vector of tuples containing the index of the instance and the distance
+/// to the query.
pub fn rnn_search(&self, query: &I, radius: U, algo: rnn::Algorithm) -> Vec<(usize, U)> {
match self {
Self::SingleShard(ss) => ss.rnn_search(query, radius, algo),
Self::RandomlySharded(rs) => rs.rnn_search(query, radius, algo),
}
}

+/// Performs Linear RNN search on a batch of queries.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `radius` - The search radius.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_linear_rnn_search(&self, queries: &[&I], radius: U) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.linear_rnn_search(q, radius)).collect()
+}
+
/// Performs a linear RNN search.
///
/// # Arguments
Expand All @@ -102,7 +135,8 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
///
/// # Returns
///
-/// A vector of tuples containing the index of the instance and the distance to the query.
+/// A vector of tuples containing the index of the instance and the distance
+/// to the query.
pub fn linear_rnn_search(&self, query: &I, radius: U) -> Vec<(usize, U)> {
match self {
Self::SingleShard(ss) => ss.linear_rnn_search(query, radius),
Expand All @@ -118,6 +152,22 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
}
}

+/// Performs KNN search on a batch of queries with the given algorithm.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `k` - The number of nearest neighbors to return.
+/// * `algo` - The algorithm to use.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_knn_search(&self, queries: &[&I], k: usize, algo: knn::Algorithm) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.knn_search(q, k, algo)).collect()
+}
+
/// Performs a KNN search with the given algorithm.
///
/// # Arguments
@@ -162,6 +212,21 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
}
}

+/// Performs Linear KNN search on a batch of queries.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `k` - The number of nearest neighbors to return.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_linear_knn_search(&self, queries: &[&I], k: usize) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.linear_knn_search(q, k)).collect()
+}
+
/// Performs a linear KNN search.
///
/// # Arguments
Expand All @@ -179,6 +244,23 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
}
}

+/// Performs RNN search on a batch of queries with the tuned algorithm.
+///
+/// If the algorithm has not been tuned, this will use the default algorithm.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `radius` - The search radius.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_tuned_rnn_search(&self, queries: &[&I], radius: U) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.tuned_rnn_search(q, radius)).collect()
+}
+
/// Performs a RNN search with the tuned algorithm.
///
/// If the algorithm has not been tuned, this will use the default algorithm.
Expand All @@ -196,6 +278,23 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> Cakes<I, U, D> {
self.rnn_search(query, radius, algo)
}

+/// Performs KNN search on a batch of queries with the tuned algorithm.
+///
+/// If the algorithm has not been tuned, this will use the default algorithm.
+///
+/// # Arguments
+///
+/// * `queries` - The queries to search.
+/// * `k` - The number of nearest neighbors to return.
+///
+/// # Returns
+///
+/// A vector of vectors of tuples containing the index of the instance and
+/// the distance to the query.
+pub fn batch_tuned_knn_search(&self, queries: &[&I], k: usize) -> Vec<Vec<(usize, U)>> {
+queries.par_iter().map(|q| self.tuned_knn_search(q, k)).collect()
+}
+
/// Performs a KNN search with the tuned algorithm.
///
/// If the algorithm has not been tuned, this will use the default algorithm.
@@ -233,6 +332,7 @@ where
.find(|(_, &o)| o > index)
.map_or_else(|| rs.num_shards() - 1, |(i, _)| i - 1);

+let index = index - rs.offsets()[i];
rs.shards()[i].data().index(index)
}
}
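Each `batch_*` method above is a thin `rayon` wrapper around the corresponding single-query search, which is why the downstream crates below can drop their direct `rayon` dependency. A sketch of the resulting call-site change, assuming a `cakes` model, a `queries: Vec<&Vec<f32>>`, and a `k` as in the report binaries:

```rust
use rayon::prelude::*;

// Before: callers parallelized per-query searches by hand.
let hand_rolled: Vec<Vec<(usize, f32)>> = queries
    .par_iter()
    .map(|q| cakes.knn_search(q, k, knn::Algorithm::Linear))
    .collect();

// After: one batched call does the same work inside abd-clam.
let batched: Vec<Vec<(usize, f32)>> = cakes.batch_knn_search(&queries, k, knn::Algorithm::Linear);
```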
14 changes: 0 additions & 14 deletions abd-clam/src/cakes/sharded.rs
@@ -65,20 +65,6 @@ impl<I: Instance, U: Number, D: Dataset<I, U>> RandomlySharded<I, U, D> {
pub fn offsets(&self) -> &[usize] {
&self.offsets
}

-// /// K-nearest neighbor search.
-// pub fn knn_search(&self, query: &I, k: usize) -> Vec<(usize, U)> {
-// let hits = self
-// .sample_shard
-// .knn_search(query, k, self.best_knn_algorithm().unwrap_or_default());
-// let mut hits = knn::Hits::from_vec(k, hits);
-// for (shard, &o) in self.shards.iter().zip(self.offsets.iter()) {
-// let radius = hits.peek();
-// let new_hits = shard.rnn_search(query, radius, rnn::Algorithm::Clustered);
-// hits.push_batch(new_hits.into_iter().map(|(i, d)| (i + o, d)));
-// }
-// hits.extract()
-// }
}

impl<I: Instance, U: Number, D: Dataset<I, U>> Search<I, U, D> for RandomlySharded<I, U, D> {
1 change: 0 additions & 1 deletion cakes-results/Cargo.toml
@@ -21,7 +21,6 @@ log = "0.4.19"
env_logger = "0.10.0"
num-format = "0.4.4"
csv = "1.2.2"
-rayon = "1.8.0"

[[bin]]
name = "knn-results"
11 changes: 2 additions & 9 deletions cakes-results/src/knn_reports.rs
@@ -26,7 +26,6 @@ use clap::Parser;
use distances::Number;
use log::info;
use num_format::ToFormattedString;
-use rayon::prelude::*;
use serde::{Deserialize, Serialize};

mod ann_datasets;
@@ -182,19 +181,13 @@ fn make_reports(
info!("k: {k}");

let start = Instant::now();
-let hits = queries
-.par_iter()
-.map(|q| cakes.tuned_knn_search(q, k))
-.collect::<Vec<_>>();
+let hits = cakes.batch_tuned_knn_search(&queries, k);
let elapsed = start.elapsed().as_secs_f32();
let throughput = queries.len().as_f32() / elapsed;
info!("Throughput: {} QPS", format_f32(throughput));

let start = Instant::now();
-let linear_hits = queries
-.par_iter()
-.map(|q| cakes.linear_knn_search(q, k))
-.collect::<Vec<_>>();
+let linear_hits = cakes.batch_linear_knn_search(&queries, k);
let linear_elapsed = start.elapsed().as_secs_f32();
let linear_throughput = queries.len().as_f32() / linear_elapsed;
info!("Linear throughput: {} QPS", format_f32(linear_throughput));
6 changes: 1 addition & 5 deletions cakes-results/src/scaling_reports.rs
@@ -8,7 +8,6 @@ use clap::Parser;
use distances::Number;
use log::{debug, error, info, warn};
use num_format::ToFormattedString;
-use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use symagen::augmentation;

@@ -314,10 +313,7 @@ fn measure_algorithm<'a>(
) -> (Vec<Vec<(usize, f32)>>, f32) {
let num_queries = queries.len();
let start = Instant::now();
-let hits = queries
-.par_iter()
-.map(|q| cakes.knn_search(q, k, algorithm))
-.collect::<Vec<_>>();
+let hits = cakes.batch_knn_search(queries, k, algorithm);
let elapsed = start.elapsed().as_secs_f32();
let throughput = num_queries.as_f32() / elapsed;

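For completeness, the timing pattern used in both report binaries wraps one batched call with a clock and divides by the number of queries. A hypothetical helper sketching that pattern; the function name and concrete types are illustrative, not from the source:

```rust
use std::time::Instant;

use abd_clam::{knn, Cakes, VecDataset};

/// Times a batched k-NN search and returns the hits plus throughput in queries per second.
fn time_batch_knn(
    cakes: &Cakes<Vec<f32>, f32, VecDataset<Vec<f32>, f32>>,
    queries: &[&Vec<f32>],
    k: usize,
    algorithm: knn::Algorithm,
) -> (Vec<Vec<(usize, f32)>>, f32) {
    let start = Instant::now();
    let hits = cakes.batch_knn_search(queries, k, algorithm);
    let elapsed = start.elapsed().as_secs_f32();
    let throughput = queries.len() as f32 / elapsed;
    (hits, throughput)
}
```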
