This patch implements P3383R3: mdspan.at().
The mdspan::at overloads cast only arguments of non-integral types to
index_type before performing the bounds checks. Checking integral arguments
in their original type allows detecting negative values, even if index_type
is unsigned, as well as values that would overflow index_type.
libstdc++-v3/ChangeLog:
* include/std/mdspan (mdspan::at, mdspan::__index_int_t):
Define.
* testsuite/23_containers/mdspan/at.cc: New test.
Signed-off-by: Tomasz Kamiński <[email protected]>
---
Tested on x86_64-linux. OK for trunk?
libstdc++-v3/include/std/mdspan | 63 ++++++++++
.../testsuite/23_containers/mdspan/at.cc | 113 ++++++++++++++++++
2 files changed, 176 insertions(+)
create mode 100644 libstdc++-v3/testsuite/23_containers/mdspan/at.cc
diff --git a/libstdc++-v3/include/std/mdspan b/libstdc++-v3/include/std/mdspan
index 0c89f8e7155..3476d42a129 100644
--- a/libstdc++-v3/include/std/mdspan
+++ b/libstdc++-v3/include/std/mdspan
@@ -51,6 +51,9 @@
#include <tuple>
#endif
+#if __cplusplus > 202302L
+#include <bits/stdexcept_throw.h>
+#endif
#ifdef __glibcxx_mdspan
@@ -3083,6 +3086,62 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
operator[](const array<_OIndexType, rank()>& __indices) const
{ return (*this)[span<const _OIndexType, rank()>(__indices)]; }
+#if __cplusplus > 202302L
+      /**
+       *  @brief  Bounds-checked access to the element at a multidimensional
+       *          index (P3383).
+       *  @param  __indices  One index per dimension.
+       *  @return The reference obtained from the accessor for that element.
+       *  @throw  std::out_of_range  If any index is negative or not less
+       *          than the extent of its dimension.
+       *
+       *  Arguments of a signed or unsigned integer type are checked in their
+       *  own type, so negative values and values too large for index_type
+       *  are detected even when index_type is unsigned.  Any other argument
+       *  (a class type, bool, or a character type) is first converted to
+       *  index_type, because std::cmp_greater_equal only accepts signed and
+       *  unsigned integer types.
+       */
+      template<__mdspan::__valid_index_type<index_type>... _OIndexTypes>
+        requires (sizeof...(_OIndexTypes) == rank())
+        constexpr reference
+        at(_OIndexTypes... __indices) const
+        {
+          if constexpr (rank() == 0)
+            return _M_accessor.access(_M_handle, _M_mapping());
+          else if constexpr (!(__is_standard_integer<_OIndexTypes>::value
+                               && ...))
+            {
+              // Keep signed/unsigned integers unchanged, convert everything
+              // else to index_type, then re-enter with standard integers only.
+              auto __to_integer = []<typename _Tp>(_Tp&& __i)
+                {
+                  using _Up = remove_cvref_t<_Tp>;
+                  if constexpr (__is_standard_integer<_Up>::value)
+                    return _Up(__i);
+                  else
+                    return static_cast<index_type>(std::forward<_Tp>(__i));
+                };
+              return at(__to_integer(std::move(__indices))...);
+            }
+          else
+            {
+              // Throws unless 0 <= __index < extent(__dim).
+              auto __check_bound
+                = [&]<typename _OIntType>(size_t __dim, _OIntType __index)
+                {
+                  if constexpr (is_signed_v<_OIntType>)
+                    if (__index < 0)
+                      std::__throw_out_of_range_fmt(
+                        __N("mdspan::at: indices[%zu] is negative"), __dim);
+
+                  // Compare in the argument's own type, so a value that does
+                  // not fit in index_type is rejected instead of wrapping.
+                  const auto __ext = extents().extent(__dim);
+                  if (std::cmp_greater_equal(__index, __ext))
+                    std::__throw_out_of_range_fmt(
+                      __N("mdspan::at: indices[%zu] (which is %zu)"
+                          " >= extent(%zu) (which is %zu)"),
+                      __dim, size_t(__index), __dim, size_t(__ext));
+                };
+              auto __check_bounds
+                = [&]<size_t... _Counts>(index_sequence<_Counts...>)
+                { (__check_bound(_Counts, __indices), ...); };
+
+              __check_bounds(make_index_sequence<rank()>());
+              auto __index = _M_mapping(static_cast<index_type>(__indices)...);
+              return _M_accessor.access(_M_handle, __index);
+            }
+        }
+
+      /**
+       *  @brief  Bounds-checked element access with the indices supplied as
+       *          a span of length rank().
+       *  @throw  std::out_of_range  As for the variadic overload.
+       *
+       *  Each element is read through a const lvalue and passed on to the
+       *  variadic overload as __index_int_t<_OIndexType>, i.e. either in its
+       *  own type or converted to index_type.
+       */
+      template<typename _OIndexType>
+        requires __mdspan::__valid_index_type<const _OIndexType&, index_type>
+        constexpr reference
+        at(span<_OIndexType, rank()> __indices) const
+        {
+          auto __call = [&]<size_t... _Counts>(index_sequence<_Counts...>)
+            -> reference
+            {
+              return at(
+                __index_int_t<_OIndexType>(as_const(__indices[_Counts]))...);
+            };
+          return __call(make_index_sequence<rank()>());
+        }
+
+      /**
+       *  @brief  Bounds-checked element access with the indices supplied as
+       *          an array of length rank().
+       *  @throw  std::out_of_range  As for the variadic overload.
+       */
+      template<typename _OIndexType>
+        requires __mdspan::__valid_index_type<const _OIndexType&, index_type>
+        constexpr reference
+        at(const array<_OIndexType, rank()>& __indices) const
+        { return at(span<const _OIndexType, rank()>(__indices)); }
+#endif // C++26
+
constexpr size_type
size() const noexcept
{
@@ -3150,6 +3209,10 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
stride(rank_type __r) const { return _M_mapping.stride(__r); }
private:
+      // Type to which at() converts an index of type _OIndexType before
+      // bounds checking.  Signed and unsigned integer types (ignoring
+      // cv-qualification) are kept, so their value is range-checked before
+      // any narrowing.  Everything else is converted to index_type.  That
+      // includes bool and the character types, which
+      // std::cmp_greater_equal rejects.
+      template<typename _OIndexType>
+        using __index_int_t = std::conditional_t<
+          __is_standard_integer<remove_cv_t<_OIndexType>>::value,
+          _OIndexType, index_type>;
+
[[no_unique_address]] accessor_type _M_accessor = accessor_type();
[[no_unique_address]] mapping_type _M_mapping = mapping_type();
[[no_unique_address]] data_handle_type _M_handle = data_handle_type();
diff --git a/libstdc++-v3/testsuite/23_containers/mdspan/at.cc
b/libstdc++-v3/testsuite/23_containers/mdspan/at.cc
new file mode 100644
index 00000000000..4e659f57275
--- /dev/null
+++ b/libstdc++-v3/testsuite/23_containers/mdspan/at.cc
@@ -0,0 +1,152 @@
+// { dg-do run { target c++26 } }
+#include <mdspan>
+
+#include <testsuite_hooks.h>
+#include "int_like.h"
+#include <stdexcept>
+
+// Satisfied when md.at(args...) is well-formed and returns
+// MDSpan::reference.
+template<typename MDSpan, typename... Args>
+concept valid_at = requires (MDSpan md, Args... args)
+{
+  { md.at(args...) } -> std::same_as<typename MDSpan::reference>;
+};
+
+// Exercises all at() overloads on a 3x5x7 layout_left mdspan whose
+// index_type is IndexType.  ValidForPacks and ValidForArrays state
+// whether Int is expected to be accepted by the variadic overload
+// and by the span/array overloads, respectively.  With an unsigned
+// IndexType this also checks that negative signed arguments are
+// rejected rather than wrapped around.
+template<typename Int, bool ValidForPacks, bool ValidForArrays,
+         typename IndexType = int>
+  constexpr bool
+  test_at()
+  {
+    using Extents = std::extents<IndexType, 3, 5, 7>;
+    auto exts = Extents{};
+
+    auto mapping = std::layout_left::mapping(exts);
+    constexpr size_t n = mapping.required_span_size();
+    std::array<double, n> storage{};
+
+    auto md = std::mdspan(storage.data(), mapping);
+    using MDSpan = decltype(md);
+
+    for (int i = 0; i < int(exts.extent(0)); ++i)
+      for (int j = 0; j < int(exts.extent(1)); ++j)
+        for (int k = 0; k < int(exts.extent(2)); ++k)
+          {
+            storage[mapping(i, j, k)] = 1.0;
+            if constexpr (ValidForPacks)
+              VERIFY(md.at(Int(i), Int(j), Int(k)) == 1.0);
+
+            if constexpr (ValidForArrays)
+              {
+                std::array<Int, 3> ijk{Int(i), Int(j), Int(k)};
+                VERIFY(md.at(ijk) == 1.0);
+                VERIFY(md.at(std::span(ijk)) == 1.0);
+              }
+            storage[mapping(i, j, k)] = 0.0;
+          }
+
+    if constexpr (!ValidForPacks)
+      static_assert(!valid_at<MDSpan, Int, int, Int>);
+
+    if constexpr (!ValidForArrays)
+      {
+        static_assert(!valid_at<MDSpan, std::array<Int, 3>>);
+        static_assert(!valid_at<MDSpan, std::span<Int, 3>>);
+      }
+
+    // Every overload that accepts Int must throw std::out_of_range
+    // for the given out-of-bounds multidimensional index.
+    auto verify_throw = [&md](int i, int j, int k)
+      {
+        if constexpr (ValidForPacks)
+          try
+            {
+              md.at(Int(i), Int(j), Int(k));
+              VERIFY(false);
+            }
+          catch (std::out_of_range&)
+            {
+              VERIFY(true);
+            }
+
+        if constexpr (ValidForArrays)
+          {
+            std::array<Int, 3> ijk{Int(i), Int(j), Int(k)};
+            try
+              {
+                md.at(ijk);
+                VERIFY(false);
+              }
+            catch (std::out_of_range&)
+              {
+                VERIFY(true);
+              }
+
+            try
+              {
+                md.at(std::span(ijk));
+                VERIFY(false);
+              }
+            catch (std::out_of_range&)
+              {
+                VERIFY(true);
+              }
+          }
+      };
+
+    // Negative indices.
+    verify_throw(-1, 0, 0);
+    verify_throw(0, -3, 0);
+    verify_throw(0, 0, -5);
+
+    // Indices equal to, and well past, the corresponding extent.
+    verify_throw(3, 0, 0);
+    verify_throw(0, 5, 0);
+    verify_throw(0, 0, 7);
+    verify_throw(11, 0, 0);
+    verify_throw(0, 13, 0);
+    verify_throw(0, 0, 15);
+
+    return true;
+  }
+
+// A rank-0 mdspan has no indices to check: at() with no arguments,
+// an empty array or an empty span refers to the single element.
+constexpr bool
+test_at_rank0()
+{
+  double value = 3.0;
+  std::mdspan<double, std::extents<int>> md(&value);
+  VERIFY(&md.at() == &value);
+
+  std::array<int, 0> empty{};
+  VERIFY(&md.at(empty) == &value);
+  VERIFY(&md.at(std::span(empty)) == &value);
+  return true;
+}
+
+int
+main()
+{
+  test_at<int, true, true>();
+  static_assert(test_at<int, true, true>());
+  test_at<short, true, true>();
+  test_at<IntLike, true, true>();
+  test_at<ThrowingInt, false, false>();
+  test_at<MutatingInt, true, false>();
+  test_at<RValueInt, true, false>();
+
+  test_at<int, true, true, unsigned>();
+  static_assert(test_at<int, true, true, unsigned>());
+  test_at<short, true, true, unsigned>();
+  test_at<IntLike, true, true, unsigned>();
+
+  test_at_rank0();
+  static_assert(test_at_rank0());
+}
--
2.53.0