From 55c8c5a6fd80fc3afba49f0418962344acf5415b Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt <18898740+jvdd@users.noreply.github.com> Date: Sat, 1 Apr 2023 22:17:37 +0200 Subject: [PATCH] :tada: add optional arrow2 support (#48) --- Cargo.toml | 3 +- README.md | 3 +- src/lib.rs | 92 ++++++++++++++++++++++ tests/argminmax_test.rs | 168 ++++++++++++++++++++++++++++++++++++++-- 4 files changed, 258 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 49a0ca7..27f2355 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ num-traits = { version = "0.2.15", default-features = false } half = { version = "2.1.0", default-features = false, features=["num-traits"], optional = true } ndarray = { version = "0.15.6", default-features = false, optional = true} arrow = { version = ">0", default-features = false, optional = true} -# arrow2 = { version = ">0", default-features = false, optional = true} +arrow2 = { version = ">0.0", default-features = false, optional = true} # once_cell = "1.16.0" [features] @@ -27,6 +27,7 @@ float = [] half = ["dep:half"] ndarray = ["dep:ndarray"] arrow = ["dep:arrow"] +arrow2 = ["dep:arrow2"] [dev-dependencies] rstest = { version = "0.16", default-features = false } diff --git a/README.md b/README.md index ed1aa8c..ae128fd 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ 🚀 The function is generic over the type of the array, so it can be used on `&[T]` or `Vec` where `T` can be `f16`1, `f32`2, `f64`2, `i8`, `i16`, `i32`, `i64`, `u8`, `u16`, `u32`, `u64`. -🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)3, and apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)4. +🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)3, apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)4 and [`arrow2::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html)5. ⚡ **Runtime CPU feature detection** is used to select the most efficient implementation for the current CPU. This means that the same binary can be used on different CPUs without recompilation. @@ -18,6 +18,7 @@ > 2 for f32 and f64 you should enable the (default) `"float"` feature. > 3 for ndarray::ArrayBase you should enable the `"ndarray"` feature. > 4 for arrow::PrimitiveArray you should enable the `"arrow"` feature. +> 5 for arrow2::PrimitiveArray you should enable the `"arrow2"` feature. ## Installing diff --git a/src/lib.rs b/src/lib.rs index d78fa15..1cebd85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,7 @@ //! - **`half`** - enables the traits for `f16` (requires the [`half`](https://crates.io/crates/half) crate). //! - **`ndarray`** - adds the traits to [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html) (requires the `ndarray` crate). //! - **`arrow`** - adds the traits to [`arrow::array::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html) (requires the `arrow` crate). +//! - **`arrow2`** - adds the traits to [`arrow2::array::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html) (requires the `arrow2` crate). //! //! //! # Examples @@ -740,3 +741,94 @@ mod arrow_impl { } } } + +// ---------------------- (optional) arrow2 ---------------------- + +#[cfg(feature = "arrow2")] +mod arrow2_impl { + use super::*; + use arrow2::array::PrimitiveArray; + + impl ArgMinMax for PrimitiveArray + where + T: arrow2::types::NativeType, + for<'a> &'a [T]: ArgMinMax, + { + fn argminmax(&self) -> (usize, usize) { + self.values().as_ref().argminmax() + } + + fn argmin(&self) -> usize { + self.values().as_ref().argmin() + } + + fn argmax(&self) -> usize { + self.values().as_ref().argmax() + } + } + + #[cfg(feature = "float")] + impl NaNArgMinMax for PrimitiveArray + where + T: arrow2::types::NativeType, + for<'a> &'a [T]: NaNArgMinMax, + { + fn nanargminmax(&self) -> (usize, usize) { + self.values().as_ref().nanargminmax() + } + + fn nanargmin(&self) -> usize { + self.values().as_ref().nanargmin() + } + + fn nanargmax(&self) -> usize { + self.values().as_ref().nanargmax() + } + } + + #[cfg(feature = "half")] + #[inline(always)] + /// Convert a PrimitiveArray to a slice of half::f16 + /// To do so, the pointer to the arrow2::types::f16 slice is casted to a pointer to + /// a slice of half::f16 (since both use u16 as their underlying type) + fn _to_half_f16_slice( + primitive_array_f16: &PrimitiveArray, + ) -> &[half::f16] { + unsafe { + std::slice::from_raw_parts( + primitive_array_f16.values().as_ptr() as *const half::f16, + primitive_array_f16.len(), + ) + } + } + + #[cfg(feature = "half")] + impl ArgMinMax for PrimitiveArray { + fn argminmax(&self) -> (usize, usize) { + _to_half_f16_slice(self).argminmax() + } + + fn argmin(&self) -> usize { + _to_half_f16_slice(self).argmin() + } + + fn argmax(&self) -> usize { + _to_half_f16_slice(self).argmax() + } + } + + #[cfg(feature = "half")] + impl NaNArgMinMax for PrimitiveArray { + fn nanargminmax(&self) -> (usize, usize) { + _to_half_f16_slice(self).nanargminmax() + } + + fn nanargmin(&self) -> usize { + _to_half_f16_slice(self).nanargmin() + } + + fn nanargmax(&self) -> usize { + _to_half_f16_slice(self).nanargmax() + } + } +} diff --git a/tests/argminmax_test.rs b/tests/argminmax_test.rs index 4f3cfd2..117b9cf 100644 --- a/tests/argminmax_test.rs +++ b/tests/argminmax_test.rs @@ -13,6 +13,8 @@ use dev_utils::utils; use rand; const ARRAY_LENGTH: usize = 100_000; +const NB_RANDOM_RUNS: usize = 500; +const RANDOM_ARR_LENGTH: usize = 5_000; // ----- dtypes_with_nan template ----- @@ -207,8 +209,8 @@ mod default_test { T: Copy + FromPrimitive + AsPrimitive + rand::distributions::uniform::SampleUniform, for<'a> &'a [T]: ArgMinMax, { - for _ in 0..500 { - let data: Vec = utils::get_random_array::(5_000, min, max); + for _ in 0..NB_RANDOM_RUNS { + let data: Vec = utils::get_random_array::(RANDOM_ARR_LENGTH, min, max); // Slice let slice: &[T] = &data; let (min_slice, max_slice) = slice.argminmax(); @@ -337,8 +339,8 @@ mod ndarray_tests { T: Copy + FromPrimitive + AsPrimitive + rand::distributions::uniform::SampleUniform, for<'a> &'a [T]: ArgMinMax, { - for _ in 0..500 { - let data: Vec = utils::get_random_array::(5_000, min, max); + for _ in 0..NB_RANDOM_RUNS { + let data: Vec = utils::get_random_array::(RANDOM_ARR_LENGTH, min, max); // Slice let slice: &[T] = &data; let (min_slice, max_slice) = slice.argminmax(); @@ -478,8 +480,8 @@ mod arrow_tests { ArrowDataType: ArrowPrimitiveType + ArrowNumericType, PrimitiveArray: From>, { - for _ in 0..500 { - let data: Vec = utils::get_random_array::(5_000, min, max); + for _ in 0..NB_RANDOM_RUNS { + let data: Vec = utils::get_random_array::(RANDOM_ARR_LENGTH, min, max); // Slice let slice: &[T] = &data; let (min_slice, max_slice) = slice.argminmax(); @@ -497,3 +499,157 @@ mod arrow_tests { } } } + +#[cfg(feature = "arrow2")] +#[cfg(test)] +mod arrow2_tests { + use super::*; + + use arrow2::array::PrimitiveArray; + use arrow2::types::NativeType; + + // Float and not half + #[cfg(feature = "float")] + #[template] + #[rstest] + #[case::float32(f32::MIN, f32::MAX)] + #[case::float64(f64::MIN, f64::MAX)] + fn dtypes_with_nan_arrow2(#[case] min: T, #[case] max: T) {} + + #[apply(dtypes)] + fn test_argminmax_arrow2(#[case] _min: T, #[case] max: T) + where + for<'a> &'a [T]: ArgMinMax, + T: Copy + FromPrimitive + AsPrimitive + NativeType, + { + // max_index is the max value that can be represented by T + let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_()); + + let data: PrimitiveArray = + PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index)); + // Test owned PrimitiveArray + let (min, max) = data.argminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + // Test borrowed PrimitiveArray + let (min, max) = (&data).argminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + } + + #[cfg(feature = "float")] + #[apply(dtypes_with_nan_arrow2)] + fn test_argminmax_arrow2_nan(#[case] _min: T, #[case] max: T) + where + for<'a> &'a [T]: NaNArgMinMax, + T: Copy + FromPrimitive + AsPrimitive + NativeType, + { + // max_index is the max value that can be represented by T + let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_()); + + let data: PrimitiveArray = + PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index)); + // Test owned PrimitiveArray + let (min, max) = data.nanargminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + // Test borrowed PrimitiveArray + let (min, max) = (&data).nanargminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + } + + #[apply(dtypes)] + fn test_argminmax_many_random_runs_arrow2(#[case] min: T, #[case] max: T) + where + for<'a> &'a [T]: ArgMinMax, + T: Copy + + FromPrimitive + + AsPrimitive + + rand::distributions::uniform::SampleUniform + + NativeType, + { + for _ in 0..NB_RANDOM_RUNS { + let data: Vec = utils::get_random_array::(RANDOM_ARR_LENGTH, min, max); + // Slice + let slice: &[T] = &data; + let (min_slice, max_slice) = slice.argminmax(); + // Vec + let (min_vec, max_vec) = data.argminmax(); + // Arrow + let arrow: PrimitiveArray = PrimitiveArray::from_vec(data); + let (min_arrow, max_arrow) = arrow.argminmax(); + + // Check + assert_eq!(min_slice, min_vec); + assert_eq!(max_slice, max_vec); + assert_eq!(min_slice, min_arrow); + assert_eq!(max_slice, max_arrow); + } + } + + // Perform the same tests with half::f16 - convert to arrow2::types::f16 + #[test] + #[cfg(feature = "half")] + fn test_argminmax_arrow2_f16() { + // Get monotonic array + let max_index: usize = 1 << f16::MANTISSA_DIGITS; + let data: Vec = get_monotonic_array(ARRAY_LENGTH, max_index); + // Convert the half::f16 vec to PrimitiveArray + let data: Vec = data + .into_iter() + .map(|x| arrow2::types::f16(x.to_bits())) + .collect(); + + let data: PrimitiveArray = PrimitiveArray::from_vec(data); + + // --- ArgMinMax + // Test owned PrimitiveArray + let (min, max) = data.argminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + // Test borrowed PrimitiveArray + let (min, max) = (&data).argminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + + // --- NaNArgMinMax + // Test owned PrimitiveArray + let (min, max) = data.nanargminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + // Test borrowed PrimitiveArray + let (min, max) = (&data).nanargminmax(); + assert_eq!(min, 0); + assert_eq!(max, max_index - 1); + + // --- many random runs + for _ in 0..NB_RANDOM_RUNS { + let data: Vec = + utils::get_random_array::(RANDOM_ARR_LENGTH, i16::MIN, i16::MAX); + // convert to half::f16 + let data_half: Vec = data.into_iter().map(|x| f16::from_bits(x as u16)).collect(); + // convert to arrow2::types::f16 + let data: Vec = data_half + .clone() + .into_iter() + .map(|x| arrow2::types::f16(x.to_bits())) + .collect(); + + // Slice + let slice: &[f16] = &data_half; + let (min_slice, max_slice) = slice.argminmax(); + // Vec + let (min_vec, max_vec) = data_half.argminmax(); + // Arrow2 + let arrow: PrimitiveArray = PrimitiveArray::from_vec(data); + let (min_arrow, max_arrow) = arrow.argminmax(); + + // Check + assert_eq!(min_slice, min_vec); + assert_eq!(max_slice, max_vec); + assert_eq!(min_slice, min_arrow); + assert_eq!(max_slice, max_arrow); + } + } +}