Skip to content

Commit

Permalink
:detetctive: float uniform test
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Jan 16, 2024
1 parent 83845d6 commit 04441e9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ exclude = [".git*", "dev_utils/**/*", "tests/**/*"]


[dependencies]
num-traits = { version = "0.2.14", default-features = false }
num-traits = { version = "0.2.17", default-features = false }
half = { version = "2.3.1", default-features = false, features=["num-traits"], optional = true }
ndarray = { version = "0.15.6", default-features = false, optional = true}
arrow = { version = ">0", default-features = false, optional = true}
Expand All @@ -31,11 +31,11 @@ arrow = ["dep:arrow"]
arrow2 = ["dep:arrow2"]

[dev-dependencies]
rstest = { version = "0.18.1", default-features = false }
rstest = { version = "0.18.2", default-features = false }
rstest_reuse = { version = "0.6", default-features = false }
rand = { version = "0.8.5", default-features = false }
codspeed-criterion-compat = "1.1"
criterion = "0.4"
codspeed-criterion-compat = "2.3.3"
criterion = "0.5.1"
dev_utils = { path = "dev_utils" }

[dev-dependencies.half]
Expand Down
21 changes: 18 additions & 3 deletions tests/argminmax_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,25 @@ fn dtypes_with_nan<T>(#[case] min: T, #[case] max: T) {}

// ----- dtypes template -----

#[cfg(feature = "float")]
#[cfg(all(feature = "float", not(feature = "half")))]
#[template]
#[rstest]
// #[case::float16(f16::MIN, f16::MAX)] // TODO -> https://github.com/starkat99/half-rs/pull/83
#[case::float32(f32::MIN, f32::MAX)]
#[case::float64(f64::MIN, f64::MAX)]
#[case::int8(i8::MIN, i8::MAX)]
#[case::int16(i16::MIN, i16::MAX)]
#[case::int32(i32::MIN, i32::MAX)]
#[case::int64(i64::MIN, i64::MAX)]
#[case::uint8(u8::MIN, u8::MAX)]
#[case::uint16(u16::MIN, u16::MAX)]
#[case::uint32(u32::MIN, u32::MAX)]
#[case::uint64(u64::MIN, u64::MAX)]
fn dtypes<T>(#[case] min: T, #[case] max: T) {}

#[cfg(all(feature = "float", feature = "half"))]
#[template]
#[rstest]
#[case::float16(f16::MIN, f16::from_usize(1 << f16::MANTISSA_DIGITS).unwrap())]
#[case::float32(f32::MIN, f32::MAX)]
#[case::float64(f64::MIN, f64::MAX)]
#[case::int8(i8::MIN, i8::MAX)]
Expand Down Expand Up @@ -223,7 +238,7 @@ mod default_test {
}

#[apply(dtypes)]
fn test_argminmax_many_random_runs<T>(#[case] min: T, #[case] max: T)
fn test_argminmax_many_random_runs<T>(#[case] _min: T, #[case] _max: T)
where
T: Copy + FromPrimitive + AsPrimitive<usize> + SampleUniformFullRange,
for<'a> &'a [T]: ArgMinMax,
Expand Down

0 comments on commit 04441e9

Please sign in to comment.