This is an automated email from the ASF dual-hosted git repository.
michaelsmith pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git
The following commit(s) were added to refs/heads/master by this push:
new 4a645105f IMPALA-13658: Enable tuple caching aggregates
4a645105f is described below
commit 4a645105f920979759e75e136ffc6156fd268266
Author: Michael Smith <[email protected]>
AuthorDate: Wed Jan 8 15:46:01 2025 -0800
IMPALA-13658: Enable tuple caching aggregates
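Enables tuple caching for aggregation nodes that sit directly above scan
nodes. An aggregate can only be cached if its children are also eligible
for caching, so aggregates above an exchange, union, or hash join are
excluded. To keep cache keys stable, the aggregation node masks its
estimated input cardinality and resource profiles out of the cache key,
and appx_median is now treated as non-deterministic so aggregates that
use it are not cached.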
Testing:
- Adds planner tests for a range of aggregate cases to confirm that they
produce stable tuple cache keys and that the expected nodes are eligible
for caching.
- Adds custom cluster tests verifying that cached aggregates are used and
can be reused by slightly different statements.
Change-Id: I9bd13c2813c90d23eb3a70f98068fdcdab97a885
Reviewed-on: http://gerrit.cloudera.org:8080/22322
Reviewed-by: Impala Public Jenkins <[email protected]>
Tested-by: Impala Public Jenkins <[email protected]>
---
.../apache/impala/analysis/FunctionCallExpr.java | 2 +
.../org/apache/impala/planner/AggregationNode.java | 34 ++++++++++--
.../org/apache/impala/planner/TupleCacheTest.java | 64 +++++++++++++++++++++-
tests/custom_cluster/test_tuple_cache.py | 46 ++++++++++++++++
4 files changed, 139 insertions(+), 7 deletions(-)
diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
index f4f59df18..ee9e44121 100644
--- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
+++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
@@ -412,6 +412,8 @@ public class FunctionCallExpr extends Expr {
// Session / system information
"coordinator", "current_database", "current_session", "current_user",
"effective_user", "logged_in_user", "pid", "user", "version",
+ // Sampling aggregate functions
+ "appx_median",
// AI Functions
"ai_generate_text", "ai_generate_text_default");
return functionNameInBuiltinSet(fnName_, knownNondeterministicFns);
diff --git a/fe/src/main/java/org/apache/impala/planner/AggregationNode.java b/fe/src/main/java/org/apache/impala/planner/AggregationNode.java
index f9e04c962..ceb2ffe08 100644
--- a/fe/src/main/java/org/apache/impala/planner/AggregationNode.java
+++ b/fe/src/main/java/org/apache/impala/planner/AggregationNode.java
@@ -41,9 +41,11 @@ import org.apache.impala.analysis.TupleDescriptor;
import org.apache.impala.analysis.TupleId;
import org.apache.impala.analysis.ValidTupleIdExpr;
import org.apache.impala.common.InternalException;
+import org.apache.impala.common.ThriftSerializationCtx;
import org.apache.impala.thrift.QueryConstants;
import org.apache.impala.thrift.TAggregationNode;
import org.apache.impala.thrift.TAggregator;
+import org.apache.impala.thrift.TBackendResourceProfile;
import org.apache.impala.thrift.TExplainLevel;
import org.apache.impala.thrift.TExpr;
import org.apache.impala.thrift.TPlanNode;
@@ -645,7 +647,7 @@ public class AggregationNode extends PlanNode {
private AggregationNode getPrevAggNode(AggregationNode aggNode) {
Preconditions.checkArgument(aggNode.getAggPhase() != AggPhase.FIRST);
PlanNode child = aggNode.getChild(0);
- if (child instanceof ExchangeNode) {
+ while (child instanceof ExchangeNode || child instanceof TupleCacheNode) {
child = child.getChild(0);
}
Preconditions.checkState(child instanceof AggregationNode);
@@ -782,24 +784,41 @@ public class AggregationNode extends PlanNode {
@Override
protected void toThrift(TPlanNode msg) {
+ Preconditions.checkState(false, "Unexpected use of old toThrift() signature.");
+ }
+
+ @Override
+ protected void toThrift(TPlanNode msg, ThriftSerializationCtx serialCtx) {
msg.agg_node = new TAggregationNode();
msg.node_type = TPlanNodeType.AGGREGATION_NODE;
boolean replicateInput = aggPhase_ == AggPhase.FIRST && aggInfos_.size() > 1;
msg.agg_node.setReplicate_input(replicateInput);
- msg.agg_node.setEstimated_input_cardinality(getChild(0).getCardinality());
+ // Normalize input cardinality estimate for caching in case stats change.
+ // Cache key of scan ensures we detect changes to actual input data.
+ msg.agg_node.setEstimated_input_cardinality(serialCtx.isTupleCache() ?
+ 1 : getChild(0).getCardinality());
msg.agg_node.setFast_limit_check(canCompleteEarly());
for (int i = 0; i < aggInfos_.size(); ++i) {
AggregateInfo aggInfo = aggInfos_.get(i);
List<TExpr> aggregateFunctions = new ArrayList<>();
for (FunctionCallExpr e : aggInfo.getMaterializedAggregateExprs()) {
- aggregateFunctions.add(e.treeToThrift());
+ aggregateFunctions.add(e.treeToThrift(serialCtx));
}
+ // At the point when TupleCachePlanner runs, the resource profile has not been
+ // calculated yet. They should not be in the cache key anyway, so mask them out.
+ TBackendResourceProfile resourceProfile = serialCtx.isTupleCache()
+ ? ResourceProfile.noReservation(0).toThrift()
+ : resourceProfiles_.get(i).toThrift();
+ // Ensure both tuple IDs are registered. Only one is added to tupleIds_.
+ serialCtx.registerTuple(aggInfo.getIntermediateTupleId());
+ serialCtx.registerTuple(aggInfo.getOutputTupleId());
TAggregator taggregator = new TAggregator(aggregateFunctions,
- aggInfo.getIntermediateTupleId().asInt(), aggInfo.getOutputTupleId().asInt(),
- needsFinalize_, useStreamingPreagg_, resourceProfiles_.get(i).toThrift());
+ serialCtx.translateTupleId(aggInfo.getIntermediateTupleId()).asInt(),
+ serialCtx.translateTupleId(aggInfo.getOutputTupleId()).asInt(),
+ needsFinalize_, useStreamingPreagg_, resourceProfile);
List<Expr> groupingExprs = aggInfo.getGroupingExprs();
if (!groupingExprs.isEmpty()) {
- taggregator.setGrouping_exprs(Expr.treesToThrift(groupingExprs));
+ taggregator.setGrouping_exprs(Expr.treesToThrift(groupingExprs, serialCtx));
}
msg.agg_node.addToAggregators(taggregator);
}
@@ -1107,4 +1126,7 @@ public class AggregationNode extends PlanNode {
return isSingleClassAgg() && hasLimit() && hasGrouping()
&& !multiAggInfo_.hasAggregateExprs() && getConjuncts().isEmpty();
}
+
+ @Override
+ public boolean isTupleCachingImplemented() { return true; }
}
diff --git a/fe/src/test/java/org/apache/impala/planner/TupleCacheTest.java b/fe/src/test/java/org/apache/impala/planner/TupleCacheTest.java
index b5ccc05ac..35963501d 100644
--- a/fe/src/test/java/org/apache/impala/planner/TupleCacheTest.java
+++ b/fe/src/test/java/org/apache/impala/planner/TupleCacheTest.java
@@ -109,6 +109,64 @@ public class TupleCacheTest extends PlannerTestBase {
String.format(basicJoinTmpl, "probe.id = build.id and build.id < 100"));
}
+ @Test
+ public void testAggregateCacheKeys() {
+ // Scan and aggregate above scan are cached, aggregate above exchange is not.
+ String basicAgg = "select count(*), count(tinyint_col), min(tinyint_col), " +
+ "max(tinyint_col), sum(tinyint_col), avg(tinyint_col) " +
+ "from functional.alltypesagg";
+ verifyNIdenticalCacheKeys(basicAgg, basicAgg, 2);
+ // Scan and aggregate above scan are cached, aggregate above exchange is not.
+ String groupingAgg = "select tinyint_col, bigint_col, count(*), min(tinyint_col), " +
+ "max(tinyint_col), sum(tinyint_col), avg(tinyint_col) " +
+ "from functional.alltypesagg group by 2, 1";
+ verifyNIdenticalCacheKeys(groupingAgg, groupingAgg, 2);
+ // Scan and aggregate above scan are cached, later aggregates are not.
+ String distinctAgg = "select avg(l_quantity), ndv(l_discount), " +
+ "count(distinct l_partkey) from tpch_parquet.lineitem";
+ verifyNIdenticalCacheKeys(distinctAgg, distinctAgg, 2);
+ // Scan and aggregate above scan are cached, later aggregates are not.
+ String groupDistinctAgg = "select group_concat(distinct string_col) " +
+ "from functional.alltypesagg";
+ verifyNIdenticalCacheKeys(groupDistinctAgg, groupDistinctAgg, 2);
+ // Scan and the only aggregate are cached.
+ String havingAgg = "select 1 from functional.alltypestiny having count(*) > 0";
+ verifyNIdenticalCacheKeys(havingAgg, havingAgg, 2);
+ // Scan and aggregate above scan are cached. All later aggregates are above the
+ // exchange and thus not cached.
+ String twoPhaseAgg = "select bigint_col bc, count(smallint_col) c1, " +
+ "count(distinct int_col) c2 from functional.alltypessmall " +
+ "group by bigint_col order by bc";
+ verifyNIdenticalCacheKeys(twoPhaseAgg, twoPhaseAgg, 2);
+ // Build scan and aggregate above scan are cached. Probe scan caching is invalid
+ // because runtime filter contains agg with exchange. Other aggregates are above
+ // an exchange or hash join and thus not cached.
+ String rightJoinAgg = "with v1 as (select c_nationkey, c_custkey, count(*) " +
+ "from tpch.customer group by c_nationkey, c_custkey), " +
+ "v2 as (select c_nationkey, c_custkey, count(*) from v1, tpch.orders " +
+ "where c_custkey = o_custkey group by c_nationkey, c_custkey) " +
+ "select c_nationkey, count(*) from v2 group by c_nationkey";
+ verifyNIdenticalCacheKeys(rightJoinAgg, rightJoinAgg, 2);
+ // Both scans are cached, but aggregates above hash join and exchange are not.
+ String innerJoinAgg = "select count(*) from functional.alltypes t1 inner join " +
+ "functional.alltypestiny t2 on t1.smallint_col = t2.smallint_col group by " +
+ "t1.tinyint_col, t2.smallint_col having count(t2.int_col) = count(t1.bigint_col)";
+ verifyNIdenticalCacheKeys(innerJoinAgg, innerJoinAgg, 2);
+ // Both scans are cached, but aggregate is not because it's above a union. Limit only
+ // applies to aggregate above exchange, which is obviously not cached.
+ String unionAgg = "select count(*) from (select * from functional.alltypes " +
+ "union all select * from functional.alltypessmall) t limit 10";
+ verifyNIdenticalCacheKeys(unionAgg, unionAgg, 2);
+ // Only scan is cached, as aggregates are above an exchange and TOP-N.
+ String groupConcatGroupAgg = "select day, group_concat(distinct string_col) " +
+ "from (select * from functional.alltypesagg where id % 100 = day order by id " +
+ "limit 99999) a group by day";
+ verifyNIdenticalCacheKeys(groupConcatGroupAgg, groupConcatGroupAgg, 1);
+ // Only scan is cached, appx_median disables caching on aggregate.
+ String appxMedianAgg = "select appx_median(tinyint_col) from functional.alltypesagg";
+ verifyNIdenticalCacheKeys(appxMedianAgg, appxMedianAgg, 1);
+ }
+
/**
* Test cases that rely on masking out unnecessary data to have cache hits.
*/
@@ -342,9 +400,13 @@ public class TupleCacheTest extends PlannerTestBase {
}
protected void verifyIdenticalCacheKeys(String query1, String query2) {
+ verifyNIdenticalCacheKeys(query1, query2, 1);
+ }
+
+ protected void verifyNIdenticalCacheKeys(String query1, String query2, int n) {
List<PlanNode> cacheEligibleNodes1 = getCacheEligibleNodes(query1);
List<PlanNode> cacheEligibleNodes2 = getCacheEligibleNodes(query2);
- assertTrue(cacheEligibleNodes1.size() > 0);
+ assertTrue(cacheEligibleNodes1.size() >= n);
List<String> cacheKeys1 = getCacheKeys(cacheEligibleNodes1);
List<String> cacheKeys2 = getCacheKeys(cacheEligibleNodes2);
List<String> cacheHashTraces1 = getCacheHashTraces(cacheEligibleNodes1);
diff --git a/tests/custom_cluster/test_tuple_cache.py b/tests/custom_cluster/test_tuple_cache.py
index 76ad6794a..2f2523496 100644
--- a/tests/custom_cluster/test_tuple_cache.py
+++ b/tests/custom_cluster/test_tuple_cache.py
@@ -58,6 +58,10 @@ def getCounterValues(profile, key):
return [int(v) for v in counter_str_list]
+def assertCounterOrder(profile, key, vals):
+ values = getCounterValues(profile, key)
+ assert values == vals, values
+
def assertCounter(profile, key, val, num_matches):
if not isinstance(num_matches, list):
num_matches = [num_matches]
@@ -355,6 +359,48 @@ class TestTupleCacheSingle(TestTupleCacheBase):
assertCounters(exempt2.runtime_profile, num_hits=1, num_halted=0, num_skipped=0)
assertCounters(exempt3.runtime_profile, num_hits=1, num_halted=0, num_skipped=0)
+ def test_aggregate(self, vector, unique_database):
+ """Simple aggregation can be cached"""
+ self.client.set_configuration(vector.get_value('exec_option'))
+ fq_table = "{0}.agg".format(unique_database)
+ self.create_table(fq_table)
+
+ result1 = self.execute_query("SELECT sum(age) FROM {0}".format(fq_table))
+ result2 = self.execute_query("SELECT sum(age) FROM {0}".format(fq_table))
+
+ assert result1.success
+ assert result2.success
+ assert result1.data == result2.data
+ assertCounters(result1.runtime_profile, 0, 0, 0, num_matches=2)
+ # Aggregate should hit, and scan node below it will miss.
+ assertCounterOrder(result2.runtime_profile, NUM_HITS, [1, 0])
+ assertCounter(result2.runtime_profile, NUM_HALTED, 0, num_matches=2)
+ assertCounter(result2.runtime_profile, NUM_SKIPPED, 0, num_matches=2)
+ # Verify that the bytes written by the first profile are the same as the bytes
+ # read by the second profile.
+ bytes_written = getCounterValues(result1.runtime_profile, "TupleCacheBytesWritten")
+ bytes_read = getCounterValues(result2.runtime_profile, "TupleCacheBytesRead")
+ assert len(bytes_written) == 2
+ assert len(bytes_read) == 1
+ assert bytes_written[0] == bytes_read[0]
+
+ def test_aggregate_reuse(self, vector):
+ """Cached aggregation can be re-used"""
+ self.client.set_configuration(vector.get_value('exec_option'))
+
+ result = self.execute_query("SELECT sum(int_col) FROM functional.alltypes")
+ assert result.success
+ assertCounters(result.runtime_profile, 0, 0, 0, num_matches=2)
+
+ result_scan = self.execute_query("SELECT avg(int_col) FROM functional.alltypes")
+ assert result_scan.success
+ assertCounterOrder(result_scan.runtime_profile, NUM_HITS, [0, 1])
+
+ result_agg = self.execute_query(
+ "SELECT avg(a) FROM (SELECT sum(int_col) as a FROM functional.alltypes) b")
+ assert result_agg.success
+ assertCounterOrder(result_agg.runtime_profile, NUM_HITS, [1, 0])
+
@CustomClusterTestSuite.with_args(start_args=CACHE_START_ARGS)
class TestTupleCacheCluster(TestTupleCacheBase):