Hi

In my work on fancy pointer support I've decided to always cache the hash code.

While doing so, I spotted a bug in the management of this cache when the hash functor is stateful.

    libstdc++: [_Hashtable] Fix hash code cache usage when hash functor is stateful

    It is wrong to reuse a hash code cached by another container instance when
    that code depends on the state of the Hash functor.

    Check that the Hash functor is stateless before reusing a cached hash code.

    libstdc++-v3/ChangeLog:

            * include/bits/hashtable_policy.h (_Hash_code_base::_M_copy_code): Remove.
            * include/bits/hashtable.h (_M_copy_code): New.
            (_M_assign): Use latter.
            (_M_bucket_index_ex): New.
            (_M_equals): Use latter.
            * testsuite/23_containers/unordered_map/modifiers/merge.cc (test10): New
            test case.

Tested under Linux x64; OK to commit?

François
diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h
index d6d76a743bb..b3c1d7aac24 100644
--- a/libstdc++-v3/include/bits/hashtable.h
+++ b/libstdc++-v3/include/bits/hashtable.h
@@ -808,6 +808,36 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       _M_bucket_index(__hash_code __c) const
       { return __hash_code_base::_M_bucket_index(__c, _M_bucket_count); }
 
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wc++17-extensions" // if constexpr
+      // Like _M_bucket_index but when the node is coming from another
+      // container instance.
+      size_type
+      _M_bucket_index_ex(const __node_value_type& __n) const
+      {
+       if constexpr (__hash_cached::value)
+         if constexpr (std::is_empty<_Hash>::value)
+           return _RangeHash{}(__n._M_hash_code, _M_bucket_count);
+
+       return _RangeHash{}
+         (this->_M_hash_code(_ExtractKey{}(__n._M_v())), _M_bucket_count);
+      }
+
+      void
+      _M_copy_code(__node_value_type& __to,
+                  const __node_value_type& __from) const
+      {
+       if constexpr (__hash_cached::value)
+         {
+           if constexpr (std::is_empty<_Hash>::value)
+             __to._M_hash_code = __from._M_hash_code;
+           else
+             __to._M_hash_code =
+               this->_M_hash_code(_ExtractKey{}(__from._M_v()));
+         }
+      }
+#pragma GCC diagnostic pop
+
       // Find and insert helper functions and types
 
       // Find the node before the one matching the criteria.
@@ -1587,7 +1617,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        __node_ptr __ht_n = __ht._M_begin();
        __node_ptr __this_n
          = __node_gen(static_cast<_FromVal>(__ht_n->_M_v()));
-       this->_M_copy_code(*__this_n, *__ht_n);
+       _M_copy_code(*__this_n, *__ht_n);
        _M_update_bbegin(__this_n);
 
        // Then deal with other nodes.
@@ -1596,7 +1626,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
          {
            __this_n = __node_gen(static_cast<_FromVal>(__ht_n->_M_v()));
            __prev_n->_M_nxt = __this_n;
-           this->_M_copy_code(*__this_n, *__ht_n);
+           _M_copy_code(*__this_n, *__ht_n);
            size_type __bkt = _M_bucket_index(*__this_n);
            if (!_M_buckets[__bkt])
              _M_buckets[__bkt] = __prev_n;
@@ -2851,7 +2881,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       if constexpr (__unique_keys::value)
        for (auto __x_n = _M_begin(); __x_n; __x_n = __x_n->_M_next())
          {
-           std::size_t __ybkt = __other._M_bucket_index(*__x_n);
+           std::size_t __ybkt = __other._M_bucket_index_ex(*__x_n);
            auto __prev_n = __other._M_buckets[__ybkt];
            if (!__prev_n)
              return false;
@@ -2878,7 +2908,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                 __x_n_end = __x_n_end->_M_next())
              ++__x_count;
 
-           std::size_t __ybkt = __other._M_bucket_index(*__x_n);
+           std::size_t __ybkt = __other._M_bucket_index_ex(*__x_n);
            auto __y_prev_n = __other._M_buckets[__ybkt];
            if (!__y_prev_n)
              return false;
diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h
index 1fa8c01d5e8..61c57651c4a 100644
--- a/libstdc++-v3/include/bits/hashtable_policy.h
+++ b/libstdc++-v3/include/bits/hashtable_policy.h
@@ -1108,19 +1108,9 @@ namespace __detail
       _M_store_code(_Hash_node_code_cache<false>&, __hash_code) const
       { }
 
-      void
-      _M_copy_code(_Hash_node_code_cache<false>&,
-                  const _Hash_node_code_cache<false>&) const
-      { }
-
       void
       _M_store_code(_Hash_node_code_cache<true>& __n, __hash_code __c) const
       { __n._M_hash_code = __c; }
-
-      void
-      _M_copy_code(_Hash_node_code_cache<true>& __to,
-                  const _Hash_node_code_cache<true>& __from) const
-      { __to._M_hash_code = __from._M_hash_code; }
     };
 
   /// Partial specialization used when nodes contain a cached hash code.
diff --git a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
index 10b61464243..010181f7038 100644
--- a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
+++ b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
@@ -417,6 +417,51 @@ test09()
   VERIFY( c2.size() == 3 );
 }
 
+struct slow_stateful_hash
+{
+  size_t seed = 0;
+
+  auto operator()(const int& i) const noexcept
+  { return std::hash<int>()(i) + seed; }
+};
+
+namespace std
+{
+  template<>
+    struct __is_fast_hash<slow_stateful_hash> : public std::false_type
+    { };
+}
+
+void
+test10()
+{
+  using map_type = std::unordered_map<int, int, slow_stateful_hash>;
+  map_type c1({ {1, 1}, {3, 3}, {5, 5} }, 0, slow_stateful_hash{1});
+  map_type c2({ {2, 2}, {4, 4}, {6, 6} }, 0, slow_stateful_hash{2});
+  const auto c3 = c2;
+
+  c1.merge(c2);
+  VERIFY( c1.size() == 6 );
+  VERIFY( c2.empty() );
+
+  c2 = c3;
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c3 );
+  VERIFY( c2.empty() );
+
+  c2.merge(std::move(c1));
+  VERIFY( c1.empty() );
+  VERIFY( c2 == c3 );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2 == c3 );
+
+  c2.merge(c2);
+  VERIFY( c2 == c3 );
+}
+
 int
 main()
 {
@@ -429,4 +474,5 @@ main()
   test07();
   test08();
   test09();
+  test10();
 }

Reply via email to