This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push: new 62e2a74538 [feature](agg) Make 'map_agg' support array type as value (#22945) (#22991) 62e2a74538 is described below commit 62e2a745387b70606f982b0734c1a2b6fb0f40da Author: Jerry Hu <mrh...@gmail.com> AuthorDate: Thu Aug 17 15:59:17 2023 +0800 [feature](agg) Make 'map_agg' support array type as value (#22945) (#22991) --- .../aggregate_functions/aggregate_function_map.h | 243 +++++++++++---------- be/src/vec/columns/column_map.cpp | 6 +- be/src/vec/exec/vaggregation_node.cpp | 12 +- .../java/org/apache/doris/catalog/FunctionSet.java | 13 +- .../data/query_p0/aggregate/map_agg.out | 3 + .../suites/query_p0/aggregate/map_agg.groovy | 4 + 6 files changed, 150 insertions(+), 131 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.h b/be/src/vec/aggregate_functions/aggregate_function_map.h index 5901c6eb66..d04f85973b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_map.h +++ b/be/src/vec/aggregate_functions/aggregate_function_map.h @@ -55,7 +55,7 @@ struct AggregateFunctionMapAggData { _value_column->clear(); } - void add(const StringRef& key, const StringRef& value) { + void add(const StringRef& key, const Field& value) { DCHECK(key.data != nullptr); if (UNLIKELY(_map.find(key) != _map.end())) { return; @@ -68,7 +68,41 @@ struct AggregateFunctionMapAggData { _map.emplace(key_holder.key, _key_column->size()); _key_column->insert_data(key_holder.key.data, key_holder.key.size); - _value_column->insert_data(value.data, value.size); + _value_column->insert(value); + } + + void add(const Field& key_, const Field& value) { + DCHECK(!key_.is_null()); + auto key_array = vectorized::get<Array>(key_); + auto value_array = vectorized::get<Array>(value); + + const auto count = key_array.size(); + DCHECK_EQ(count, value_array.size()); + + for (size_t i = 0; i != count; ++i) { + StringRef key; + if constexpr (std::is_same_v<K, String>) { + auto string = key_array[i].get<K>(); + key = string; + } else { + auto& k = key_array[i].get<KeyType>(); + key.data = reinterpret_cast<const char*>(&k); + key.size = sizeof(k); + } + + if (UNLIKELY(_map.find(key) != _map.end())) { + return; + } + + ArenaKeyHolder key_holder {key, _arena}; + if (key.size > 0) { + key_holder_persist_key(key_holder); + } + + _map.emplace(key_holder.key, _key_column->size()); + _key_column->insert_data(key_holder.key.data, key_holder.key.size); + _value_column->insert(value_array[i]); + } } void merge(const AggregateFunctionMapAggData& other) { @@ -98,65 +132,6 @@ struct AggregateFunctionMapAggData { } } - static void serialize(BufferWritable& buf, const IColumn& key_column, - const IColumn& value_column, const DataTypePtr& key_type, - const DataTypePtr& value_type) { - size_t element_number = key_column.size(); - write_binary(element_number, buf); - - DCHECK(!key_column.is_nullable()); - DCHECK(!key_type->is_nullable()); - DCHECK(value_column.is_nullable()); - DCHECK(value_type->is_nullable()); - - if (element_number > 0) { - size_t serialized_size = key_type->get_uncompressed_serialized_bytes(key_column, 0); - serialized_size += value_type->get_uncompressed_serialized_bytes(value_column, 0); - - std::string serialized_buffer; - serialized_buffer.resize(serialized_size); - auto* serialized_data = serialized_buffer.data(); - - serialized_data = key_type->serialize(key_column, serialized_data, 0); - value_type->serialize(value_column, serialized_data, 0); - - write_binary(serialized_size, buf); - buf.write(serialized_buffer.data(), serialized_buffer.size()); - } - } - - void write(BufferWritable& buf) const { - serialize(buf, *_key_column, *_value_column, _key_type, _value_type); - } - - void read(BufferReadable& buf) { - size_t element_number = 0; - read_binary(element_number, buf); - - if (element_number > 0) { - _map.reserve(element_number); - - size_t serialized_size; - read_binary(serialized_size, buf); - std::string serialized_buffer; - serialized_buffer.resize(serialized_size); - - buf.read(serialized_buffer.data(), serialized_size); - const auto* serialized_data = serialized_buffer.data(); - serialized_data = _key_type->deserialize(serialized_data, _key_column.get(), 0); - _value_type->deserialize(serialized_data, _value_column.get(), 0); - - DCHECK_EQ(element_number, _key_column->size()); - DCHECK_EQ(element_number, _value_column->size()); - - for (size_t i = 0; i != element_number; ++i) { - auto key = static_cast<KeyColumnType&>(*_key_column).get_data_at(i); - DCHECK(_map.find(key) == _map.cend()); - _map.emplace(key, i); - } - } - } - void insert_result_into(IColumn& to) const { auto& dst = assert_cast<ColumnMap&>(to); size_t num_rows = _key_column->size(); @@ -211,14 +186,17 @@ public: if (nullable_map[row_num]) { return; } + Field value; + columns[1]->get(row_num, value); this->data(place).add( assert_cast<const KeyColumnType&>(nullable_col.get_nested_column()) .get_data_at(row_num), - columns[1]->get_data_at(row_num)); + value); } else { + Field value; + columns[1]->get(row_num, value); this->data(place).add( - assert_cast<const KeyColumnType&>(*columns[0]).get_data_at(row_num), - columns[1]->get_data_at(row_num)); + assert_cast<const KeyColumnType&>(*columns[0]).get_data_at(row_num), value); } } @@ -233,80 +211,107 @@ public: this->data(place).merge(this->data(rhs)); } - void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { - this->data(place).write(buf); + void serialize(ConstAggregateDataPtr /* __restrict place */, + BufferWritable& /* buf */) const override { + __builtin_unreachable(); + } + + void deserialize(AggregateDataPtr /* __restrict place */, BufferReadable& /* buf */, + Arena*) const override { + __builtin_unreachable(); } - template <bool key_nullable, bool value_nullable> - void streaming_agg_serialize_to_column_impl(const size_t num_rows, const IColumn& key_column, - const IColumn& value_column, - const NullMap& null_map, - BufferWritable& writer) const { - auto& key_col = assert_cast<const KeyColumnType&>(key_column); - auto key_to_serialize = key_col.clone_empty(); - auto val_to_serialize = value_column.clone_empty(); - auto key_type = remove_nullable(argument_types[0]); - auto val_type = make_nullable(argument_types[1]); + void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, + const size_t num_rows, Arena* arena) const override { + auto& col = assert_cast<ColumnMap&>(*dst); for (size_t i = 0; i != num_rows; ++i) { - key_to_serialize->clear(); - val_to_serialize->clear(); - if constexpr (key_nullable) { - if (!null_map[i]) { - key_to_serialize->insert_range_from(key_col, i, 1); - val_to_serialize->insert_range_from(value_column, i, 1); - } - } else { - key_to_serialize->insert_range_from(key_col, i, 1); - val_to_serialize->insert_range_from(value_column, i, 1); + Map map(2); + columns[0]->get(i, map[0]); + if (map[0].is_null()) { + continue; } + columns[1]->get(i, map[1]); + col.insert(map); + } + } - if constexpr (value_nullable) { - Data::serialize(writer, *key_to_serialize, *val_to_serialize, key_type, val_type); - } else { - auto nullable_value_col = make_nullable(val_to_serialize->assume_mutable(), false); - Data::serialize(writer, *key_to_serialize, *nullable_value_col, key_type, val_type); - val_to_serialize = value_column.clone_empty(); - } - writer.commit(); + void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena, + size_t num_rows) const override { + auto& col = assert_cast<const ColumnMap&>(column); + auto* data = &(this->data(places)); + for (size_t i = 0; i != num_rows; ++i) { + auto map = doris::vectorized::get<Map>(col[i]); + data->add(map[0], map[1]); } } - void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, - const size_t num_rows, Arena* arena) const override { - auto& col = assert_cast<ColumnString&>(*dst); - col.reserve(num_rows); - VectorBufferWriter writer(col); + void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset, + MutableColumnPtr& dst, const size_t num_rows) const override { + for (size_t i = 0; i != num_rows; ++i) { + Data& data_ = this->data(places[i] + offset); + data_.insert_result_into(*dst); + } + } - if (columns[0]->is_nullable()) { - auto& nullable_col = assert_cast<const ColumnNullable&>(*columns[0]); - auto& null_map = nullable_col.get_null_map_data(); - if (columns[0]->is_nullable()) { - this->streaming_agg_serialize_to_column_impl<true, true>( - num_rows, nullable_col.get_nested_column(), *columns[1], null_map, writer); - } else { - this->streaming_agg_serialize_to_column_impl<true, false>( - num_rows, nullable_col.get_nested_column(), *columns[1], null_map, writer); - } - } else { - if (columns[0]->is_nullable()) { - this->streaming_agg_serialize_to_column_impl<false, true>(num_rows, *columns[0], - *columns[1], {}, writer); - } else { - this->streaming_agg_serialize_to_column_impl<false, false>(num_rows, *columns[0], - *columns[1], {}, writer); + void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, + Arena* arena) const override { + auto& col = assert_cast<const ColumnMap&>(column); + const size_t num_rows = column.size(); + for (size_t i = 0; i != num_rows; ++i) { + auto map = doris::vectorized::get<Map>(col[i]); + this->data(place).add(map[0], map[1]); + } + } + + void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, + const IColumn& column, size_t begin, size_t end, + Arena* arena) const override { + DCHECK(end <= column.size() && begin <= end) + << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size(); + auto& col = assert_cast<const ColumnMap&>(column); + for (size_t i = begin; i <= end; ++i) { + auto map = doris::vectorized::get<Map>(col[i]); + this->data(place).add(map[0], map[1]); + } + } + + void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset, + AggregateDataPtr rhs, const ColumnString* column, Arena* arena, + const size_t num_rows) const override { + auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column)); + for (size_t i = 0; i != num_rows; ++i) { + auto map = doris::vectorized::get<Map>(col[i]); + this->data(places[i]).add(map[0], map[1]); + } + } + + void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset, + AggregateDataPtr rhs, const ColumnString* column, + Arena* arena, const size_t num_rows) const override { + auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column)); + for (size_t i = 0; i != num_rows; ++i) { + if (places[i]) { + auto map = doris::vectorized::get<Map>(col[i]); + this->data(places[i]).add(map[0], map[1]); } } } - void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, - Arena*) const override { - this->data(place).read(buf); + void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, + IColumn& to) const override { + this->data(place).insert_result_into(to); } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { this->data(place).insert_result_into(to); } + [[nodiscard]] MutableColumnPtr create_serialize_column() const override { + return get_return_type()->create_column(); + } + + [[nodiscard]] DataTypePtr get_serialized_type() const override { return get_return_type(); } + protected: using IAggregateFunction::argument_types; }; diff --git a/be/src/vec/columns/column_map.cpp b/be/src/vec/columns/column_map.cpp index ac7c5da1a9..8126b3e4e6 100644 --- a/be/src/vec/columns/column_map.cpp +++ b/be/src/vec/columns/column_map.cpp @@ -98,8 +98,6 @@ MutableColumnPtr ColumnMap::clone_resized(size_t to_size) const { // to support field functions Field ColumnMap::operator[](size_t n) const { - // Map is FieldVector, now we keep key value in seperate , see in field.h - Map m(2); size_t start_offset = offset_at(n); size_t element_size = size_at(n); @@ -116,9 +114,7 @@ Field ColumnMap::operator[](size_t n) const { v[i] = get_values()[start_offset + i]; } - m.push_back(k); - m.push_back(v); - return m; + return Map {k, v}; } // here to compare to below diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index 93cc3d97e9..c483e02ffa 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -685,11 +685,13 @@ Status AggregationNode::_get_without_key_result(RuntimeState* state, Block* bloc } } - ColumnPtr ptr = std::move(columns[i]); - // unless `count`, other aggregate function dispose empty set should be null - // so here check the children row return - ptr = make_nullable(ptr, _children[0]->rows_returned() == 0); - columns[i] = std::move(*ptr).mutate(); + if (column_type->is_nullable() && !data_types[i]->is_nullable()) { + ColumnPtr ptr = std::move(columns[i]); + // unless `count`, other aggregate function dispose empty set should be null + // so here check the children row return + ptr = make_nullable(ptr, _children[0]->rows_returned() == 0); + columns[i] = ptr->assume_mutable(); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 2391e1ec84..41a9de9b66 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1040,12 +1040,21 @@ public class FunctionSet<T> { } if (!Type.JSONB.equals(t)) { - for (Type valueType : Type.getTrivialTypes()) { - addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, valueType), new MapType(t, valueType), + for (Type valueType : Type.getMapSubTypes()) { + addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, valueType), + new MapType(t, valueType), Type.VARCHAR, "", "", "", "", "", null, "", true, true, false, true)); } + + for (Type v : Type.getArraySubTypes()) { + addBuiltin(AggregateFunction.createBuiltin(MAP_AGG, Lists.newArrayList(t, new ArrayType(v)), + new MapType(t, new ArrayType(v)), + new MapType(t, new ArrayType(v)), + "", "", "", "", "", null, "", + true, true, false, true)); + } } if (STDDEV_UPDATE_SYMBOL.containsKey(t)) { diff --git a/regression-test/data/query_p0/aggregate/map_agg.out b/regression-test/data/query_p0/aggregate/map_agg.out index 0b8d5f3be0..62c8ecc101 100644 --- a/regression-test/data/query_p0/aggregate/map_agg.out +++ b/regression-test/data/query_p0/aggregate/map_agg.out @@ -20,3 +20,6 @@ 4 V4_1 V4_2 V4_3 5 V5_1 V5_2 V5_3 +-- !sql3 -- +{"key":["ab", "efg", NULL]} + diff --git a/regression-test/suites/query_p0/aggregate/map_agg.groovy b/regression-test/suites/query_p0/aggregate/map_agg.groovy index e779e6061e..2337f2fcea 100644 --- a/regression-test/suites/query_p0/aggregate/map_agg.groovy +++ b/regression-test/suites/query_p0/aggregate/map_agg.groovy @@ -168,6 +168,10 @@ suite("map_agg") { ORDER BY `id`; """ + qt_sql3 """ + select map_agg(k, v) from (select 'key' as k, array('ab', 'efg', null) v) a; + """ + sql "DROP TABLE `test_map_agg`" sql "DROP TABLE `test_map_agg_nullable`" sql "DROP TABLE `test_map_agg_numeric_key`" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org