On Wed, 5 May 2021, Jonathan Wakely wrote:

> On 04/05/21 21:42 -0400, Patrick Palka via Libstdc++ wrote:
> > This rewrites ranges::minmax and ranges::minmax_element so that it
> > performs at most 3*N/2 many comparisons, as required by the standard.
> > In passing, this also fixes PR100387 by avoiding a premature std::move
> > in ranges::minmax and in std::shift_right.
> > 
> > Tested on x86_64-pc-linux-gnu, does this look OK for trunk and perhaps
> > 10/11?
> > 
> > libstdc++-v3/ChangeLog:
> > 
> >     PR libstdc++/100387
> >     * include/bits/ranges_algo.h (__minmax_fn::operator()): Rewrite
> >     to limit comparison complexity to 3*N/2.  Avoid premature std::move.
> >     (__minmax_element_fn::operator()): Likewise.
> >     (shift_right): Avoid premature std::move of __result.
> >     * testsuite/25_algorithms/minmax/constrained.cc (test04, test05):
> >     New tests.
> >     * testsuite/25_algorithms/minmax_element/constrained.cc (test02):
> >     Likewise.
> > ---
> > libstdc++-v3/include/bits/ranges_algo.h       | 87 ++++++++++++++-----
> > .../25_algorithms/minmax/constrained.cc       | 31 +++++++
> > .../minmax_element/constrained.cc             | 19 ++++
> > 3 files changed, 113 insertions(+), 24 deletions(-)
> > 
> > diff --git a/libstdc++-v3/include/bits/ranges_algo.h b/libstdc++-v3/include/bits/ranges_algo.h
> > index cda3042c11f..bbd29127e89 100644
> > --- a/libstdc++-v3/include/bits/ranges_algo.h
> > +++ b/libstdc++-v3/include/bits/ranges_algo.h
> > @@ -3291,18 +3291,39 @@ namespace ranges
> >     auto __first = ranges::begin(__r);
> >     auto __last = ranges::end(__r);
> >     __glibcxx_assert(__first != __last);
> > +   auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
> >     minmax_result<range_value_t<_Range>> __result = {*__first, *__first};
> >     while (++__first != __last)
> >       {
> > -       auto __tmp = *__first;
> > -       if (std::__invoke(__comp,
> > -                         std::__invoke(__proj, __tmp),
> > -                         std::__invoke(__proj, __result.min)))
> > -         __result.min = std::move(__tmp);
> > -       if (!(bool)std::__invoke(__comp,
> > -                                std::__invoke(__proj, __tmp),
> > -                                std::__invoke(__proj, __result.max)))
> > -         __result.max = std::move(__tmp);
> > +       // Process two elements at a time so that we perform at most
> > +       // 3*N/2 many comparisons in total (each of the N/2 iterations
> 
> Is "many" a typo here?

Just a bad habit of mine to usually write "<count> many" instead of just
"<count>" :) Consider the "many" removed.

> 
> > +       // of this loop performs three comparisions).
> > +       auto __val1 = *__first;
> 
> Can we avoid making this copy if the range satisfies forward_range, by
> keeping copies of the min/max iterators, or just forwarding to
> ranges::minmax_element?

Maybe we can make __val1 and __val2 universal references?  Ah, but then
__val1 would potentially be invalidated after incrementing __first.  I
think it should be safe to make __val2 a universal reference though.
I've done this in v2 below.
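
To illustrate the invalidation concern, here is a minimal sketch that is not part of the patch. It assumes a hypothetical iterator type, caching_iterator, that stores its current value inside the iterator object, as some single-pass iterators do. A reference bound with auto&& before the increment then observes the new value afterwards, whereas a copy does not.

```cpp
#include <cassert>

// Hypothetical iterator (illustration only): operator* returns a reference
// to a value cached inside the iterator itself.
struct caching_iterator
{
  int cache = 0;

  constexpr const int& operator*() const { return cache; }
  constexpr caching_iterator& operator++() { ++cache; return *this; }
};

// Returns true if, after ++it, the reference sees the new value while the
// copy still holds the old one.
constexpr bool
reference_tracks_iterator()
{
  caching_iterator it;
  auto&& ref = *it;   // binds to it.cache; no copy is made
  int copy = *it;     // independent copy of the first value (0)
  ++it;               // overwrites it.cache with 1
  return ref == 1 && copy == 0;
}

static_assert(reference_tracks_iterator(),
              "the reference follows the iterator's cached value");
```

This is why __val1, which must outlive an increment of __first, needs to be a real copy, while __val2 is only used before the next increment and can stay a reference.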

Forwarding to ranges::minmax_element seems like it would be profitable
in some situations, e.g. if the value type isn't trivially copyable.  I
can do this in a followup patch for ranges::max/max_element and
ranges::min/min_element too, they should all use the same heuristic.

> 
> 
> > +       if (++__first == __last)
> > +         {
> > +           // N is odd; in this final iteration, we perform a just one
> 
> s/perform a just one/perform just one/

Fixed.

> 
> > +           // comparison, for a total of 3*(N-1)/2 + 1 < 3*N/2 comparisons.
> 
> I find this a bit hard to parse with the inequality there.

Removed.

> 
> > +           if (__comp_proj(__val1, __result.min))
> > +             __result.min = std::move(__val1);
> > +           else if (!__comp_proj(__val1, __result.max))
> > +             __result.max = std::move(__val1);
> 
> This can be two comparisons, can't it? Would this be better...

Whoops, yeah...

> 
>   // N is odd; in this final iteration, we perform at most two
>   // comparisons, for a total of 3*(N-1)/2 + 2 comparisons,
>   // which is not more than 3*N/2, as required.
> 
> ?

Ah, but then the total is more than 3*N/2 :(  And I think we reach this
case really when N is even, not odd (sorry, I really botched this
patch).

And when N=2 in particular, we perform up to two comparisons instead of
three, but actually a single comparison should suffice in this case.  I
think all this is fixed in v2 below by handling the second element in
the range specially.
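
For reference, the worst-case comparison count C(N) of the v2 scheme can be tallied as follows: one comparison for the second element, three for each full pair after that, and at most two for a leftover element when N is odd. This is only a restatement of the arithmetic in the v2 code comments.

```latex
\begin{align*}
C(1) &= 0, \\
C(N) &\le 1 + 3\cdot\frac{N-2}{2} = \frac{3N}{2} - 2
      && \text{for even } N \ge 2, \\
C(N) &\le 1 + 3\cdot\frac{N-3}{2} + 2 = \frac{3(N-1)}{2}
      && \text{for odd } N \ge 3.
\end{align*}
```

Both bounds are below 3*N/2, and they give C(2) = 1 and C(3) <= 3, which is what the new N=2,3 tests check.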

> 
> > +           break;
> > +         }
> > +       auto __val2 = *__first;
> > +       if (!__comp_proj(__val2, __val1))
> > +         {
> > +           if (__comp_proj(__val1, __result.min))
> > +             __result.min = std::move(__val1);
> > +           if (!__comp_proj(__val2, __result.max))
> > +             __result.max = std::move(__val2);
> > +         }
> > +       else
> > +         {
> > +           if (__comp_proj(__val2, __result.min))
> > +             __result.min = std::move(__val2);
> > +           if (!__comp_proj(__val1, __result.max))
> > +             __result.max = std::move(__val1);
> > +         }
> 
> I thought we might be able to simplify this to something like:
> 
>     auto __val2 = *__first;
>     auto&& [__min, __max] = (*this)(__val1, __val2, __comp, __proj);
>     if (__comp_proj(__min, __result.min))
>       __result.min = __min;
>     if (__comp_proj(__result.max, __max))
>       __result.max = __max;
> 
> But it doesn't work because we need to move from __min and __max, but
> the (*this)(...) returns minmax_result<const T&> and can't be moved
> from.
> 
> We could get around that but it's not much of a simplification:
> 
>     range_value_t<Range> __val2 = *__first;
>     auto [__min, __max] = (*this)(std::addressof(__val1),
>                                   std::addressof(__val2),
>                                   __comp,
>                                   [](auto __p) -> const auto& {
>                                     return *__p;
>                                   });
>     if (__comp_proj(*__min, __result.min))
>       __result.min = std::move(*__min);
>     if (__comp_proj(__result.max, *__max))
>       __result.max = std::move(*__max);

Hmm, now that __val2 was made a universal reference, this simplification
might not work.

> 
> >       }
> >     return __result;
> >       }
> > @@ -3408,21 +3429,40 @@ namespace ranges
> >       operator()(_Iter __first, _Sent __last,
> >              _Comp __comp = {}, _Proj __proj = {}) const
> >       {
> > -   if (__first == __last)
> > -     return {__first, __first};
> > -
> > +   auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
> >     minmax_element_result<_Iter> __result = {__first, __first};
> > -   auto __i = __first;
> > -   while (++__i != __last)
> > +   if (__first == __last)
> > +     return __result;
> > +   while (++__first != __last)
> >       {
> > -       if (std::__invoke(__comp,
> > -                         std::__invoke(__proj, *__i),
> > -                         std::__invoke(__proj, *__result.min)))
> > -         __result.min = __i;
> > -       if (!(bool)std::__invoke(__comp,
> > -                                std::__invoke(__proj, *__i),
> > -                                std::__invoke(__proj, *__result.max)))
> > -         __result.max = __i;
> > +       // Process two elements at a time so that we perform at most
> > +       // 3*N/2 many comparisons in total (each of the N/2 iterations
> > +       // of this loop performs three comparisions).
> > +       auto __prev = __first;
> > +       if (++__first == __last)
> > +         {
> > +           // N is odd; in this final iteration, we perform a just one
> > +           // comparison, for a total of 3*(N-1)/2 + 1 < 3*N/2 comparisons.
> 
> Same comments on the comments as above.

Fixed.

> 
> > +           if (__comp_proj(*__prev, *__result.min))
> > +             __result.min = __prev;
> > +           else if (!__comp_proj(*__prev, *__result.max))
> > +             __result.max = __prev;
> > +           break;
> > +         }
> > +       if (!__comp_proj(*__first, *__prev))
> > +         {
> > +           if (__comp_proj(*__prev, *__result.min))
> > +             __result.min = __prev;
> > +           if (!__comp_proj(*__first, *__result.max))
> > +             __result.max = __first;
> > +         }
> > +       else
> > +         {
> > +           if (__comp_proj(*__first, *__result.min))
> > +             __result.min = __first;
> > +           if (!__comp_proj(*__prev, *__result.max))
> > +             __result.max = __prev;
> > +         }
> 
> We don't need to move anything here, so this could be written using
> ranges::minmax. I'm not sure it is an improvement though (except for
> being slightly fewer lines of code):
> 
>     auto __mm = minmax(__prev, __first, __comp,
>                        [](auto&& __it) -> auto&& { return *__it; });
> 
>     if (__comp_proj(*__mm.min, *__result.min))
>       __result.min = __mm.min;
>     if (__comp_proj(*__result.max, *__mm.max))
>       __result.max = __mm.max;

Hmm, it's shorter, but on the other hand it'd make the implementations
of minmax and minmax_element no longer mirror each other as closely.  So
I didn't make this change for now.

Here's v2, which makes the following additional changes:

  * Fixes comment typos, and some nearby indentation in the signature
    of minmax
  * Makes both minmax and minmax_element handle the second element
    specially in order to optimally perform just a single comparison
    when N=2 and to always stay below 3*N/2 comparisons
  * Avoids using a deduced type for __val1
  * Makes __val2 a universal reference to avoid a move
  * Adds comparison complexity tests for the N=2,3 cases
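
As a rough, standalone illustration of the scheme described in the list above, the following sketch applies the same steps to a non-empty std::vector<int> using indices instead of iterators. It is a simplification and not the libstdc++ implementation; minmax_counted and its count out-parameter are hypothetical names used only here.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Illustration only.  Returns {min, max} of a non-empty vector and stores
// the number of `<` evaluations in `count` (a hypothetical out-parameter).
inline std::pair<int, int>
minmax_counted(const std::vector<int>& v, std::size_t& count)
{
  assert(!v.empty());
  count = 0;
  auto less = [&count](int a, int b) { ++count; return a < b; };

  int min = v[0], max = v[0];
  if (v.size() == 1)
    return {min, max};

  // min == max at this point, so one comparison places the second element.
  if (less(v[1], min))
    min = v[1];
  else
    max = v[1];

  std::size_t i = 2;
  // Full pairs: order the pair first, then compare the smaller element
  // against min and the larger against max (three comparisons per pair).
  for (; i + 1 < v.size(); i += 2)
    {
      int a = v[i], b = v[i + 1];
      if (!less(b, a))
        {
          if (less(a, min)) min = a;
          if (!less(b, max)) max = b;
        }
      else
        {
          if (less(b, min)) min = b;
          if (!less(a, max)) max = a;
        }
    }

  // Odd N: one leftover element, at most two comparisons.
  if (i < v.size())
    {
      if (less(v[i], min)) min = v[i];
      else if (!less(v[i], max)) max = v[i];
    }
  return {min, max};
}
```

Called on {1,2} this performs one comparison and on {1,2,3} three, matching the counts asserted by the new tests.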

-- >8 --

Subject: [PATCH] libstdc++: Reduce ranges::minmax/minmax_element comparison
 complexity

This rewrites ranges::minmax and ranges::minmax_element so that they
perform at most 3*N/2 comparisons, as required by the standard.
In passing, this also fixes PR100387 by avoiding a premature std::move
in ranges::minmax and in std::shift_right.

Tested on x86_64-pc-linux-gnu, does this look OK for trunk and perhaps
10/11?

libstdc++-v3/ChangeLog:

        PR libstdc++/100387
        * include/bits/ranges_algo.h (__minmax_fn::operator()): Rewrite
        to limit comparison complexity to 3*N/2.
        (__minmax_element_fn::operator()): Likewise.
        (shift_right): Avoid premature std::move of __result.
        * testsuite/25_algorithms/minmax/constrained.cc (test04, test05):
        New tests.
        * testsuite/25_algorithms/minmax_element/constrained.cc (test02):
        Likewise.
---
 libstdc++-v3/include/bits/ranges_algo.h       | 113 ++++++++++++++----
 .../25_algorithms/minmax/constrained.cc       |  42 +++++++
 .../minmax_element/constrained.cc             |  27 +++++
 3 files changed, 156 insertions(+), 26 deletions(-)

diff --git a/libstdc++-v3/include/bits/ranges_algo.h b/libstdc++-v3/include/bits/ranges_algo.h
index cda3042c11f..2091cbf5b4e 100644
--- a/libstdc++-v3/include/bits/ranges_algo.h
+++ b/libstdc++-v3/include/bits/ranges_algo.h
@@ -3283,26 +3283,59 @@ namespace ranges
     template<input_range _Range, typename _Proj = identity,
             indirect_strict_weak_order<projected<iterator_t<_Range>, _Proj>>
               _Comp = ranges::less>
-      requires indirectly_copyable_storable<iterator_t<_Range>,
-      range_value_t<_Range>*>
+      requires indirectly_copyable_storable<iterator_t<_Range>, range_value_t<_Range>*>
       constexpr minmax_result<range_value_t<_Range>>
       operator()(_Range&& __r, _Comp __comp = {}, _Proj __proj = {}) const
       {
        auto __first = ranges::begin(__r);
        auto __last = ranges::end(__r);
        __glibcxx_assert(__first != __last);
+       auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
        minmax_result<range_value_t<_Range>> __result = {*__first, *__first};
+       if (++__first == __last)
+         return __result;
+       else
+         {
+           // At this point __result.min == __result.max, so a single
+           // comparison with the next element suffices.
+           auto&& __val = *__first;
+           if (__comp_proj(__val, __result.min))
+             __result.min = std::forward<decltype(__val)>(__val);
+           else
+             __result.max = std::forward<decltype(__val)>(__val);
+         }
        while (++__first != __last)
          {
-           auto __tmp = *__first;
-           if (std::__invoke(__comp,
-                             std::__invoke(__proj, __tmp),
-                             std::__invoke(__proj, __result.min)))
-             __result.min = std::move(__tmp);
-           if (!(bool)std::__invoke(__comp,
-                                    std::__invoke(__proj, __tmp),
-                                    std::__invoke(__proj, __result.max)))
-             __result.max = std::move(__tmp);
+           // Now process two elements at a time so that we perform at most
+           // 3*(N-2)/2 comparisons in total (each of the (N-2)/2 iterations
+           // of this loop performs three comparisons).
+           range_value_t<_Range> __val1 = *__first;
+           if (++__first == __last)
+             {
+               // N is odd; in this final iteration, we perform at most two
+               // comparisons, for a total of 1 + 3*(N-3)/2 + 2 comparisons,
+               // which is not more than 3*N/2, as required.
+               if (__comp_proj(__val1, __result.min))
+                 __result.min = std::move(__val1);
+               else if (!__comp_proj(__val1, __result.max))
+                 __result.max = std::move(__val1);
+               break;
+             }
+           auto&& __val2 = *__first;
+           if (!__comp_proj(__val2, __val1))
+             {
+               if (__comp_proj(__val1, __result.min))
+                 __result.min = std::move(__val1);
+               if (!__comp_proj(__val2, __result.max))
+                 __result.max = std::forward<decltype(__val2)>(__val2);
+             }
+           else
+             {
+               if (__comp_proj(__val2, __result.min))
+                 __result.min = std::forward<decltype(__val2)>(__val2);
+               if (!__comp_proj(__val1, __result.max))
+                 __result.max = std::move(__val1);
+             }
          }
        return __result;
       }
@@ -3408,21 +3441,50 @@ namespace ranges
       operator()(_Iter __first, _Sent __last,
                 _Comp __comp = {}, _Proj __proj = {}) const
       {
-       if (__first == __last)
-         return {__first, __first};
-
+       auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
        minmax_element_result<_Iter> __result = {__first, __first};
-       auto __i = __first;
-       while (++__i != __last)
+       if (__first == __last || ++__first == __last)
+         return __result;
+       else
          {
-           if (std::__invoke(__comp,
-                             std::__invoke(__proj, *__i),
-                             std::__invoke(__proj, *__result.min)))
-             __result.min = __i;
-           if (!(bool)std::__invoke(__comp,
-                                    std::__invoke(__proj, *__i),
-                                    std::__invoke(__proj, *__result.max)))
-             __result.max = __i;
+           // At this point __result.min == __result.max, so a single
+           // comparison with the next element suffices.
+           if (__comp_proj(*__first, *__result.min))
+             __result.min = __first;
+           else
+             __result.max = __first;
+         }
+       while (++__first != __last)
+         {
+           // Now process two elements at a time so that we perform at most
+           // 3*(N-2)/2 comparisons in total (each of the (N-2)/2 iterations
+           // of this loop performs three comparisons).
+           auto __prev = __first;
+           if (++__first == __last)
+             {
+               // N is odd; in this final iteration, we perform at most two
+               // comparisons, for a total of 1 + 3*(N-3)/2 + 2 comparisons,
+               // which is not more than 3*N/2, as required.
+               if (__comp_proj(*__prev, *__result.min))
+                 __result.min = __prev;
+               else if (!__comp_proj(*__prev, *__result.max))
+                 __result.max = __prev;
+               break;
+             }
+           if (!__comp_proj(*__first, *__prev))
+             {
+               if (__comp_proj(*__prev, *__result.min))
+                 __result.min = __prev;
+               if (!__comp_proj(*__first, *__result.max))
+                 __result.max = __first;
+             }
+           else
+             {
+               if (__comp_proj(*__first, *__result.min))
+                 __result.min = __first;
+               if (!__comp_proj(*__prev, *__result.max))
+                 __result.max = __prev;
+             }
          }
        return __result;
       }
@@ -3749,8 +3811,7 @@ namespace ranges
                  // i.e. we are shifting out at least half of the range.  In
                  // this case we can safely perform the shift with a single
                  // move.
-                 std::move(std::move(__first), std::move(__dest_head),
-                           std::move(__result));
+                 std::move(std::move(__first), std::move(__dest_head), __result);
                  return __result;
                }
              ++__dest_head;
diff --git a/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
index 786922414b5..c365152bf2b 100644
--- a/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
+++ b/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
@@ -19,6 +19,8 @@
 // { dg-do run { target c++2a } }
 
 #include <algorithm>
+#include <string>
+#include <vector>
 #include <testsuite_hooks.h>
 #include <testsuite_iterators.h>
 
@@ -89,10 +91,50 @@ test03()
          == res_t(1,4) );
 }
 
+void
+test04()
+{
+  // Verify we perform at most 3*N/2 applications of the comparison predicate.
+  static int counter;
+  struct counted_less
+  { bool operator()(int a, int b) { ++counter; return a < b; } };
+
+  ranges::minmax({1,2}, counted_less{});
+  VERIFY( counter == 1 );
+
+  counter = 0;
+  ranges::minmax({1,2,3}, counted_less{});
+  VERIFY( counter == 3 );
+
+  counter = 0;
+  ranges::minmax({1,2,3,4,5,6,7,8,9,10}, counted_less{});
+  VERIFY( counter <= 15 );
+
+  counter = 0;
+  ranges::minmax({10,9,8,7,6,5,4,3,2,1}, counted_less{});
+  VERIFY( counter <= 15 );
+}
+
+void
+test05()
+{
+  // PR libstdc++/100387
+  using namespace std::literals::string_literals;
+  auto comp = [](const auto& a, const auto& b) {
+    return a.size() == b.size() ? a.front() < b.front() : a.size() > b.size();
+  };
+  auto result = ranges::minmax({"b"s, "a"s}, comp);
+  VERIFY( result.min == "a"s && result.max == "b"s );
+  result = ranges::minmax({"c"s, "b"s, "a"s}, comp);
+  VERIFY( result.min == "a"s && result.max == "c"s );
+}
+
 int
 main()
 {
   test01();
   test02();
   test03();
+  test04();
+  test05();
 }
diff --git a/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
index 3b11c0dd96c..0919f7dda8f 100644
--- a/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
+++ b/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
@@ -61,8 +61,35 @@ test01()
   static_assert(ranges::minmax_element(y, y+3, {}, &X::i).max->j == 3);
 }
 
+void
+test02()
+{
+  // Verify we perform at most 3*N/2 applications of the comparison predicate.
+  static int counter;
+  struct counted_less
+  { bool operator()(int a, int b) { ++counter; return a < b; } };
+
+  int x[] = {1,2,3,4,5,6,7,8,9,10};
+  ranges::minmax_element(x, x+2, counted_less{});
+  VERIFY( counter == 1 );
+
+  counter = 0;
+  ranges::minmax_element(x, x+3, counted_less{});
+  VERIFY( counter == 3 );
+
+  counter = 0;
+  ranges::minmax_element(x, counted_less{});
+  VERIFY( counter <= 15 );
+
+  ranges::reverse(x);
+  counter = 0;
+  ranges::minmax_element(x, counted_less{});
+  VERIFY( counter <= 15 );
+}
+
 int
 main()
 {
   test01();
+  test02();
 }
-- 
2.31.1.442.g7e39198978
