Skip to content

Commit

Permalink
🎉 add optional arrow2 support (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd authored Apr 1, 2023
1 parent dc6548a commit 55c8c5a
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ num-traits = { version = "0.2.15", default-features = false }
half = { version = "2.1.0", default-features = false, features=["num-traits"], optional = true }
ndarray = { version = "0.15.6", default-features = false, optional = true}
arrow = { version = ">0", default-features = false, optional = true}
# arrow2 = { version = ">0", default-features = false, optional = true}
arrow2 = { version = ">0.0", default-features = false, optional = true}
# once_cell = "1.16.0"

[features]
Expand All @@ -27,6 +27,7 @@ float = []
half = ["dep:half"]
ndarray = ["dep:ndarray"]
arrow = ["dep:arrow"]
arrow2 = ["dep:arrow2"]

[dev-dependencies]
rstest = { version = "0.16", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

🚀 The function is generic over the type of the array, so it can be used on `&[T]` or `Vec<T>` where `T` can be `f16`<sup>1</sup>, `f32`<sup>2</sup>, `f64`<sup>2</sup>, `i8`, `i16`, `i32`, `i64`, `u8`, `u16`, `u32`, `u64`.

🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)<sup>3</sup>, and apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)<sup>4</sup>.
🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)<sup>3</sup>, apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)<sup>4</sup> and [`arrow2::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html)<sup>5</sup>.

**Runtime CPU feature detection** is used to select the most efficient implementation for the current CPU. This means that the same binary can be used on different CPUs without recompilation.

Expand All @@ -18,6 +18,7 @@
> <i><sup>2</sup> for <code>f32</code> and <code>f64</code> you should enable the (default) `"float"` feature.</i>
> <i><sup>3</sup> for <code>ndarray::ArrayBase</code> you should enable the `"ndarray"` feature.</i>
> <i><sup>4</sup> for <code>arrow::PrimitiveArray</code> you should enable the `"arrow"` feature.</i>
> <i><sup>5</sup> for <code>arrow2::PrimitiveArray</code> you should enable the `"arrow2"` feature.</i>
## Installing

Expand Down
92 changes: 92 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
//! - **`half`** - enables the traits for `f16` (requires the [`half`](https://crates.io/crates/half) crate).
//! - **`ndarray`** - adds the traits to [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html) (requires the `ndarray` crate).
//! - **`arrow`** - adds the traits to [`arrow::array::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html) (requires the `arrow` crate).
//! - **`arrow2`** - adds the traits to [`arrow2::array::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html) (requires the `arrow2` crate).
//!
//!
//! # Examples
Expand Down Expand Up @@ -740,3 +741,94 @@ mod arrow_impl {
}
}
}

// ---------------------- (optional) arrow2 ----------------------

#[cfg(feature = "arrow2")]
mod arrow2_impl {
use super::*;
use arrow2::array::PrimitiveArray;

impl<T> ArgMinMax for PrimitiveArray<T>
where
T: arrow2::types::NativeType,
for<'a> &'a [T]: ArgMinMax,
{
fn argminmax(&self) -> (usize, usize) {
self.values().as_ref().argminmax()
}

fn argmin(&self) -> usize {
self.values().as_ref().argmin()
}

fn argmax(&self) -> usize {
self.values().as_ref().argmax()
}
}

#[cfg(feature = "float")]
impl<T> NaNArgMinMax for PrimitiveArray<T>
where
T: arrow2::types::NativeType,
for<'a> &'a [T]: NaNArgMinMax,
{
fn nanargminmax(&self) -> (usize, usize) {
self.values().as_ref().nanargminmax()
}

fn nanargmin(&self) -> usize {
self.values().as_ref().nanargmin()
}

fn nanargmax(&self) -> usize {
self.values().as_ref().nanargmax()
}
}

#[cfg(feature = "half")]
#[inline(always)]
/// Convert a PrimitiveArray<arrow2::types::f16> to a slice of half::f16
/// To do so, the pointer to the arrow2::types::f16 slice is casted to a pointer to
/// a slice of half::f16 (since both use u16 as their underlying type)
fn _to_half_f16_slice(
primitive_array_f16: &PrimitiveArray<arrow2::types::f16>,
) -> &[half::f16] {
unsafe {
std::slice::from_raw_parts(
primitive_array_f16.values().as_ptr() as *const half::f16,
primitive_array_f16.len(),
)
}
}

#[cfg(feature = "half")]
impl ArgMinMax for PrimitiveArray<arrow2::types::f16> {
fn argminmax(&self) -> (usize, usize) {
_to_half_f16_slice(self).argminmax()
}

fn argmin(&self) -> usize {
_to_half_f16_slice(self).argmin()
}

fn argmax(&self) -> usize {
_to_half_f16_slice(self).argmax()
}
}

#[cfg(feature = "half")]
impl NaNArgMinMax for PrimitiveArray<arrow2::types::f16> {
fn nanargminmax(&self) -> (usize, usize) {
_to_half_f16_slice(self).nanargminmax()
}

fn nanargmin(&self) -> usize {
_to_half_f16_slice(self).nanargmin()
}

fn nanargmax(&self) -> usize {
_to_half_f16_slice(self).nanargmax()
}
}
}
168 changes: 162 additions & 6 deletions tests/argminmax_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use dev_utils::utils;
use rand;

