This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 527eb5b059 [Enhancement](function) nullable inline refactor of
min_max_by/bitmap && add register_function_both (#17228)
527eb5b059 is described below
commit 527eb5b059e15c22411ba44f2434b47aace0dc9f
Author: Pxl <[email protected]>
AuthorDate: Thu Mar 2 00:00:01 2023 +0800
[Enhancement](function) nullable inline refactor of min_max_by/bitmap &&
add register_function_both (#17228)
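1. Nullable-inline refactor of min_max_by/bitmap/group_concat/histogram/topn: their creators now build the function through creator_without_type::create, passing result_is_nullable.
2. Add a register_function_both method, which registers a creator for both the non-nullable and the nullable variant.
3. Add a DateTimeV2 type creator for min_max_by.
4. Remove uint16/32/64 from FOR_INTEGER_TYPES.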
---
.../aggregate_functions/aggregate_function_avg.cpp | 3 +-
.../aggregate_functions/aggregate_function_bit.cpp | 21 +---
.../aggregate_function_bitmap.cpp | 58 +++++-----
.../aggregate_function_group_concat.cpp | 14 ++-
.../aggregate_function_histogram.cpp | 46 +++-----
.../aggregate_function_min_max.cpp | 11 +-
.../aggregate_function_min_max_by.cpp | 124 ++++++++++-----------
.../aggregate_function_min_max_by.h | 16 +--
.../aggregate_function_orthogonal_bitmap.cpp | 28 ++---
.../aggregate_function_simple_factory.h | 5 +
.../aggregate_functions/aggregate_function_sum.cpp | 5 +-
.../aggregate_function_topn.cpp | 81 ++++++++------
.../aggregate_function_uniq.cpp | 3 +-
.../aggregate_function_window.cpp | 12 +-
be/src/vec/aggregate_functions/helpers.h | 119 +++++---------------
be/src/vec/core/types.h | 9 ++
16 files changed, 229 insertions(+), 326 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index 8bda389f4a..4f493c9529 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -52,7 +52,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const
std::string& name,
}
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
- factory.register_function("avg", create_aggregate_function_avg);
- factory.register_function("avg", create_aggregate_function_avg, true);
+ factory.register_function_both("avg", create_aggregate_function_avg);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
index 379df49559..6b9be5c92c 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
@@ -47,21 +47,12 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const
std::string& name,
}
void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) {
- factory.register_function("group_bit_or",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>);
- factory.register_function("group_bit_and",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>);
- factory.register_function("group_bit_xor",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>);
-
- factory.register_function(
- "group_bit_or",
createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>, true);
- factory.register_function("group_bit_and",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>,
- true);
- factory.register_function("group_bit_xor",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>,
- true);
+ factory.register_function_both("group_bit_or",
+
createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>);
+ factory.register_function_both(
+ "group_bit_and",
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>);
+ factory.register_function_both(
+ "group_bit_xor",
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
index eb9a8fb35c..e2dd7e309d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
@@ -18,50 +18,45 @@
#include "vec/aggregate_functions/aggregate_function_bitmap.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
template <bool nullable, template <bool, typename> class
AggregateFunctionTemplate>
-static IAggregateFunction* createWithIntDataType(const DataTypes&
argument_type) {
- auto type = argument_type[0].get();
- if (type->is_nullable()) {
- type = assert_cast<const
DataTypeNullable*>(type)->get_nested_type().get();
- }
+static IAggregateFunction* create_with_int_data_type(const DataTypes&
argument_type) {
+ auto type = remove_nullable(argument_type[0]);
WhichDataType which(type);
- if (which.idx == TypeIndex::Int8) {
- return new AggregateFunctionTemplate<nullable,
ColumnVector<Int8>>(argument_type);
- }
- if (which.idx == TypeIndex::Int16) {
- return new AggregateFunctionTemplate<nullable,
ColumnVector<Int16>>(argument_type);
- }
- if (which.idx == TypeIndex::Int32) {
- return new AggregateFunctionTemplate<nullable,
ColumnVector<Int32>>(argument_type);
- }
- if (which.idx == TypeIndex::Int64) {
- return new AggregateFunctionTemplate<nullable,
ColumnVector<Int64>>(argument_type);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE) {
\
+ return new AggregateFunctionTemplate<nullable,
ColumnVector<TYPE>>(argument_type); \
}
+ FOR_INTEGER_TYPES(DISPATCH)
+#undef DISPATCH
return nullptr;
}
AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return
std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>(
- argument_types);
+ return AggregateFunctionPtr(
+
creator_without_type::create<AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>(
+ result_is_nullable, argument_types));
}
AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const
std::string& name,
const
DataTypes& argument_types,
const bool
result_is_nullable) {
- return
std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>(
- argument_types);
+ return AggregateFunctionPtr(creator_without_type::create<
+
AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>(
+ result_is_nullable, argument_types));
}
AggregateFunctionPtr create_aggregate_function_group_bitmap_xor(const
std::string& name,
const
DataTypes& argument_types,
const bool
result_is_nullable) {
- return
std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>(
- argument_types);
+ return AggregateFunctionPtr(creator_without_type::create<
+
AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>(
+ result_is_nullable, argument_types));
}
AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const
std::string& name,
@@ -81,22 +76,19 @@ AggregateFunctionPtr
create_aggregate_function_bitmap_union_int(const std::strin
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::shared_ptr<IAggregateFunction>(
- createWithIntDataType<true,
AggregateFunctionBitmapCount>(argument_types));
+ create_with_int_data_type<true,
AggregateFunctionBitmapCount>(argument_types));
} else {
return std::shared_ptr<IAggregateFunction>(
- createWithIntDataType<false,
AggregateFunctionBitmapCount>(argument_types));
+ create_with_int_data_type<false,
AggregateFunctionBitmapCount>(argument_types));
}
}
void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("bitmap_union",
create_aggregate_function_bitmap_union);
- factory.register_function("bitmap_intersect",
create_aggregate_function_bitmap_intersect);
- factory.register_function("group_bitmap_xor",
create_aggregate_function_group_bitmap_xor);
- factory.register_function("bitmap_union_count",
create_aggregate_function_bitmap_union_count);
- factory.register_function("bitmap_union_count",
create_aggregate_function_bitmap_union_count,
- true);
-
- factory.register_function("bitmap_union_int",
create_aggregate_function_bitmap_union_int);
- factory.register_function("bitmap_union_int",
create_aggregate_function_bitmap_union_int, true);
+ factory.register_function_both("bitmap_union",
create_aggregate_function_bitmap_union);
+ factory.register_function_both("bitmap_intersect",
create_aggregate_function_bitmap_intersect);
+ factory.register_function_both("group_bitmap_xor",
create_aggregate_function_group_bitmap_xor);
+ factory.register_function_both("bitmap_union_count",
+
create_aggregate_function_bitmap_union_count);
+ factory.register_function_both("bitmap_union_int",
create_aggregate_function_bitmap_union_int);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
index bcd7becc5e..5bd070ada3 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
@@ -17,6 +17,8 @@
#include "vec/aggregate_functions/aggregate_function_group_concat.h"
+#include "vec/aggregate_functions/helpers.h"
+
namespace doris::vectorized {
const std::string AggregateFunctionGroupConcatImplStr::separator = ", ";
@@ -26,12 +28,14 @@ AggregateFunctionPtr
create_aggregate_function_group_concat(const std::string& n
const bool
result_is_nullable) {
if (argument_types.size() == 1) {
return AggregateFunctionPtr(
- new
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>(
- argument_types));
+ creator_without_type::create<
+
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
+ result_is_nullable, argument_types));
} else if (argument_types.size() == 2) {
return AggregateFunctionPtr(
- new
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>(
- argument_types));
+ creator_without_type::create<
+
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>>(
+ result_is_nullable, argument_types));
}
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate
function {}",
@@ -40,6 +44,6 @@ AggregateFunctionPtr
create_aggregate_function_group_concat(const std::string& n
}
void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("group_concat",
create_aggregate_function_group_concat);
+ factory.register_function_both("group_concat",
create_aggregate_function_group_concat);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
index 81dece0c95..77e67ab29a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
@@ -23,56 +23,46 @@
namespace doris::vectorized {
template <typename T>
-AggregateFunctionPtr create_agg_function_histogram(const DataTypes&
argument_types) {
+AggregateFunctionPtr create_agg_function_histogram(const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
bool has_input_param = (argument_types.size() == 3);
if (has_input_param) {
return AggregateFunctionPtr(
- new
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, true>(
- argument_types));
+ creator_without_type::create<
+
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, true>>(
+ result_is_nullable, argument_types));
} else {
return AggregateFunctionPtr(
- new
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, false>(
- argument_types));
+ creator_without_type::create<
+
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, false>>(
+ result_is_nullable, argument_types));
}
}
AggregateFunctionPtr create_aggregate_function_histogram(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- WhichDataType type(argument_types[0]);
+ WhichDataType type(remove_nullable(argument_types[0]));
- LOG(INFO) << fmt::format("supported input type {} for aggregate function
{}",
- argument_types[0]->get_name(), name);
-
-#define DISPATCH(TYPE) \
- if (type.idx == TypeIndex::TYPE) return
create_agg_function_histogram<TYPE>(argument_types);
+#define DISPATCH(TYPE) \
+ if (type.idx == TypeIndex::TYPE) \
+ return create_agg_function_histogram<TYPE>(argument_types,
result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
+ FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
if (type.idx == TypeIndex::String) {
- return create_agg_function_histogram<String>(argument_types);
+ return create_agg_function_histogram<String>(argument_types,
result_is_nullable);
}
if (type.idx == TypeIndex::DateTime || type.idx == TypeIndex::Date) {
- return create_agg_function_histogram<Int64>(argument_types);
+ return create_agg_function_histogram<Int64>(argument_types,
result_is_nullable);
}
if (type.idx == TypeIndex::DateV2) {
- return create_agg_function_histogram<UInt32>(argument_types);
+ return create_agg_function_histogram<UInt32>(argument_types,
result_is_nullable);
}
if (type.idx == TypeIndex::DateTimeV2) {
- return create_agg_function_histogram<UInt64>(argument_types);
- }
- if (type.idx == TypeIndex::Decimal32) {
- return create_agg_function_histogram<Decimal32>(argument_types);
- }
- if (type.idx == TypeIndex::Decimal64) {
- return create_agg_function_histogram<Decimal64>(argument_types);
- }
- if (type.idx == TypeIndex::Decimal128) {
- return create_agg_function_histogram<Decimal128>(argument_types);
- }
- if (type.idx == TypeIndex::Decimal128I) {
- return create_agg_function_histogram<Decimal128I>(argument_types);
+ return create_agg_function_histogram<UInt64>(argument_types,
result_is_nullable);
}
LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
@@ -81,7 +71,7 @@ AggregateFunctionPtr
create_aggregate_function_histogram(const std::string& name
}
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("histogram",
create_aggregate_function_histogram);
+ factory.register_function_both("histogram",
create_aggregate_function_histogram);
factory.register_alias("histogram", "hist");
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
index 882b532c7e..46606142b2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
@@ -97,14 +97,9 @@ AggregateFunctionPtr create_aggregate_function_any(const
std::string& name,
}
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("max", create_aggregate_function_max);
- factory.register_function("min", create_aggregate_function_min);
- factory.register_function("any", create_aggregate_function_any);
-
- factory.register_function("max", create_aggregate_function_max, true);
- factory.register_function("min", create_aggregate_function_min, true);
- factory.register_function("any", create_aggregate_function_any, true);
-
+ factory.register_function_both("max", create_aggregate_function_max);
+ factory.register_function_both("min", create_aggregate_function_min);
+ factory.register_function_both("any", create_aggregate_function_any);
factory.register_alias("any", "any_value");
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
index 8a4ad945f9..2252da7721 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
@@ -26,101 +26,95 @@
namespace doris::vectorized {
/// min_by, max_by
-template <template <typename, bool> class AggregateFunctionTemplate,
+template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data, typename VT>
static IAggregateFunction* create_aggregate_function_min_max_by_impl(
- const DataTypes& argument_types) {
- const DataTypePtr& value_arg_type = argument_types[0];
- const DataTypePtr& key_arg_type = argument_types[1];
+ const DataTypes& argument_types, const bool result_is_nullable) {
+ WhichDataType which(remove_nullable(argument_types[1]));
- WhichDataType which(key_arg_type);
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<TYPE>>, false>( \
- value_arg_type, key_arg_type);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return creator_without_type::create<
\
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<TYPE>>>>( \
+ result_is_nullable, argument_types);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
+
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return creator_without_type::create<
\
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<TYPE>>>>( \
+ result_is_nullable, argument_types);
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+
if (which.idx == TypeIndex::String) {
- return new AggregateFunctionTemplate<Data<VT, SingleValueDataString>,
false>(value_arg_type,
-
key_arg_type);
+ return creator_without_type::create<
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataString>>>(result_is_nullable,
+
argument_types);
}
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<Int64>>, false>(
- value_arg_type, key_arg_type);
+ return creator_without_type::create<
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<Int64>>>>(
+ result_is_nullable, argument_types);
}
if (which.idx == TypeIndex::DateV2) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<UInt32>>, false>(
- value_arg_type, key_arg_type);
- }
- if (which.idx == TypeIndex::Decimal32) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<Decimal32>>, false>(
- value_arg_type, key_arg_type);
- }
- if (which.idx == TypeIndex::Decimal64) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<Decimal64>>, false>(
- value_arg_type, key_arg_type);
+ return creator_without_type::create<
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<UInt32>>>>(
+ result_is_nullable, argument_types);
}
- if (which.idx == TypeIndex::Decimal128) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<Decimal128>>, false>(
- value_arg_type, key_arg_type);
- }
- if (which.idx == TypeIndex::Decimal128I) {
- return new AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<Decimal128I>>, false>(
- value_arg_type, key_arg_type);
+ if (which.idx == TypeIndex::DateTimeV2) {
+ return creator_without_type::create<
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<UInt64>>>>(
+ result_is_nullable, argument_types);
}
return nullptr;
}
/// min_by, max_by
-template <template <typename, bool> class AggregateFunctionTemplate,
+template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data>
static IAggregateFunction* create_aggregate_function_min_max_by(const String&
name,
- const
DataTypes& argument_types) {
+ const
DataTypes& argument_types,
+ const bool
result_is_nullable) {
assert_binary(name, argument_types);
- const DataTypePtr& value_arg_type = argument_types[0];
-
- WhichDataType which(value_arg_type);
+ WhichDataType which(remove_nullable(argument_types[0]));
#define DISPATCH(TYPE)
\
if (which.idx == TypeIndex::TYPE)
\
return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \
SingleValueDataFixed<TYPE>>( \
- argument_types);
+ argument_types, result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
+
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \
+
SingleValueDataDecimal<TYPE>>( \
+ argument_types, result_is_nullable);
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+
if (which.idx == TypeIndex::String) {
return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
-
SingleValueDataString>(argument_types);
+
SingleValueDataString>(argument_types,
+
result_is_nullable);
}
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
SingleValueDataFixed<Int64>>(
- argument_types);
+ argument_types, result_is_nullable);
}
if (which.idx == TypeIndex::DateV2) {
return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
SingleValueDataFixed<UInt32>>(
- argument_types);
- }
- if (which.idx == TypeIndex::Decimal128) {
- return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
-
SingleValueDataDecimal<Decimal128>>(
- argument_types);
- }
- if (which.idx == TypeIndex::Decimal32) {
- return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
-
SingleValueDataDecimal<Decimal32>>(
- argument_types);
- }
- if (which.idx == TypeIndex::Decimal64) {
- return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
-
SingleValueDataDecimal<Decimal64>>(
- argument_types);
+ argument_types, result_is_nullable);
}
- if (which.idx == TypeIndex::Decimal128I) {
+ if (which.idx == TypeIndex::DateTimeV2) {
return
create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
-
SingleValueDataDecimal<Decimal128I>>(
- argument_types);
+
SingleValueDataFixed<UInt64>>(
+ argument_types, result_is_nullable);
}
return nullptr;
}
@@ -128,22 +122,22 @@ static IAggregateFunction*
create_aggregate_function_min_max_by(const String& na
AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return AggregateFunctionPtr(
- create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMaxByData>(name, argument_types));
+ return
AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+
AggregateFunctionMaxByData>(
+ name, argument_types, result_is_nullable));
}
AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return AggregateFunctionPtr(
- create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMinByData>(name, argument_types));
+ return
AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+
AggregateFunctionMinByData>(
+ name, argument_types, result_is_nullable));
}
void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("max_by", create_aggregate_function_max_by);
- factory.register_function("min_by", create_aggregate_function_min_by);
+ factory.register_function_both("max_by", create_aggregate_function_max_by);
+ factory.register_function_both("min_by", create_aggregate_function_min_by);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
index b25e771862..28133dbb5d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
@@ -95,24 +95,22 @@ struct AggregateFunctionMinByData : public
AggregateFunctionMinMaxByBaseData<VT,
static const char* name() { return "min_by"; }
};
-template <typename Data, bool AllocatesMemoryInArena>
+template <typename Data>
class AggregateFunctionsMinMaxBy final
- : public IAggregateFunctionDataHelper<
- Data, AggregateFunctionsMinMaxBy<Data,
AllocatesMemoryInArena>> {
+ : public IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>> {
private:
DataTypePtr& value_type;
DataTypePtr& key_type;
public:
- AggregateFunctionsMinMaxBy(const DataTypePtr& value_type_, const
DataTypePtr& key_type_)
- : IAggregateFunctionDataHelper<
- Data, AggregateFunctionsMinMaxBy<Data,
AllocatesMemoryInArena>>(
- {value_type_, key_type_}),
+ AggregateFunctionsMinMaxBy(const DataTypes& arguments)
+ : IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>>(
+ {arguments[0], arguments[1]}),
value_type(this->argument_types[0]),
key_type(this->argument_types[1]) {
if (StringRef(Data::name()) == StringRef("min_by") ||
StringRef(Data::name()) == StringRef("max_by")) {
- CHECK(key_type_->is_comparable());
+ CHECK(key_type->is_comparable());
}
}
@@ -141,8 +139,6 @@ public:
this->data(place).read(buf);
}
- bool allocates_memory_in_arena() const override { return
AllocatesMemoryInArena; }
-
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
this->data(place).insert_result_into(to);
}
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
index 579fe930cf..d894a5be06 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
@@ -85,26 +85,12 @@ AggregateFunctionPtr
create_aggregate_function_orthogonal_bitmap_union_count(
}
void
register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("orthogonal_bitmap_intersect",
-
create_aggregate_function_orthogonal_bitmap_intersect);
-
- factory.register_function("orthogonal_bitmap_intersect_count",
-
create_aggregate_function_orthogonal_bitmap_intersect_count);
-
- factory.register_function("orthogonal_bitmap_union_count",
-
create_aggregate_function_orthogonal_bitmap_union_count);
-
- factory.register_function("intersect_count",
create_aggregate_function_intersect_count);
-
- factory.register_function("orthogonal_bitmap_intersect",
-
create_aggregate_function_orthogonal_bitmap_intersect, true);
-
- factory.register_function("orthogonal_bitmap_intersect_count",
-
create_aggregate_function_orthogonal_bitmap_intersect_count, true);
-
- factory.register_function("orthogonal_bitmap_union_count",
-
create_aggregate_function_orthogonal_bitmap_union_count, true);
-
- factory.register_function("intersect_count",
create_aggregate_function_intersect_count, true);
+ factory.register_function_both("orthogonal_bitmap_intersect",
+
create_aggregate_function_orthogonal_bitmap_intersect);
+ factory.register_function_both("orthogonal_bitmap_intersect_count",
+
create_aggregate_function_orthogonal_bitmap_intersect_count);
+ factory.register_function_both("orthogonal_bitmap_union_count",
+
create_aggregate_function_orthogonal_bitmap_union_count);
+ factory.register_function_both("intersect_count",
create_aggregate_function_intersect_count);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index 12298d8aa8..4ebc804d5d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -107,6 +107,11 @@ public:
}
}
+ void register_function_both(const std::string& name, const Creator&
creator) {
+ register_function(name, creator, false);
+ register_function(name, creator, true);
+ }
+
void register_alias(const std::string& name, const std::string& alias) {
function_alias[alias] = name;
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
index 7a08be0c1c..0f7b47193a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
@@ -73,9 +73,8 @@ AggregateFunctionPtr
create_aggregate_function_sum_reader(const std::string& nam
}
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
- factory.register_function("sum",
create_aggregate_function_sum<AggregateFunctionSumSimple>);
- factory.register_function("sum",
create_aggregate_function_sum<AggregateFunctionSumSimple>,
- true);
+ factory.register_function_both("sum",
+
create_aggregate_function_sum<AggregateFunctionSumSimple>);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
index cb4224a322..c57ec934e5 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -26,10 +26,12 @@ AggregateFunctionPtr create_aggregate_function_topn(const
std::string& name,
const bool
result_is_nullable) {
if (argument_types.size() == 2) {
return AggregateFunctionPtr(
- new
AggregateFunctionTopN<AggregateFunctionTopNImplInt>(argument_types));
+
creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplInt>>(
+ result_is_nullable, argument_types));
} else if (argument_types.size() == 3) {
- return AggregateFunctionPtr(
- new
AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>(argument_types));
+ return AggregateFunctionPtr(creator_without_type::create<
+
AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>>(
+ result_is_nullable, argument_types));
}
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate
function {}",
@@ -39,44 +41,47 @@ AggregateFunctionPtr create_aggregate_function_topn(const
std::string& name,
template <template <typename, bool> class AggregateFunctionTemplate, bool
has_default_param,
bool is_weighted>
-AggregateFunctionPtr create_topn_array(const DataTypes& argument_types) {
- auto type = argument_types[0].get();
- if (type->is_nullable()) {
- type = assert_cast<const
DataTypeNullable*>(type)->get_nested_type().get();
- }
+AggregateFunctionPtr create_topn_array(const DataTypes& argument_types,
+ const bool result_is_nullable) {
+ WhichDataType which(remove_nullable(argument_types[0]));
- WhichDataType which(*type);
-
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
- return AggregateFunctionPtr(
\
- new AggregateFunctionTopNArray<AggregateFunctionTemplate<TYPE,
has_default_param>, \
- TYPE,
is_weighted>(argument_types));
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return AggregateFunctionPtr(
\
+ creator_without_type::create<AggregateFunctionTopNArray<
\
+ AggregateFunctionTemplate<TYPE, has_default_param>,
TYPE, is_weighted>>( \
+ result_is_nullable, argument_types));
FOR_NUMERIC_TYPES(DISPATCH)
+ FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
+
if (which.is_string_or_fixed_string()) {
- return AggregateFunctionPtr(new AggregateFunctionTopNArray<
- AggregateFunctionTemplate<std::string,
has_default_param>,
- std::string, is_weighted>(argument_types));
- }
- if (which.is_decimal()) {
- return AggregateFunctionPtr(new AggregateFunctionTopNArray<
- AggregateFunctionTemplate<Decimal128,
has_default_param>,
- Decimal128, is_weighted>(argument_types));
+ return AggregateFunctionPtr(
+ creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<std::string,
has_default_param>, std::string,
+ is_weighted>>(result_is_nullable, argument_types));
}
- if (which.is_date_or_datetime() || which.is_date_time_v2()) {
+ if (which.is_date_or_datetime()) {
return AggregateFunctionPtr(
- new
AggregateFunctionTopNArray<AggregateFunctionTemplate<Int64, has_default_param>,
- Int64,
is_weighted>(argument_types));
+ creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<Int64, has_default_param>,
Int64, is_weighted>>(
+ result_is_nullable, argument_types));
}
if (which.is_date_v2()) {
return AggregateFunctionPtr(
- new
AggregateFunctionTopNArray<AggregateFunctionTemplate<UInt32, has_default_param>,
- UInt32,
is_weighted>(argument_types));
+ creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<UInt32, has_default_param>,
UInt32, is_weighted>>(
+ result_is_nullable, argument_types));
+ }
+ if (which.is_date_time_v2()) {
+ return AggregateFunctionPtr(
+ creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<UInt64, has_default_param>,
UInt64, is_weighted>>(
+ result_is_nullable, argument_types));
}
LOG(WARNING) << fmt::format("Illegal argument type for aggregate function
topn_array is: {}",
- type->get_name());
+
remove_nullable(argument_types[0])->get_name());
return nullptr;
}
@@ -85,9 +90,11 @@ AggregateFunctionPtr
create_aggregate_function_topn_array(const std::string& nam
const bool
result_is_nullable) {
bool has_default_param = (argument_types.size() == 3);
if (has_default_param) {
- return create_topn_array<AggregateFunctionTopNImplArray, true,
false>(argument_types);
+ return create_topn_array<AggregateFunctionTopNImplArray, true,
false>(argument_types,
+
result_is_nullable);
} else {
- return create_topn_array<AggregateFunctionTopNImplArray, false,
false>(argument_types);
+ return create_topn_array<AggregateFunctionTopNImplArray, false,
false>(argument_types,
+
result_is_nullable);
}
}
@@ -96,16 +103,18 @@ AggregateFunctionPtr
create_aggregate_function_topn_weighted(const std::string&
const bool
result_is_nullable) {
bool has_default_param = (argument_types.size() == 4);
if (has_default_param) {
- return create_topn_array<AggregateFunctionTopNImplWeight, true,
true>(argument_types);
+ return create_topn_array<AggregateFunctionTopNImplWeight, true,
true>(argument_types,
+
result_is_nullable);
} else {
- return create_topn_array<AggregateFunctionTopNImplWeight, false,
true>(argument_types);
+ return create_topn_array<AggregateFunctionTopNImplWeight, false,
true>(argument_types,
+
result_is_nullable);
}
}
void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory)
{
- factory.register_function("topn", create_aggregate_function_topn);
- factory.register_function("topn_array",
create_aggregate_function_topn_array);
- factory.register_function("topn_weighted",
create_aggregate_function_topn_weighted);
+ factory.register_function_both("topn", create_aggregate_function_topn);
+ factory.register_function_both("topn_array",
create_aggregate_function_topn_array);
+ factory.register_function_both("topn_weighted",
create_aggregate_function_topn_weighted);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
index 399fdb6317..18bd119a21 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
@@ -70,8 +70,7 @@ AggregateFunctionPtr create_aggregate_function_uniq(const
std::string& name,
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory)
{
AggregateFunctionCreator creator =
create_aggregate_function_uniq<AggregateFunctionUniqExactData>;
- factory.register_function("multi_distinct_count", creator);
- factory.register_function("multi_distinct_count", creator, true);
+ factory.register_function_both("multi_distinct_count", creator);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
index 7bb9e94524..a36b9601c2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
@@ -114,14 +114,10 @@ void
register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
void register_aggregate_function_window_lead_lag_first_last(
AggregateFunctionSimpleFactory& factory) {
- factory.register_function("lead", create_aggregate_function_window_lead);
- factory.register_function("lead", create_aggregate_function_window_lead,
true);
- factory.register_function("lag", create_aggregate_function_window_lag);
- factory.register_function("lag", create_aggregate_function_window_lag,
true);
- factory.register_function("first_value",
create_aggregate_function_window_first);
- factory.register_function("first_value",
create_aggregate_function_window_first, true);
- factory.register_function("last_value",
create_aggregate_function_window_last);
- factory.register_function("last_value",
create_aggregate_function_window_last, true);
+ factory.register_function_both("lead",
create_aggregate_function_window_lead);
+ factory.register_function_both("lag",
create_aggregate_function_window_lag);
+ factory.register_function_both("first_value",
create_aggregate_function_window_first);
+ factory.register_function_both("last_value",
create_aggregate_function_window_last);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/helpers.h
b/be/src/vec/aggregate_functions/helpers.h
index e67e0976d7..7a811871ed 100644
--- a/be/src/vec/aggregate_functions/helpers.h
+++ b/be/src/vec/aggregate_functions/helpers.h
@@ -25,12 +25,8 @@
#include "vec/data_types/data_type.h"
#include "vec/utils/template_helpers.hpp"
-// TODO: Should we support decimal in numeric types?
#define FOR_INTEGER_TYPES(M) \
M(UInt8) \
- M(UInt16) \
- M(UInt32) \
- M(UInt64) \
M(Int8) \
M(Int16) \
M(Int32) \
@@ -53,49 +49,50 @@
namespace doris::vectorized {
-/** Create an aggregate function with a numeric type in the template
parameter, depending on the type of the argument.
- */
-template <template <typename> class AggregateFunctionTemplate, typename Type>
-struct BuilderDirect {
- using T = AggregateFunctionTemplate<Type>;
-};
-template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
- typename Type>
-struct BuilderData {
- using T = AggregateFunctionTemplate<Data<Type>>;
-};
-template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
- template <typename> class Impl, typename Type>
-struct BuilderDataImpl {
- using T = AggregateFunctionTemplate<Data<Impl<Type>>>;
-};
-template <template <typename, typename> class AggregateFunctionTemplate,
- template <typename> class Data, typename Type>
-struct BuilderDirectAndData {
- using T = AggregateFunctionTemplate<Type, Data<Type>>;
+struct creator_without_type {
+ template <bool multi_arguments, bool f, typename T>
+ using NullableT = std::conditional_t<multi_arguments,
AggregateFunctionNullVariadicInline<T, f>,
+ AggregateFunctionNullUnaryInline<T,
f>>;
+
+ template <typename AggregateFunctionTemplate, typename... TArgs>
+ static IAggregateFunction* create(const bool result_is_nullable,
+ const DataTypes& argument_types,
TArgs&&... args) {
+ IAggregateFunction* result(new
AggregateFunctionTemplate(std::forward<TArgs>(args)...,
+
remove_nullable(argument_types)));
+ if (have_nullable(argument_types)) {
+ std::visit(
+ [&](auto multi_arguments, auto result_is_nullable) {
+ result = new NullableT<multi_arguments,
result_is_nullable,
+
AggregateFunctionTemplate>(result, argument_types);
+ },
+ make_bool_variant(argument_types.size() > 1),
+ make_bool_variant(result_is_nullable));
+ }
+ return result;
+ }
};
template <template <typename> class AggregateFunctionTemplate>
struct CurryDirect {
template <typename Type>
- using Builder = BuilderDirect<AggregateFunctionTemplate, Type>;
+ using T = AggregateFunctionTemplate<Type>;
};
template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data>
struct CurryData {
template <typename Type>
- using Builder = BuilderData<AggregateFunctionTemplate, Data, Type>;
+ using T = AggregateFunctionTemplate<Data<Type>>;
};
template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
template <typename> class Impl>
struct CurryDataImpl {
template <typename Type>
- using Builder = BuilderDataImpl<AggregateFunctionTemplate, Data, Impl,
Type>;
+ using T = AggregateFunctionTemplate<Data<Impl<Type>>>;
};
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data>
struct CurryDirectAndData {
template <typename Type>
- using Builder = BuilderDirectAndData<AggregateFunctionTemplate, Data,
Type>;
+ using T = AggregateFunctionTemplate<Type, Data<Type>>;
};
template <bool allow_integer, bool allow_float, bool allow_decimal, int
define_index = 0>
@@ -104,35 +101,10 @@ struct creator_with_type_base {
static IAggregateFunction* create_base(const bool result_is_nullable,
const DataTypes& argument_types,
TArgs&&... args) {
WhichDataType which(remove_nullable(argument_types[define_index]));
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE) {
\
- using T = typename Class::template Builder<TYPE>::T;
\
- if (have_nullable(argument_types)) {
\
- IAggregateFunction* result = nullptr;
\
- if (argument_types.size() > 1) {
\
- std::visit(
\
- [&](auto result_is_nullable) {
\
- result = new
AggregateFunctionNullVariadicInline<T, \
-
result_is_nullable>( \
- new T(std::forward<TArgs>(args)...,
\
- remove_nullable(argument_types)),
\
- argument_types);
\
- },
\
- make_bool_variant(result_is_nullable));
\
- } else {
\
- std::visit(
\
- [&](auto result_is_nullable) {
\
- result = new AggregateFunctionNullUnaryInline<T,
result_is_nullable>( \
- new T(std::forward<TArgs>(args)...,
\
- remove_nullable(argument_types)),
\
- argument_types);
\
- },
\
- make_bool_variant(result_is_nullable));
\
- }
\
- return result;
\
- } else {
\
- return new T(std::forward<TArgs>(args)..., argument_types);
\
- }
\
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE) {
\
+ return creator_without_type::create<typename Class::template T<TYPE>>(
\
+ result_is_nullable, argument_types,
std::forward<TArgs>(args)...); \
}
if constexpr (allow_integer) {
@@ -180,37 +152,4 @@ using creator_with_numeric_type =
creator_with_type_base<true, true, false>;
using creator_with_decimal_type = creator_with_type_base<false, false, true>;
using creator_with_type = creator_with_type_base<true, true, true>;
-struct creator_without_type {
- template <typename AggregateFunctionTemplate, typename... TArgs>
- static IAggregateFunction* create(const bool result_is_nullable,
- const DataTypes& argument_types,
TArgs&&... args) {
- if (have_nullable(argument_types)) {
- IAggregateFunction* result = nullptr;
- if (argument_types.size() > 1) {
- std::visit(
- [&](auto result_is_nullable) {
- result = new AggregateFunctionNullVariadicInline<
- AggregateFunctionTemplate,
result_is_nullable>(
- new
AggregateFunctionTemplate(std::forward<TArgs>(args)...,
-
remove_nullable(argument_types)),
- argument_types);
- },
- make_bool_variant(result_is_nullable));
- } else {
- std::visit(
- [&](auto result_is_nullable) {
- result = new
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate,
-
result_is_nullable>(
- new
AggregateFunctionTemplate(std::forward<TArgs>(args)...,
-
remove_nullable(argument_types)),
- argument_types);
- },
- make_bool_variant(result_is_nullable));
- }
- return result;
- } else {
- return new AggregateFunctionTemplate(std::forward<TArgs>(args)...,
argument_types);
- }
- }
-};
} // namespace doris::vectorized
diff --git a/be/src/vec/core/types.h b/be/src/vec/core/types.h
index f8f0203208..177d2166ed 100644
--- a/be/src/vec/core/types.h
+++ b/be/src/vec/core/types.h
@@ -626,6 +626,15 @@ struct std::hash<doris::vectorized::Decimal128> {
}
};
+template <>
+struct std::hash<doris::vectorized::Decimal128I> {
+ size_t operator()(const
doris::vectorized::Decimal<doris::vectorized::Int128I>& x) const {
+ return std::hash<doris::vectorized::Int64>()(x.value >> 64) ^
+ std::hash<doris::vectorized::Int64>()(
+ x.value &
std::numeric_limits<doris::vectorized::UInt64>::max());
+ }
+};
+
constexpr bool typeindex_is_int(doris::vectorized::TypeIndex index) {
using TypeIndex = doris::vectorized::TypeIndex;
switch (index) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]