libstdc++-v3/ChangeLog:

        * include/bits/atomic_timed_wait.h: Use __wait_result_type.
        * include/bits/atomic_wait.h (__wait_result_type): New struct.
        (__wait_args::_M_prep_for_wait_on): Rename to _M_setup_wait, use
        __wait_result_type.
        (__atomic_wait_address): Adjust to call _M_setup_wait.
        * src/c++20/atomic.cc (__spin_impl): Use __wait_result_type.
        (__wait_impl): Likewise.
        (__spin_until_impl): Likewise.
        (__wait_until_impl): Likewise.
---
 libstdc++-v3/include/bits/atomic_timed_wait.h | 25 ++++-----
 libstdc++-v3/include/bits/atomic_wait.h       | 48 +++++++++++-----
 libstdc++-v3/src/c++20/atomic.cc              | 55 ++++++++++++-------
 3 files changed, 80 insertions(+), 48 deletions(-)

diff --git a/libstdc++-v3/include/bits/atomic_timed_wait.h b/libstdc++-v3/include/bits/atomic_timed_wait.h
index 3e25607b7d4c..230afbc96e7d 100644
--- a/libstdc++-v3/include/bits/atomic_timed_wait.h
+++ b/libstdc++-v3/include/bits/atomic_timed_wait.h
@@ -98,18 +98,17 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                                                 __at.time_since_epoch());
 
        if constexpr (!is_same_v<__wait_clock_t, _Clock>)
-         if (!__res.first)
+         if (__res._M_timeout)
            {
              // We got a timeout when measured against __clock_t but
              // we need to check against the caller-supplied clock
              // to tell whether we should return a timeout.
              if (_Clock::now() < __atime)
-               __res.first = true;
+               __res._M_timeout = false;
            }
        return __res;
       }
 
-    // Returns {true, val} if wait ended before a timeout.
     template<typename _Rep, typename _Period>
       __wait_result_type
       __wait_for(const void* __addr, __wait_args_base& __args,
@@ -139,14 +138,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                                bool __bare_wait = false) noexcept
     {
       __detail::__wait_args __args{ __addr, __bare_wait };
-      _Tp __val = __args._M_prep_for_wait_on(__addr, __vfn);
+      _Tp __val = __args._M_setup_wait(__addr, __vfn);
       while (!__pred(__val))
        {
          auto __res = __detail::__wait_until(__addr, __args, __atime);
-         if (!__res.first)
-           // timed out
-           return __res.first; // C++26 will also return last observed __val
-         __val = __args._M_prep_for_wait_on(__addr, __vfn);
+         if (__res._M_timeout)
+           return false; // C++26 will also return last observed __val
+         __val = __args._M_setup_wait(__addr, __vfn, __res);
        }
       return true; // C++26 will also return last observed __val
     }
@@ -189,14 +187,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                              bool __bare_wait = false) noexcept
     {
       __detail::__wait_args __args{ __addr, __bare_wait };
-      _Tp __val = __args._M_prep_for_wait_on(__addr, __vfn);
+      _Tp __val = __args._M_setup_wait(__addr, __vfn);
       while (!__pred(__val))
        {
          auto __res = __detail::__wait_for(__addr, __args, __rtime);
-         if (!__res.first)
-           // timed out
-           return __res.first; // C++26 will also return last observed __val
-         __val = __args._M_prep_for_wait_on(__addr, __vfn);
+         if (__res._M_timeout)
+           return false; // C++26 will also return last observed __val
+         __val = __args._M_setup_wait(__addr, __vfn);
        }
       return true; // C++26 will also return last observed __val
     }
@@ -211,7 +208,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     {
       __detail::__wait_args __args{ __addr, __old, __order, __bare_wait };
       auto __res = __detail::__wait_for(__addr, __args, __rtime);
-      return __res.first; // C++26 will also return last observed __Val
+      return !__res._M_timeout; // C++26 will also return last observed __val
     }
 
   template<typename _Tp, typename _ValFn,
diff --git a/libstdc++-v3/include/bits/atomic_wait.h b/libstdc++-v3/include/bits/atomic_wait.h
index 33e8d3202566..815726c16ccb 100644
--- a/libstdc++-v3/include/bits/atomic_wait.h
+++ b/libstdc++-v3/include/bits/atomic_wait.h
@@ -105,6 +105,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        return __builtin_memcmp(&__a, &__b, sizeof(_Tp)) == 0;
       }
 
