feat: add runtime batch_bool mask overloads for load_masked/store_masked #1332
base: master
Changes from all commits
b57a766
e227346
d5f21c7
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| #include <algorithm> | ||
| #include <array> | ||
| #include <complex> | ||
| #include <cstdint> | ||
|
|
||
| #include "../../types/xsimd_batch_constant.hpp" | ||
| #include "./xsimd_common_details.hpp" | ||
|
|
@@ -374,6 +375,21 @@ namespace xsimd | |
| return batch<T_out, A>::load(buffer.data(), aligned_mode {}); | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE batch<T, A> | ||
| load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<common>) noexcept | ||
| { | ||
| // Per-lane validity contract: only active lanes are read. | ||
| // Arches with hardware predicated loads override this. | ||
| constexpr std::size_t size = batch<T, A>::size; | ||
| alignas(A::alignment()) std::array<T, size> buffer {}; | ||
|
Contributor
this array assignment forces everything to zero, while some stores are not needed, and the compiler is not able to optimize this away in the generic case
||
| const uint64_t bits = mask.mask(); | ||
| for (std::size_t i = 0; i < size; ++i) | ||
| if ((bits >> i) & uint64_t(1)) | ||
| buffer[i] = mem[i]; | ||
| return batch<T, A>::load_aligned(buffer.data()); | ||
| } | ||
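Picking up the zero-initialization remark above, a minimal sketch of one possible variant (not part of this PR): fold the zeroing of inactive lanes into the loop itself, so every buffer element is stored exactly once while `mem[i]` is still only read when its mask bit is set.

```cpp
// Hypothetical variant of the generic overload above (illustration only):
// inactive lanes are zeroed inside the loop, so the separate aggregate
// zero-fill and its redundant stores on active lanes disappear, while the
// per-lane validity contract is preserved (mem[i] is read only when bit i is set).
template <class A, class T, class Mode>
XSIMD_INLINE batch<T, A>
load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<common>) noexcept
{
    constexpr std::size_t size = batch<T, A>::size;
    alignas(A::alignment()) std::array<T, size> buffer; // no blanket zero-fill
    const uint64_t bits = mask.mask();
    for (std::size_t i = 0; i < size; ++i)
        buffer[i] = ((bits >> i) & uint64_t(1)) ? mem[i] : T(0);
    return batch<T, A>::load_aligned(buffer.data());
}
```

Whether this actually emits fewer stores is compiler-dependent; it is only meant to show the shape of the change the reviewer is pointing at.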
|
|
||
| template <class A, class T_in, class T_out, bool... Values, class alignment> | ||
| XSIMD_INLINE void | ||
| store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept | ||
|
|
@@ -388,6 +404,73 @@ namespace xsimd | |
| } | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE void | ||
| store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, Mode, requires_arch<common>) noexcept | ||
| { | ||
| // Per-lane validity contract: only active lanes are written. | ||
| // Arches with hardware predicated stores override this. | ||
| constexpr std::size_t size = batch<T, A>::size; | ||
| alignas(A::alignment()) std::array<T, size> src_buf; | ||
| src.store_aligned(src_buf.data()); | ||
| const uint64_t bits = mask.mask(); | ||
| for (std::size_t i = 0; i < size; ++i) | ||
| if ((bits >> i) & uint64_t(1)) | ||
| mem[i] = src_buf[i]; | ||
| } | ||
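For context, a hedged usage sketch of the runtime-mask pair added here. Only the kernel-level overloads appear in this diff, so the public spellings `xsimd::load_masked(mem, mask)` and `xsimd::store_masked(mem, batch, mask)` below are assumptions about how the dispatch layer forwards to them, not the confirmed API.

```cpp
#include <cstdint>
#include "xsimd/xsimd.hpp"

// Assumed public wrappers over the kernel overloads above; treat the exact
// names and argument order as placeholders.
void double_selected(float* data, std::uint64_t active_bits)
{
    using batch_t = xsimd::batch<float>;
    using mask_t = xsimd::batch_bool<float>;

    // One bit per lane: bit i set means lane i participates.
    const mask_t mask = mask_t::from_mask(active_bits);

    batch_t v = xsimd::load_masked(data, mask); // only active lanes are read
    v = v * 2.0f;
    xsimd::store_masked(data, v, mask); // only active lanes are written back
}
```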
|
|
||
| // Head/tail forward to the runtime-mask path. ``tail`` offsets | ||
| // the base pointer back by ``(size - n)`` so the active high-``n`` | ||
| // lanes land at ``[mem, mem + n)``; the offset goes through | ||
| // ``uintptr_t`` to dodge ``-Warray-bounds`` on small buffers. | ||
| namespace detail | ||
| { | ||
| template <class T> | ||
| XSIMD_INLINE T const* offset_back(T const* p, std::size_t k) noexcept | ||
| { | ||
| return reinterpret_cast<T const*>(reinterpret_cast<std::uintptr_t>(p) - k * sizeof(T)); | ||
| } | ||
| template <class T> | ||
| XSIMD_INLINE T* offset_back(T* p, std::size_t k) noexcept | ||
| { | ||
| return reinterpret_cast<T*>(reinterpret_cast<std::uintptr_t>(p) - k * sizeof(T)); | ||
| } | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE batch<T, A> | ||
| load_head(T const* mem, std::size_t n, Mode, requires_arch<common>) noexcept | ||
| { | ||
| const auto mask = batch_bool<T, A>::from_mask(::xsimd::details::full_mask(n)); | ||
| return load_masked<A>(mem, mask, convert<T> {}, unaligned_mode {}, A {}); | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE void | ||
| store_head(T* mem, std::size_t n, batch<T, A> const& src, Mode, requires_arch<common>) noexcept | ||
| { | ||
| const auto mask = batch_bool<T, A>::from_mask(::xsimd::details::full_mask(n)); | ||
| store_masked<A>(mem, src, mask, unaligned_mode {}, A {}); | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE batch<T, A> | ||
| load_tail(T const* mem, std::size_t n, Mode, requires_arch<common>) noexcept | ||
| { | ||
| constexpr std::size_t size = batch<T, A>::size; | ||
| const auto mask = batch_bool<T, A>::from_mask(::xsimd::details::full_mask(n) << (size - n)); | ||
| return load_masked<A>(detail::offset_back(mem, size - n), mask, convert<T> {}, unaligned_mode {}, A {}); | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE void | ||
| store_tail(T* mem, std::size_t n, batch<T, A> const& src, Mode, requires_arch<common>) noexcept | ||
| { | ||
| constexpr std::size_t size = batch<T, A>::size; | ||
| const auto mask = batch_bool<T, A>::from_mask(::xsimd::details::full_mask(n) << (size - n)); | ||
| store_masked<A>(detail::offset_back(mem, size - n), src, mask, unaligned_mode {}, A {}); | ||
| } | ||
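A hedged sketch of the remainder handling these head/tail helpers enable, assuming thin public wrappers `xsimd::load_head(mem, n)` / `xsimd::store_head(mem, n, src)` exist above the kernel overloads shown here (the diff only exposes the kernel layer).

```cpp
#include <cstddef>
#include "xsimd/xsimd.hpp"

// Vectorised loop over an arbitrary-length array with no scalar epilogue:
// the final partial batch goes through load_head/store_head, which only
// touch the first (n - i) lanes' worth of memory.
void negate(float* data, std::size_t n)
{
    using batch_t = xsimd::batch<float>;
    constexpr std::size_t lanes = batch_t::size;

    std::size_t i = 0;
    for (; i + lanes <= n; i += lanes)
    {
        batch_t v = batch_t::load_unaligned(data + i);
        (-v).store_unaligned(data + i);
    }
    if (i != n) // remainder of fewer than `lanes` elements
    {
        batch_t v = xsimd::load_head(data + i, n - i);
        xsimd::store_head(data + i, n - i, -v);
    }
}
```

The tail variants are the mirror image: the active lanes sit at the high end of the register and the base pointer is offset back accordingly, as the comment above describes.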
|
|
||
| template <class A, bool... Values, class Mode> | ||
| XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...>, convert<int32_t>, Mode, requires_arch<A>) noexcept | ||
| { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,7 +119,6 @@ namespace xsimd | |
| } | ||
|
|
||
| // load_masked | ||
| // AVX2 low-level helpers (operate on raw SIMD registers) | ||
| namespace detail | ||
| { | ||
| XSIMD_INLINE __m256i maskload(const int32_t* mem, __m256i mask) noexcept | ||
|
|
@@ -138,14 +137,12 @@ namespace xsimd | |
| } | ||
| } | ||
|
|
||
| // single templated implementation for integer masked loads (32/64-bit) | ||
| template <class A, class T, bool... Values, class Mode> | ||
| XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) >= 4), batch<T, A>> | ||
| load_masked(T const* mem, batch_bool_constant<T, A, Values...> mask, convert<T>, Mode, requires_arch<avx2>) noexcept | ||
| { | ||
| static_assert(sizeof(T) == 4 || sizeof(T) == 8, "load_masked supports only 32/64-bit integers on AVX2"); | ||
| using int_t = std::conditional_t<sizeof(T) == 4, int32_t, long long>; | ||
| // Use the raw register-level maskload helpers for the remaining cases. | ||
| return detail::maskload(reinterpret_cast<const int_t*>(mem), mask.as_batch()); | ||
| } | ||
|
|
||
|
|
@@ -175,6 +172,20 @@ namespace xsimd | |
| return bitwise_cast<uint64_t>(r); | ||
| } | ||
|
|
||
| // Runtime-mask load for 32/64-bit integers on AVX2. 8/16-bit integers | ||
| // fall back to the scalar common path: AVX2 has no native maskload for | ||
| // those widths, and a load-then-blend would break fault-suppression at | ||
| // page boundaries (the main reason callers ask for a masked load). | ||
| // Both aligned_mode and unaligned_mode route to the same intrinsic — | ||
| // masked-off lanes do not fault regardless of alignment. | ||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) == 4 || sizeof(T) == 8), batch<T, A>> | ||
| load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<avx2>) noexcept | ||
| { | ||
| using int_t = std::conditional_t<sizeof(T) == 4, int32_t, long long>; | ||
|
Contributor
why
||
| return detail::maskload(reinterpret_cast<const int_t*>(mem), __m256i(mask)); | ||
| } | ||
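To illustrate the fault-suppression point in the comment above without relying on any xsimd plumbing, here is a self-contained AVX2 snippet (intrinsics only): inactive lanes are never dereferenced, so a partial read at the very end of a mapping cannot fault.

```cpp
#include <cstdint>
#include <immintrin.h>

// Load the first n (0 <= n <= 8) int32 values from p. Lanes i >= n carry a
// zero mask, and vpmaskmovd suppresses their loads entirely, so this is safe
// even if p + n is the end of the mapped region.
__m256i load_first_n_i32(const std::int32_t* p, int n)
{
    alignas(32) std::int32_t lane_mask[8];
    for (int i = 0; i < 8; ++i)
        lane_mask[i] = (i < n) ? -1 : 0; // sign bit set == lane active
    const __m256i mask = _mm256_load_si256(reinterpret_cast<const __m256i*>(lane_mask));
    return _mm256_maskload_epi32(reinterpret_cast<const int*>(p), mask);
}
```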
|
|
||
| // store_masked | ||
| namespace detail | ||
| { | ||
|
|
@@ -196,14 +207,12 @@ namespace xsimd | |
| { | ||
| constexpr size_t lanes_per_half = batch<T, A>::size / 2; | ||
|
|
||
| // confined to lower 128-bit half → forward to SSE | ||
| XSIMD_IF_CONSTEXPR(mask.countl_zero() >= lanes_per_half) | ||
| { | ||
| constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask); | ||
| const auto lo = detail::lower_half(src); | ||
| store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {}); | ||
| } | ||
| // confined to upper 128-bit half → forward to SSE | ||
| else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= lanes_per_half) | ||
| { | ||
| constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask); | ||
|
|
@@ -230,6 +239,20 @@ namespace xsimd | |
| store_masked<A>(reinterpret_cast<int64_t*>(mem), s64, batch_bool_constant<int64_t, A, Values...> {}, Mode {}, avx2 {}); | ||
| } | ||
|
|
||
| template <class A, class T, class Mode> | ||
| XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) == 4 || sizeof(T) == 8), void> | ||
| store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, Mode, requires_arch<avx2>) noexcept | ||
| { | ||
| XSIMD_IF_CONSTEXPR(sizeof(T) == 4) | ||
| { | ||
| _mm256_maskstore_epi32(reinterpret_cast<int*>(mem), __m256i(mask), __m256i(src)); | ||
| } | ||
| else | ||
| { | ||
| _mm256_maskstore_epi64(reinterpret_cast<long long*>(mem), __m256i(mask), __m256i(src)); | ||
|
Contributor
ok, I guess that's a constraint of the Intel intrinsic, at least static_assert that
||
| } | ||
| } | ||
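A minimal sketch of the guard the review comment above asks for, assuming the concern is the `long long` pointer type that `_mm256_maskstore_epi64` imposes; this fragment would sit inside the 64-bit branch and is not part of the PR.

```cpp
// Hypothetical addition inside the sizeof(T) == 8 branch above (not in the PR):
// state the intrinsic's constraint explicitly rather than relying on the
// implicit sizeof(T) == sizeof(long long) assumption behind the cast.
static_assert(sizeof(long long) == 8 && sizeof(T) == sizeof(long long),
              "_mm256_maskstore_epi64 addresses 64-bit lanes through long long");
```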
|
|
||
| // load_stream | ||
| template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value, void>> | ||
| XSIMD_INLINE batch<T, A> load_stream(T const* mem, convert<T>, requires_arch<avx2>) noexcept | ||
|
|
||
to make it worse, building a mask is not always a single operation depending on the target...