Signed-off-by: Matthias Kretz <[email protected]>

libstdc++-v3/ChangeLog:

        * include/bits/simd_loadstore.h (unchecked_load): A masked load
        into a single-element vector is either a simple load or a
        value-initialized result.
        * include/bits/simd_mask.h (_S_partial_mask_of_n): Conversion
        from bool or unsigned integer to basic_mask needs an explicit
        constructor call. Size-1 sub-masks don't need to recurse into
        _S_partial_mask_of_n.
        * include/bits/simd_vec.h (_S_masked_load): Add a case for size 1.
        This is needed for the recursion from _ScalarAbi<N> with N > 1.
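        * testsuite/std/simd/loads.cc (load_zeros): Rename from loads.
        Add masked loads.
        (load_iotas): Rename from loads_iota. Add masked loads.
        (make_iota_array): New helper.
        (masked_loads): New test.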
---
 libstdc++-v3/include/bits/simd_loadstore.h | 11 ++--
 libstdc++-v3/include/bits/simd_mask.h      | 15 ++++--
 libstdc++-v3/include/bits/simd_vec.h       |  6 ++-
 libstdc++-v3/testsuite/std/simd/loads.cc   | 63 ++++++++++++++++------
 4 files changed, 70 insertions(+), 25 deletions(-)


--
──────────────────────────────────────────────────────────────────────────
 Dr. Matthias Kretz                           https://mattkretz.github.io
 GSI Helmholtz Center for Heavy Ion Research               https://gsi.de
 std::simd
──────────────────────────────────────────────────────────────────────────
diff --git a/libstdc++-v3/include/bits/simd_loadstore.h b/libstdc++-v3/include/bits/simd_loadstore.h
index 7154c8d2e208..7ea9de0f6f98 100644
--- a/libstdc++-v3/include/bits/simd_loadstore.h
+++ b/libstdc++-v3/include/bits/simd_loadstore.h
@@ -141,9 +141,14 @@ __glibcxx_simd_precondition(
 	}
       else
 	{
-	  if constexpr (!__allow_out_of_bounds
-			  || (__static_size != dynamic_extent
-				&& __static_size >= size_t(_RV::size.value)))
+	  constexpr bool __no_size_check
+	    = !__allow_out_of_bounds
+		|| (__static_size != dynamic_extent
+		      && __static_size >= size_t(_RV::size.value));
+	  if constexpr (_RV::size() == 1)
+	    return __mask[0] && (__no_size_check || __rg_size > 0) ? _RV(_LoadCtorTag(), __ptr)
+								   : _RV();
+	  else if constexpr (__no_size_check)
 	    return _RV::_S_masked_load(__ptr, __mask);
 	  else if (__rg_size >= size_t(_RV::size()))
 	    return _RV::_S_masked_load(__ptr, __mask);
diff --git a/libstdc++-v3/include/bits/simd_mask.h b/libstdc++-v3/include/bits/simd_mask.h
index 44676ff150d1..03bf68ec3d09 100644
--- a/libstdc++-v3/include/bits/simd_mask.h
+++ b/libstdc++-v3/include/bits/simd_mask.h
@@ -614,7 +614,7 @@ _S_partial_mask_of_n(int __n)
 					  "positive __n that does not overflow.");
 	      constexpr _DataType __0123
 		= __builtin_bit_cast(_DataType, _IotaArray<_Ip(_S_full_size)>);
-	      return __0123 < _Ip(__n);
+	      return basic_mask(__0123 < _Ip(__n));
 	    }
 	  else
 	    {
@@ -1584,13 +1584,18 @@ _S_partial_mask_of_n(int __n)
 	{
 #if __has_builtin(__builtin_ia32_bzhi_di)
 	  if constexpr (_S_use_bitmask && _S_size <= 64 && _Traits._M_have_bmi2())
-	    return __builtin_ia32_bzhi_di(~0ull >> (64 - _S_size), unsigned(__n));
+	    return basic_mask(__builtin_ia32_bzhi_di(~0ull >> (64 - _S_size), unsigned(__n)));
 #endif
-	  if (__n < _N0)
+	  if constexpr (_N0 == 1)
+	    {
+	      static_assert(_S_size == 2); // => __n == 1
+	      return _S_init(_Mask0(true), _Mask1(false));
+	    }
+	  else if (__n < _N0)
 	    return _S_init(_Mask0::_S_partial_mask_of_n(__n), _Mask1(false));
-	  else if (__n == _N0)
+	  else if (__n == _N0 || _N1 == 1)
 	    return _S_init(_Mask0(true), _Mask1(false));
-	  else
+	  else if constexpr (_N1 != 1)
 	    return _S_init(_Mask0(true), _Mask1::_S_partial_mask_of_n(__n - _N0));
 	}
 
diff --git a/libstdc++-v3/include/bits/simd_vec.h b/libstdc++-v3/include/bits/simd_vec.h
index c76f052d0532..e4276e8ea9c4 100644
--- a/libstdc++-v3/include/bits/simd_vec.h
+++ b/libstdc++-v3/include/bits/simd_vec.h
@@ -1103,8 +1103,10 @@ _S_partial_load(const _Up* __mem, size_t __n)
 	static inline basic_vec
 	_S_masked_load(const _Up* __mem, mask_type __k)
 	{
+	  if constexpr (_S_size == 1)
+	    return __k[0] ? static_cast<value_type>(__mem[0]) : value_type();
 #if _GLIBCXX_X86
-	  if constexpr (_Traits._M_have_avx512f())
+	  else if constexpr (_Traits._M_have_avx512f())
 	    return __x86_masked_load<_DataType>(__mem, __k._M_data);
 	  else if constexpr (_Traits._M_have_avx() && (sizeof(_Up) == 4 || sizeof(_Up) == 8))
 	    {
@@ -1117,7 +1119,7 @@ _S_masked_load(const _Up* __mem, mask_type __k)
 		}
 	    }
 #endif
-	  if (__k._M_none_of()) [[unlikely]]
+	  else if (__k._M_none_of()) [[unlikely]]
 	    return basic_vec();
 	  else if constexpr (_S_is_scalar)
 	    return basic_vec(static_cast<value_type>(*__mem));
diff --git a/libstdc++-v3/testsuite/std/simd/loads.cc b/libstdc++-v3/testsuite/std/simd/loads.cc
index 869346aa4d5b..e4e67ab98791 100644
--- a/libstdc++-v3/testsuite/std/simd/loads.cc
+++ b/libstdc++-v3/testsuite/std/simd/loads.cc
@@ -18,7 +18,7 @@ struct Tests
 
     static_assert(simd::alignment_v<V> <= 256);
 
-    ADD_TEST(loads) {
+    ADD_TEST(load_zeros) {
       std::tuple {aligned_array<T, V::size * 2, 256> {}, aligned_array<int, V::size * 2, 256> {}},
       [](auto& t, auto mem, auto ints) {
 	t.verify_equal(simd::unchecked_load<V>(mem), V());
@@ -37,24 +37,30 @@ struct Tests
 
 	t.verify_equal(simd::unchecked_load<V>(ints, simd::flag_convert), V());
 	t.verify_equal(simd::partial_load<V>(ints, simd::flag_convert), V());
+
+	t.verify_equal(simd::unchecked_load<V>(mem, M(true)), V());
+	t.verify_equal(simd::unchecked_load<V>(mem, M(false)), V());
+	t.verify_equal(simd::partial_load<V>(mem, M(true)), V());
+	t.verify_equal(simd::partial_load<V>(mem, M(false)), V());
       }
     };
 
-    ADD_TEST(loads_iota, requires {T() + T(1);}) {
-      std::tuple {[] {
-	aligned_array<T, V::size * 2, simd::alignment_v<V>> arr = {};
-	T init = 0;
-	for (auto& x : arr) x = (init += T(1));
-	return arr;
-      }(), [] {
-	aligned_array<int, V::size * 2, simd::alignment_v<V, int>> arr = {};
-	std::iota(arr.begin(), arr.end(), 1);
-	return arr;
-      }()},
-      [](auto& t, auto mem, auto ints) {
-	constexpr V ref = test_iota<V, 1, 0>;
-	constexpr V ref1 = V([](int i) { return i == 0 ? T(1): T(); });
+    static constexpr V ref = test_iota<V, 1, 0>;
+    static constexpr V ref1 = V([](int i) { return i == 0 ? T(1): T(); });
 
+    template <typename U>
+    static constexpr auto
+    make_iota_array()
+    {
+      aligned_array<U, V::size * 2, simd::alignment_v<V, U>> arr = {};
+      U init = 0;
+      for (auto& x : arr) x = (init += U(1));
+      return arr;
+    }
+
+    ADD_TEST(load_iotas, requires {T() + T(1);}) {
+      std::tuple {make_iota_array<T>(), make_iota_array<int>()},
+      [](auto& t, auto mem, auto ints) {
 	t.verify_equal(simd::unchecked_load<V>(mem), ref);
 	t.verify_equal(simd::partial_load<V>(mem), ref);
 
@@ -71,6 +77,33 @@ struct Tests
 			 ints.begin(), ints.begin(), simd::flag_convert), V());
 	t.verify_equal(simd::partial_load<V>(
 			 ints.begin(), ints.begin() + 1, simd::flag_convert), ref1);
+
+	t.verify_equal(simd::unchecked_load<V>(mem, M(true)), ref);
+	t.verify_equal(simd::unchecked_load<V>(mem, M(false)), V());
+	t.verify_equal(simd::partial_load<V>(mem, M(true)), ref);
+	t.verify_equal(simd::partial_load<V>(mem, M(false)), V());
+      }
+    };
+
+    static constexpr M alternating = M([](int i) { return 1 == (i & 1); });
+    static constexpr V ref_k = select(alternating, ref, T());
+    static constexpr V ref_2 = select(M([](int i) { return i < 2; }), ref, T());
+    static constexpr V ref_k_2 = select(M([](int i) { return i < 2; }), ref_k, T());
+
+    ADD_TEST(masked_loads) {
+      std::tuple {make_iota_array<T>(), alternating, M(true), M(false)},
+      [](auto& t, auto mem, M k, M tr, M fa) {
+	t.verify_equal(simd::unchecked_load<V>(mem, tr), ref);
+	t.verify_equal(simd::unchecked_load<V>(mem, fa), V());
+	t.verify_equal(simd::unchecked_load<V>(mem, k), ref_k);
+
+	t.verify_equal(simd::partial_load<V>(mem, tr), ref);
+	t.verify_equal(simd::partial_load<V>(mem, fa), V());
+	t.verify_equal(simd::partial_load<V>(mem, k), ref_k);
+
+	t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, tr), ref_2);
+	t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, fa), V());
+	t.verify_equal(simd::partial_load<V>(mem.begin(), mem.begin() + 2, k), ref_k_2);
       }
     };
   };

Attachment: signature.asc
Description: This is a digitally signed message part.

Reply via email to