This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push: new 3d667e95d25 [branch-2.1](function) support array_split and array_reverse_split functions (#35619) (#43761) 3d667e95d25 is described below commit 3d667e95d250268a7a3ce9c25cba0573e5f84d45 Author: zclllhhjj <zhaochan...@selectdb.com> AuthorDate: Tue Nov 12 21:27:55 2024 +0800 [branch-2.1](function) support array_split and array_reverse_split functions (#35619) (#43761) pick https://github.com/apache/doris/pull/35619 --- be/src/vec/columns/column_array.cpp | 8 +- be/src/vec/columns/column_array.h | 18 +-- .../vec/functions/array/function_array_popback.cpp | 6 +- .../functions/array/function_array_register.cpp | 2 + .../vec/functions/array/function_array_split.cpp | 157 +++++++++++++++++++++ .../doris/analysis/LambdaFunctionCallExpr.java | 31 +++- .../doris/catalog/BuiltinScalarFunctions.java | 4 + .../functions/scalar/ArrayReverseSort.java | 2 +- ...rrayReverseSort.java => ArrayReverseSplit.java} | 54 +++---- .../{ArrayReverseSort.java => ArraySplit.java} | 54 +++---- .../expressions/visitor/ScalarFunctionVisitor.java | 10 ++ gensrc/script/doris_builtins_functions.py | 37 +++++ .../array_functions/test_array_split.out | 60 ++++++++ .../array_functions/test_array_split.groovy | 81 +++++++++++ 14 files changed, 448 insertions(+), 76 deletions(-) diff --git a/be/src/vec/columns/column_array.cpp b/be/src/vec/columns/column_array.cpp index f8e5b595169..110c7f492b1 100644 --- a/be/src/vec/columns/column_array.cpp +++ b/be/src/vec/columns/column_array.cpp @@ -20,13 +20,9 @@ #include "vec/columns/column_array.h" -#include <assert.h> -#include <string.h> - #include <algorithm> #include <boost/iterator/iterator_facade.hpp> -#include <limits> -#include <memory> +#include <cstring> #include <vector> #include "common/status.h" @@ -90,7 +86,7 @@ INSTANTIATE_INDEX_IMPL(ColumnArray) ColumnArray::ColumnArray(MutableColumnPtr&& nested_column, MutableColumnPtr&& offsets_column) : data(std::move(nested_column)), offsets(std::move(offsets_column)) { - const ColumnOffsets* offsets_concrete = typeid_cast<const ColumnOffsets*>(offsets.get()); + const auto* offsets_concrete = typeid_cast<const ColumnOffsets*>(offsets.get()); if (!offsets_concrete) { throw doris::Exception(ErrorCode::INTERNAL_ERROR, "offsets_column must be a ColumnUInt64"); diff --git a/be/src/vec/columns/column_array.h b/be/src/vec/columns/column_array.h index a476c45f94d..f8769cb1e79 100644 --- a/be/src/vec/columns/column_array.h +++ b/be/src/vec/columns/column_array.h @@ -21,9 +21,9 @@ #pragma once #include <glog/logging.h> -#include <stdint.h> #include <sys/types.h> +#include <cstdint> #include <functional> #include <ostream> #include <string> @@ -31,25 +31,16 @@ #include <utility> #include "common/compiler_util.h" // IWYU pragma: keep -#include "common/status.h" #include "vec/columns/column.h" -#include "vec/columns/column_impl.h" #include "vec/columns/column_vector.h" #include "vec/common/assert_cast.h" #include "vec/common/cow.h" -#include "vec/common/pod_array_fwd.h" #include "vec/common/string_ref.h" #include "vec/core/field.h" #include "vec/core/types.h" class SipHash; -namespace doris { -namespace vectorized { -class Arena; -} // namespace vectorized -} // namespace doris - //TODO: use marcos below to decouple array function calls #define ALL_COLUMNS_NUMBER \ ColumnUInt8, ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnInt128, ColumnFloat32, \ @@ -61,6 +52,8 @@ class Arena; namespace doris::vectorized { +class Arena; + /** Obtaining array as Field can be slow for large arrays and consume vast amount of memory. * Just don't allow to do it. * You can increase the limit if the following query: @@ -71,7 +64,6 @@ static constexpr size_t max_array_size_as_field = 1000000; /** A column of array values. * In memory, it is represented as one column of a nested type, whose size is equal to the sum of the sizes of all arrays, * and as an array of offsets in it, which allows you to get each element. - * NOTE: the ColumnArray won't nest multi-layers. That means the nested type will be concrete data-type. */ class ColumnArray final : public COWHelper<IColumn, ColumnArray> { private: @@ -268,8 +260,8 @@ public: double get_ratio_of_default_rows(double sample_ratio) const override; private: - // [[2,1,5,9,1], [1,2,4]] --> data column [2,1,5,9,1,1,2,4], offset[-1] = 0, offset[0] = 5, offset[1] = 8 - // [[[2,1,5],[9,1]], [[1,2]]] --> data column [3 column array], offset[-1] = 0, offset[0] = 2, offset[1] = 3 + // [2,1,5,9,1]\n[1,2,4] --> data column [2,1,5,9,1,1,2,4], offset[-1] = 0, offset[0] = 5, offset[1] = 8 + // [[2,1,5],[9,1]]\n[[1,2]] --> data column [3 column array], offset[-1] = 0, offset[0] = 2, offset[1] = 3 WrappedPtr data; WrappedPtr offsets; diff --git a/be/src/vec/functions/array/function_array_popback.cpp b/be/src/vec/functions/array/function_array_popback.cpp index 08c0b2f40bc..dc1f7818292 100644 --- a/be/src/vec/functions/array/function_array_popback.cpp +++ b/be/src/vec/functions/array/function_array_popback.cpp @@ -14,13 +14,13 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + #include <fmt/format.h> #include <glog/logging.h> -#include <stddef.h> +#include <cstddef> #include <memory> #include <ostream> -#include <string> #include <utility> #include "common/status.h" @@ -74,7 +74,7 @@ public: block.get_by_position(arguments[0]).type->get_name())); } // prepare dst array column - bool is_nullable = src.nested_nullmap_data ? true : false; + bool is_nullable = src.nested_nullmap_data != nullptr; ColumnArrayMutableData dst = create_mutable_data(src.nested_col, is_nullable); dst.offsets_ptr->reserve(input_rows_count); // start from 1 diff --git a/be/src/vec/functions/array/function_array_register.cpp b/be/src/vec/functions/array/function_array_register.cpp index 5a7c33c687a..6ddf6b6a5e2 100644 --- a/be/src/vec/functions/array/function_array_register.cpp +++ b/be/src/vec/functions/array/function_array_register.cpp @@ -55,6 +55,7 @@ void register_function_array_first_or_last_index(SimpleFunctionFactory& factory) void register_function_array_cum_sum(SimpleFunctionFactory& factory); void register_function_array_count(SimpleFunctionFactory&); void register_function_array_filter_function(SimpleFunctionFactory&); +void register_function_array_splits(SimpleFunctionFactory&); void register_function_array(SimpleFunctionFactory& factory) { register_function_array_shuffle(factory); @@ -90,6 +91,7 @@ void register_function_array(SimpleFunctionFactory& factory) { register_function_array_cum_sum(factory); register_function_array_count(factory); register_function_array_filter_function(factory); + register_function_array_splits(factory); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_split.cpp b/be/src/vec/functions/array/function_array_split.cpp new file mode 100644 index 00000000000..30e46d18c8f --- /dev/null +++ b/be/src/vec/functions/array/function_array_split.cpp @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/array/arraySplit.cpp +// and modified by Doris + +#include <cstddef> +#include <memory> +#include <utility> + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/block.h" +#include "vec/core/column_numbers.h" +#include "vec/core/column_with_type_and_name.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/functions/function.h" +#include "vec/functions/simple_function_factory.h" + +namespace doris { +class FunctionContext; +} // namespace doris + +namespace doris::vectorized { + +template <bool reverse> +class FunctionArraySplit : public IFunction { +public: + static constexpr auto name = reverse ? "array_reverse_split" : "array_split"; + static FunctionPtr create() { return std::make_shared<FunctionArraySplit>(); } + String get_name() const override { return name; } + + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared<DataTypeArray>(make_nullable(arguments[0])); + }; + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + // <Nullable>(Array(<Nullable>(Int))) + auto src_column = + block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); + auto spliter_column = + block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); + + // only change its split(i.e. offsets) + const auto& src_data = assert_cast<const ColumnArray&>(*src_column).get_data_ptr(); + const auto& src_offsets = assert_cast<const ColumnArray&>(*src_column).get_offsets(); + + auto split_col = assert_cast<const ColumnArray*>(spliter_column.get())->get_data_ptr(); + const auto& split_offsets = assert_cast<const ColumnArray&>(*spliter_column) + .get_offsets(); // for check uneven array + + const NullMap* null_map = nullptr; + if (split_col->is_nullable()) { + if (split_col->has_null()) { + null_map = + &assert_cast<const ColumnNullable*>(split_col.get())->get_null_map_data(); + } + split_col = + assert_cast<const ColumnNullable*>(split_col.get())->get_nested_column_ptr(); + } + + const IColumn::Filter& cut = assert_cast<const ColumnBool*>(split_col.get())->get_data(); + + auto col_offsets_inner = ColumnArray::ColumnOffsets::create(); + auto col_offsets_outer = ColumnArray::ColumnOffsets::create(); + auto& offsets_inner = col_offsets_inner->get_data(); + auto& offsets_outer = col_offsets_outer->get_data(); + offsets_inner.reserve(src_offsets.size()); // assume the actual size to be equal or larger + offsets_outer.reserve(src_offsets.size()); + + if (null_map != nullptr) { + RETURN_IF_ERROR(do_loop<true>(src_offsets, split_offsets, cut, null_map, offsets_inner, + offsets_outer)); + } else { + RETURN_IF_ERROR(do_loop<false>(src_offsets, split_offsets, cut, null_map, offsets_inner, + offsets_outer)); + } + + auto inner_result = ColumnArray::create(src_data, std::move(col_offsets_inner)); + auto outer_result = ColumnArray::create( + ColumnNullable::create(inner_result, ColumnUInt8::create(inner_result->size(), 0)), + std::move(col_offsets_outer)); + block.replace_by_position(result, outer_result); + return Status::OK(); + } + + template <bool CONSIDER_NULL> + static Status do_loop(const IColumn::Offsets64& src_offsets, + const IColumn::Offsets64& split_offsets, const IColumn::Filter& cut, + const NullMap* null_map, PaddedPODArray<IColumn::Offset64>& offsets_inner, + PaddedPODArray<IColumn::Offset64>& offsets_outer) { + size_t pos = 0; + for (auto i = 0; i < src_offsets.size(); i++) { // per cells + auto in_offset = src_offsets[i]; + auto sp_offset = split_offsets[i]; + if (in_offset != sp_offset) [[unlikely]] { + return Status::InvalidArgument("function {} has uneven arguments on row {}", name, + i); + } + + // [1,2,3,4,5] + if (pos < in_offset) { // values in a cell + pos += !reverse; + for (; pos < in_offset - reverse; ++pos) { + if constexpr (CONSIDER_NULL) { + if (cut[pos] && !(*null_map)[pos]) { + offsets_inner.push_back(pos + reverse); // cut a array [1,2,3] + } + } else { + if (cut[pos]) { + offsets_inner.push_back(pos + reverse); // cut a array [1,2,3] + } + } + } + pos += reverse; + // put the tail offset, always last. + offsets_inner.push_back(pos); // put [4,5] + } + + offsets_outer.push_back(offsets_inner.size()); + } + return Status::OK(); + } +}; + +void register_function_array_splits(SimpleFunctionFactory& factory) { + factory.register_function<FunctionArraySplit<true>>(); + factory.register_function<FunctionArraySplit<false>>(); +} +} // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java index ebf3d1307aa..93b9dc4d5c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java @@ -38,6 +38,7 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr { public static final ImmutableSet<String> LAMBDA_FUNCTION_SET = new ImmutableSortedSet.Builder( String.CASE_INSENSITIVE_ORDER).add("array_map").add("array_filter").add("array_exists").add("array_sortby") .add("array_first_index").add("array_last_index").add("array_first").add("array_last").add("array_count") + .add("array_split").add("array_reverse_split") .build(); // The functions in this set are all normal array functions when implemented initially. // and then wants add lambda expr as the input param, so we rewrite it to contains an array_map lambda function @@ -45,7 +46,7 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr { public static final ImmutableSet<String> LAMBDA_MAPPED_FUNCTION_SET = new ImmutableSortedSet.Builder( String.CASE_INSENSITIVE_ORDER).add("array_exists").add("array_sortby") .add("array_first_index").add("array_last_index").add("array_first").add("array_last").add("array_count") - .add("element_at") + .add("element_at").add("array_split").add("array_reverse_split") .build(); private static final Logger LOG = LogManager.getLogger(LambdaFunctionCallExpr.class); @@ -202,6 +203,34 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr { } fn = getBuiltinFunction(fnName.getFunction(), argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); + } else if (fnName.getFunction().equalsIgnoreCase("array_split") + || fnName.getFunction().equalsIgnoreCase("array_reverse_split")) { + if (fnParams.exprs() == null || fnParams.exprs().size() < 2) { + throw new AnalysisException("The " + fnName.getFunction() + " function must have at least two params"); + } + /* + * array_split((x,y)->y, [1,-2,3], [0,1,1]) + * ---> array_split([1,-2,3],[0,1,1], (x,y)->y) + * ---> array_split([1,-2,3], array_map((x,y)->y, [1,-2,3], [0,1,1])) + */ + if (getChild(childSize - 1) instanceof LambdaFunctionExpr) { + List<Expr> params = new ArrayList<>(); + for (int i = 0; i <= childSize - 1; ++i) { + params.add(getChild(i)); + } + LambdaFunctionCallExpr arrayMapFunc = new LambdaFunctionCallExpr("array_map", + params); + arrayMapFunc.analyzeImpl(analyzer); + Expr firstExpr = getChild(0); + this.clearChildren(); + this.addChild(firstExpr); + this.addChild(arrayMapFunc); + argTypes = new Type[2]; + argTypes[0] = getChild(0).getType(); + argTypes[1] = getChild(1).getType(); + } + fn = getBuiltinFunction(fnName.getFunction(), argTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (fnName.getFunction().equalsIgnoreCase("array_last")) { // array_last(lambda,array)--->array_last(array,lambda)--->element_at(array_filter,-1) if (getChild(childSize - 1) instanceof LambdaFunctionExpr) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index 49b13a0fd8a..6579cffde38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -58,10 +58,12 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySum; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayUnion; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayWithConstant; @@ -514,9 +516,11 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(ArrayRemove.class, "array_remove"), scalar(ArrayRepeat.class, "array_repeat"), scalar(ArrayReverseSort.class, "array_reverse_sort"), + scalar(ArrayReverseSplit.class, "array_reverse_split"), scalar(ArraySlice.class, "array_slice"), scalar(ArraySort.class, "array_sort"), scalar(ArraySortBy.class, "array_sortby"), + scalar(ArraySplit.class, "array_split"), scalar(ArrayShuffle.class, "array_shuffle", "shuffle"), scalar(ArraySum.class, "array_sum"), scalar(ArrayUnion.class, "array_union"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java index dd62fdb7e45..1fb920e0bd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java @@ -34,7 +34,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; /** - * ScalarFunction 'array_sort'. This class is generated by GenerateFunction. + * ScalarFunction 'array_reverse_sort'. */ public class ArrayReverseSort extends ScalarFunction implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSplit.java similarity index 53% copy from fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java copy to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSplit.java index dd62fdb7e45..4b7cea0f23d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSplit.java @@ -20,60 +20,62 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; -import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; -import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.coercion.AnyDataType; +import org.apache.doris.nereids.types.coercion.FollowToArgumentType; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; /** - * ScalarFunction 'array_sort'. This class is generated by GenerateFunction. + * ScalarFunction 'array_reverse_split'. */ -public class ArrayReverseSort extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { - +public class ArrayReverseSplit extends ScalarFunction implements PropagateNullable, HighOrderFunction { + // arg0 = Array<T>, return_value = Array<Array<T>> public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( - FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX)) - ); + FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).args( + ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX), + ArrayType.of(BooleanType.INSTANCE))); + + private ArrayReverseSplit(List<Expression> expressions) { + super("array_reverse_split", expressions); + } /** - * constructor with 1 argument. + * constructor with arguments. */ - public ArrayReverseSort(Expression arg) { - super("array_reverse_sort", arg); + public ArrayReverseSplit(Expression arg0, Expression arg1) { + super("array_reverse_split", arg0, arg1); } - @Override - public void checkLegalityBeforeTypeCoercion() { - DataType argType = child().getDataType(); - if (((ArrayType) argType).getItemType().isComplexType()) { - throw new AnalysisException("array_reverse_sort does not support complex types: " + toSql()); + /** + * constructor with arguments. + * array_split(lambda, a1, ...) = array_split(a1, array_map(lambda, a1, ...)) + */ + public ArrayReverseSplit(Expression arg) { + super("array_reverse_split", arg.child(1).child(0), new ArrayMap(arg)); + if (!(arg instanceof Lambda)) { + throw new AnalysisException( + String.format("The 1st arg of %s must be lambda but is %s", getName(), arg)); } } - /** - * withChildren. - */ @Override - public ArrayReverseSort withChildren(List<Expression> children) { - Preconditions.checkArgument(children.size() == 1); - return new ArrayReverseSort(children.get(0)); + public ArrayReverseSplit withChildren(List<Expression> children) { + return new ArrayReverseSplit(children.get(0), children.get(1)); } @Override public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { - return visitor.visitArrayReverseSort(this, context); + return visitor.visitArrayReverseSplit(this, context); } @Override - public List<FunctionSignature> getSignatures() { + public List<FunctionSignature> getImplSignature() { return SIGNATURES; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySplit.java similarity index 54% copy from fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java copy to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySplit.java index dd62fdb7e45..07ca4aafe65 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySplit.java @@ -20,60 +20,62 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; -import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; -import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.coercion.AnyDataType; +import org.apache.doris.nereids.types.coercion.FollowToArgumentType; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; /** - * ScalarFunction 'array_sort'. This class is generated by GenerateFunction. + * ScalarFunction 'array_split'. */ -public class ArrayReverseSort extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { - +public class ArraySplit extends ScalarFunction implements PropagateNullable, HighOrderFunction { + // arg0 = Array<T>, return_value = Array<Array<T>> public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( - FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX)) - ); + FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).args( + ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX), + ArrayType.of(BooleanType.INSTANCE))); + + private ArraySplit(List<Expression> expressions) { + super("array_split", expressions); + } /** - * constructor with 1 argument. + * constructor with arguments. */ - public ArrayReverseSort(Expression arg) { - super("array_reverse_sort", arg); + public ArraySplit(Expression arg0, Expression arg1) { + super("array_split", arg0, arg1); } - @Override - public void checkLegalityBeforeTypeCoercion() { - DataType argType = child().getDataType(); - if (((ArrayType) argType).getItemType().isComplexType()) { - throw new AnalysisException("array_reverse_sort does not support complex types: " + toSql()); + /** + * constructor with arguments. + * array_split(lambda, a1, ...) = array_split(a1, array_map(lambda, a1, ...)) + */ + public ArraySplit(Expression arg) { + super("array_split", arg.child(1).child(0), new ArrayMap(arg)); + if (!(arg instanceof Lambda)) { + throw new AnalysisException( + String.format("The 1st arg of %s must be lambda but is %s", getName(), arg)); } } - /** - * withChildren. - */ @Override - public ArrayReverseSort withChildren(List<Expression> children) { - Preconditions.checkArgument(children.size() == 1); - return new ArrayReverseSort(children.get(0)); + public ArraySplit withChildren(List<Expression> children) { + return new ArraySplit(children.get(0), children.get(1)); } @Override public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { - return visitor.visitArrayReverseSort(this, context); + return visitor.visitArraySplit(this, context); } @Override - public List<FunctionSignature> getSignatures() { + public List<FunctionSignature> getImplSignature() { return SIGNATURES; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 55d58e320b2..40c1ada9fd3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -65,10 +65,12 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRangeYea import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySum; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayUnion; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayWithConstant; @@ -643,6 +645,14 @@ public interface ScalarFunctionVisitor<R, C> { return visitScalarFunction(arraySortBy, context); } + default R visitArraySplit(ArraySplit arraySplit, C context) { + return visitScalarFunction(arraySplit, context); + } + + default R visitArrayReverseSplit(ArrayReverseSplit arrayReverseSplit, C context) { + return visitScalarFunction(arrayReverseSplit, context); + } + default R visitArrayShuffle(ArrayShuffle arrayShuffle, C context) { return visitScalarFunction(arrayShuffle, context); } diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 0da5f697100..74f7b42f7a8 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -840,6 +840,43 @@ visible_functions = { [['array_zip'], 'ARRAY', ['ARRAY<T>', '...'], '', ['T']], + [['array_split'], 'ARRAY_BOOLEAN',['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_TINYINT',['ARRAY_TINYINT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_SMALLINT',['ARRAY_SMALLINT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_INT',['ARRAY_INT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_BIGINT',['ARRAY_BIGINT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_LARGEINT',['ARRAY_LARGEINT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_FLOAT',['ARRAY_FLOAT', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DOUBLE',['ARRAY_DOUBLE', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_VARCHAR',['ARRAY_VARCHAR', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_STRING',['ARRAY_STRING', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DECIMALV2',['ARRAY_DECIMALV2', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DECIMAL32',['ARRAY_DECIMAL32', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DECIMAL64',['ARRAY_DECIMAL64', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DECIMAL128',['ARRAY_DECIMAL128', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DATETIME',['ARRAY_DATETIME', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DATE',['ARRAY_DATE', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DATETIMEV2',['ARRAY_DATETIMEV2', 'ARRAY_BOOLEAN'], ''], + [['array_split'], 'ARRAY_DATEV2',['ARRAY_DATEV2', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_BOOLEAN',['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_TINYINT',['ARRAY_TINYINT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_SMALLINT',['ARRAY_SMALLINT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_INT',['ARRAY_INT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_BIGINT',['ARRAY_BIGINT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_LARGEINT',['ARRAY_LARGEINT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_FLOAT',['ARRAY_FLOAT', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DOUBLE',['ARRAY_DOUBLE', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_VARCHAR',['ARRAY_VARCHAR', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_STRING',['ARRAY_STRING', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DECIMALV2',['ARRAY_DECIMALV2', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DECIMAL32',['ARRAY_DECIMAL32', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DECIMAL64',['ARRAY_DECIMAL64', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DECIMAL128',['ARRAY_DECIMAL128', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DATETIME',['ARRAY_DATETIME', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DATE',['ARRAY_DATE', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DATETIMEV2',['ARRAY_DATETIMEV2', 'ARRAY_BOOLEAN'], ''], + [['array_reverse_split'], 'ARRAY_DATEV2',['ARRAY_DATEV2', 'ARRAY_BOOLEAN'], ''], + # reverse function for string builtin [['reverse'], 'VARCHAR', ['VARCHAR'], ''], diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_split.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_split.out new file mode 100644 index 00000000000..f06e4309c86 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_split.out @@ -0,0 +1,60 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql1 -- +[[1, 2, 3, 4, 5]] + +-- !sql2 -- +[[1, 2, 3, 4, 5]] + +-- !sql3 -- +[[1, 2, 3, 4, 5]] + +-- !sql4 -- +[[1, 2, 3, 4, 5]] + +-- !sql5 -- +\N + +-- !sql6 -- +\N + +-- !sql7 -- +\N + +-- !sql8 -- +[[1, 2, 3, 4], [5]] + +-- !lambda1 -- +[[1, 2, 3, 4, 5]] + +-- !lambda2 -- +[["a", "b", "c"], ["d"]] + +-- !lambda3 -- +[["a", "b"], ["c"], ["d"]] + +-- !null1 -- +[[1, 2], [3, 4, 5]] + +-- !null2 -- +[[1, null, null], [4, 5]] + +-- !table1 -- +1 [[1], [2], [3], [4], [5]] [[1], [2], [3], [4], [5]] +2 [[2, 3], [4]] [[2, 3], [4]] +3 \N [[1], [2], [3], [4], [5]] +4 \N [[1], [2], [3], [4], [5]] + +-- !table2 -- +1 [[1], [2], [3], [4], [5]] [[1], [2], [3], [4], [5]] +2 [[2], [3, 4]] [[2], [3, 4]] +3 \N [[1], [2], [3], [4], [5]] +4 \N [[1], [2], [3], [4], [5]] + +-- !dt1 -- +1 [["2020-12-12 00:00:00.000000", "2013-12-12 00:00:00.000000"], ["2015-12-12 00:00:00.000000", null]] [["2020-12-12 00:00:00.000000"], ["2013-12-12 00:00:00.000000", "2015-12-12 00:00:00.000000"], [null]] +2 [["2020-12-12 00:00:00.000000", "2013-12-12 00:00:00.000000"], ["2015-12-12 00:00:00.000000", null], ["2200-12-12 12:12:12.123456"]] [["2020-12-12 00:00:00.000000"], ["2013-12-12 00:00:00.000000", "2015-12-12 00:00:00.000000"], [null, "2200-12-12 12:12:12.123456"]] + +-- !dt_null -- +1 [["2020-12-12 00:00:00.000000", "2013-12-12 00:00:00.000000", "2015-12-12 00:00:00.000000", null]] +2 [["2020-12-12 00:00:00.000000", "2013-12-12 00:00:00.000000", "2015-12-12 00:00:00.000000", null], ["2200-12-12 12:12:12.123456"]] + diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_split.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_split.groovy new file mode 100644 index 00000000000..f74760f4cef --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_split.groovy @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_array_split") { + sql " set parallel_pipeline_task_num = 1; " + qt_sql1 " select array_split([1,2,3,4,5], [1,0,0,0,0]); " + qt_sql2 " select array_split(nullable([1,2,3,4,5]), [1,0,0,0,0]); " + qt_sql3 " select array_split([1,2,3,4,5], nullable([1,0,0,0,0])); " + qt_sql4 " select array_split(nullable([1,2,3,4,5]), nullable([1,0,0,0,0])); " + qt_sql5 " select array_split(cast(null as array<int>), [1,0,0]); " + qt_sql6 " select array_split([1,2,3], cast(null as array<tinyint>)); " + qt_sql7 " select array_split(cast(null as array<int>), cast(null as array<tinyint>)); " + qt_sql8 " select array_reverse_split([1,2,3,4,5], [0,0,0,1,0]); " + qt_lambda1 " select array_split((x,y)->y, [1,2,3,4,5], [1,0,0,0,0]); " + qt_lambda2 " select array_reverse_split((x,y)->(y+1), ['a', 'b', 'c', 'd'], [-1, -1, 0, -1]); " + qt_lambda3 " select array_reverse_split(x->(x>'a'), ['a', 'b', 'c', 'd']); " + qt_null1 " select array_split([1,2,3,4,5], [null,null,1,0,0]); " + qt_null2 " select array_reverse_split([1,null,null,4,5], [null,null,1,0,0]); " + + sql " drop table if exists arr_int; " + sql """ + create table arr_int( + x int, + a0 array<int> NULL, + s0 array<int> NULL, + a1 array<int> NOT NULL, + s1 array<int> NOT NULL + ) + DISTRIBUTED BY HASH(`x`) BUCKETS auto + properties("replication_num" = "1"); + """ + sql """ + insert into arr_int values + (1, [1,2,3,4,5], [1,1,1,1,1], [1,2,3,4,5], [1,1,1,1,1]), + (2, [2,3,4], [1,0,1], [2,3,4], [1,0,1]), + (3, NULL, [1,1,1,1,1], [1,2,3,4,5], [1,1,1,1,1]), + (4, [1,2,3,4,5], NULL, [1,2,3,4,5], [1,1,1,1,1]); + """ + qt_table1 " select x, array_split(a0, s0), array_split(a1, s1) from arr_int order by x; " + qt_table2 " select x, array_reverse_split(a0, s0), array_reverse_split(a1, s1) from arr_int order by x; " + + sql " drop table if exists dt; " + sql """ + create table dt( + x int, + k0 array<datetime(6)> + ) + DISTRIBUTED BY HASH(`x`) BUCKETS auto + properties("replication_num" = "1"); + """ + sql """ insert into dt values + (1, ["2020-12-12", "2013-12-12", "2015-12-12", null]), + (2, ["2020-12-12", "2013-12-12", "2015-12-12", null, "2200-12-12 12:12:12.123456"]); """ + + qt_dt1 """ select x, array_split(x->(year(x)>2013), k0), array_reverse_split(x->(year(x)>2013), k0) + from dt order by x; """ + qt_dt_null """ select x, array_reverse_split(x->(null_or_empty(x)), k0) from dt order by x; """ + + test { + sql " select array_split([1,2,3,4,5], [1,1,1]); " + exception "function array_split has uneven arguments on row 0" + } + test { + sql " select array_reverse_split((x,y)->(y), [1,2,3,4,5], [1,1,1]); " + exception "in array map function, the input column size are not equal completely" + } +} \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org