Skip to content

Commit

Permalink
add f32 support
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Feb 23, 2024
1 parent 098661e commit d3e19fd
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 35 deletions.
18 changes: 14 additions & 4 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ where
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
all(feature = "nightly_simd", target_arch = "arm")
))]
macro_rules! impl_SIMDInit_Int {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
Expand All @@ -201,14 +201,19 @@ macro_rules! impl_SIMDInit_Int {
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
all(feature = "nightly_simd", target_arch = "arm")
))]
pub(crate) use impl_SIMDInit_Int; // Now classic paths Just Work™

// --------------- Float Return NaNs

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
all(feature = "nightly_simd", target_arch = "arm")
))]
macro_rules! impl_SIMDInit_FloatReturnNaN {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
Expand All @@ -231,7 +236,12 @@ macro_rules! impl_SIMDInit_FloatReturnNaN {
}

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
all(feature = "nightly_simd", target_arch = "arm") // TODO: all like this?
))]
pub(crate) use impl_SIMDInit_FloatReturnNaN; // Now classic paths Just Work™

// --------------- Float Ignore NaNs
Expand Down
52 changes: 42 additions & 10 deletions src/simd/simd_f32_ignore_nan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,37 @@
/// As comparisons with NaN always return false, it is guaranteed that no NaN values
/// are added to the accumulating SIMD register.
///
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::config::SIMDInstructionSet;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::generic::{
impl_SIMDArgMinMax, impl_SIMDInit_FloatIgnoreNaN, SIMDArgMinMax, SIMDInit, SIMDOps,
};
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use crate::SCALAR;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use num_traits::Zero;
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
use std::arch::arm::*;
Expand All @@ -31,11 +51,21 @@ use std::arch::x86::*;
use std::arch::x86_64::*;

/// The dtype-strategy for performing operations on f32 data: ignore NaN values
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::super::dtype_strategy::FloatIgnoreNaN;

// https://stackoverflow.com/a/3793950
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS;

// --------------------------------------- AVX2 ----------------------------------------
Expand Down Expand Up @@ -250,8 +280,10 @@ mod avx512_ignore_nan {

// --------------------------------------- NEON ----------------------------------------

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
))]
mod neon_ignore_nan {
use super::super::config::NEON;
use super::*;
Expand Down Expand Up @@ -328,7 +360,7 @@ mod neon_ignore_nan {
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
all(target_arch = "aarch64", feature = "nightly_simd"),
target_arch = "aarch64",
))]
#[cfg(test)]
mod tests {
Expand Down
73 changes: 60 additions & 13 deletions src/simd/simd_f32_return_nan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,28 @@
/// SIMDOps::_get_overflow_lane_size_limit() chunk of the data - which is not
/// necessarily the index of the first NaN value.
///
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::config::SIMDInstructionSet;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::generic::{impl_SIMDInit_FloatReturnNaN, SIMDArgMinMax, SIMDInit, SIMDOps};
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use crate::SCALAR;
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
use std::arch::arm::*;
Expand All @@ -44,25 +59,55 @@ use std::arch::x86::*;
use std::arch::x86_64::*;

/// The dtype-strategy for performing operations on f32 data: return NaN index
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::super::dtype_strategy::FloatReturnNaN;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
use super::task::{max_index_value, min_index_value};

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
const BIT_SHIFT: i32 = 31;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
const MASK_VALUE: i32 = 0x7FFFFFFF; // i32::MAX - masks everything but the sign bit

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
#[inline(always)]
fn _i32ord_to_f32(ord_i32: i32) -> f32 {
let v = ((ord_i32 >> BIT_SHIFT) & MASK_VALUE) ^ ord_i32;
f32::from_bits(v as u32)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
feature = "nightly_simd"
))]
const MAX_INDEX: usize = i32::MAX as usize;

// --------------------------------------- AVX2 ----------------------------------------
Expand Down Expand Up @@ -371,8 +416,10 @@ mod avx512 {

// --------------------------------------- NEON ----------------------------------------

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
use super::*;
Expand Down Expand Up @@ -475,7 +522,7 @@ mod neon {
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
all(target_arch = "aarch64", feature = "nightly_simd"),
target_arch = "aarch64",
))]
#[cfg(test)]
mod tests {
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_i16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_i32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_i64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ mod neon {
unimpl_SIMDArgMinMax!(i64, usize, SCALAR<Int>, NEON<Int>);
}

#[cfg(target_arch = "aarch64")]
#[cfg(target_arch = "aarch64")] // stable for AArch64
mod neon {
use super::super::config::NEON;
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_i8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_u16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_u32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_u64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ mod neon {
unimpl_SIMDArgMinMax!(u64, usize, SCALAR<Int>, NEON<Int>);
}

#[cfg(target_arch = "aarch64")]
#[cfg(target_arch = "aarch64")] // stable for AArch64
mod neon {
use super::super::config::NEON;
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_u8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ mod avx512 {

#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64" // stable for AArch64
target_arch = "aarch64" // stable for AArch64
))]
mod neon {
use super::super::config::NEON;
Expand Down

0 comments on commit d3e19fd

Please sign in to comment.