Signed-off-by: Matthias Kretz <[email protected]>

libstdc++-v3/ChangeLog:
* include/bits/simd_loadstore.h (unchecked_load): A masked load of a
single-element vec is either a plain load or a value-initialized vec.
* include/bits/simd_mask.h (_S_partial_mask_of_n): Conversion
from bool or unsigned integer to basic_mask needs an explicit
constructor call. Don't recurse into _S_partial_mask_of_n for
size-1 (scalar) sub-masks.
* include/bits/simd_vec.h (_S_masked_load): Add a case for size 1.
This is needed when recursing from _ScalarAbi<N> with N > 1.
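* testsuite/std/simd/loads.cc: Add tests for masked loads. Rename
loads to load_zeros and loads_iota to load_iotas. Factor out
make_iota_array and the reference vecs into static members.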
---
libstdc++-v3/include/bits/simd_loadstore.h | 11 ++--
libstdc++-v3/include/bits/simd_mask.h | 15 ++++--
libstdc++-v3/include/bits/simd_vec.h | 6 ++-
libstdc++-v3/testsuite/std/simd/loads.cc | 63 ++++++++++++++++------
4 files changed, 70 insertions(+), 25 deletions(-)
--
──────────────────────────────────────────────────────────────────────────
Dr. Matthias Kretz https://mattkretz.github.io
GSI Helmholtz Center for Heavy Ion Research https://gsi.de
std::simd
──────────────────────────────────────────────────────────────────────────diff --git a/libstdc++-v3/include/bits/simd_loadstore.h b/libstdc++-v3/include/bits/simd_loadstore.h
index 7154c8d2e208..7ea9de0f6f98 100644
--- a/libstdc++-v3/include/bits/simd_loadstore.h
+++ b/libstdc++-v3/include/bits/simd_loadstore.h
@@ -141,9 +141,14 @@ __glibcxx_simd_precondition(
}
else
{
- if constexpr (!__allow_out_of_bounds
- || (__static_size != dynamic_extent
- && __static_size >= size_t(_RV::size.value)))
+ constexpr bool __no_size_check
+ = !__allow_out_of_bounds
+ || (__static_size != dynamic_extent
+ && __static_size >= size_t(_RV::size.value));
+ if constexpr (_RV::size() == 1)
+ return __mask[0] && (__no_size_check || __rg_size > 0) ? _RV(_LoadCtorTag(), __ptr)
+ : _RV();
+ else if constexpr (__no_size_check)
return _RV::_S_masked_load(__ptr, __mask);
else if (__rg_size >= size_t(_RV::size()))
return _RV::_S_masked_load(__ptr, __mask);
diff --git a/libstdc++-v3/include/bits/simd_mask.h b/libstdc++-v3/include/bits/simd_mask.h
index 44676ff150d1..03bf68ec3d09 100644
--- a/libstdc++-v3/include/bits/simd_mask.h
+++ b/libstdc++-v3/include/bits/simd_mask.h
@@ -614,7 +614,7 @@ _S_partial_mask_of_n(int __n)
"positive __n that does not overflow.");
constexpr _DataType __0123
= __builtin_bit_cast(_DataType, _IotaArray<_Ip(_S_full_size)>);
- return __0123 < _Ip(__n);
+ return basic_mask(__0123 < _Ip(__n));
}
else
{
@@ -1584,13 +1584,18 @@ _S_partial_mask_of_n(int __n)
{
#if __has_builtin(__builtin_ia32_bzhi_di)
if constexpr (_S_use_bitmask && _S_size <= 64 && _Traits._M_have_bmi2())
- return __builtin_ia32_bzhi_di(~0ull >> (64 - _S_size), unsigned(__n));
+ return basic_mask(__builtin_ia32_bzhi_di(~0ull >> (64 - _S_size), unsigned(__n)));
#endif
- if (__n < _N0)
+ if constexpr (_N0 == 1)
+ {
+ static_assert(_S_size == 2); // => __n == 1
+ return _S_init(_Mask0(true), _Mask1(false));
+ }
+ else if (__n < _N0)
return _S_init(_Mask0::_S_partial_mask_of_n(__n), _Mask1(false));
- else if (__n == _N0)
+ else if (__n == _N0 || _N1 == 1)
return _S_init(_Mask0(true), _Mask1(false));
- else
+ else if constexpr (_N1 != 1)
return _S_init(_Mask0(true), _Mask1::_S_partial_mask_of_n(__n - _N0));
}
diff --git a/libstdc++-v3/include/bits/simd_vec.h b/libstdc++-v3/include/bits/simd_vec.h
index c76f052d0532..e4276e8ea9c4 100644
--- a/libstdc++-v3/include/bits/simd_vec.h
+++ b/libstdc++-v3/include/bits/simd_vec.h
@@ -1103,8 +1103,10 @@ _S_partial_load(const _Up* __mem, size_t __n)
static inline basic_vec
_S_masked_load(const _Up* __mem, mask_type __k)
{
+ if constexpr (_S_size == 1)
+ return __k[0] ? static_cast<value_type>(__mem[0]) : value_type();
#if _GLIBCXX_X86
- if constexpr (_Traits._M_have_avx512f())
+ else if constexpr (_Traits._M_have_avx512f())
return __x86_masked_load<_DataType>(__mem, __k._M_data);
else if constexpr (_Traits._M_have_avx() && (sizeof(_Up) == 4 || sizeof(_Up) == 8))
{
@@ -1117,7 +1119,7 @@ _S_masked_load(const _Up* __mem, mask_type __k)
}
}
#endif
- if (__k._M_none_of()) [[unlikely]]
+ else if (__k._M_none_of()) [[unlikely]]
return basic_vec();
else if constexpr (_S_is_scalar)
return basic_vec(static_cast<value_type>(*__mem));
diff --git a/libstdc++-v3/testsuite/std/simd/loads.cc b/libstdc++-v3/testsuite/std/simd/loads.cc
index 869346aa4d5b..e4e67ab98791 100644
--- a/libstdc++-v3/testsuite/std/simd/loads.cc
+++ b/libstdc++-v3/testsuite/std/simd/loads.cc
@@ -18,7 +18,7 @@ struct Tests
static_assert(simd::alignment_v<V> <= 256);
- ADD_TEST(loads) {
+ ADD_TEST(load_zeros) {
std::tuple {aligned_array<T, V::size * 2, 256> {}, aligned_array<int, V::size * 2, 256> {}},
[](auto& t, auto mem, auto ints) {
t.verify_equal(simd::unchecked_load<V>(mem), V());
@@ -37,24 +37,30 @@ struct Tests
t.verify_equal(simd::unchecked_load<V>(ints, simd::flag_convert), V());
t.verify_equal(simd::partial_load<V>(ints, simd::flag_convert), V());
+
+ t.verify_equal(simd::unchecked_load<V>(mem, M(true)), V());
+ t.verify_equal(simd::unchecked_load<V>(mem, M(false)), V());
+ t.verify_equal(simd::partial_load<V>(mem, M(true)), V());
+ t.verify_equal(simd::partial_load<V>(mem, M(false)), V());
}
};
- ADD_TEST(loads_iota, requires {T() + T(1);}) {
- std::tuple {[] {
- aligned_array<T, V::size * 2, simd::alignment_v<V>> arr = {};
- T init = 0;
- for (auto& x : arr) x = (init += T(1));
- return arr;
- }(), [] {
- aligned_array<int, V::size * 2, simd::alignment_v<V, int>> arr = {};
- std::iota(arr.begin(), arr.end(), 1);
- return arr;
- }()},
- [](auto& t, auto mem, auto ints) {
- constexpr V ref = test_iota<V, 1, 0>;
- constexpr V ref1 = V([](int i) { return i == 0 ? T(1): T(); });
+ static constexpr V ref = test_iota<V, 1, 0>;
+ static constexpr V ref1 = V([](int i) { return i == 0 ? T(1): T(); });
+ template <typename U>
+ static constexpr auto
+ make_iota_array()
+ {
+ aligned_array<U, V::size * 2, simd::alignment_v<V, U>> arr = {};
+ U init = 0;
+ for (auto& x : arr) x = (init += U(1));
+ return arr;
+ }
+
+ ADD_TEST(load_iotas, requires {T() + T(1);}) {
+ std::tuple {make_iota_array<T>(), make_iota_array<int>()},
+ [](auto& t, auto mem, auto ints) {
t.verify_equal(simd::unchecked_load<V>(mem), ref);
t.verify_equal(simd::partial_load<V>(mem), ref);
@@ -71,6 +77,33 @@ struct Tests
ints.begin(), ints.begin(), simd::flag_convert), V());
t.verify_equal(simd::partial_load<V>(
ints.begin(), ints.begin() + 1, simd::flag_convert), ref1);
+
+ t.verify_equal(simd::unchecked_load<V>(mem, M(true)), ref);
+ t.verify_equal(simd::unchecked_load<V>(mem, M(false)), V());
+ t.verify_equal(simd::partial_load<V>(mem, M(true)), ref);
+ t.verify_equal(simd::partial_load<V>(mem, M(false)), V());
+ }
+ };
+
+ static constexpr M alternating = M([](int i) { return 1 == (i & 1); });
+ static constexpr V ref_k = select(alternating, ref, T());
+ static constexpr V ref_2 = select(M([](int i) { return i < 2; }), ref, T());
+ static constexpr V ref_k_2 = select(M([](int i) { return i < 2; }), ref_k, T());
+
+ ADD_TEST(masked_loads) {
+ std::tuple {make_iota_array<T>(), alternating, M(true), M(false)},
+ [](auto& t, auto mem, M k, M tr, M fa) {
+ t.verify_equal(simd::unchecked_load<V>(mem, tr), ref);
+ t.verify_equal(simd::unchecked_load<V>(mem, fa), V());
+ t.verify_equal(simd::unchecked_load<V>(mem, k), ref_k);
+
+ t.verify_equal(simd::partial_load<V>(mem, tr), ref);
+ t.verify_equal(simd::partial_load<V>(mem, fa), V());
+ t.verify_equal(simd::partial_load<V>(mem, k), ref_k);
+
+ t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, tr), ref_2);
+ t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, fa), V());
+ t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, k), ref_k_2);
}
};
};
signature.asc
Description: This is a digitally signed message part.