+    // lightweight std::optional<__platform_wait_t>
+    struct __wait_result_type
+    {
+      __platform_wait_t _M_val;
+      unsigned char _M_has_val : 1; // _M_val value was loaded before return.
+      unsigned char _M_timeout : 1; // Waiting function ended with timeout.
+      unsigned char _M_unused : 6;  // padding
+    };
+
     enum class __wait_flags : __UINT_LEAST32_TYPE__
     {
        __abi_version = 0,
@@ -166,21 +175,32 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       template<typename _ValFn,
               typename _Tp = decay_t<decltype(std::declval<_ValFn&>()())>>
        _Tp
-       _M_prep_for_wait_on(const void* __addr, _ValFn __vfn)
+       _M_setup_wait(const void* __addr, _ValFn __vfn,
+                     __wait_result_type __res = {})
        {
          if constexpr (__platform_wait_uses_type<_Tp>)
            {
-             _Tp __val = __vfn();
-             // If the wait is not proxied, set the value that we're waiting
-             // to change.
-             _M_old = __builtin_bit_cast(__platform_wait_t, __val);
-             return __val;
+             // If the wait is not proxied, the value we check when waiting
+             // is the value of the atomic variable itself.
+
+             if (__res._M_has_val) // The previous wait loaded a recent value.
+               {
+                 _M_old = __res._M_val;
+                 return __builtin_bit_cast(_Tp, __res._M_val);
+               }
+             else // Load the value from __vfn
+               {
+                 _Tp __val = __vfn();
+                 _M_old = __builtin_bit_cast(__platform_wait_t, __val);
+                 return __val;
+               }
            }
-         else
+         else // It's a proxy wait and the proxy's _M_ver is used.
            {
-             // Otherwise, it's a proxy wait and the proxy's _M_ver is used.
-             // This load must happen before the one done by __vfn().
-             _M_load_proxy_wait_val(__addr);
+             if (__res._M_has_val) // The previous wait loaded a recent value.
+               _M_old = __res._M_val;
+             else // Load _M_ver from the proxy (must happen before __vfn()).
+               _M_load_proxy_wait_val(__addr);
              return __vfn();
            }
        }
@@ -204,8 +224,6 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        }
     };
 
-    using __wait_result_type = pair<bool, __platform_wait_t>;
-
     __wait_result_type
     __wait_impl(const void* __addr, __wait_args_base&);
 
@@ -222,11 +240,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                          bool __bare_wait = false) noexcept
     {
       __detail::__wait_args __args{ __addr, __bare_wait };
-      _Tp __val = __args._M_prep_for_wait_on(__addr, __vfn);
+      _Tp __val = __args._M_setup_wait(__addr, __vfn);
       while (!__pred(__val))
        {
-         __detail::__wait_impl(__addr, __args);
-         __val = __args._M_prep_for_wait_on(__addr, __vfn);
+         auto __res = __detail::__wait_impl(__addr, __args);
+         __val = __args._M_setup_wait(__addr, __vfn, __res);
        }
       // C++26 will return __val
     }
