diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c9ec9859e..9019ae4b8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md). You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful. ## [Unreleased] +- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420) - Bump the MSRV to 1.61.0 ## [0.9.0-alpha.1] - 2024-03-18 diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 59ff1c1915..dfb30ef62a 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -15,6 +15,7 @@ use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; +use core::fmt::Debug; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -243,6 +244,124 @@ impl WeightedIndex { } } +/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution. +/// This is returned by [`WeightedIndex::weights`]. +pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> { + weighted_index: &'a WeightedIndex, + index: usize, +} + +impl<'a, X> Debug for WeightedIndexIter<'a, X> + where + X: SampleUniform + PartialOrd + Debug, + X::Sampler: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WeightedIndexIter") + .field("weighted_index", &self.weighted_index) + .field("index", &self.index) + .finish() + } +} + +impl<'a, X> Clone for WeightedIndexIter<'a, X> +where + X: SampleUniform + PartialOrd, +{ + fn clone(&self) -> Self { + WeightedIndexIter { + weighted_index: self.weighted_index, + index: self.index, + } + } +} + +impl<'a, X> Iterator for WeightedIndexIter<'a, X> +where + X: for<'b> ::core::ops::SubAssign<&'b X> + + SampleUniform + + PartialOrd + + Clone, +{ + type Item = X; + + fn next(&mut self) -> Option { + match self.weighted_index.weight(self.index) { + None => None, + Some(weight) => { + self.index += 1; + Some(weight) + } + } + } +} + +impl WeightedIndex { + /// Returns the weight at the given index, if it exists. + /// + /// If the index is out of bounds, this will return `None`. + /// + /// # Example + /// + /// ``` + /// use rand::distributions::WeightedIndex; + /// + /// let weights = [0, 1, 2]; + /// let dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weight(0), Some(0)); + /// assert_eq!(dist.weight(1), Some(1)); + /// assert_eq!(dist.weight(2), Some(2)); + /// assert_eq!(dist.weight(3), None); + /// ``` + pub fn weight(&self, index: usize) -> Option + where + X: for<'a> ::core::ops::SubAssign<&'a X> + { + let mut weight = if index < self.cumulative_weights.len() { + self.cumulative_weights[index].clone() + } else if index == self.cumulative_weights.len() { + self.total_weight.clone() + } else { + return None; + }; + if index > 0 { + weight -= &self.cumulative_weights[index - 1]; + } + Some(weight) + } + + /// Returns a lazy-loading iterator containing the current weights of this distribution. + /// + /// If this distribution has not been updated since its creation, this will return the + /// same weights as were passed to `new`. + /// + /// # Example + /// + /// ``` + /// use rand::distributions::WeightedIndex; + /// + /// let weights = [1, 2, 3]; + /// let mut dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![1, 2, 3]); + /// dist.update_weights(&[(0, &2)]).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![2, 2, 3]); + /// ``` + pub fn weights(&self) -> WeightedIndexIter<'_, X> + where + X: for<'a> ::core::ops::SubAssign<&'a X> + { + WeightedIndexIter { + weighted_index: self, + index: 0, + } + } + + /// Returns the sum of all weights in this distribution. + pub fn total_weight(&self) -> X { + self.total_weight.clone() + } +} + impl Distribution for WeightedIndex where X: SampleUniform + PartialOrd, @@ -458,6 +577,75 @@ mod test { } } + #[test] + fn test_update_weights_errors() { + let data = [ + ( + &[1i32, 0, 0][..], + &[(0, &0)][..], + WeightError::InsufficientNonZero, + ), + ( + &[10, 10, 10, 10][..], + &[(1, &-11)][..], + WeightError::InvalidWeight, // A weight is negative + ), + ( + &[1, 2, 3, 4, 5][..], + &[(1, &5), (0, &5)][..], // Wrong order + WeightError::InvalidInput, + ), + ( + &[1][..], + &[(1, &1)][..], // Index too large + WeightError::InvalidInput, + ), + ]; + + for (weights, update, err) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + match distr.update_weights(update) { + Ok(_) => panic!("Expected update_weights to fail, but it succeeded"), + Err(e) => assert_eq!(e, *err), + } + } + } + + #[test] + fn test_weight_at() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for (i, weight) in weights.iter().enumerate() { + assert_eq!(distr.weight(i), Some(*weight)); + } + assert_eq!(distr.weight(weights.len()), None); + } + } + + #[test] + fn test_weights() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.weights().collect::>(), weights.to_vec()); + } + } + #[test] fn value_stability() { fn test_samples(