diff --git a/avx.hpp b/avx.hpp
index e1e6bf9..f13846a 100644
--- a/avx.hpp
+++ b/avx.hpp
@@ -392,6 +392,103 @@ SIMD_ALWAYS_INLINE inline simd<double, simd_abi::avx> choose(
   return simd<double, simd_abi::avx>(_mm256_blendv_pd(c.get(), b.get(), a.get()));
 }
 
+template <>
+class simd<int, simd_abi::avx> {
+  __m256i m_value;
+ public:
+  using value_type = int;
+  using abi_type = simd_abi::avx;
+  using mask_type = simd_mask<int, simd_abi::avx>;
+  using storage_type = simd_storage<int, simd_abi::avx>;
+  SIMD_ALWAYS_INLINE inline simd() = default;
+  SIMD_ALWAYS_INLINE inline static constexpr int size() { return 8; }
+  SIMD_ALWAYS_INLINE inline simd(int value)
+    :m_value(_mm256_set1_epi32(value))
+  {}
+  SIMD_ALWAYS_INLINE inline simd(
+      int a, int b, int c, int d,
+      int e, int f, int g, int h)
+    :m_value(_mm256_setr_epi32(a, b, c, d, e, f, g, h))
+  {}
+  // Partially filled vector; the upper four lanes are zeroed.
+  SIMD_ALWAYS_INLINE inline simd(
+      int a, int b, int c, int d)
+    :m_value(_mm256_setr_epi32(a, b, c, d, 0, 0, 0, 0))
+  {}
+  SIMD_ALWAYS_INLINE inline
+  simd(storage_type const& value) {
+    copy_from(value.data(), element_aligned_tag());
+  }
+  SIMD_ALWAYS_INLINE inline
+  simd& operator=(storage_type const& value) {
+    copy_from(value.data(), element_aligned_tag());
+    return *this;
+  }
+  template <class Flags>
+  SIMD_ALWAYS_INLINE inline simd(int const* ptr, Flags /*flags*/)
+    :m_value(::_mm256_loadu_si256((__m256i const*)ptr))
+  {}
+  SIMD_ALWAYS_INLINE inline simd(int const* ptr, int stride)
+    :simd(ptr[0], ptr[stride], ptr[2*stride], ptr[3*stride],
+          ptr[4*stride], ptr[5*stride], ptr[6*stride], ptr[7*stride])
+  {}
+  SIMD_ALWAYS_INLINE inline constexpr simd(__m256i const& value_in)
+    :m_value(value_in)
+  {}
+
+#if defined(__AVX2__)
+  // _mm256_mullo_epi32 is the lane-wise 32-bit product; _mm256_mul_epi32
+  // would multiply only the even lanes and widen the results to 64 bits.
+  SIMD_ALWAYS_INLINE inline simd operator*(simd const& other) const {
+    return simd(_mm256_mullo_epi32(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE inline simd operator+(simd const& other) const {
+    return simd(_mm256_add_epi32(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE inline simd operator-(simd const& other) const {
+    return simd(_mm256_sub_epi32(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE inline simd operator-() const {
+    return simd(_mm256_sub_epi32(_mm256_set1_epi32(0), m_value));
+  }
+#endif
+
+#ifdef __INTEL_COMPILER
+  // Integer division is an SVML routine, only available with the Intel compiler.
+  SIMD_ALWAYS_INLINE inline simd operator/(simd const& other) const {
+    return simd(_mm256_div_epi32(m_value, other.m_value));
+  }
+#endif
+
+  SIMD_ALWAYS_INLINE inline void copy_from(int const* ptr, element_aligned_tag) {
+    m_value = _mm256_loadu_si256((__m256i const*)ptr);
+  }
+  SIMD_ALWAYS_INLINE inline void copy_to(int* ptr, element_aligned_tag) const {
+    _mm256_storeu_si256((__m256i*)ptr, m_value);
+  }
+  SIMD_ALWAYS_INLINE inline constexpr __m256i get() const { return m_value; }
+
+  /* FIXME: these need integer compare intrinsics; the float _mm256_cmp_ps
+     does not apply to __m256i operands.
+  SIMD_ALWAYS_INLINE inline mask_type operator<(simd const& other) const {
+    return mask_type(_mm256_cmp_ps(m_value, other.m_value, _CMP_LT_OS));
+  }
+  SIMD_ALWAYS_INLINE inline mask_type operator==(simd const& other) const {
+    return mask_type(_mm256_cmp_ps(m_value, other.m_value, _CMP_EQ_OS));
+  }
+  */
+};
+
+#if defined(__AVX2__)
+// Specialized permute. _mm256_permutevar8x32_ps is the AVX2 lane-crossing
+// shuffle (_mm256_permutexvar_ps would require AVX-512VL); note that it
+// takes the data vector first and the index vector second.
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<float, simd_abi::avx> permute(simd<int, simd_abi::avx> const& control,
+                                          simd<float, simd_abi::avx> const& a) {
+  simd<float, simd_abi::avx> result(_mm256_permutevar8x32_ps(a.get(), control.get()));
+  return result;
+}
+#endif
+
 }
 
 #endif
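A quick usage sketch for the new int-controlled AVX permute. This is a hypothetical driver, not part of the patch: it assumes the existing 8-argument constructor of simd<float, simd_abi::avx>, the in-repo "avx.hpp" include path, and a translation unit compiled with -mavx2.

#include "avx.hpp"
#include <cstdio>

int main() {
  using namespace SIMD_NAMESPACE;
  // Lane i of the result takes a[control[i]]; this rotates the lanes left by one.
  simd<float, simd_abi::avx> a(0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f);
  simd<int, simd_abi::avx> control(1, 2, 3, 4, 5, 6, 7, 0);
  float out[8];
  permute(control, a).copy_to(out, element_aligned_tag());
  for (float x : out) std::printf("%g ", x);  // prints: 1 2 3 4 5 6 7 0
  return 0;
}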
diff --git a/avx512.hpp b/avx512.hpp
index f592b98..dc7c43c 100644
--- a/avx512.hpp
+++ b/avx512.hpp
@@ -397,6 +397,169 @@ SIMD_ALWAYS_INLINE inline simd<double, simd_abi::avx512> choose(
   return simd<double, simd_abi::avx512>(_mm512_mask_blend_pd(a.get(), c.get(), b.get()));
 }
 
+// Mask for the integer SIMD type. Essentially the same as the mask for
+// simd<float, simd_abi::avx512>; deriving from that class did not work
+// out, so it is duplicated here. The only difference is the
+// 'using simd_type =' alias.
+template <>
+class simd_mask<int, simd_abi::avx512>
+{
+  __mmask16 m_value;
+ public:
+  using value_type = bool;
+  using simd_type = simd<int, simd_abi::avx512>;
+  using abi_type = simd_abi::avx512;
+  SIMD_ALWAYS_INLINE inline simd_mask() = default;
+  SIMD_ALWAYS_INLINE inline simd_mask(bool value)
+    :m_value(-std::int16_t(value))
+  {}
+  SIMD_ALWAYS_INLINE inline static constexpr int size() { return 16; }
+  SIMD_ALWAYS_INLINE inline constexpr simd_mask(__mmask16 const& value_in)
+    :m_value(value_in)
+  {}
+  SIMD_ALWAYS_INLINE inline constexpr __mmask16 get() const { return m_value; }
+  SIMD_ALWAYS_INLINE inline simd_mask operator||(simd_mask const& other) const {
+    return simd_mask(_kor_mask16(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE inline simd_mask operator&&(simd_mask const& other) const {
+    return simd_mask(_kand_mask16(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE inline simd_mask operator!() const {
+    return simd_mask(_knot_mask16(m_value));
+  }
+};
+
+// An integer SIMD class for AVX-512
+template <>
+class simd<int, simd_abi::avx512> {
+  __m512i m_value;
+ public:
+  SIMD_ALWAYS_INLINE simd() = default;
+  using value_type = int;
+  using abi_type = simd_abi::avx512;
+  using mask_type = simd_mask<int, simd_abi::avx512>;
+  using storage_type = simd_storage<int, simd_abi::avx512>;
+  SIMD_ALWAYS_INLINE inline static constexpr int size() { return 16; }
+  SIMD_ALWAYS_INLINE inline simd(int value)
+    :m_value(_mm512_set1_epi32(value))
+  {}
+  SIMD_ALWAYS_INLINE inline simd(
+      int a, int b, int c, int d,
+      int e, int f, int g, int h,
+      int i, int j, int k, int l,
+      int m, int n, int o, int p)
+    :m_value(_mm512_setr_epi32(
+        a, b, c, d, e, f, g, h,
+        i, j, k, l, m, n, o, p))
+  {}
+  SIMD_ALWAYS_INLINE inline
+  simd(storage_type const& value) {
+    copy_from(value.data(), element_aligned_tag());
+  }
+  SIMD_ALWAYS_INLINE inline
+  simd& operator=(storage_type const& value) {
+    copy_from(value.data(), element_aligned_tag());
+    return *this;
+  }
+  template <class Flags>
+  SIMD_ALWAYS_INLINE inline simd(int const* ptr, Flags /*flags*/)
+    :m_value(_mm512_loadu_si512(static_cast<void const*>(ptr)))
+  {}
+  SIMD_ALWAYS_INLINE inline simd(int const* ptr, int stride)
+    :simd(ptr[0], ptr[stride], ptr[2*stride], ptr[3*stride],
+          ptr[4*stride], ptr[5*stride], ptr[6*stride], ptr[7*stride],
+          ptr[8*stride], ptr[9*stride], ptr[10*stride], ptr[11*stride],
+          ptr[12*stride], ptr[13*stride], ptr[14*stride], ptr[15*stride])
+  {}
+  SIMD_ALWAYS_INLINE inline constexpr simd(__m512i const& value_in)
+    :m_value(value_in)
+  {}
+  // _mm512_mullo_epi32 is the lane-wise 32-bit product; _mm512_mul_epi32
+  // would multiply only the even lanes and widen the results to 64 bits.
+  SIMD_ALWAYS_INLINE inline simd operator*(simd const& other) const {
+    return simd(_mm512_mullo_epi32(m_value, other.m_value));
+  }
+
+#if 0
+  // This needs the SVML extension
+  SIMD_ALWAYS_INLINE inline simd operator/(simd const& other) const {
+    return simd(_mm512_div_epi32(m_value, other.m_value));
+  }
+#endif
+
+  SIMD_ALWAYS_INLINE inline simd operator+(simd const& other) const {
+    return simd(_mm512_add_epi32(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE inline simd operator-(simd const& other) const {
+    return simd(_mm512_sub_epi32(m_value, other.m_value));
+  }
+  SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE inline simd operator-() const {
+    return simd(_mm512_sub_epi32(_mm512_set1_epi32(0), m_value));
+  }
+  // element_aligned_tag only guarantees alignof(int), so use the
+  // unaligned load/store intrinsics here.
+  SIMD_ALWAYS_INLINE inline void copy_from(int const* ptr, element_aligned_tag) {
+    m_value = _mm512_loadu_si512(static_cast<void const*>(ptr));
+  }
+  SIMD_ALWAYS_INLINE inline void copy_to(int* ptr, element_aligned_tag) const {
+    _mm512_storeu_si512(static_cast<void*>(ptr), m_value);
+  }
+  SIMD_ALWAYS_INLINE inline constexpr __m512i get() const { return m_value; }
+  // Integer compares take the _MM_CMPINT_* predicates, not the float
+  // _CMP_*_OS predicates.
+  SIMD_ALWAYS_INLINE inline simd_mask<int, simd_abi::avx512> operator<(simd const& other) const {
+    return simd_mask<int, simd_abi::avx512>(_mm512_cmp_epi32_mask(m_value, other.m_value, _MM_CMPINT_LT));
+  }
+  SIMD_ALWAYS_INLINE inline simd_mask<int, simd_abi::avx512> operator==(simd const& other) const {
+    return simd_mask<int, simd_abi::avx512>(_mm512_cmp_epi32_mask(m_value, other.m_value, _MM_CMPINT_EQ));
+  }
+};
+
+template <>
+class simd_utils< simd<float, simd_abi::avx512> > {
+ public:
+  SIMD_ALWAYS_INLINE
+  inline static typename simd<int, simd_abi::avx512>::storage_type make_permute(const int source_lanes[simd<float, simd_abi::avx512>::size()]) {
+    using simd_t = simd<float, simd_abi::avx512>;
+    using control_storage_t = typename simd<int, simd_abi::avx512>::storage_type;
+    control_storage_t my_mask_storage;
+    for (int i = 0; i < simd_t::size(); ++i) {
+      my_mask_storage[i] = source_lanes[i];
+    }
+    return my_mask_storage;
+  }
+};
+
+template <>
+class simd_utils< simd<double, simd_abi::avx512> > {
+ public:
+  SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+  inline static typename simd<int, simd_abi::avx512>::storage_type make_permute(const int source_lanes[simd<double, simd_abi::avx512>::size()]) {
+    using control_storage_t = typename simd<int, simd_abi::avx512>::storage_type;
+    control_storage_t my_mask_storage;
+    // The pd permute consumes 64-bit indices: write each lane index into
+    // the low half of an int64 pair (little endian), zero the high half.
+    for (int i = 0; i < simd<double, simd_abi::avx512>::size(); ++i) {
+      my_mask_storage[2*i]   = source_lanes[i];
+      my_mask_storage[2*i+1] = 0;
+    }
+    return my_mask_storage;
+  }
+};
+
+// Specialized permute
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<float, simd_abi::avx512> permute(simd<int, simd_abi::avx512> const& control,
+                                             simd<float, simd_abi::avx512> const& a) {
+  simd<float, simd_abi::avx512> result(_mm512_permutexvar_ps(control.get(), a.get()));
+  return result;
+}
+
+// Specialized permute
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<double, simd_abi::avx512> permute(simd<int, simd_abi::avx512> const& control,
+                                              simd<double, simd_abi::avx512> const& a) {
+  simd<double, simd_abi::avx512> result(_mm512_permutexvar_pd(control.get(), a.get()));
+  return result;
+}
+
 }
 
 #endif
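Why the operators above use mullo and _MM_CMPINT_*: _mm512_mul_epi32 (like _mm256_mul_epi32) multiplies only the even 32-bit lanes and widens the products to 64 bits, while operator* needs one 32-bit product per lane, which is mullo. Likewise, _mm512_cmp_epi32_mask takes the integer predicates _MM_CMPINT_LT / _MM_CMPINT_EQ; the float predicates _CMP_LT_OS / _CMP_EQ_OS belong to a different enumeration (_CMP_EQ_OS is 16, out of range for integer compares). A scalar model of the two multiply semantics, for illustration only:

#include <cstdint>

// mullo semantics: one 32-bit product per lane (what operator* needs).
// The unsigned cast sidesteps signed-overflow UB, matching hardware wrap-around.
void mullo_model(const int32_t a[16], const int32_t b[16], int32_t r[16]) {
  for (int i = 0; i < 16; ++i) r[i] = int32_t(uint32_t(a[i]) * uint32_t(b[i]));
}

// mul_epi32 semantics: even lanes only, sign-extended 64-bit products.
void mul_even_model(const int32_t a[16], const int32_t b[16], int64_t r[8]) {
  for (int i = 0; i < 8; ++i) r[i] = int64_t(a[2*i]) * int64_t(b[2*i]);
}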
diff --git a/cuda_warp.hpp b/cuda_warp.hpp
index 40a1233..53a8adc 100644
--- a/cuda_warp.hpp
+++ b/cuda_warp.hpp
@@ -72,7 +72,8 @@ class cuda_warp {
   static_assert(N <= 32, "CUDA warps can't be more than 32 threads");
  public:
   SIMD_HOST_DEVICE static unsigned mask() {
-    return (unsigned(1) << N) - unsigned(1);
+    // 1u << 32 is undefined behavior, so a full warp needs its own case.
+    return N == 32 ? 0xffffffff : ((unsigned(1) << N) - unsigned(1));
   }
 };
 
@@ -267,6 +268,14 @@ SIMD_CUDA_ALWAYS_INLINE SIMD_HOST_DEVICE simd<T, simd_abi::cuda_warp<N>> choose(
   return simd<T, simd_abi::cuda_warp<N>>(a.get() ? b.get() : c.get());
 }
 
+// Generic Permute
+template <class T, int N>
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<T, simd_abi::cuda_warp<N>> permute(simd<int, simd_abi::cuda_warp<N>> const& control,
+                                               simd<T, simd_abi::cuda_warp<N>> const& a) {
+  return simd<T, simd_abi::cuda_warp<N>>(
+      __shfl_sync(simd_abi::cuda_warp<N>::mask(), a.get(), control.get(), N));
+}
+
 }
 
 #endif
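The N == 32 special case in mask() exists because shifting a 32-bit unsigned by 32 is undefined behavior in C++, so (1u << N) - 1u can never produce the all-ones full-warp mask. A self-contained model of the fixed logic (warp_lane_mask is a hypothetical name, not part of the patch):

#include <cstdint>

constexpr uint32_t warp_lane_mask(int n) {
  return n >= 32 ? 0xffffffffu : (uint32_t(1) << n) - 1u;
}
static_assert(warp_lane_mask(32) == 0xffffffffu, "full warp");
static_assert(warp_lane_mask(8)  == 0x000000ffu, "partial warp");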
diff --git a/hip_wavefront.hpp b/hip_wavefront.hpp
index 888708d..e460cc3 100644
--- a/hip_wavefront.hpp
+++ b/hip_wavefront.hpp
@@ -64,6 +64,8 @@
 #ifdef __HIPCC__
 #include <hip/hip_runtime.h>
+#include <cstdint>
+
 namespace SIMD_NAMESPACE {
 
 namespace simd_abi {
@@ -72,8 +74,11 @@ template <int N>
 class hip_wavefront {
   static_assert(N <= 64, "HIP wavefronts can't be more than 64 threads");
  public:
-  SIMD_HOST_DEVICE static unsigned mask() {
-    return (unsigned(1) << N) - unsigned(1);
+
+  // Do we strictly need 64-bit masks now? A HIP wavefront can be 64 lanes
+  // wide, so a 32-bit unsigned is too narrow (and 1 << 64 would be UB).
+  SIMD_HOST_DEVICE static uint64_t mask() {
+    return (N == 64) ? 0xffffffffffffffff : ((uint64_t(1) << N) - uint64_t(1));
   }
 };
 
@@ -144,16 +149,23 @@ class simd_mask<T, simd_abi::hip_wavefront<N>> {
   }
 };
 
+/*! FIXME: HIP does not yet support the lane-masked __all_sync
+ * that CUDA provides.
+ */
 template <class T, int N>
 SIMD_HIP_ALWAYS_INLINE SIMD_DEVICE
 bool all_of(simd_mask<T, simd_abi::hip_wavefront<N>> const& a) {
-  return bool(__all_sync(simd_abi::hip_wavefront<N>::mask(), int(a.get())));
+  return bool(__all(int(a.get())));
 }
 
+/*! FIXME: HIP does not yet support the lane-masked __any_sync
+ * that CUDA provides.
+ */
+
 template <class T, int N>
 SIMD_HIP_ALWAYS_INLINE SIMD_DEVICE
 bool any_of(simd_mask<T, simd_abi::hip_wavefront<N>> const& a) {
-  return bool(__any_sync(simd_abi::hip_wavefront<N>::mask(), int(a.get())));
+  return bool(__any(int(a.get())));
 }
 
 template <class T, int N>
@@ -289,6 +301,16 @@ SIMD_HIP_ALWAYS_INLINE SIMD_HOST_DEVICE simd<T, simd_abi::hip_wavefront<N>> choose(
   return simd<T, simd_abi::hip_wavefront<N>>(a.get() ? b.get() : c.get());
 }
 
+
+// Generic Permute
+template <class T, int N>
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<T, simd_abi::hip_wavefront<N>> permute(simd<int, simd_abi::hip_wavefront<N>> const& control,
+                                                   simd<T, simd_abi::hip_wavefront<N>> const& a) {
+  return simd<T, simd_abi::hip_wavefront<N>>(__shfl(a.get(), control.get(), N));
+}
+
+
 } // SIMD_NAMESPACE
 
 #endif
diff --git a/simd_common.hpp b/simd_common.hpp
index 6d4d4be..6171962 100644
--- a/simd_common.hpp
+++ b/simd_common.hpp
@@ -278,4 +278,46 @@ class simd_size<simd<T, Abi>> {
   static constexpr int value = simd<T, Abi>::size();
 };
 
+template <class SimdType>
+class simd_utils;
+
+template <class T, class Abi>
+class simd_utils< simd<T, Abi> > {
+ public:
+  SIMD_ALWAYS_INLINE
+  static inline typename simd<int, Abi>::storage_type make_permute(const int source_lanes[simd<T, Abi>::size()]) {
+    using control_storage_t = typename simd<int, Abi>::storage_type;
+    control_storage_t my_mask_storage;
+    for (int i = 0; i < simd<T, Abi>::size(); ++i) {
+      my_mask_storage[i] = source_lanes[i];
+    }
+    return my_mask_storage;
+  }
+};
+
+// Generic Permute -- does not work on GPUs
+// FIXME: This guard is not sufficient: it should also cover Clang
+// GPU combinations in general.
+#if !defined(__CUDACC__) && !defined(__HIPCC__)
+template <class T, class Abi>
+SIMD_ALWAYS_INLINE SIMD_HOST_DEVICE
+inline simd<T, Abi> permute(simd<int, Abi> const& control, simd<T, Abi> const& a) {
+  T stack_a[simd<T, Abi>::size()];
+  T stack_res[simd<T, Abi>::size()];
+  int stack_control[simd<int, Abi>::size()];
+
+  a.copy_to(stack_a, element_aligned_tag());
+  control.copy_to(stack_control, element_aligned_tag());
+
+  for (int i = 0; i < simd<T, Abi>::size(); ++i) {
+    stack_res[i] = stack_a[stack_control[i]];
+  }
+  simd<T, Abi> result;
+  result.copy_from(stack_res, element_aligned_tag());
+  return result;
+}
+#endif
+
 }
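How the pieces compose: simd_utils<...>::make_permute packs lane indices into the int control storage (the AVX-512 double specialization doubles them into int64 pairs) and permute applies them. A round-trip sketch, written as a hypothetical test under the assumption that a CPU ABI such as simd_abi::avx512 is in use, where either the specialized or the generic permute above is visible:

template <class Abi>
bool permute_round_trip() {
  using namespace SIMD_NAMESPACE;
  using simd_t = simd<float, Abi>;
  constexpr int n = simd_t::size();
  int fwd[n];
  int inv[n];
  for (int i = 0; i < n; ++i) fwd[i] = (i + 1) % n;  // rotate lanes left
  for (int i = 0; i < n; ++i) inv[fwd[i]] = i;       // inverse permutation
  simd<int, Abi> cf(simd_utils<simd_t>::make_permute(fwd));
  simd<int, Abi> ci(simd_utils<simd_t>::make_permute(inv));
  float in[n];
  float out[n];
  for (int i = 0; i < n; ++i) in[i] = float(i);
  simd_t a(in, element_aligned_tag());
  // Applying a permutation and then its inverse must restore the input.
  permute(ci, permute(cf, a)).copy_to(out, element_aligned_tag());
  for (int i = 0; i < n; ++i) if (out[i] != in[i]) return false;
  return true;
}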