diff --git a/libstdc++-v3/src/c++20/atomic.cc b/libstdc++-v3/src/c++20/atomic.cc
index b9ad66b1ec30..a3ec92a10d56 100644
--- a/libstdc++-v3/src/c++20/atomic.cc
+++ b/libstdc++-v3/src/c++20/atomic.cc
@@ -48,6 +48,8 @@
 # endif
 #endif
 
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+
 namespace std
 {
 _GLIBCXX_BEGIN_NAMESPACE_VERSION
@@ -208,21 +210,23 @@ namespace
   constexpr auto __atomic_spin_count_relax = 12;
   constexpr auto __atomic_spin_count = 16;
 
+  // This function always returns _M_has_val == true and _M_val == *__addr.
+  // _M_timeout == (*__addr == __args._M_old).
   __wait_result_type
   __spin_impl(const __platform_wait_t* __addr, const __wait_args_base& __args)
   {
-    __platform_wait_t __val;
+    __platform_wait_t __val{};
     for (auto __i = 0; __i < __atomic_spin_count; ++__i)
       {
        __atomic_load(__addr, &__val, __args._M_order);
        if (__val != __args._M_old)
-         return { true, __val };
+         return { ._M_val = __val, ._M_has_val = true, ._M_timeout = false };
        if (__i < __atomic_spin_count_relax)
          __thread_relax();
        else
          __thread_yield();
       }
-    return { false, __val };
+    return { ._M_val = __val, ._M_has_val = true, ._M_timeout = true };
   }
 
   inline __waitable_state*
@@ -263,7 +267,7 @@ __wait_impl(const void* __addr, __wait_args_base& __args)
   if (__args & __wait_flags::__do_spin)
     {
       auto __res = __detail::__spin_impl(__wait_addr, __args);
-      if (__res.first)
+      if (!__res._M_timeout)
        return __res;
       if (__args & __wait_flags::__spin_only)
        return __res;
@@ -271,17 +275,21 @@ __wait_impl(const void* __addr, __wait_args_base& __args)
 
 #ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
   if (__args & __wait_flags::__track_contention)
-    set_wait_state(__addr, __args);
+    set_wait_state(__addr, __args); // scoped_wait needs a __waitable_state
   scoped_wait s(__args);
   __platform_wait(__wait_addr, __args._M_old);
-  return { false, __args._M_old };
+  // We haven't loaded a new value so return _M_has_val == false:
+  return { ._M_val = __args._M_old, ._M_has_val = false, ._M_timeout = false };
 #else
   waiter_lock l(__args);
   __platform_wait_t __val;
   __atomic_load(__wait_addr, &__val, __args._M_order);
   if (__val == __args._M_old)
-    __state->_M_cv.wait(__state->_M_mtx);
-  return { false, __val };
+    {
+      __state->_M_cv.wait(__state->_M_mtx);
+      return { ._M_val = __val, ._M_has_val = false, ._M_timeout = false };
+    }
+  return { ._M_val = __val, ._M_has_val = true, ._M_timeout = false };
 #endif
 }
 
@@ -389,6 +397,7 @@ __cond_wait_until(__condvar& __cv, mutex& __mx,
 }
 #endif // ! HAVE_PLATFORM_TIMED_WAIT
 
+// Like __spin_impl, always returns _M_has_val == true.
 __wait_result_type
 __spin_until_impl(const __platform_wait_t* __addr,
                  const __wait_args_base& __args,
@@ -411,14 +420,18 @@ __spin_until_impl(const __platform_wait_t* __addr,
 #endif
       if (__elapsed > 4us)
        __thread_yield();
-      else if (auto __res = __detail::__spin_impl(__addr, __args); __res.first)
-       return __res;
+      else
+       {
+         auto __res = __detail::__spin_impl(__addr, __args);
+         if (!__res._M_timeout)
+           return __res;
+       }
 
       __atomic_load(__addr, &__val, __args._M_order);
       if (__val != __args._M_old)
-       return { true, __val };
+       return { ._M_val = __val, ._M_has_val = true, ._M_timeout = false };
     }
-  return { false, __val };
+  return { ._M_val = __val, ._M_has_val = true, ._M_timeout = true };
 }
 } // namespace
 
@@ -437,7 +450,7 @@ __wait_until_impl(const void* __addr, __wait_args_base& __args,
   if (__args & __wait_flags::__do_spin)
     {
       auto __res = __detail::__spin_until_impl(__wait_addr, __args, __atime);
-      if (__res.first)
+      if (!__res._M_timeout)
        return __res;
       if (__args & __wait_flags::__spin_only)
        return __res;
@@ -448,17 +461,21 @@ __wait_until_impl(const void* __addr, __wait_args_base& __args,
     set_wait_state(__addr, __args);
   scoped_wait s(__args);
   if (__platform_wait_until(__wait_addr, __args._M_old, __atime))
-    return { true, __args._M_old };
+    return { ._M_val = __args._M_old, ._M_has_val = false, ._M_timeout = false };
   else
-    return { false, __args._M_old };
+    return { ._M_val = __args._M_old, ._M_has_val = false, ._M_timeout = true };
 #else
   waiter_lock l(__args);
   __platform_wait_t __val;
   __atomic_load(__wait_addr, &__val, __args._M_order);
-  if (__val == __args._M_old
-       && __cond_wait_until(__state->_M_cv, __state->_M_mtx, __atime))
-    return { true, __val };
-  return { false, __val };
+  if (__val == __args._M_old)
+    {
+      if (__cond_wait_until(__state->_M_cv, __state->_M_mtx, __atime))
+       return { ._M_val = __val, ._M_has_val = false, ._M_timeout = false };
+      else
+       return { ._M_val = __val, ._M_has_val = false, ._M_timeout = true };
+    }
+  return { ._M_val = __val, ._M_has_val = true, ._M_timeout = false };
 #endif
 }
 
-- 
2.49.0

Reply via email to