const ARRAY_LENGTH: usize = 100_000;
const NB_RANDOM_RUNS: usize = 500;
const RANDOM_ARR_LENGTH: usize = 5_000;

// ----- dtypes_with_nan template -----

Expand Down Expand Up @@ -207,8 +209,8 @@ mod default_test {
T: Copy + FromPrimitive + AsPrimitive<usize> + rand::distributions::uniform::SampleUniform,
for<'a> &'a [T]: ArgMinMax,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand Down Expand Up @@ -337,8 +339,8 @@ mod ndarray_tests {
T: Copy + FromPrimitive + AsPrimitive<usize> + rand::distributions::uniform::SampleUniform,
for<'a> &'a [T]: ArgMinMax,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand Down Expand Up @@ -478,8 +480,8 @@ mod arrow_tests {
ArrowDataType: ArrowPrimitiveType<Native = T> + ArrowNumericType,
PrimitiveArray<ArrowDataType>: From<Vec<T>>,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand All @@ -497,3 +499,157 @@ mod arrow_tests {
}
}
}

#[cfg(feature = "arrow2")]
#[cfg(test)]
mod arrow2_tests {
use super::*;

use arrow2::array::PrimitiveArray;
use arrow2::types::NativeType;

// Float and not half
#[cfg(feature = "float")]
#[template]
#[rstest]
#[case::float32(f32::MIN, f32::MAX)]
#[case::float64(f64::MIN, f64::MAX)]
fn dtypes_with_nan_arrow2<T>(#[case] min: T, #[case] max: T) {}

#[apply(dtypes)]
fn test_argminmax_arrow2<T>(#[case] _min: T, #[case] max: T)
where
for<'a> &'a [T]: ArgMinMax,
T: Copy + FromPrimitive + AsPrimitive<usize> + NativeType,
{
// max_index is the max value that can be represented by T
let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_());

let data: PrimitiveArray<T> =
PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index));
// Test owned PrimitiveArray
let (min, max) = data.argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
}

#[cfg(feature = "float")]
#[apply(dtypes_with_nan_arrow2)]
fn test_argminmax_arrow2_nan<T>(#[case] _min: T, #[case] max: T)
where
for<'a> &'a [T]: NaNArgMinMax,
T: Copy + FromPrimitive + AsPrimitive<usize> + NativeType,
{
// max_index is the max value that can be represented by T
let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_());

let data: PrimitiveArray<T> =
PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index));
// Test owned PrimitiveArray
let (min, max) = data.nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
}

#[apply(dtypes)]
fn test_argminmax_many_random_runs_arrow2<T>(#[case] min: T, #[case] max: T)
where
for<'a> &'a [T]: ArgMinMax,
T: Copy
+ FromPrimitive
+ AsPrimitive<usize>
+ rand::distributions::uniform::SampleUniform
+ NativeType,
{
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
// Vec
let (min_vec, max_vec) = data.argminmax();
// Arrow
let arrow: PrimitiveArray<T> = PrimitiveArray::from_vec(data);
let (min_arrow, max_arrow) = arrow.argminmax();

// Check
assert_eq!(min_slice, min_vec);
assert_eq!(max_slice, max_vec);
assert_eq!(min_slice, min_arrow);
assert_eq!(max_slice, max_arrow);
}
}

// Perform the same tests with half::f16 - convert to arrow2::types::f16
#[test]
#[cfg(feature = "half")]
fn test_argminmax_arrow2_f16() {
// Get monotonic array
let max_index: usize = 1 << f16::MANTISSA_DIGITS;
let data: Vec<f16> = get_monotonic_array(ARRAY_LENGTH, max_index);
// Convert the half::f16 vec to PrimitiveArray<arrow2::types::f16>
let data: Vec<arrow2::types::f16> = data
.into_iter()
.map(|x| arrow2::types::f16(x.to_bits()))
.collect();

let data: PrimitiveArray<arrow2::types::f16> = PrimitiveArray::from_vec(data);

// --- ArgMinMax
// Test owned PrimitiveArray
let (min, max) = data.argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);

// --- NaNArgMinMax
// Test owned PrimitiveArray
let (min, max) = data.nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);

// --- many random runs
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<i16> =
utils::get_random_array::<i16>(RANDOM_ARR_LENGTH, i16::MIN, i16::MAX);
// convert to half::f16
let data_half: Vec<f16> = data.into_iter().map(|x| f16::from_bits(x as u16)).collect();
// convert to arrow2::types::f16
let data: Vec<arrow2::types::f16> = data_half
.clone()
.into_iter()
.map(|x| arrow2::types::f16(x.to_bits()))
.collect();

// Slice
let slice: &[f16] = &data_half;
let (min_slice, max_slice) = slice.argminmax();
// Vec
let (min_vec, max_vec) = data_half.argminmax();
// Arrow2
let arrow: PrimitiveArray<arrow2::types::f16> = PrimitiveArray::from_vec(data);
let (min_arrow, max_arrow) = arrow.argminmax();

// Check
assert_eq!(min_slice, min_vec);
assert_eq!(max_slice, max_vec);
assert_eq!(min_slice, min_arrow);
assert_eq!(max_slice, max_arrow);
}
}
}

0 comments on commit 55c8c5a

Please sign in to comment.