This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 62e2a74538 [feature](agg) Make 'map_agg' support array type as value (#22945) (#22991)
62e2a74538 is described below

commit 62e2a745387b70606f982b0734c1a2b6fb0f40da
Author: Jerry Hu <mrh...@gmail.com>
AuthorDate: Thu Aug 17 15:59:17 2023 +0800

    [feature](agg) Make 'map_agg' support array type as value (#22945) (#22991)
---
 .../aggregate_functions/aggregate_function_map.h   | 243 +++++++++++----------
 be/src/vec/columns/column_map.cpp                  |   6 +-
 be/src/vec/exec/vaggregation_node.cpp              |  12 +-
 .../java/org/apache/doris/catalog/FunctionSet.java |  13 +-
 .../data/query_p0/aggregate/map_agg.out            |   3 +
 .../suites/query_p0/aggregate/map_agg.groovy       |   4 +
 6 files changed, 150 insertions(+), 131 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.h b/be/src/vec/aggregate_functions/aggregate_function_map.h
index 5901c6eb66..d04f85973b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_map.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_map.h
@@ -55,7 +55,8 @@ struct AggregateFunctionMapAggData {
         _value_column->clear();
     }

-    void add(const StringRef& key, const StringRef& value) {
+    // Adds one (key, value) pair; if the key is already present, its first value is kept.
+    void add(const StringRef& key, const Field& value) {
         DCHECK(key.data != nullptr);
         if (UNLIKELY(_map.find(key) != _map.end())) {
             return;
@@ -68,7 +69,41 @@ struct AggregateFunctionMapAggData {

         _map.emplace(key_holder.key, _key_column->size());
         _key_column->insert_data(key_holder.key.data, key_holder.key.size);
-        _value_column->insert_data(value.data, value.size);
+        _value_column->insert(value);
+    }
+
+    // Merges one (keys array, values array) pair, e.g. one row of a serialized ColumnMap.
+    void add(const Field& key_, const Field& value) {
+        DCHECK(!key_.is_null());
+        const auto& key_array = vectorized::get<Array>(key_);
+        const auto& value_array = vectorized::get<Array>(value);
+        const auto count = key_array.size();
+        DCHECK_EQ(count, value_array.size());
+
+        for (size_t i = 0; i != count; ++i) {
+            StringRef key;
+            if constexpr (std::is_same_v<K, String>) {
+                // Bind by reference: `key` must not point into a copy destroyed at block end.
+                const auto& string = key_array[i].get<K>();
+                key = string;
+            } else {
+                // NOTE(review): reads the Field storage as KeyType; verify for Float32 keys.
+                auto& k = key_array[i].get<KeyType>();
+                key.data = reinterpret_cast<const char*>(&k);
+                key.size = sizeof(k);
+            }
+            // A duplicate key keeps its first value; skip it but keep merging the rest.
+            if (UNLIKELY(_map.find(key) != _map.end())) {
+                continue;
+            }
+            ArenaKeyHolder key_holder {key, _arena};
+            if (key.size > 0) {
+                key_holder_persist_key(key_holder);
+            }
+            _map.emplace(key_holder.key, _key_column->size());
+            _key_column->insert_data(key_holder.key.data, key_holder.key.size);
+            _value_column->insert(value_array[i]);
+        }
     }

     void merge(const AggregateFunctionMapAggData& other) {
@@ -98,65 +133,6 @@ struct AggregateFunctionMapAggData {
         }
     }

-    static void serialize(BufferWritable& buf, const IColumn& key_column,
-                          const IColumn& value_column, const DataTypePtr& key_type,
-                          const DataTypePtr& value_type) {
-        size_t element_number = key_column.size();
-        write_binary(element_number, buf);
-
-        DCHECK(!key_column.is_nullable());
-        DCHECK(!key_type->is_nullable());
-        DCHECK(value_column.is_nullable());
-        DCHECK(value_type->is_nullable());
-
-        if (element_number > 0) {
-            size_t serialized_size = key_type->get_uncompressed_serialized_bytes(key_column, 0);
-            serialized_size += value_type->get_uncompressed_serialized_bytes(value_column, 0);
-
-            std::string serialized_buffer;
-            serialized_buffer.resize(serialized_size);
-            auto* serialized_data = serialized_buffer.data();
-
-            serialized_data = key_type->serialize(key_column, serialized_data, 0);
-            value_type->serialize(value_column, serialized_data, 0);
-
-            write_binary(serialized_size, buf);
-            buf.write(serialized_buffer.data(), serialized_buffer.size());
-        }
-    }
-
-    void write(BufferWritable& buf) const {
-        serialize(buf, *_key_column, *_value_column, _key_type, _value_type);
-    }
-
-    void read(BufferReadable& buf) {
-        size_t element_number = 0;
-        read_binary(element_number, buf);
-
-        if (element_number > 0) {
-            _map.reserve(element_number);
-
-            size_t serialized_size;
-            read_binary(serialized_size, buf);
-            std::string serialized_buffer;
-            serialized_buffer.resize(serialized_size);
-
-            buf.read(serialized_buffer.data(), serialized_size);
-            const auto* serialized_data = serialized_buffer.data();
-            serialized_data = _key_type->deserialize(serialized_data, _key_column.get(), 0);
-            _value_type->deserialize(serialized_data, _value_column.get(), 0);
-
-            DCHECK_EQ(element_number, _key_column->size());
-            DCHECK_EQ(element_number, _value_column->size());
-
-            for (size_t i = 0; i != element_number; ++i) {
-                auto key = static_cast<KeyColumnType&>(*_key_column).get_data_at(i);
-                DCHECK(_map.find(key) == _map.cend());
-                _map.emplace(key, i);
-            }
-        }
-    }
-
     void insert_result_into(IColumn& to) const {
         auto& dst = assert_cast<ColumnMap&>(to);
         size_t num_rows = _key_column->size();
@@ -211,14 +187,17 @@ public:
             if (nullable_map[row_num]) {
                 return;
             }
+            Field value;
+            columns[1]->get(row_num, value);
             this->data(place).add(
                     assert_cast<const KeyColumnType&>(nullable_col.get_nested_column())
                             .get_data_at(row_num),
-                    columns[1]->get_data_at(row_num));
+                    value);
         } else {
+            Field value;
+            columns[1]->get(row_num, value);
             this->data(place).add(
-                    assert_cast<const KeyColumnType&>(*columns[0]).get_data_at(row_num),
-                    columns[1]->get_data_at(row_num));
+                    assert_cast<const KeyColumnType&>(*columns[0]).get_data_at(row_num), value);
         }
     }

@@ -233,80 +212,125 @@ public:
         this->data(place).merge(this->data(rhs));
     }

-    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
-        this->data(place).write(buf);
+    // Buffer-based (de)serialization is not used: this function serializes to a column whose
+    // type is the result map type (see get_serialized_type), so these are marked unreachable.
+    void serialize(ConstAggregateDataPtr /* __restrict place */,
+                   BufferWritable& /* buf */) const override {
+        __builtin_unreachable();
+    }
+
+    void deserialize(AggregateDataPtr /* __restrict place */, BufferReadable& /* buf */,
+                     Arena*) const override {
+        __builtin_unreachable();
     }

-    template <bool key_nullable, bool value_nullable>
-    void streaming_agg_serialize_to_column_impl(const size_t num_rows, const IColumn& key_column,
-                                                const IColumn& value_column,
-                                                const NullMap& null_map,
-                                                BufferWritable& writer) const {
-        auto& key_col = assert_cast<const KeyColumnType&>(key_column);
-        auto key_to_serialize = key_col.clone_empty();
-        auto val_to_serialize = value_column.clone_empty();
-        auto key_type = remove_nullable(argument_types[0]);
-        auto val_type = make_nullable(argument_types[1]);
+    // Emits one map per input row, like the removed buffer-based version: {key: value}, or an
+    // empty map when the key is NULL. Uses the (keys array, values array) Map layout.
+    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* arena) const override {
+        auto& col = assert_cast<ColumnMap&>(*dst);
         for (size_t i = 0; i != num_rows; ++i) {
-            key_to_serialize->clear();
-            val_to_serialize->clear();
-            if constexpr (key_nullable) {
-                if (!null_map[i]) {
-                    key_to_serialize->insert_range_from(key_col, i, 1);
-                    val_to_serialize->insert_range_from(value_column, i, 1);
-                }
-            } else {
-                key_to_serialize->insert_range_from(key_col, i, 1);
-                val_to_serialize->insert_range_from(value_column, i, 1);
+            Field key;
+            columns[0]->get(i, key);
+            if (key.is_null()) {
+                col.insert(Map {Array {}, Array {}});
+                continue;
             }
+            Field value;
+            columns[1]->get(i, value);
+            // Same (keys array, values array) layout that ColumnMap::operator[] produces.
+            col.insert(Map {Array {key}, Array {value}});
+        }
+    }

-            if constexpr (value_nullable) {
-                Data::serialize(writer, *key_to_serialize, *val_to_serialize, key_type, val_type);
-            } else {
-                auto nullable_value_col = make_nullable(val_to_serialize->assume_mutable(), false);
-                Data::serialize(writer, *key_to_serialize, *nullable_value_col, key_type, val_type);
-                val_to_serialize = value_column.clone_empty();
-            }
-            writer.commit();
+    // Deserializes row i of `column` into data[i], assuming `places` points to num_rows
+    // consecutive states. NOTE(review): states are not constructed here; confirm the caller does.
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                                 size_t num_rows) const override {
+        auto& col = assert_cast<const ColumnMap&>(column);
+        auto* data = &(this->data(places));
+        for (size_t i = 0; i != num_rows; ++i) {
+            auto map = doris::vectorized::get<Map>(col[i]);
+            data[i].add(map[0], map[1]);
         }
     }

-    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
-                                           const size_t num_rows, Arena* arena) const override {
-        auto& col = assert_cast<ColumnString&>(*dst);
-        col.reserve(num_rows);
-        VectorBufferWriter writer(col);
+    // Appends each state's accumulated map as one row of `dst`.
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) const override {
+        for (size_t i = 0; i != num_rows; ++i) {
+            Data& data_ = this->data(places[i] + offset);
+            data_.insert_result_into(*dst);
+        }
+    }

