This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new fe8acdb268 [feature-wip](array-type) add agg function collect_list and collect_set (#10606) fe8acdb268 is described below commit fe8acdb268942aeda7a12b6d1078830b0fa7eb3d Author: camby <104178...@qq.com> AuthorDate: Fri Jul 8 12:48:46 2022 +0800 [feature-wip](array-type) add agg function collect_list and collect_set (#10606) Add code for collect_list and collect_set, and update the regression output, since the output format for ARRAY(string) was changed previously. Co-authored-by: cambyzju <zhuxiaol...@baidu.com> --- be/src/vec/CMakeLists.txt | 1 + .../aggregate_function_collect.cpp | 86 +++++++ .../aggregate_function_collect.h | 257 +++++++++++++++++++++ .../aggregate_function_simple_factory.cpp | 3 + be/test/CMakeLists.txt | 1 + .../vec/aggregate_functions/agg_collect_test.cpp | 161 +++++++++++++ .../aggregate-functions/collect_list.md | 70 ++++++ .../aggregate-functions/collect_set.md | 70 ++++++ .../aggregate-functions/collect_list.md | 71 ++++++ .../aggregate-functions/collect_set.md | 70 ++++++ .../java/org/apache/doris/catalog/FunctionSet.java | 12 + .../aggregate_functions/test_aggregate_collect.out | 9 + .../test_aggregate_collect.groovy | 41 ++++ 13 files changed, 852 insertions(+) diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index 549e931672..3eb9ae6c5a 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -22,6 +22,7 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec") set(VEC_FILES aggregate_functions/aggregate_function_window_funnel.cpp aggregate_functions/aggregate_function_avg.cpp + aggregate_functions/aggregate_function_collect.cpp aggregate_functions/aggregate_function_count.cpp aggregate_functions/aggregate_function_distinct.cpp aggregate_functions/aggregate_function_sum.cpp diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp new file mode 100644 index 
0000000000..9f58b8ee17 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/aggregate_functions/aggregate_function_collect.h" + +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" + +namespace doris::vectorized { + +template <typename T> +AggregateFunctionPtr create_agg_function_collect(bool distinct, const DataTypes& argument_types) { + if (distinct) { + return AggregateFunctionPtr( + new AggregateFunctionCollect<AggregateFunctionCollectSetData<T>>(argument_types)); + } else { + return AggregateFunctionPtr( + new AggregateFunctionCollect<AggregateFunctionCollectListData<T>>(argument_types)); + } +} + +AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, + const DataTypes& argument_types, + const Array& parameters, + const bool result_is_nullable) { + if (argument_types.size() != 1) { + LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", + argument_types.size(), name); + return nullptr; + } + + bool distinct = false; + if (name == "collect_set") { + distinct = true; + } + + WhichDataType type(argument_types[0]); + if 
(type.is_uint8()) { + return create_agg_function_collect<UInt8>(distinct, argument_types); + } else if (type.is_int8()) { + return create_agg_function_collect<Int8>(distinct, argument_types); + } else if (type.is_int16()) { + return create_agg_function_collect<Int16>(distinct, argument_types); + } else if (type.is_int32()) { + return create_agg_function_collect<Int32>(distinct, argument_types); + } else if (type.is_int64()) { + return create_agg_function_collect<Int64>(distinct, argument_types); + } else if (type.is_int128()) { + return create_agg_function_collect<Int128>(distinct, argument_types); + } else if (type.is_float32()) { + return create_agg_function_collect<Float32>(distinct, argument_types); + } else if (type.is_float64()) { + return create_agg_function_collect<Float64>(distinct, argument_types); + } else if (type.is_decimal128()) { + return create_agg_function_collect<Decimal128>(distinct, argument_types); + } else if (type.is_date()) { + return create_agg_function_collect<Int64>(distinct, argument_types); + } else if (type.is_date_time()) { + return create_agg_function_collect<Int64>(distinct, argument_types); + } else if (type.is_string()) { + return create_agg_function_collect<StringRef>(distinct, argument_types); + } + + LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", + argument_types[0]->get_name(), name); + return nullptr; +} + +void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory) { + factory.register_function("collect_list", create_aggregate_function_collect); + factory.register_function("collect_set", create_aggregate_function_collect); +} +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h b/be/src/vec/aggregate_functions/aggregate_function_collect.h new file mode 100644 index 0000000000..5df33ab6f2 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h @@ -0,0 +1,257 @@ +// Licensed to the Apache 
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/key_holder_helpers.h" +#include "vec/common/aggregation_common.h" +#include "vec/common/hash_table/hash_set.h" +#include "vec/common/pod_array_fwd.h" +#include "vec/common/string_ref.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_string.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionCollectSetData { + using ElementType = T; + using ColVecType = ColumnVectorOrDecimal<ElementType>; + using ElementNativeType = typename NativeType<T>::Type; + using Set = HashSetWithStackMemory<ElementNativeType, DefaultHash<ElementNativeType>, 4>; + Set set; + + void add(const IColumn& column, size_t row_num) { + const auto& vec = assert_cast<const ColVecType&>(column).get_data(); + set.insert(vec[row_num]); + } + void merge(const AggregateFunctionCollectSetData& rhs) { set.merge(rhs.set); } + void write(BufferWritable& buf) const { set.write(buf); } + void read(BufferReadable& buf) { set.read(buf); } + void reset() { set.clear(); } + void insert_result_into(IColumn& to) const 
{ + auto& vec = assert_cast<ColVecType&>(to).get_data(); + vec.reserve(set.size()); + for (auto item : set) { + vec.push_back(item.key); + } + } +}; + +template <> +struct AggregateFunctionCollectSetData<StringRef> { + using ElementType = StringRef; + using ColVecType = ColumnString; + using Set = HashSetWithSavedHashWithStackMemory<ElementType, DefaultHash<ElementType>, 4>; + Set set; + + void add(const IColumn& column, size_t row_num, Arena* arena) { + Set::LookupResult it; + bool inserted; + auto key_holder = get_key_holder<true>(column, row_num, *arena); + set.emplace(key_holder, it, inserted); + } + + void merge(const AggregateFunctionCollectSetData& rhs, Arena* arena) { + Set::LookupResult it; + bool inserted; + for (const auto& elem : rhs.set) { + set.emplace(ArenaKeyHolder {elem.get_value(), *arena}, it, inserted); + } + } + void write(BufferWritable& buf) const { + write_var_uint(set.size(), buf); + for (const auto& elem : set) { + write_string_binary(elem.get_value(), buf); + } + } + void read(BufferReadable& buf) { + size_t rows; + read_var_uint(rows, buf); + + StringRef ref; + for (size_t i = 0; i < rows; ++i) { + read_string_binary(ref, buf); + set.insert(ref); + } + } + void reset() { set.clear(); } + void insert_result_into(IColumn& to) const { + auto& vec = assert_cast<ColVecType&>(to); + vec.reserve(set.size()); + for (const auto& item : set) { + vec.insert_data(item.key.data, item.key.size); + } + } +}; + +template <typename T> +struct AggregateFunctionCollectListData { + using ElementType = T; + using ColVecType = ColumnVectorOrDecimal<ElementType>; + PaddedPODArray<ElementType> data; + + void add(const IColumn& column, size_t row_num) { + const auto& vec = assert_cast<const ColVecType&>(column).get_data(); + data.push_back(vec[row_num]); + } + void merge(const AggregateFunctionCollectListData& rhs) { + data.insert(rhs.data.begin(), rhs.data.end()); + } + void write(BufferWritable& buf) const { + write_var_uint(data.size(), buf); + 
buf.write(data.raw_data(), data.size() * sizeof(ElementType)); + } + void read(BufferReadable& buf) { + size_t rows = 0; + read_var_uint(rows, buf); + data.resize(rows); + buf.read(reinterpret_cast<char*>(data.data()), rows * sizeof(ElementType)); + } + void reset() { data.clear(); } + void insert_result_into(IColumn& to) const { + auto& vec = assert_cast<ColVecType&>(to).get_data(); + size_t old_size = vec.size(); + vec.resize(old_size + data.size()); + memcpy(vec.data() + old_size, data.data(), data.size() * sizeof(ElementType)); + } +}; + +template <> +struct AggregateFunctionCollectListData<StringRef> { + using ElementType = StringRef; + using ColVecType = ColumnString; + MutableColumnPtr data; + + AggregateFunctionCollectListData<ElementType>() { data = ColVecType::create(); } + + void add(const IColumn& column, size_t row_num) { data->insert_from(column, row_num); } + + void merge(const AggregateFunctionCollectListData& rhs) { + data->insert_range_from(*rhs.data, 0, rhs.data->size()); + } + + void write(BufferWritable& buf) const { + auto& col = assert_cast<ColVecType&>(*data); + + write_var_uint(col.size(), buf); + buf.write(col.get_offsets().raw_data(), col.size() * sizeof(IColumn::Offset)); + + write_var_uint(col.get_chars().size(), buf); + buf.write(col.get_chars().raw_data(), col.get_chars().size()); + } + + void read(BufferReadable& buf) { + auto& col = assert_cast<ColVecType&>(*data); + size_t offs_size = 0; + read_var_uint(offs_size, buf); + col.get_offsets().resize(offs_size); + buf.read(reinterpret_cast<char*>(col.get_offsets().data()), + offs_size * sizeof(IColumn::Offset)); + + size_t chars_size = 0; + read_var_uint(chars_size, buf); + col.get_chars().resize(chars_size); + buf.read(reinterpret_cast<char*>(col.get_chars().data()), chars_size); + } + + void reset() { data->clear(); } + + void insert_result_into(IColumn& to) const { + auto& to_str = assert_cast<ColVecType&>(to); + to_str.insert_range_from(*data, 0, data->size()); + } +}; + +template 
<typename Data> +class AggregateFunctionCollect final + : public IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data>> { +public: + static constexpr bool alloc_memory_in_arena = + std::is_same_v<Data, AggregateFunctionCollectSetData<StringRef>>; + + AggregateFunctionCollect(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data>>(argument_types_, + {}), + _argument_type(argument_types_[0]) {} + + std::string get_name() const override { + if constexpr (std::is_same_v<AggregateFunctionCollectListData<typename Data::ElementType>, + Data>) { + return "collect_list"; + } else { + return "collect_set"; + } + } + + DataTypePtr get_return_type() const override { + return std::make_shared<DataTypeArray>(make_nullable(_argument_type)); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena* arena) const override { + assert(!columns[0]->is_null_at(row_num)); + if constexpr (alloc_memory_in_arena) { + this->data(place).add(*columns[0], row_num, arena); + } else { + this->data(place).add(*columns[0], row_num); + } + } + + void reset(AggregateDataPtr place) const override { this->data(place).reset(); } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + if constexpr (alloc_memory_in_arena) { + this->data(place).merge(this->data(rhs), arena); + } else { + this->data(place).merge(this->data(rhs)); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + auto& to_arr = assert_cast<ColumnArray&>(to); + auto& to_nested_col = to_arr.get_data(); + if (to_nested_col.is_nullable()) { + auto col_null = 
reinterpret_cast<ColumnNullable*>(&to_nested_col); + this->data(place).insert_result_into(col_null->get_nested_column()); + col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(), 0); + } else { + this->data(place).insert_result_into(to_nested_col); + } + to_arr.get_offsets().push_back(to_nested_col.size()); + } + + bool allocates_memory_in_arena() const override { return alloc_memory_in_arena; } + +private: + DataTypePtr _argument_type; +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index badf756f8b..58dc0a4c9b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -47,6 +47,8 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory); + AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; static AggregateFunctionSimpleFactory instance; @@ -70,6 +72,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_percentile_approx(instance); register_aggregate_function_window_funnel(instance); register_aggregate_function_orthogonal_bitmap(instance); + register_aggregate_function_collect_list(instance); // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line register_aggregate_function_combinator_null(instance); diff --git 
a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt index ed6693b992..1e2ffc8ade 100644 --- a/be/test/CMakeLists.txt +++ b/be/test/CMakeLists.txt @@ -317,6 +317,7 @@ set(UTIL_TEST_FILES util/interval_tree_test.cpp ) set(VEC_TEST_FILES + vec/aggregate_functions/agg_collect_test.cpp vec/aggregate_functions/agg_test.cpp vec/aggregate_functions/agg_min_max_test.cpp vec/aggregate_functions/vec_window_funnel_test.cpp diff --git a/be/test/vec/aggregate_functions/agg_collect_test.cpp b/be/test/vec/aggregate_functions/agg_collect_test.cpp new file mode 100644 index 0000000000..28e31ca58c --- /dev/null +++ b/be/test/vec/aggregate_functions/agg_collect_test.cpp @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> + +#include "common/logging.h" +#include "gtest/gtest.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_collect.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/columns/column_vector.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_date.h" +#include "vec/data_types/data_type_date_time.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris::vectorized { + +void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory); + +class VAggCollectTest : public testing::Test { +public: + void SetUp() { + AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); + register_aggregate_function_collect_list(factory); + } + + void TearDown() {} + + bool is_distinct(const std::string& fn_name) { return fn_name == "collect_set"; } + + template <typename DataType> + void agg_collect_add_elements(AggregateFunctionPtr agg_function, AggregateDataPtr place, + size_t input_nums) { + using FieldType = typename DataType::FieldType; + auto type = std::make_shared<DataType>(); + auto input_col = type->create_column(); + for (size_t i = 0; i < input_nums; ++i) { + for (size_t j = 0; j < _repeated_times; ++j) { + if constexpr (std::is_same_v<DataType, DataTypeString>) { + auto item = std::string("item") + std::to_string(i); + input_col->insert_data(item.c_str(), item.size()); + } else { + auto item = FieldType(i); + input_col->insert_data(reinterpret_cast<const char*>(&item), 0); + } + } + } + EXPECT_EQ(input_col->size(), input_nums * _repeated_times); + + const IColumn* column[1] = {input_col.get()}; + for (int i = 0; i < input_col->size(); i++) { + agg_function->add(place, column, i, &_agg_arena_pool); + } + } + + template <typename DataType> + void test_agg_collect(const 
std::string& fn_name, size_t input_nums = 0) { + DataTypes data_types = {(DataTypePtr)std::make_shared<DataType>()}; + LOG(INFO) << "test_agg_collect for " << fn_name << "(" << data_types[0]->get_name() << ")"; + Array array; + AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); + auto agg_function = factory.get(fn_name, data_types, array); + EXPECT_NE(agg_function, nullptr); + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + + agg_collect_add_elements<DataType>(agg_function, place, input_nums); + + ColumnString buf; + VectorBufferWriter buf_writer(buf); + agg_function->serialize(place, buf_writer); + buf_writer.commit(); + VectorBufferReader buf_reader(buf.get_data_at(0)); + agg_function->deserialize(place, buf_reader, &_agg_arena_pool); + + std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_collect_add_elements<DataType>(agg_function, place2, input_nums); + + agg_function->merge(place, place2, &_agg_arena_pool); + auto column_result = ColumnArray::create(data_types[0]->create_column()); + agg_function->insert_result_into(place, *column_result); + EXPECT_EQ(column_result->size(), 1); + EXPECT_EQ(column_result->get_offsets()[0], + is_distinct(fn_name) ? input_nums : 2 * input_nums * _repeated_times); + + auto column_result2 = ColumnArray::create(data_types[0]->create_column()); + agg_function->insert_result_into(place2, *column_result2); + EXPECT_EQ(column_result2->size(), 1); + EXPECT_EQ(column_result2->get_offsets()[0], + is_distinct(fn_name) ? 
input_nums : input_nums * _repeated_times); + + agg_function->destroy(place); + agg_function->destroy(place2); + } + +private: + const size_t _repeated_times = 2; + Arena _agg_arena_pool; +}; + +TEST_F(VAggCollectTest, test_empty) { + test_agg_collect<DataTypeInt8>("collect_list"); + test_agg_collect<DataTypeInt8>("collect_set"); + test_agg_collect<DataTypeInt16>("collect_list"); + test_agg_collect<DataTypeInt16>("collect_set"); + test_agg_collect<DataTypeInt32>("collect_list"); + test_agg_collect<DataTypeInt32>("collect_set"); + test_agg_collect<DataTypeInt64>("collect_list"); + test_agg_collect<DataTypeInt64>("collect_set"); + test_agg_collect<DataTypeInt128>("collect_list"); + test_agg_collect<DataTypeInt128>("collect_set"); + + test_agg_collect<DataTypeDecimal<Decimal128>>("collect_list"); + test_agg_collect<DataTypeDecimal<Decimal128>>("collect_set"); + + test_agg_collect<DataTypeDate>("collect_list"); + test_agg_collect<DataTypeDate>("collect_set"); + + test_agg_collect<DataTypeString>("collect_list"); + test_agg_collect<DataTypeString>("collect_set"); +} + +TEST_F(VAggCollectTest, test_with_data) { + test_agg_collect<DataTypeInt32>("collect_list", 7); + test_agg_collect<DataTypeInt32>("collect_set", 9); + test_agg_collect<DataTypeInt128>("collect_list", 20); + test_agg_collect<DataTypeInt128>("collect_set", 30); + + test_agg_collect<DataTypeDecimal<Decimal128>>("collect_list", 10); + test_agg_collect<DataTypeDecimal<Decimal128>>("collect_set", 11); + + test_agg_collect<DataTypeDateTime>("collect_list", 5); + test_agg_collect<DataTypeDateTime>("collect_set", 6); + + test_agg_collect<DataTypeString>("collect_list", 10); + test_agg_collect<DataTypeString>("collect_set", 5); +} + +} // namespace doris::vectorized diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md new file mode 100644 index 0000000000..a27b540c78 --- /dev/null +++ 
b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md @@ -0,0 +1,70 @@ +--- +{ + "title": "COLLECT_LIST", + "language": "en" +} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## COLLECT_LIST +### description +#### Syntax + +`ARRAY<T> collect_list(expr)` + +Returns an array consisting of all values in expr within the group. +The order of elements in the array is non-deterministic. NULL values are excluded; if every value of expr in the group is NULL, the result is NULL.
+ +### notice + +``` +Only supported in vectorized engine +``` + +### example + +``` +mysql> set enable_vectorized_engine=true; +mysql> set enable_array_type = true; + +mysql> select k1,k2,k3 from collect_test order by k1; ++------+------------+-------+ +| k1 | k2 | k3 | ++------+------------+-------+ +| 1 | 2022-07-05 | hello | +| 2 | 2022-07-04 | NULL | +| 2 | 2022-07-04 | hello | +| 3 | NULL | world | +| 3 | NULL | world | ++------+------------+-------+ + +mysql> select k1,collect_list(k2),collect_list(k3) from collect_test group by k1 order by k1; ++------+--------------------------+--------------------+ +| k1 | collect_list(`k2`) | collect_list(`k3`) | ++------+--------------------------+--------------------+ +| 1 | [2022-07-05] | [hello] | +| 2 | [2022-07-04, 2022-07-04] | [hello] | +| 3 | NULL | [world, world] | ++------+--------------------------+--------------------+ +``` + +### keywords +COLLECT_LIST,COLLECT_SET,ARRAY diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md new file mode 100644 index 0000000000..8c3d617b3e --- /dev/null +++ b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md @@ -0,0 +1,70 @@ +--- +{ + "title": "COLLECT_SET", + "language": "en" +} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. 
See the License for the +specific language governing permissions and limitations +under the License. +--> + +## COLLECT_SET +### description +#### Syntax + +`ARRAY<T> collect_set(expr)` + +Returns an array consisting of all unique values in expr within the group. +The order of elements in the array is non-deterministic. NULL values are excluded; if every value of expr in the group is NULL, the result is NULL. + +### notice + +``` +Only supported in vectorized engine +``` + +### example + +``` +mysql> set enable_vectorized_engine=true; +mysql> set enable_array_type = true; + +mysql> select k1,k2,k3 from collect_test order by k1; ++------+------------+-------+ +| k1 | k2 | k3 | ++------+------------+-------+ +| 1 | 2022-07-05 | hello | +| 2 | 2022-07-04 | NULL | +| 2 | 2022-07-04 | hello | +| 3 | NULL | world | +| 3 | NULL | world | ++------+------------+-------+ + +mysql> select k1,collect_set(k2),collect_set(k3) from collect_test group by k1 order by k1; ++------+-------------------+-------------------+ +| k1 | collect_set(`k2`) | collect_set(`k3`) | ++------+-------------------+-------------------+ +| 1 | [2022-07-05] | [hello] | +| 2 | [2022-07-04] | [hello] | +| 3 | NULL | [world] | ++------+-------------------+-------------------+ +``` + +### keywords +COLLECT_SET,COLLECT_LIST,ARRAY diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md new file mode 100644 index 0000000000..eaf56bf959 --- /dev/null +++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_list.md @@ -0,0 +1,71 @@ +--- +{ + "title": "COLLECT_LIST", + "language": "zh-CN" +} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership.
The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## COLLECT_LIST +### description +#### Syntax + +`ARRAY<T> collect_list(expr)` + +返回一个包含 expr 中所有元素(不包括NULL)的数组,数组中元素顺序是不确定的。 + + +### notice + +``` +仅支持向量化引擎中使用 +``` + +### example + +``` +mysql> set enable_vectorized_engine=true; +mysql> set enable_array_type = true; + +mysql> select k1,k2,k3 from collect_test order by k1; ++------+------------+-------+ +| k1 | k2 | k3 | ++------+------------+-------+ +| 1 | 2022-07-05 | hello | +| 2 | 2022-07-04 | NULL | +| 2 | 2022-07-04 | hello | +| 3 | NULL | world | +| 3 | NULL | world | ++------+------------+-------+ + +mysql> select k1,collect_list(k2),collect_list(k3) from collect_test group by k1 order by k1; ++------+--------------------------+--------------------+ +| k1 | collect_list(`k2`) | collect_list(`k3`) | ++------+--------------------------+--------------------+ +| 1 | [2022-07-05] | [hello] | +| 2 | [2022-07-04, 2022-07-04] | [hello] | +| 3 | NULL | [world, world] | ++------+--------------------------+--------------------+ + +``` + +### keywords +COLLECT_LIST,COLLECT_SET,ARRAY diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md new file mode 100644 index 0000000000..ccc734cea8 --- /dev/null +++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/collect_set.md @@ -0,0 +1,70 @@ +--- +{ + "title": "COLLECT_SET", + "language": "zh-CN" 
+} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## COLLECT_SET +### description +#### Syntax + +`ARRAY<T> collect_set(expr)` + +返回一个包含 expr 中所有去重后元素(不包括NULL)的数组,数组中元素顺序是不确定的。 + +### notice + +``` +仅支持向量化引擎中使用 +``` + +### example + +``` +mysql> set enable_vectorized_engine=true; +mysql> set enable_array_type = true; + +mysql> select k1,k2,k3 from collect_test order by k1; ++------+------------+-------+ +| k1 | k2 | k3 | ++------+------------+-------+ +| 1 | 2022-07-05 | hello | +| 2 | 2022-07-04 | NULL | +| 2 | 2022-07-04 | hello | +| 3 | NULL | world | +| 3 | NULL | world | ++------+------------+-------+ + +mysql> select k1,collect_set(k2),collect_set(k3) from collect_test group by k1 order by k1; ++------+-------------------+-------------------+ +| k1 | collect_set(`k2`) | collect_set(`k3`) | ++------+-------------------+-------------------+ +| 1 | [2022-07-05] | [hello] | +| 2 | [2022-07-04] | [hello] | +| 3 | NULL | [world] | ++------+-------------------+-------------------+ + +``` + +### keywords +COLLECT_SET,COLLECT_LIST,ARRAY diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 884a8fa452..c03722b0fa 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -855,6 +855,8 @@ public class FunctionSet<T> { //TODO(weixiang): is quantile_percent can be replaced by approx_percentile? public static final String QUANTILE_PERCENT = "quantile_percent"; public static final String TO_QUANTILE_STATE = "to_quantile_state"; + public static final String COLLECT_LIST = "collect_list"; + public static final String COLLECT_SET = "collect_set"; private static final Map<Type, String> ORTHOGONAL_BITMAP_INTERSECT_INIT_SYMBOL = ImmutableMap.<Type, String>builder() @@ -2215,6 +2217,16 @@ public class FunctionSet<T> { prefix + "26percentile_approx_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", false, true, false, true)); + // collect_list + Type[] arraySubTypes = {Type.BOOLEAN, Type.SMALLINT, Type.TINYINT, Type.INT, Type.BIGINT, Type.LARGEINT, + Type.FLOAT, Type.DOUBLE, Type.DATE, Type.DATETIME, Type.DECIMALV2, Type.VARCHAR, Type.STRING}; + for (Type t : arraySubTypes) { + addBuiltin(AggregateFunction.createBuiltin(COLLECT_LIST, Lists.newArrayList(t), new ArrayType(t), t, + "", "", "", "", "", true, false, true, true)); + addBuiltin(AggregateFunction.createBuiltin(COLLECT_SET, Lists.newArrayList(t), new ArrayType(t), t, + "", "", "", "", "", true, false, true, true)); + } + // Avg // TODO: switch to CHAR(sizeof(AvgIntermediateType) when that becomes available addBuiltin(AggregateFunction.createBuiltin("avg", diff --git a/regression-test/data/query/sql_functions/aggregate_functions/test_aggregate_collect.out b/regression-test/data/query/sql_functions/aggregate_functions/test_aggregate_collect.out new file mode 100644 index 0000000000..cec8df6151 --- /dev/null +++ b/regression-test/data/query/sql_functions/aggregate_functions/test_aggregate_collect.out @@ -0,0 +1,9 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !select -- +1 ['hello', 'hello'] [2022-07-04, 2022-07-04] [1.23, 1.23] +2 \N \N \N + +-- !select -- +1 ['hello'] [2022-07-04] [1.23] +2 \N \N \N + diff --git a/regression-test/suites/query/sql_functions/aggregate_functions/test_aggregate_collect.groovy b/regression-test/suites/query/sql_functions/aggregate_functions/test_aggregate_collect.groovy new file mode 100644 index 0000000000..a1c1ff260e --- /dev/null +++ b/regression-test/suites/query/sql_functions/aggregate_functions/test_aggregate_collect.groovy @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("test_aggregate_collect", "query") { + sql "set enable_vectorized_engine = true" + sql "set enable_array_type = true;" + + def tableName = "collect_test" + sql "DROP TABLE IF EXISTS ${tableName}" + sql """ + CREATE TABLE IF NOT EXISTS ${tableName} ( + c_int INT, + c_string VARCHAR(10), + c_date Date, + c_decimal DECIMAL(10, 2) + ) + DISTRIBUTED BY HASH(c_int) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO ${tableName} values(1,'hello','2022-07-04',1.23), (2,NULL,NULL,NULL)" + sql "INSERT INTO ${tableName} values(1,'hello','2022-07-04',1.23), (2,NULL,NULL,NULL)" + + qt_select "select c_int,collect_list(c_string),collect_list(c_date),collect_list(c_decimal) from ${tableName} group by c_int order by c_int" + qt_select "select c_int,collect_set(c_string),collect_set(c_date),collect_set(c_decimal) from ${tableName} group by c_int order by c_int" +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org