HappenLee commented on code in PR #12349:
URL: https://github.com/apache/doris/pull/12349#discussion_r965753805


##########
be/src/vec/core/sort_block.h:
##########
@@ -46,9 +48,425 @@ void stable_get_permutation(const Block& block, const 
SortDescription& descripti
   */
 bool is_already_sorted(const Block& block, const SortDescription& description);
 
-using ColumnsWithSortDescriptions = std::vector<std::pair<const IColumn*, 
SortColumnDescription>>;
+using ColumnWithSortDescription = std::pair<const IColumn*, 
SortColumnDescription>;
+
+using ColumnsWithSortDescriptions = std::vector<ColumnWithSortDescription>;
 
 ColumnsWithSortDescriptions get_columns_with_sort_description(const Block& 
block,
                                                               const 
SortDescription& description);
 
+struct EqualRangeIterator {
+    int range_begin;
+    int range_end;
+
+    EqualRangeIterator(const EqualFlags& flags) : EqualRangeIterator(flags, 0, 
flags.size()) {}
+
+    EqualRangeIterator(const EqualFlags& flags, int begin, int end) : 
flags_(flags), end_(end) {
+        range_begin = begin;
+        range_end = end;
+        cur_range_begin_ = begin;
+        cur_range_end_ = end;
+    }
+
+    bool next() {
+        if (cur_range_begin_ >= end_) {
+            return false;
+        }
+
+        // `flags_[i]=1` indicates that the i-th row is equal to the previous 
row, which means we
+        // should continue to sort this row according to current column. Using 
the first non-zero
+        // value and first zero value after first non-zero value as two 
bounds, we can get an equal range here
+        if (!(cur_range_begin_ == 0) || !(flags_[cur_range_begin_] == 1)) {
+            cur_range_begin_ = simd::find_nonzero(flags_, cur_range_begin_ + 
1);
+            if (cur_range_begin_ >= end_) {
+                return false;
+            }
+            cur_range_begin_--;
+        }
+
+        cur_range_end_ = simd::find_zero(flags_, cur_range_begin_ + 1);
+        cur_range_end_ = std::min(cur_range_end_, end_);
+
+        if (cur_range_begin_ >= cur_range_end_) {
+            return false;
+        }
+
+        range_begin = cur_range_begin_;
+        range_end = cur_range_end_;
+        cur_range_begin_ = cur_range_end_;
+        return true;
+    }
+
+private:
+    int cur_range_begin_;
+    int cur_range_end_;
+
+    const EqualFlags& flags_;
+    const int end_;
+};
+
+struct ColumnPartialSortingLess {
+    const ColumnWithSortDescription& column_;
+
+    explicit ColumnPartialSortingLess(const ColumnWithSortDescription& column) 
: column_(column) {}
+
+    bool operator()(size_t a, size_t b) const {
+        int res = column_.second.direction *
+                  column_.first->compare_at(a, b, *column_.first, 
column_.second.nulls_direction);
+        if (res < 0) {
+            return true;
+        } else if (res > 0) {
+            return false;
+        }
+        return false;
+    }
+};
+
+template <typename T>
+struct PermutationWithInlineValue {
+    T inline_value_;
+    uint32_t row_id_;
+
+    PermutationWithInlineValue(T inline_value, uint32_t row_id)
+            : inline_value_(inline_value), row_id_(row_id) {}
+
+    PermutationWithInlineValue() : row_id_(-1) {}
+};
+
+template <typename T>
+using PermutationForColumn = std::vector<PermutationWithInlineValue<T>>;
+
+class ColumnSorter {
+public:
+    explicit ColumnSorter(const ColumnWithSortDescription& column, const int 
limit)
+            : column_(column),
+              limit_(limit),
+              nulls_direction_(column.second.nulls_direction),
+              direction_(column.second.direction) {}
+
+    void operator()(EqualFlags& flags, IColumn::Permutation& perms, 
EqualRange& range,
+                    bool last_column) const {
+        column_.first->sort_column(this, flags, perms, range, last_column);
+    }
+
+    void sort_column(const IColumn& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                     EqualRange& range, bool last_column) const {
+        int new_limit = limit_;
+        auto comparator = [&](const size_t a, const size_t b) {
+            return column.compare_at(a, b, *column_.first, nulls_direction_);
+        };
+        ColumnPartialSortingLess less(column_);
+        auto do_sort = [&](size_t first_iter, size_t last_iter) {
+            auto begin = perms.begin() + first_iter;
+            auto end = perms.begin() + last_iter;
+
+            if (UNLIKELY(limit_ > 0 && first_iter < limit_ && limit_ <= 
last_iter)) {
+                int n = limit_ - first_iter;
+                std::partial_sort(begin, begin + n, end, less);
+
+                auto nth = perms[limit_ - 1];
+                size_t equal_count = 0;
+                for (auto iter = begin + n; iter < end; iter++) {
+                    if (comparator(*iter, nth) == 0) {
+                        std::iter_swap(iter, begin + n + equal_count);
+                        equal_count++;
+                    }
+                }
+                new_limit = limit_ + equal_count;
+            } else {
+                pdqsort(begin, end, less);
+            }
+        };
+
+        EqualRangeIterator iterator(flags, range.first, range.second);
+        while (iterator.next()) {
+            int range_begin = iterator.range_begin;
+            int range_end = iterator.range_end;
+
+            if (UNLIKELY(limit_ > 0 && range_begin > limit_)) {
+                break;
+            }
+            if (LIKELY(range_end - range_begin > 1)) {
+                do_sort(range_begin, range_end);
+                if (!last_column) {
+                    flags[range_begin] = 0;
+                    for (int i = range_begin + 1; i < range_end; i++) {
+                        flags[i] &= comparator(perms[i - 1], perms[i]) == 0;
+                    }
+                }
+            }
+        }
+        _shrink_to_fit(perms, flags, new_limit);
+    }
+
+    template <typename T>
+    void sort_column(const ColumnVector<T>& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                     EqualRange& range, bool last_column) const {
+        if (!_should_inline_value(perms)) {
+            _sort_by_default(column, flags, perms, range, last_column);
+        } else {
+            _sort_by_inlined_permutation<T>(column, flags, perms, range, 
last_column);
+        }
+    }
+
+    template <typename T>
+    void sort_column(const ColumnDecimal<T>& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                     EqualRange& range, bool last_column) const {
+        if (!_should_inline_value(perms)) {
+            _sort_by_default(column, flags, perms, range, last_column);
+        } else {
+            _sort_by_inlined_permutation<T>(column, flags, perms, range, 
last_column);
+        }
+    }
+
+    void sort_column(const ColumnString& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                     EqualRange& range, bool last_column) const {
+        if (!_should_inline_value(perms)) {
+            _sort_by_default(column, flags, perms, range, last_column);
+        } else {
+            _sort_by_inlined_permutation<StringRef>(column, flags, perms, 
range, last_column);
+        }
+    }
+
+    void sort_column(const ColumnNullable& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                     EqualRange& range, bool last_column) const {
+        if (!column.has_null()) {
+            column.get_nested_column().sort_column(this, flags, perms, range, 
last_column);
+        } else {
+            const auto& null_map = column.get_null_map_data();
+            EqualRangeIterator iterator(flags, range.first, range.second);
+            while (iterator.next()) {
+                int range_begin = iterator.range_begin;
+                int range_end = iterator.range_end;
+
+                if (UNLIKELY(limit_ > 0 && range_begin > limit_)) {
+                    break;
+                }
+                bool null_first = nulls_direction_ * direction_ < 0;
+                if (LIKELY(range_end - range_begin > 1)) {
+                    int range_split = 0;
+                    if (null_first) {
+                        range_split = std::partition(perms.begin() + 
range_begin,
+                                                     perms.begin() + range_end,
+                                                     [&](size_t row_id) -> 
bool {
+                                                         return 
null_map[row_id] != 0;
+                                                     }) -
+                                      perms.begin();
+                    } else {
+                        range_split = std::partition(perms.begin() + 
range_begin,
+                                                     perms.begin() + range_end,
+                                                     [&](size_t row_id) -> 
bool {
+                                                         return 
null_map[row_id] == 0;
+                                                     }) -
+                                      perms.begin();
+                    }
+                    std::pair<size_t, size_t> is_null_range = {range_begin, 
range_split};
+                    std::pair<size_t, size_t> not_null_range = {range_split, 
range_end};
+                    if (!null_first) {
+                        std::swap(is_null_range, not_null_range);
+                    }
+
+                    if (not_null_range.first < not_null_range.second) {
+                        flags[not_null_range.first] = 0;
+                    }
+                    if (range_begin <= is_null_range.first && 
is_null_range.first < range_end) {
+                        std::fill(flags.begin() + is_null_range.first,
+                                  flags.begin() + is_null_range.second, 1);
+
+                        flags[is_null_range.first] = 0;
+                    }
+                }
+            }
+
+            column.get_nested_column().sort_column(this, flags, perms, range, 
last_column);
+        }
+    }
+
+private:
+    bool _should_inline_value(const IColumn::Permutation& perms) const {
+        return limit_ == 0 || limit_ > (perms.size() / 5);
+    }
+
+    template <typename T>
+    void _shrink_to_fit(PermutationForColumn<T>& permutation_for_column,
+                        IColumn::Permutation& perms, EqualFlags& flags, int 
limit) const {
+        if (limit < perms.size() && limit != 0) {
+            permutation_for_column.resize(limit);
+            perms.resize(limit);
+            flags.resize(limit);
+        }
+    }
+
+    void _shrink_to_fit(IColumn::Permutation& perms, EqualFlags& flags, int 
limit) const {
+        if (limit_ < perms.size() && limit != 0) {
+            perms.resize(limit);
+            flags.resize(limit);
+        }
+    }
+
+    template <typename ColumnType>
+    static constexpr bool always_false_v = false;
+
+    template <typename ColumnType, typename T>
+    void _create_permutation(const ColumnType& column,
+                             PermutationWithInlineValue<T>* __restrict 
permutation_for_column,
+                             const IColumn::Permutation& perms) const {
+        for (size_t i = 0; i < perms.size(); i++) {
+            size_t row_id = perms[i];
+            if constexpr (std::is_same_v<ColumnType, ColumnVector<T>> ||
+                          std::is_same_v<ColumnType, ColumnDecimal<T>>) {
+                permutation_for_column[i].inline_value_ = 
column.get_data()[row_id];
+            } else if constexpr (std::is_same_v<ColumnType, ColumnString>) {
+                permutation_for_column[i].inline_value_ = 
column.get_data_at(row_id);
+            } else {
+                static_assert(always_false_v<ColumnType>);
+            }
+            permutation_for_column[i].row_id_ = row_id;
+        }
+    }
+
+    template <typename ColumnType>
+    void _sort_by_default(const ColumnType& column, EqualFlags& flags, 
IColumn::Permutation& perms,
+                          EqualRange& range, bool last_column) const {
+        int new_limit = limit_;
+        auto comparator = [&](const size_t a, const size_t b) {
+            if constexpr (!std::is_same_v<ColumnType, ColumnString>) {
+                auto value_a = column.get_data()[a];

Review Comment:
   Why not use `compare_at` here? Could comparing the values directly cause a problem for decimal values?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to