-        if (columns[0]->is_nullable()) {
-            auto& nullable_col = assert_cast<const ColumnNullable&>(*columns[0]);
-            auto& null_map = nullable_col.get_null_map_data();
-            if (columns[0]->is_nullable()) {
-                this->streaming_agg_serialize_to_column_impl<true, true>(
-                        num_rows, nullable_col.get_nested_column(), *columns[1], null_map, writer);
-            } else {
-                this->streaming_agg_serialize_to_column_impl<true, false>(
-                        num_rows, nullable_col.get_nested_column(), *columns[1], null_map, writer);
-            }
-        } else {
-            if (columns[0]->is_nullable()) {
-                this->streaming_agg_serialize_to_column_impl<false, true>(num_rows, *columns[0],
-                                                                          *columns[1], {}, writer);
-            } else {
-                this->streaming_agg_serialize_to_column_impl<false, false>(num_rows, *columns[0],
-                                                                           *columns[1], {}, writer);
+    // Merges every row of a serialized map column into the single state at `place`.
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
+                                           Arena* arena) const override {
+        auto& col = assert_cast<const ColumnMap&>(column);
+        const size_t num_rows = column.size();
+        for (size_t i = 0; i != num_rows; ++i) {
+            auto map = doris::vectorized::get<Map>(col[i]);
+            this->data(place).add(map[0], map[1]);
+        }
+    }
+
+    // Merges rows begin..end (inclusive, per the loop) of a serialized map column into `place`.
+    // NOTE(review): the DCHECK admits end == column.size(), which the inclusive loop would read
+    // past; confirm the caller's convention for `end`.
+    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
+                                                 const IColumn& column, size_t begin, size_t end,
+                                                 Arena* arena) const override {
+        DCHECK(end <= column.size() && begin <= end)
+                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
+        auto& col = assert_cast<const ColumnMap&>(column);
+        for (size_t i = begin; i <= end; ++i) {
+            auto map = doris::vectorized::get<Map>(col[i]);
+            this->data(place).add(map[0], map[1]);
+        }
+    }
+
+    // Merges row i of `column` into the state at places[i] + offset, the same addressing
+    // serialize_to_column uses; `rhs` is unused.
+    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
+                                   AggregateDataPtr rhs, const ColumnString* column, Arena* arena,
+                                   const size_t num_rows) const override {
+        auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column));
+        for (size_t i = 0; i != num_rows; ++i) {
+            auto map = doris::vectorized::get<Map>(col[i]);
+            this->data(places[i] + offset).add(map[0], map[1]);
+        }
+    }
+
+    // Same as deserialize_and_merge_vec, but rows whose places[i] is null are skipped.
+    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
+                                            AggregateDataPtr rhs, const ColumnString* column,
+                                            Arena* arena, const size_t num_rows) const override {
+        auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column));
+        for (size_t i = 0; i != num_rows; ++i) {
+            if (places[i]) {
+                auto map = doris::vectorized::get<Map>(col[i]);
+                this->data(places[i] + offset).add(map[0], map[1]);
             }
         }
     }

