diff --git a/Cargo.toml b/Cargo.toml index b03d72d..d1a2c80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ exclude = [".git*", "dev_utils/**/*", "tests/**/*"] [dependencies] -num-traits = { version = "0.2.14", default-features = false } +num-traits = { version = "0.2.17", default-features = false } half = { version = "2.3.1", default-features = false, features=["num-traits"], optional = true } ndarray = { version = "0.15.6", default-features = false, optional = true} arrow = { version = ">0", default-features = false, optional = true} @@ -31,11 +31,11 @@ arrow = ["dep:arrow"] arrow2 = ["dep:arrow2"] [dev-dependencies] -rstest = { version = "0.18.1", default-features = false } +rstest = { version = "0.18.2", default-features = false } rstest_reuse = { version = "0.6", default-features = false } rand = { version = "0.8.5", default-features = false } -codspeed-criterion-compat = "1.1" -criterion = "0.4" +codspeed-criterion-compat = "2.3.3" +criterion = "0.5.1" dev_utils = { path = "dev_utils" } [dev-dependencies.half] diff --git a/tests/argminmax_test.rs b/tests/argminmax_test.rs index fe394bc..05dcdb5 100644 --- a/tests/argminmax_test.rs +++ b/tests/argminmax_test.rs @@ -45,10 +45,25 @@ fn dtypes_with_nan(#[case] min: T, #[case] max: T) {} // ----- dtypes template ----- -#[cfg(feature = "float")] +#[cfg(all(feature = "float", not(feature = "half")))] #[template] #[rstest] -// #[case::float16(f16::MIN, f16::MAX)] // TODO -> https://github.com/starkat99/half-rs/pull/83 +#[case::float32(f32::MIN, f32::MAX)] +#[case::float64(f64::MIN, f64::MAX)] +#[case::int8(i8::MIN, i8::MAX)] +#[case::int16(i16::MIN, i16::MAX)] +#[case::int32(i32::MIN, i32::MAX)] +#[case::int64(i64::MIN, i64::MAX)] +#[case::uint8(u8::MIN, u8::MAX)] +#[case::uint16(u16::MIN, u16::MAX)] +#[case::uint32(u32::MIN, u32::MAX)] +#[case::uint64(u64::MIN, u64::MAX)] +fn dtypes(#[case] min: T, #[case] max: T) {} + +#[cfg(all(feature = "float", feature = "half"))] +#[template] +#[rstest] +#[case::float16(f16::MIN, f16::from_usize(1 << f16::MANTISSA_DIGITS).unwrap())] #[case::float32(f32::MIN, f32::MAX)] #[case::float64(f64::MIN, f64::MAX)] #[case::int8(i8::MIN, i8::MAX)] @@ -223,7 +238,7 @@ mod default_test { } #[apply(dtypes)] - fn test_argminmax_many_random_runs(#[case] min: T, #[case] max: T) + fn test_argminmax_many_random_runs(#[case] _min: T, #[case] _max: T) where T: Copy + FromPrimitive + AsPrimitive + SampleUniformFullRange, for<'a> &'a [T]: ArgMinMax,