This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 800220b860 VarianceAggregationFunction NULL support. (#11365)
800220b860 is described below
commit 800220b860a31243b179554c9d87732f9456e372
Author: Shen Yu <[email protected]>
AuthorDate: Tue Aug 22 17:55:46 2023 +0000
VarianceAggregationFunction NULL support. (#11365)
---
.../function/AggregationFunctionFactory.java | 8 +-
.../function/VarianceAggregationFunction.java | 84 +++++++++++---
.../queries/NullHandlingEnabledQueriesTest.java | 122 +++++++++++++++++++++
3 files changed, 196 insertions(+), 18 deletions(-)
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index a79f01a0f1..0f03ee8723 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -317,13 +317,13 @@ public class AggregationFunctionFactory {
case BOOLOR:
return new BooleanOrAggregationFunction(firstArgument,
nullHandlingEnabled);
case VARPOP:
- return new VarianceAggregationFunction(firstArgument, false,
false);
+ return new VarianceAggregationFunction(firstArgument, false,
false, nullHandlingEnabled);
case VARSAMP:
- return new VarianceAggregationFunction(firstArgument, true, false);
+ return new VarianceAggregationFunction(firstArgument, true, false,
nullHandlingEnabled);
case STDDEVPOP:
- return new VarianceAggregationFunction(firstArgument, false, true);
+ return new VarianceAggregationFunction(firstArgument, false, true,
nullHandlingEnabled);
case STDDEVSAMP:
- return new VarianceAggregationFunction(firstArgument, true, true);
+ return new VarianceAggregationFunction(firstArgument, true, true,
nullHandlingEnabled);
case SKEWNESS:
return new FourthMomentAggregationFunction(firstArgument,
FourthMomentAggregationFunction.Type.SKEWNESS);
case KURTOSIS:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
index c86269b7ce..5498731442 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
@@ -29,6 +29,7 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import
org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils;
import org.apache.pinot.segment.local.customobject.VarianceTuple;
import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.roaringbitmap.RoaringBitmap;
/**
@@ -41,13 +42,15 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
public class VarianceAggregationFunction extends
BaseSingleInputAggregationFunction<VarianceTuple, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
protected final boolean _isSample;
-
protected final boolean _isStdDev;
+ protected final boolean _nullHandlingEnabled;
- public VarianceAggregationFunction(ExpressionContext expression, boolean
isSample, boolean isStdDev) {
+ public VarianceAggregationFunction(ExpressionContext expression, boolean
isSample, boolean isStdDev,
+ boolean nullHandlingEnabled) {
super(expression);
_isSample = isSample;
_isStdDev = isStdDev;
+ _nullHandlingEnabled = nullHandlingEnabled;
}
@Override
@@ -72,18 +75,38 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[] values =
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
+ RoaringBitmap nullBitmap = null;
+ if (_nullHandlingEnabled) {
+ nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+ }
long count = 0;
double sum = 0.0;
double variance = 0.0;
- for (int i = 0; i < length; i++) {
- count++;
- sum += values[i];
- if (count > 1) {
- variance = computeIntermediateVariance(count, sum, variance,
values[i]);
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ count++;
+ sum += values[i];
+ if (count > 1) {
+ variance = computeIntermediateVariance(count, sum, variance,
values[i]);
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < length; i++) {
+ count++;
+ sum += values[i];
+ if (count > 1) {
+ variance = computeIntermediateVariance(count, sum, variance,
values[i]);
+ }
}
}
- setAggregationResult(aggregationResultHolder, length, sum, variance);
+
+ if (_nullHandlingEnabled && count == 0) {
+ return;
+ }
+ setAggregationResult(aggregationResultHolder, count, sum, variance);
}
private double computeIntermediateVariance(long count, double sum, double
m2, double value) {
@@ -116,8 +139,20 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[] values =
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
- for (int i = 0; i < length; i++) {
- setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i],
0.0);
+ RoaringBitmap nullBitmap = null;
+ if (_nullHandlingEnabled) {
+ nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+ }
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L,
values[i], 0.0);
+ }
+ }
+ } else {
+ for (int i = 0; i < length; i++) {
+ setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i],
0.0);
+ }
}
}
@@ -125,9 +160,23 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[] values =
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0);
+ RoaringBitmap nullBitmap = null;
+ if (_nullHandlingEnabled) {
+ nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+ }
+ if (nullBitmap != null && !nullBitmap.isEmpty()) {
+ for (int i = 0; i < length; i++) {
+ if (!nullBitmap.contains(i)) {
+ for (int groupKey : groupKeysArray[i]) {
+ setGroupByResult(groupKey, groupByResultHolder, 1L, values[i],
0.0);
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < length; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0);
+ }
}
}
}
@@ -136,7 +185,7 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
public VarianceTuple extractAggregationResult(AggregationResultHolder
aggregationResultHolder) {
VarianceTuple varianceTuple = aggregationResultHolder.getResult();
if (varianceTuple == null) {
- return new VarianceTuple(0L, 0.0, 0.0);
+ return _nullHandlingEnabled ? null : new VarianceTuple(0L, 0.0, 0.0);
} else {
return varianceTuple;
}
@@ -149,6 +198,13 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
@Override
public VarianceTuple merge(VarianceTuple intermediateResult1, VarianceTuple
intermediateResult2) {
+ if (_nullHandlingEnabled) {
+ if (intermediateResult1 == null) {
+ return intermediateResult2;
+ } else if (intermediateResult2 == null) {
+ return intermediateResult1;
+ }
+ }
intermediateResult1.apply(intermediateResult2);
return intermediateResult1;
}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
index b53bfa17fa..9830f35f7e 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
@@ -1400,4 +1400,126 @@ public class NullHandlingEnabledQueriesTest extends BaseQueriesTest {
assertEquals(rows.size(), NUM_OF_SEGMENT_COPIES);
assertArrayEquals(rows.get(0), new Object[]{null});
}
+
+ @Test(dataProvider = "NumberTypes")
+ public void testStddevPop(FieldSpec.DataType dataType)
+ throws Exception {
+ initializeRows();
+ insertRow(null);
+ insertRow(1);
+ insertRow(2);
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType).build();
+ setUpSegments(tableConfig, schema);
+ String query = String.format("SELECT STDDEV_POP(%s) FROM testTable",
COLUMN1);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
QUERY_OPTIONS);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals(rows.get(0)[0], 0.5);
+ }
+
+ @Test(dataProvider = "NumberTypes")
+ public void testGroupByStddevPop(FieldSpec.DataType dataType)
+ throws Exception {
+ initializeRows();
+ insertRowWithTwoColumns(null, "key");
+ insertRowWithTwoColumns(1, "key");
+ insertRowWithTwoColumns(2, "key");
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType)
+ .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+ setUpSegments(tableConfig, schema);
+ String query = String.format("SELECT STDDEV_POP(%s), %s FROM testTable
GROUP BY %s", COLUMN1, COLUMN2, COLUMN2);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
QUERY_OPTIONS);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertArrayEquals(rows.get(0), new Object[]{0.5, "key"});
+ }
+
+ @Test(dataProvider = "NumberTypes")
+ public void testGroupByMvStddevPop(FieldSpec.DataType dataType)
+ throws Exception {
+ initializeRows();
+ insertRowWithTwoColumns(null, new String[]{"key1", "key2"});
+ insertRowWithTwoColumns(1, new String[]{"key1", "key2"});
+ insertRowWithTwoColumns(2, new String[]{"key1"});
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType)
+ .addMultiValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+ setUpSegments(tableConfig, schema);
+ String query =
+ String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s
ORDER BY %s", COLUMN1, COLUMN2, COLUMN2,
+ COLUMN2);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
QUERY_OPTIONS);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 2);
+ assertArrayEquals(rows.get(0), new Object[]{0.5, "key1"});
+ assertArrayEquals(rows.get(1), new Object[]{0.0, "key2"});
+ }
+
+ @Test
+ public void testAllNullGroupByStddevPopReturnsNull()
+ throws Exception {
+ initializeRows();
+ insertRowWithTwoColumns(null, "key1");
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, FieldSpec.DataType.INT)
+ .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+ setUpSegments(tableConfig, schema);
+ String query =
+ String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s
ORDER BY %s", COLUMN1, COLUMN2, COLUMN2,
+ COLUMN2);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
QUERY_OPTIONS);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals(rows.get(0)[0], null);
+ }
+
+ @Test
+ public void testAllNullStddevPopReturnsNull()
+ throws Exception {
+ initializeRows();
+ insertRow(null);
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1,
FieldSpec.DataType.DOUBLE).build();
+ setUpSegments(tableConfig, schema);
+ String query = String.format("SELECT STDDEV_POP(%s) FROM testTable",
COLUMN1);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query,
QUERY_OPTIONS);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals(rows.get(0)[0], null);
+ }
+
+ @Test
+ public void testNoMatchingRowNullHandlingDisabledStddevPopReturnsNull()
+ throws Exception {
+ initializeRows();
+ insertRow(1);
+ TableConfig tableConfig = new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+ Schema schema = new
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1,
FieldSpec.DataType.DOUBLE).build();
+ setUpSegments(tableConfig, schema);
+ String query = String.format("SELECT STDDEV_POP(%s) FROM testTable WHERE
%s != 1", COLUMN1, COLUMN1);
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query);
+
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+ assertEquals(rows.size(), 1);
+ assertEquals(rows.get(0)[0], Double.NEGATIVE_INFINITY);
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]