-    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
-                     Arena*) const override {
-        this->data(place).read(buf);
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
+                                         IColumn& to) const override {
+        this->data(place).insert_result_into(to);
     }

     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
         this->data(place).insert_result_into(to);
     }

+    // The serialized (intermediate) form of the state is the result map type itself.
+    [[nodiscard]] MutableColumnPtr create_serialize_column() const override {
+        return get_return_type()->create_column();
+    }
+
+    [[nodiscard]] DataTypePtr get_serialized_type() const override { return get_return_type(); }
+
 protected:
     using IAggregateFunction::argument_types;
 };
diff --git a/be/src/vec/columns/column_map.cpp b/be/src/vec/columns/column_map.cpp
index ac7c5da1a9..8126b3e4e6 100644
--- a/be/src/vec/columns/column_map.cpp
+++ b/be/src/vec/columns/column_map.cpp
@@ -98,8 +98,6 @@ MutableColumnPtr ColumnMap::clone_resized(size_t to_size) const {

 // to support field functions
 Field ColumnMap::operator[](size_t n) const {
-    // Map is FieldVector, now we keep key value in seperate  , see in field.h
-    Map m(2);
     size_t start_offset = offset_at(n);
     size_t element_size = size_at(n);

@@ -116,9 +114,8 @@ Field ColumnMap::operator[](size_t n) const {
         v[i] = get_values()[start_offset + i];
     }

-    m.push_back(k);
-    m.push_back(v);
-    return m;
+    // A map field is a pair of parallel arrays: Map {keys, values}.
+    return Map {std::move(k), std::move(v)};
 }

 // here to compare to below
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index 93cc3d97e9..c483e02ffa 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -685,11 +685,14 @@ Status AggregationNode::_get_without_key_result(RuntimeState* state, Block* bloc
                 }
             }

-            ColumnPtr ptr = std::move(columns[i]);
-            // unless `count`, other aggregate function dispose empty set should be null
-            // so here check the children row return
-            ptr = make_nullable(ptr, _children[0]->rows_returned() == 0);
-            columns[i] = std::move(*ptr).mutate();
+            if (column_type->is_nullable() && !data_types[i]->is_nullable()) {
+                ColumnPtr ptr = std::move(columns[i]);
+                // Wrap only when `column_type` is nullable but the function's data type is not.
+                // Aggregates other than `count` should return NULL over an empty input set, so
+                // the value is marked null when the child returned no rows.
+                ptr = make_nullable(ptr, _children[0]->rows_returned() == 0);
+                columns[i] = ptr->assume_mutable();
+            }
         }
     }

diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 2391e1ec84..41a9de9b66 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1040,12 +1040,23 @@ public class FunctionSet<T> {
             }

             if (!Type.JSONB.equals(t)) {
-                for (Type valueType : Type.getTrivialTypes()) {
-                    addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, valueType), new MapType(t, valueType),
+                // NOTE(review): the BE map_agg now uses the map type itself as its serialized
+                // type; confirm Type.VARCHAR is still the right intermediate type here.
+                for (Type valueType : Type.getMapSubTypes()) {
+                    addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, valueType),
+                            new MapType(t, valueType),
                             Type.VARCHAR,
                             "", "", "", "", "", null, "",
                             true, true, false, true));
                 }
+
+                for (Type v : Type.getArraySubTypes()) {
+                    addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, new ArrayType(v)),
+                            new MapType(t, new ArrayType(v)),
+                            new MapType(t, new ArrayType(v)),
+                            "", "", "", "", "", null, "",
+                            true, true, false, true));
+                }
             }

             if (STDDEV_UPDATE_SYMBOL.containsKey(t)) {
diff --git a/regression-test/data/query_p0/aggregate/map_agg.out b/regression-test/data/query_p0/aggregate/map_agg.out
index 0b8d5f3be0..62c8ecc101 100644
--- a/regression-test/data/query_p0/aggregate/map_agg.out
+++ b/regression-test/data/query_p0/aggregate/map_agg.out
@@ -20,3 +20,6 @@
 4      V4_1    V4_2    V4_3
 5      V5_1    V5_2    V5_3

+-- !sql3 --
+{"key":["ab", "efg", NULL]}
+
diff --git a/regression-test/suites/query_p0/aggregate/map_agg.groovy b/regression-test/suites/query_p0/aggregate/map_agg.groovy
index e779e6061e..2337f2fcea 100644
--- a/regression-test/suites/query_p0/aggregate/map_agg.groovy
+++ b/regression-test/suites/query_p0/aggregate/map_agg.groovy
@@ -168,6 +168,11 @@ suite("map_agg") {
         ORDER BY `id`;
     """

+    // map_agg with an ARRAY value type, including a NULL array element.
+    qt_sql3 """
+        select map_agg(k, v) from (select 'key' as k, array('ab', 'efg', null) v) a;
+    """
+
      sql "DROP TABLE `test_map_agg`"
      sql "DROP TABLE `test_map_agg_nullable`"
      sql "DROP TABLE `test_map_agg_numeric_key`"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to