This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-4.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit f980f13fa940b2e98ca83b711cc85b73289be4b9 Author: minghong <[email protected]> AuthorDate: Fri Mar 13 19:54:15 2026 +0800 branch-4.1 [feature](nereids) add RewriteSimpleAggToConstantRule to rewrite simple agg to constant (#61306) ### What problem does this PR solve? pick#61183 Issue Number: close #xxx Related PR: #xxx Problem Summary: ### Release note None ### Check List (For Author) - Test <!-- At least one of them must be included. --> - [ ] Regression test - [ ] Unit Test - [ ] Manual test (add detailed scripts or steps below) - [ ] No need to test or manual test. Explain why: - [ ] This is a refactor/code format and no logic has been changed. - [ ] Previous test can cover this change. - [ ] No code files have been changed. - [ ] Other reason <!-- Add your reason? --> - Behavior changed: - [ ] No. - [ ] Yes. <!-- Explain the behavior change --> - Does this need documentation? - [ ] No. - [ ] Yes. <!-- Add document PR link here. eg: https://github.com/apache/doris-website/pull/1214 --> ### Check List (For Reviewer who merge this PR) - [ ] Confirm the release note - [ ] Confirm test cases - [ ] Confirm document - [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into --> --- .../java/org/apache/doris/catalog/OlapTable.java | 5 + .../org/apache/doris/catalog/TableAttributes.java | 5 + .../apache/doris/datasource/InternalCatalog.java | 4 + .../doris/nereids/jobs/executor/Rewriter.java | 3 + .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../rewrite/RewriteSimpleAggToConstantRule.java | 267 ++++++++++ .../doris/nereids/stats/SimpleAggCacheMgr.java | 546 +++++++++++++++++++++ .../RewriteSimpleAggToConstantRuleTest.java | 294 +++++++++++ .../rewrite_simple_agg_to_constant.out | 37 ++ .../truncate_version_reset.out | 7 + .../agg_use_key_direct/agg_use_key_direct.groovy | 2 +- .../suites/nereids_p0/hint/test_hint.groovy | 2 +- .../rewrite_simple_agg_to_constant.groovy | 317 ++++++++++++ .../truncate_version_reset.groovy | 112 
+++++ 14 files changed, 1600 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java index 74bffbbd392..9d4904370c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java @@ -3302,6 +3302,11 @@ public class OlapTable extends Table implements MTMVRelatedTableIf, GsonPostProc tableAttributes.updateVisibleVersionAndTime(visibleVersion, visibleVersionTime); } + public void resetVisibleVersion() { + LOG.info("resetVisibleVersion, tableName: {}", name); + tableAttributes.resetVisibleVersion(); + } + // During `getNextVersion` and `updateVisibleVersionAndTime` period, // the write lock on the table should be held continuously public long getNextVersion() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableAttributes.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableAttributes.java index 5edcc8704ff..b581ffd4713 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableAttributes.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableAttributes.java @@ -66,4 +66,9 @@ public class TableAttributes { public long getNextVersion() { return visibleVersion + 1; } + + public void resetVisibleVersion() { + this.visibleVersion = TABLE_INIT_VERSION; + this.visibleVersionTime = System.currentTimeMillis(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalCatalog.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalCatalog.java index 02f9ab3c672..9a719c8bef1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalCatalog.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalCatalog.java @@ -3782,6 +3782,10 @@ public class InternalCatalog implements CatalogIf<Database> { olapTable.dropPartitionForTruncate(olapTable.getDatabase().getId(),
isforceDrop, pair.getValue()); } + // Reset table-level visibleVersion to TABLE_INIT_VERSION (and refresh visibleVersionTime) so it stays + // consistent with the newly created partitions, which also start at PARTITION_INIT_VERSION. + olapTable.resetVisibleVersion(); + return oldPartitions; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 80d55db9a7e..342df5d8a05 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -156,6 +156,7 @@ import org.apache.doris.nereids.rules.rewrite.ReduceAggregateChildOutputRows; import org.apache.doris.nereids.rules.rewrite.ReorderJoin; import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren; import org.apache.doris.nereids.rules.rewrite.RewriteSearchToSlots; +import org.apache.doris.nereids.rules.rewrite.RewriteSimpleAggToConstantRule; import org.apache.doris.nereids.rules.rewrite.SaltJoin; import org.apache.doris.nereids.rules.rewrite.SetPreAggStatus; import org.apache.doris.nereids.rules.rewrite.SimplifyEncodeDecode; @@ -283,6 +284,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown( new NormalizeAggregate(), new CountLiteralRewrite(), + new RewriteSimpleAggToConstantRule(), new NormalizeSort() ), @@ -519,6 +521,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown( new NormalizeAggregate(), new CountLiteralRewrite(), + new RewriteSimpleAggToConstantRule(), new NormalizeSort() ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 9ce73a3df4d..0f20c8164fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -74,6 +74,7 @@ public enum RuleType { 
COUNT_LITERAL_REWRITE(RuleTypeClass.REWRITE), SUM_LITERAL_REWRITE(RuleTypeClass.REWRITE), + REWRITE_SIMPLE_AGG_TO_CONSTANT(RuleTypeClass.REWRITE), REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE), FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRule.java new file mode 100644 index 00000000000..9347b7edb72 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRule.java @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.analysis.LiteralExpr;
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.KeysType;
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.stats.SimpleAggCacheMgr;
+import org.apache.doris.nereids.stats.SimpleAggCacheMgr.ColumnMinMax;
+import org.apache.doris.nereids.stats.SimpleAggCacheMgr.ColumnMinMaxKey;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.rpc.RpcException;
+import org.apache.doris.statistics.util.StatisticsUtil;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.function.LongSupplier;
+
+/**
+ * For simple aggregation queries like
+ * 'select count(*), count(not-null column), min(col), max(col) from olap_table',
+ * rewrite them to return constants computed in FE, bypassing BE at query time.
+ *
+ * <p>Both COUNT and MIN/MAX values come from the async FE-side cache {@link SimpleAggCacheMgr},
+ * which stores exact values obtained via internal SQL queries
+ * ({@code SELECT count(*)} / {@code SELECT min(col), max(col)}) — NOT sampled ColumnStatistic
+ * and NOT BE tablet-stats reporting. A cached value is used only when its version stamp
+ * matches the table's current visibleVersionTime / visibleVersion.
+ *
+ * <p>All aggregate functions of one query are validated against a single snapshot of the
+ * table's version stamp, so a write committing while the rule runs cannot produce a mix of
+ * pre-write and post-write constants in the same result row.
+ *
+ * <p>Conditions:
+ * 0. Plain full-table scan (no selected index, no explicit partitions/tablets, no TABLESAMPLE)
+ * 1. DUP_KEYS table only (AGG_KEYS: rowCount inflated; UNIQUE_KEYS: min/max may be inaccurate
+ *    in MoW model before compaction merges delete-marked rows)
+ * 2. No GROUP BY
+ * 3. Only COUNT / MIN / MAX aggregate functions
+ * 4. COUNT(col): col must be NOT NULL (so count(col) == rowCount)
+ * 5. MIN/MAX(col): col must be numeric or date type, and not aggregated
+ */
+public class RewriteSimpleAggToConstantRule implements RewriteRuleFactory {
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                // pattern: agg -> scan
+                logicalAggregate(logicalOlapScan())
+                        .thenApply(ctx -> {
+                            LogicalAggregate<LogicalOlapScan> agg = ctx.root;
+                            LogicalOlapScan olapScan = agg.child();
+                            return tryRewrite(agg, olapScan, ctx.statementContext);
+                        })
+                        .toRule(RuleType.REWRITE_SIMPLE_AGG_TO_CONSTANT),
+                // pattern: agg -> project -> scan
+                logicalAggregate(logicalProject(logicalOlapScan()))
+                        .thenApply(ctx -> {
+                            LogicalAggregate<?> agg = ctx.root;
+                            LogicalOlapScan olapScan = (LogicalOlapScan) ctx.root.child().child();
+                            return tryRewrite(agg, olapScan, ctx.statementContext);
+                        })
+                        .toRule(RuleType.REWRITE_SIMPLE_AGG_TO_CONSTANT)
+        );
+    }
+
+    /**
+     * Replace {@code agg} with a one-row project of constants when every output can be served
+     * from the cache; otherwise return {@code null} to leave the plan unchanged.
+     */
+    private Plan tryRewrite(LogicalAggregate<?> agg, LogicalOlapScan olapScan,
+            StatementContext statementContext) {
+        // The cached values are computed over the whole table, so scans that select an index,
+        // name explicit partitions/tablets, or sample the table are not eligible.
+        if (olapScan.isIndexSelected()
+                || !olapScan.getManuallySpecifiedPartitions().isEmpty()
+                || !olapScan.getManuallySpecifiedTabletIds().isEmpty()
+                || olapScan.getTableSample().isPresent()) {
+            return null;
+        }
+        OlapTable table = olapScan.getTable();
+
+        // Condition 1: DUP_KEYS only.
+        // - DUP_KEYS: FE rowCount equals actual count(*); min/max are accurate.
+        // - AGG_KEYS: rowCount may be inflated before full compaction.
+        // - UNIQUE_KEYS: in MoW model, min/max may include values from delete-marked rows
+        //   not yet compacted, so the result could be inaccurate.
+        if (table.getKeysType() != KeysType.DUP_KEYS) {
+            return null;
+        }
+
+        // Condition 2: No GROUP BY
+        if (!agg.getGroupByExpressions().isEmpty()) {
+            return null;
+        }
+
+        // Condition 3: Only COUNT / MIN / MAX aggregate functions.
+        Set<AggregateFunction> funcs = agg.getAggregateFunctions();
+        if (funcs.isEmpty()) {
+            return null;
+        }
+        for (AggregateFunction func : funcs) {
+            if (!(func instanceof Count) && !(func instanceof Min) && !(func instanceof Max)) {
+                return null;
+            }
+        }
+
+        // Snapshot the table's version stamp ONCE so every aggregate below is validated against
+        // the same table state. versionTime is a cheap local read; visibleVersion may need an RPC
+        // in cloud mode, so it is fetched lazily and at most once per rewrite.
+        long versionTime = table.getVisibleVersionTime();
+        LongSupplier versionSupplier = new MemoizedVersionSupplier(table);
+
+        // Try to compute a constant for each output expression.
+        // If ANY one cannot be replaced, we give up the entire rewrite.
+        List<NamedExpression> newOutputExprs = new ArrayList<>();
+        for (NamedExpression outputExpr : agg.getOutputExpressions()) {
+            if (!(outputExpr instanceof Alias)) {
+                // Unexpected: for no-group-by aggregates, outputs should all be aliases over agg funcs
+                return null;
+            }
+            Alias alias = (Alias) outputExpr;
+            Expression child = alias.child();
+            if (!(child instanceof AggregateFunction)) {
+                return null;
+            }
+            AggregateFunction func = (AggregateFunction) child;
+            Optional<Literal> constant = tryGetConstant(func, table, versionTime, versionSupplier);
+            if (!constant.isPresent()) {
+                // Cannot replace this agg function — give up
+                return null;
+            }
+            newOutputExprs.add(new Alias(alias.getExprId(), constant.get(), alias.getName()));
+        }
+
+        if (newOutputExprs.isEmpty()) {
+            return null;
+        }
+
+        // Build: LogicalProject(constants) -> LogicalOneRowRelation(dummy)
+        // The OneRowRelation provides a single-row source; all real values come from the project.
+        LogicalOneRowRelation oneRowRelation = new LogicalOneRowRelation(
+                statementContext.getNextRelationId(),
+                ImmutableList.of(new Alias(new NullLiteral(), "__dummy__")));
+        return new LogicalProject<>(newOutputExprs, oneRowRelation);
+    }
+
+    /**
+     * Try to compute a compile-time constant value for the given aggregate function, using the
+     * exact row count (for COUNT) or exact min/max (for MIN/MAX) cached by {@link SimpleAggCacheMgr}.
+     * The cached values come from internal SQL queries (SELECT count/min/max), NOT from BE tablet
+     * stats reporting, to avoid delayed-reporting and version issues.
+     *
+     * @param func the aggregate function to replace
+     * @param table the scanned table
+     * @param versionTime the table's visibleVersionTime, captured once for the whole rewrite
+     * @param versionSupplier per-rewrite memoized supplier of the table's visibleVersion
+     *        (-1 if unavailable); the cache calls it only when versionTime values match
+     * @return the constant, or empty if this function cannot be replaced
+     */
+    private Optional<Literal> tryGetConstant(AggregateFunction func, OlapTable table,
+            long versionTime, LongSupplier versionSupplier) {
+        if (func.isDistinct()) {
+            return Optional.empty();
+        }
+
+        // --- COUNT ---
+        if (func instanceof Count) {
+            // Look up exact row count from the async cache.
+            // The count is obtained by executing "SELECT count(*) FROM table" internally,
+            // so it is accurate and versioned, unlike BE tablet stats reporting which
+            // has delayed-reporting and version-mismatch issues.
+            OptionalLong cachedCount = SimpleAggCacheMgr.internalInstance()
+                    .getRowCount(table.getId(), versionTime, versionSupplier);
+            if (!cachedCount.isPresent()) {
+                return Optional.empty();
+            }
+            long rowCount = cachedCount.getAsLong();
+            if (func.getArguments().isEmpty()) {
+                // count(*) or count()
+                return Optional.of(new BigIntLiteral(rowCount));
+            }
+            if (func.getArguments().size() == 1) {
+                Expression arg = func.getArguments().get(0);
+                if (arg instanceof SlotReference) {
+                    Optional<Column> colOpt = ((SlotReference) arg).getOriginalColumn();
+                    // count(not-null col) == rowCount
+                    if (colOpt.isPresent() && !colOpt.get().isAllowNull()) {
+                        return Optional.of(new BigIntLiteral(rowCount));
+                    }
+                }
+            }
+            return Optional.empty();
+        }
+
+        // --- MIN / MAX ---
+        if (func instanceof Min || func instanceof Max) {
+            if (func.getArguments().size() != 1) {
+                return Optional.empty();
+            }
+            Expression arg = func.getArguments().get(0);
+            if (!(arg instanceof SlotReference)) {
+                return Optional.empty();
+            }
+            SlotReference slot = (SlotReference) arg;
+            Optional<Column> colOpt = slot.getOriginalColumn();
+            if (!colOpt.isPresent()) {
+                return Optional.empty();
+            }
+            Column column = colOpt.get();
+            // Only numeric and date/datetime columns are supported
+            if (!column.getType().isNumericType() && !column.getType().isDateType()) {
+                return Optional.empty();
+            }
+            // Aggregated columns cannot give correct min/max
+            if (column.isAggregated()) {
+                return Optional.empty();
+            }
+
+            // Look up exact min/max from the async cache
+            ColumnMinMaxKey cacheKey = new ColumnMinMaxKey(table.getId(), column.getName());
+            Optional<ColumnMinMax> minMax = SimpleAggCacheMgr.internalInstance()
+                    .getStats(cacheKey, versionTime, versionSupplier);
+            if (!minMax.isPresent()) {
+                return Optional.empty();
+            }
+
+            // Convert the string value to a Nereids Literal
+            try {
+                String value = (func instanceof Min) ? minMax.get().minValue() : minMax.get().maxValue();
+                LiteralExpr legacyLiteral = StatisticsUtil.readableValue(column.getType(), value);
+                return Optional.of(Literal.fromLegacyLiteral(legacyLiteral, column.getType()));
+            } catch (Exception e) {
+                // Unparseable cached value: fall back to normal execution rather than fail the query.
+                return Optional.empty();
+            }
+        }
+
+        return Optional.empty();
+    }
+
+    /**
+     * Returns the table's visibleVersion, or -1 if unavailable (e.g., RPC failure in cloud mode).
+     * This is only used as a fallback when versionTime comparison is insufficient.
+     */
+    private static long getVisibleVersionOrUnknown(OlapTable table) {
+        try {
+            return table.getVisibleVersion();
+        } catch (RpcException e) {
+            return -1L;
+        }
+    }
+
+    /**
+     * Calls {@link #getVisibleVersionOrUnknown} on first use and returns that same value afterwards.
+     * One rewrite therefore issues at most one (possibly remote) visibleVersion lookup, and every
+     * cache check in that rewrite compares against the same version. Not thread-safe: an instance
+     * is created and consumed inside a single {@code tryRewrite} call.
+     */
+    private static final class MemoizedVersionSupplier implements LongSupplier {
+        private final OlapTable table;
+        private boolean fetched;
+        private long version;
+
+        MemoizedVersionSupplier(OlapTable table) {
+            this.table = table;
+        }
+
+        @Override
+        public long getAsLong() {
+            if (!fetched) {
+                version = getVisibleVersionOrUnknown(table);
+                fetched = true;
+            }
+            return version;
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/SimpleAggCacheMgr.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/SimpleAggCacheMgr.java
new file mode 100644
index 00000000000..cc9cf81d275
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/SimpleAggCacheMgr.java
@@ -0,0 +1,546 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.stats;
+
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.catalog.TableIf;
+import org.apache.doris.common.Config;
+import org.apache.doris.common.ThreadPoolManager;
+import org.apache.doris.qe.AutoCloseConnectContext;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.statistics.ResultRow;
+import org.apache.doris.statistics.util.StatisticsUtil;
+
+import com.github.benmanes.caffeine.cache.AsyncCacheLoader;
+import com.github.benmanes.caffeine.cache.AsyncLoadingCache;
+import com.github.benmanes.caffeine.cache.Caffeine;
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.checkerframework.checker.nullness.qual.NonNull;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.function.LongSupplier;
+
+/**
+ * Async cache that stores exact MIN/MAX/COUNT values for OlapTable,
+ * used by {@code RewriteSimpleAggToConstantRule} to replace simple
+ * aggregations with constant values.
+ *
+ * <p>MIN/MAX values are obtained by executing
+ * {@code SELECT min(col), max(col) FROM table}, and COUNT values by
+ * {@code SELECT count(*) FROM table}, both as internal SQL queries
+ * inside FE. Results are stamped with the table's {@code getVisibleVersion()} and
+ * {@code getVisibleVersionTime()} as observed while the query ran.
+ * When a lookup finds an entry whose stamp does not match the caller's,
+ * the entry is evicted, and the next lookup for that key starts a background reload.
+ *
+ * <p>Cache validation uses a two-level check driven by {@code versionTime}:
+ * <ol>
+ *   <li>If times differ: the table has definitely changed — evict immediately,
+ *       without calling {@code getVisibleVersion()} (which may involve an RPC in cloud mode).</li>
+ *   <li>If times match: call {@code getVisibleVersion()} once to compare versions,
+ *       guarding against the rare case of two writes completing within the same millisecond.</li>
+ * </ol>
+ *
+ * <p>Only numeric and date/datetime columns are cached for MIN/MAX;
+ * aggregated columns are skipped.
+ */
+public class SimpleAggCacheMgr {
+
+    // ======================== Public inner types ========================
+
+    /**
+     * Holds exact min and max values for a column as strings.
+     */
+    public static class ColumnMinMax {
+        private final String minValue;
+        private final String maxValue;
+
+        public ColumnMinMax(String minValue, String maxValue) {
+            this.minValue = minValue;
+            this.maxValue = maxValue;
+        }
+
+        public String minValue() {
+            return minValue;
+        }
+
+        public String maxValue() {
+            return maxValue;
+        }
+    }
+
+    /**
+     * Cache key identifying a column by its table ID and column name.
+     * Column names compare case-insensitively.
+     */
+    public static final class ColumnMinMaxKey {
+        private final long tableId;
+        private final String columnName;
+
+        public ColumnMinMaxKey(long tableId, String columnName) {
+            this.tableId = tableId;
+            this.columnName = columnName;
+        }
+
+        public long getTableId() {
+            return tableId;
+        }
+
+        public String getColumnName() {
+            return columnName;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (!(o instanceof ColumnMinMaxKey)) {
+                return false;
+            }
+            ColumnMinMaxKey that = (ColumnMinMaxKey) o;
+            return tableId == that.tableId && columnName.equalsIgnoreCase(that.columnName);
+        }
+
+        @Override
+        public int hashCode() {
+            // Locale.ROOT keeps the hash consistent with the locale-independent equalsIgnoreCase()
+            // above; with the default locale (e.g. Turkish, where "I" lowercases to dotless "ı")
+            // two equal keys could otherwise hash differently.
+            return Objects.hash(tableId, columnName.toLowerCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return "ColumnMinMaxKey{tableId=" + tableId + ", column=" + columnName + "}";
+        }
+    }
+
+    private static class CacheValue {
+        private final ColumnMinMax minMax;
+        private final long version;
+        private final long versionTime;
+
+        CacheValue(ColumnMinMax minMax, long version, long versionTime) {
+            this.minMax = minMax;
+            this.version = version;
+            this.versionTime = versionTime;
+        }
+
+        ColumnMinMax minMax() {
+            return minMax;
+        }
+
+        long version() {
+            return version;
+        }
+
+        long versionTime() {
+            return versionTime;
+        }
+    }
+
+    /**
+     * Cached row count with version stamp.
+     */
+    private static class RowCountValue {
+        private final long rowCount;
+        private final long version;
+        private final long versionTime;
+
+        RowCountValue(long rowCount, long version, long versionTime) {
+            this.rowCount = rowCount;
+            this.version = version;
+            this.versionTime = versionTime;
+        }
+
+        long rowCount() {
+            return rowCount;
+        }
+
+        long version() {
+            return version;
+        }
+
+        long versionTime() {
+            return versionTime;
+        }
+    }
+
+    private static final Logger LOG = LogManager.getLogger(SimpleAggCacheMgr.class);
+
+    private static volatile SimpleAggCacheMgr INSTANCE;
+    private static volatile SimpleAggCacheMgr TEST_INSTANCE;
+
+    private final AsyncLoadingCache<ColumnMinMaxKey, Optional<CacheValue>> cache;
+    private final AsyncLoadingCache<Long, Optional<RowCountValue>> rowCountCache;
+
+    /**
+     * Protected no-arg constructor for test subclassing.
+     * Subclasses override {@link #getStats}, {@link #getRowCount}, etc.
+     */
+    protected SimpleAggCacheMgr() {
+        this.cache = null;
+        this.rowCountCache = null;
+    }
+
+    private SimpleAggCacheMgr(ExecutorService executor) {
+        this.cache = Caffeine.newBuilder()
+                .maximumSize(Config.stats_cache_size)
+                .executor(executor)
+                .buildAsync(new CacheLoader());
+        this.rowCountCache = Caffeine.newBuilder()
+                .maximumSize(Config.stats_cache_size)
+                .executor(executor)
+                .buildAsync(new RowCountLoader());
+    }
+
+    private static SimpleAggCacheMgr getInstance() {
+        if (INSTANCE == null) {
+            synchronized (SimpleAggCacheMgr.class) {
+                if (INSTANCE == null) {
+                    ExecutorService executor = ThreadPoolManager.newDaemonCacheThreadPool(
+                            4, "simple-agg-cache-pool", true);
+                    INSTANCE = new SimpleAggCacheMgr(executor);
+                }
+            }
+        }
+        return INSTANCE;
+    }
+
+    /**
+     * Returns the singleton instance backed by async-loading cache,
+     * or the test override if one has been set.
+     */
+    public static SimpleAggCacheMgr internalInstance() {
+        SimpleAggCacheMgr test = TEST_INSTANCE;
+        if (test != null) {
+            return test;
+        }
+        return getInstance();
+    }
+
+    /**
+     * Override used only in unit tests to inject a mock implementation.
+     */
+    @VisibleForTesting
+    public static void setTestInstance(SimpleAggCacheMgr instance) {
+        TEST_INSTANCE = instance;
+    }
+
+    /**
+     * Reset the test override so that subsequent calls go back to the real cache.
+     */
+    @VisibleForTesting
+    public static void clearTestInstance() {
+        TEST_INSTANCE = null;
+    }
+
+    /**
+     * Get the cached min/max for a column.
+     *
+     * <p>Cache validation uses a two-level check driven by {@code versionTime}:
+     * <ol>
+     *   <li>If {@code cachedVersionTime != callerVersionTime}: times differ, the table has
+     *       definitely been modified — evict immediately. {@code versionSupplier} is
+     *       <em>never</em> called, saving a potentially expensive RPC in cloud mode.</li>
+     *   <li>If {@code cachedVersionTime == callerVersionTime}: times match, but two writes
+     *       in the same millisecond could produce the same time with different versions.
+     *       {@code versionSupplier} is called exactly once to confirm {@code version} equality.</li>
+     * </ol>
+     *
+     * <p>Eviction only removes the exact future that was inspected, so a reload that another
+     * thread started concurrently for the same key is not discarded.
+     *
+     * @param key               cache key (tableId + columnName)
+     * @param callerVersionTime the caller's current {@code table.getVisibleVersionTime()} (cheap, always local)
+     * @param versionSupplier   lazy supplier for {@code table.getVisibleVersion()}; called only when
+     *                          {@code versionTime} values are equal. Returns -1 on RPC failure,
+     *                          in which case the method returns empty conservatively.
+     */
+    public Optional<ColumnMinMax> getStats(ColumnMinMaxKey key, long callerVersionTime,
+            LongSupplier versionSupplier) {
+        CompletableFuture<Optional<CacheValue>> future = cache.get(key);
+        if (future.isDone()) {
+            try {
+                Optional<CacheValue> cacheValue = future.get();
+                if (cacheValue.isPresent()) {
+                    CacheValue value = cacheValue.get();
+                    if (value.versionTime() == callerVersionTime) {
+                        // Times match — verify version to guard against two writes in the same ms.
+                        // versionSupplier may be an RPC in cloud mode; it is only invoked here.
+                        long callerVersion = versionSupplier.getAsLong();
+                        if (callerVersion < 0) {
+                            // RPC failed: cannot verify version — return empty conservatively.
+                            // Do not invalidate: keep the existing entry so it can be reused once the RPC recovers.
+                            return Optional.empty();
+                        }
+                        if (value.version() == callerVersion) {
+                            return Optional.of(value.minMax());
+                        }
+                        // Same time but different version: two writes in the same ms — stale.
+                    }
+                    // Times differ → definitely stale; skip the version RPC entirely.
+                }
+                // Either empty (load failed) or stale — evict so next call triggers a fresh reload.
+                // Conditional remove: if another thread already replaced this future with a new
+                // in-flight load, keep that load instead of throwing it away.
+                cache.asMap().remove(key, future);
+            } catch (Exception e) {
+                LOG.warn("Failed to get MinMax for column: {}, versionTime: {}", key, callerVersionTime, e);
+                cache.asMap().remove(key, future);
+            }
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Evict the cached stats for a column, if present. Used when we know the data has changed
+     */
+    public void removeStats(ColumnMinMaxKey key) {
+        cache.synchronous().invalidate(key);
+    }
+
+    /**
+     * Get the cached row count for a table.
+     *
+     * <p>Cache validation uses a two-level check driven by {@code versionTime}:
+     * <ol>
+     *   <li>If {@code cachedVersionTime != callerVersionTime}: times differ, the table has
+     *       definitely been modified — evict immediately. {@code versionSupplier} is
+     *       <em>never</em> called, saving a potentially expensive RPC in cloud mode.</li>
+     *   <li>If {@code cachedVersionTime == callerVersionTime}: times match, but two writes
+     *       in the same millisecond could produce the same time with different versions.
+     *       {@code versionSupplier} is called exactly once to confirm {@code version} equality.</li>
+     * </ol>
+     *
+     * <p>Eviction only removes the exact future that was inspected, so a reload that another
+     * thread started concurrently for the same table is not discarded.
+     *
+     * @param tableId           the table id
+     * @param callerVersionTime the caller's current {@code table.getVisibleVersionTime()} (cheap, always local)
+     * @param versionSupplier   lazy supplier for {@code table.getVisibleVersion()}; called only when
+     *                          {@code versionTime} values are equal. Returns -1 on RPC failure,
+     *                          in which case the method returns empty conservatively.
+     */
+    public OptionalLong getRowCount(long tableId, long callerVersionTime, LongSupplier versionSupplier) {
+        CompletableFuture<Optional<RowCountValue>> future = rowCountCache.get(tableId);
+        if (future.isDone()) {
+            try {
+                Optional<RowCountValue> cached = future.get();
+                if (cached.isPresent()) {
+                    RowCountValue value = cached.get();
+                    if (value.versionTime() == callerVersionTime) {
+                        // Times match — verify version to guard against two writes in the same ms.
+                        long callerVersion = versionSupplier.getAsLong();
+                        if (callerVersion < 0) {
+                            // RPC failed: cannot verify version — return empty conservatively.
+                            // Do not invalidate: keep the existing entry so it can be reused once the RPC recovers.
+                            return OptionalLong.empty();
+                        }
+                        if (value.version() == callerVersion) {
+                            return OptionalLong.of(value.rowCount());
+                        }
+                        // Same time but different version: two writes in the same ms — stale.
+                    }
+                    // Times differ → definitely stale; skip the version RPC entirely.
+                }
+                // Either empty (load failed) or stale — evict so next call triggers a fresh reload.
+                // Conditional remove: keep a newer in-flight load started by another thread.
+                rowCountCache.asMap().remove(tableId, future);
+            } catch (Exception e) {
+                LOG.warn("Failed to get row count for table: {}, versionTime: {}", tableId, callerVersionTime, e);
+                rowCountCache.asMap().remove(tableId, future);
+            }
+        }
+        return OptionalLong.empty();
+    }
+
+    /**
+     * Generate the internal SQL for fetching exact min/max values.
+     */
+    @VisibleForTesting
+    public static String genMinMaxSql(List<String> qualifiers, String columnName) {
+        // qualifiers: [catalogName, dbName, tableName]
+        String quotedCol = "`" + StatisticsUtil.escapeColumnName(columnName) + "`";
+        String fullTable = "`" + qualifiers.get(0) + "`.`"
+                + qualifiers.get(1) + "`.`"
+                + qualifiers.get(2) + "`";
+        return "SELECT min(" + quotedCol + "), max(" + quotedCol + ") FROM " + fullTable;
+    }
+
+    /**
+     * Generate the internal SQL for fetching exact row count.
+     */
+    @VisibleForTesting
+    public static String genCountSql(List<String> qualifiers) {
+        String fullTable = "`" + qualifiers.get(0) + "`.`"
+                + qualifiers.get(1) + "`.`"
+                + qualifiers.get(2) + "`";
+        return "SELECT count(*) FROM " + fullTable;
+    }
+
+    /**
+     * Async cache loader that issues internal SQL queries to compute exact min/max.
+ */ + protected static final class CacheLoader + implements AsyncCacheLoader<ColumnMinMaxKey, Optional<CacheValue>> { + + @Override + public @NonNull CompletableFuture<Optional<CacheValue>> asyncLoad( + @NonNull ColumnMinMaxKey key, @NonNull Executor executor) { + return CompletableFuture.supplyAsync(() -> { + try { + return doLoad(key); + } catch (Exception e) { + LOG.warn("Failed to load MinMax for column: {}", key, e); + return Optional.empty(); + } + }, executor); + } + + private Optional<CacheValue> doLoad(ColumnMinMaxKey key) throws Exception { + // Look up the table by its ID + TableIf tableIf = Env.getCurrentInternalCatalog().getTableByTableId(key.getTableId()); + if (!(tableIf instanceof OlapTable)) { + return Optional.empty(); + } + OlapTable olapTable = (OlapTable) tableIf; + + // Validate column exists and is eligible + Column column = olapTable.getColumn(key.getColumnName()); + if (column == null) { + return Optional.empty(); + } + if (!column.getType().isNumericType() && !column.getType().isDateType()) { + return Optional.empty(); + } + if (column.isAggregated()) { + return Optional.empty(); + } + + // Capture version and versionTime before the query. + // Both are needed: versionTime (cheap) catches writes that cross a millisecond boundary; + // version (may be RPC in cloud mode) is the ground truth for same-millisecond writes. + long versionBefore = olapTable.getVisibleVersion(); + long versionTimeBefore = olapTable.getVisibleVersionTime(); + + // Build and execute internal SQL + List<String> qualifiers = olapTable.getFullQualifiers(); + String sql = genMinMaxSql(qualifiers, column.getName()); + + List<ResultRow> rows; + try (AutoCloseConnectContext r = StatisticsUtil.buildConnectContext(false)) { + r.connectContext.getSessionVariable().setPipelineTaskNum("1"); + // Disable our own rule to prevent infinite recursion: + // this internal SQL goes through Nereids and would otherwise trigger + // RewriteSimpleAggToConstantRule again. 
+ r.connectContext.getSessionVariable().setDisableNereidsRules( + "REWRITE_SIMPLE_AGG_TO_CONSTANT"); + StmtExecutor stmtExecutor = new StmtExecutor(r.connectContext, sql); + rows = stmtExecutor.executeInternalQuery(); + } + if (rows == null || rows.isEmpty()) { + return Optional.empty(); + } + ResultRow row = rows.get(0); + String minVal = row.get(0); + String maxVal = row.get(1); + if (minVal == null || maxVal == null) { + return Optional.empty(); + } + // Fast check: if versionTime changed, a write definitely occurred during the query. + long versionTimeAfter = olapTable.getVisibleVersionTime(); + if (versionTimeAfter != versionTimeBefore) { + return Optional.empty(); + } + // Definitive check: compare version before and after the query. + // A same-millisecond write would not be caught by versionTime alone but shows up here. + long versionAfter = olapTable.getVisibleVersion(); + if (versionAfter != versionBefore) { + return Optional.empty(); + } + return Optional.of(new CacheValue(new ColumnMinMax(minVal, maxVal), versionAfter, versionTimeBefore)); + } + } + + /** + * Async cache loader that issues {@code SELECT count(*) FROM table} + * to compute exact row counts. + */ + protected static final class RowCountLoader + implements AsyncCacheLoader<Long, Optional<RowCountValue>> { + + @Override + public @NonNull CompletableFuture<Optional<RowCountValue>> asyncLoad( + @NonNull Long tableId, @NonNull Executor executor) { + return CompletableFuture.supplyAsync(() -> { + try { + return doLoad(tableId); + } catch (Exception e) { + LOG.warn("Failed to load row count for table: {}", tableId, e); + return Optional.empty(); + } + }, executor); + } + + private Optional<RowCountValue> doLoad(Long tableId) throws Exception { + TableIf tableIf = Env.getCurrentInternalCatalog().getTableByTableId(tableId); + if (!(tableIf instanceof OlapTable)) { + return Optional.empty(); + } + OlapTable olapTable = (OlapTable) tableIf; + + // Capture version and versionTime before the query. 
+ long versionBefore = olapTable.getVisibleVersion(); + long versionTimeBefore = olapTable.getVisibleVersionTime(); + + List<String> qualifiers = olapTable.getFullQualifiers(); + String sql = genCountSql(qualifiers); + + List<ResultRow> rows; + try (AutoCloseConnectContext r = StatisticsUtil.buildConnectContext(false)) { + r.connectContext.getSessionVariable().setPipelineTaskNum("1"); + // Disable our own rule to prevent infinite recursion: + // this internal SQL goes through Nereids and would otherwise trigger + // RewriteSimpleAggToConstantRule again. + r.connectContext.getSessionVariable().setDisableNereidsRules( + "REWRITE_SIMPLE_AGG_TO_CONSTANT"); + StmtExecutor stmtExecutor = new StmtExecutor(r.connectContext, sql); + rows = stmtExecutor.executeInternalQuery(); + } + if (rows == null || rows.isEmpty()) { + return Optional.empty(); + } + String countStr = rows.get(0).get(0); + if (countStr == null) { + return Optional.empty(); + } + long count = Long.parseLong(countStr); + // Fast check: if versionTime changed, a write definitely occurred during the query. + long versionTimeAfter = olapTable.getVisibleVersionTime(); + if (versionTimeAfter != versionTimeBefore) { + return Optional.empty(); + } + // Definitive check: compare version before and after the query. + // A same-millisecond write would not be caught by versionTime alone but shows up here. 
+ long versionAfter = olapTable.getVisibleVersion(); + if (versionAfter != versionBefore) { + return Optional.empty(); + } + return Optional.of(new RowCountValue(count, versionAfter, versionTimeBefore)); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRuleTest.java new file mode 100644 index 00000000000..9cf179238d2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSimpleAggToConstantRuleTest.java @@ -0,0 +1,294 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.catalog.Database; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.nereids.stats.SimpleAggCacheMgr; +import org.apache.doris.nereids.stats.SimpleAggCacheMgr.ColumnMinMax; +import org.apache.doris.nereids.stats.SimpleAggCacheMgr.ColumnMinMaxKey; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; + +import java.util.Optional; +import java.util.OptionalLong; +import java.util.function.LongSupplier; + +/** + * Tests for {@link RewriteSimpleAggToConstantRule}. + * + * <p>This rule rewrites simple aggregation queries (count/min/max on DUP_KEYS tables + * without GROUP BY) into constant values from FE metadata, producing + * LogicalProject -> LogicalOneRowRelation plans. 
+ */ +class RewriteSimpleAggToConstantRuleTest extends TestWithFeService implements MemoPatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + + // DUP_KEYS table with NOT NULL columns + createTable("CREATE TABLE test.dup_tbl (\n" + + " k1 INT NOT NULL,\n" + + " v1 INT NOT NULL,\n" + + " v2 BIGINT NOT NULL,\n" + + " v3 DATE NOT NULL,\n" + + " v4 VARCHAR(128)\n" + + ") DUPLICATE KEY(k1)\n" + + "DISTRIBUTED BY HASH(k1) BUCKETS 1\n" + + "PROPERTIES('replication_num' = '1');"); + + // UNIQUE_KEYS table (should NOT be rewritten) + createTable("CREATE TABLE test.uniq_tbl (\n" + + " k1 INT NOT NULL,\n" + + " v1 INT NOT NULL\n" + + ") UNIQUE KEY(k1)\n" + + "DISTRIBUTED BY HASH(k1) BUCKETS 1\n" + + "PROPERTIES('replication_num' = '1');"); + + // AGG_KEYS table (should NOT be rewritten) + createTable("CREATE TABLE test.agg_tbl (\n" + + " k1 INT NOT NULL,\n" + + " v1 INT SUM NOT NULL\n" + + ") AGGREGATE KEY(k1)\n" + + "DISTRIBUTED BY HASH(k1) BUCKETS 1\n" + + "PROPERTIES('replication_num' = '1');"); + + connectContext.setDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + + // Install a mock SimpleAggCacheMgr that returns known min/max and row count for dup_tbl + SimpleAggCacheMgr.setTestInstance(new MockMinMaxStatsMgr()); + } + + @AfterAll + public static void tearDown() { + SimpleAggCacheMgr.clearTestInstance(); + } + + private OlapTable getOlapTable(String tableName) throws Exception { + Database db = Env.getCurrentInternalCatalog().getDbOrMetaException("test"); + return (OlapTable) db.getTableOrMetaException(tableName); + } + + // ======================== Positive tests: should rewrite ======================== + + @Test + void testCountStarRewrite() { + // count(*) on DUP_KEYS with reported row counts → rewrite to constant + PlanChecker.from(connectContext) + .analyze("SELECT count(*) FROM dup_tbl") + .rewrite() + 
.matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + @Test + void testCountNotNullColumnRewrite() { + // count(not-null column) on DUP_KEYS → rewrite to constant (equals rowCount) + PlanChecker.from(connectContext) + .analyze("SELECT count(k1) FROM dup_tbl") + .rewrite() + .matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + @Test + void testMinRewrite() { + // min(int col) on DUP_KEYS with cache hit → rewrite to constant + PlanChecker.from(connectContext) + .analyze("SELECT min(v1) FROM dup_tbl") + .rewrite() + .matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + @Test + void testMaxRewrite() { + // max(bigint col) on DUP_KEYS with cache hit → rewrite to constant + PlanChecker.from(connectContext) + .analyze("SELECT max(v2) FROM dup_tbl") + .rewrite() + .matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + @Test + void testMinMaxDateRewrite() { + // min/max(date col) on DUP_KEYS with cache hit → rewrite to constant + PlanChecker.from(connectContext) + .analyze("SELECT min(v3), max(v3) FROM dup_tbl") + .rewrite() + .matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + @Test + void testMixedCountMinMax() { + // count(*), min(v1), max(v2) on DUP_KEYS → rewrite to constant + PlanChecker.from(connectContext) + .analyze("SELECT count(*), min(v1), max(v2) FROM dup_tbl") + .rewrite() + .matches(logicalResultSink(logicalOneRowRelation())) + .printlnTree(); + } + + // ======================== Negative tests: should NOT rewrite ======================== + + @Test + void testUniqueKeysNotRewrite() { + // UNIQUE_KEYS table → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT count(*) FROM uniq_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testAggKeysNotRewrite() { + // AGG_KEYS table → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT count(*) FROM 
agg_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testGroupByNotRewrite() { + // GROUP BY present → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT count(*) FROM dup_tbl GROUP BY k1") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testUnsupportedAggFuncNotRewrite() { + // SUM is not supported → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT sum(v1) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testAvgNotRewrite() { + // AVG is not supported → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT avg(v1) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testDistinctCountNotRewrite() { + // count(distinct col) → rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT count(distinct k1) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testCountNullableColumnNotRewrite() { + // count(nullable column v4) → cannot guarantee count(v4) == rowCount, rule should NOT trigger + PlanChecker.from(connectContext) + .analyze("SELECT count(v4) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testMinMaxStringColumnNotRewrite() { + // min(varchar col) → not supported for string types + PlanChecker.from(connectContext) + .analyze("SELECT min(v4) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + @Test + void testMixedSupportedAndUnsupportedNotRewrite() { + // If ANY agg function cannot be replaced, the entire rewrite is skipped. 
+ // count(*) can be replaced, but sum(v1) cannot → entire query NOT rewritten + PlanChecker.from(connectContext) + .analyze("SELECT count(*), sum(v1) FROM dup_tbl") + .rewrite() + .nonMatch(logicalResultSink(logicalOneRowRelation())); + } + + // ======================== Mock SimpleAggCacheMgr ======================== + + /** + * A simple mock that returns known min/max values and row count for dup_tbl. + * It accepts any version (returns the values regardless of version). + */ + private class MockMinMaxStatsMgr extends SimpleAggCacheMgr { + + @Override + public Optional<ColumnMinMax> getStats(ColumnMinMaxKey key, long callerVersionTime, + LongSupplier versionSupplier) { + try { + OlapTable table = getOlapTable("dup_tbl"); + if (key.getTableId() != table.getId()) { + return Optional.empty(); + } + } catch (Exception e) { + return Optional.empty(); + } + + String colName = key.getColumnName().toLowerCase(); + switch (colName) { + case "k1": + return Optional.of(new ColumnMinMax("1", "100")); + case "v1": + return Optional.of(new ColumnMinMax("10", "999")); + case "v2": + return Optional.of(new ColumnMinMax("100", "99999")); + case "v3": + return Optional.of(new ColumnMinMax("2024-01-01", "2025-12-31")); + default: + // v4 (varchar) and unknown columns → no cache + return Optional.empty(); + } + } + + @Override + public OptionalLong getRowCount(long tableId, long callerVersionTime, LongSupplier versionSupplier) { + try { + OlapTable table = getOlapTable("dup_tbl"); + if (tableId == table.getId()) { + return OptionalLong.of(100L); + } + } catch (Exception e) { + // fall through + } + return OptionalLong.empty(); + } + + @Override + public void removeStats(ColumnMinMaxKey key) { + // no-op for mock + } + } +} diff --git a/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.out b/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.out new file mode 100644 index 
00000000000..a982c987c22 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.out @@ -0,0 +1,37 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !count_star -- +5 + +-- !count_notnull -- +5 + +-- !min_int -- +10 + +-- !max_int -- +50 + +-- !min_bigint -- +100 + +-- !max_bigint -- +500 + +-- !min_date -- +2024-01-01 + +-- !max_date -- +2025-06-30 + +-- !mixed -- +5 10 500 + +-- !count_nullable -- +4 + +-- !uniq_count -- +3 + +-- !agg_count -- +3 + diff --git a/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.out b/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.out new file mode 100644 index 00000000000..44aab81c0b6 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.out @@ -0,0 +1,7 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !count_before_truncate -- +5 + +-- !count_after_truncate -- +0 + diff --git a/regression-test/suites/mv_p0/agg_use_key_direct/agg_use_key_direct.groovy b/regression-test/suites/mv_p0/agg_use_key_direct/agg_use_key_direct.groovy index 4a866c20e3c..938d015b7ff 100644 --- a/regression-test/suites/mv_p0/agg_use_key_direct/agg_use_key_direct.groovy +++ b/regression-test/suites/mv_p0/agg_use_key_direct/agg_use_key_direct.groovy @@ -23,7 +23,7 @@ suite ("agg_use_key_direct") { // this mv rewrite would not be rewritten in RBO phase, so set TRY_IN_RBO explicitly to make case stable sql "set pre_materialized_view_rewrite_strategy = TRY_IN_RBO" - + sql "set disable_nereids_rules='REWRITE_SIMPLE_AGG_TO_CONSTANT'"; sql "drop table if exists ${tblName} force;" sql """ create table ${tblName} ( diff --git a/regression-test/suites/nereids_p0/hint/test_hint.groovy b/regression-test/suites/nereids_p0/hint/test_hint.groovy index 2e89cabafe2..a696a22c36c 100644 --- a/regression-test/suites/nereids_p0/hint/test_hint.groovy +++ b/regression-test/suites/nereids_p0/hint/test_hint.groovy @@ -27,7 +27,7 @@ suite("test_hint") { sql 'set exec_mem_limit=21G' sql 'set be_number_for_test=1' sql 'set parallel_pipeline_task_num=1' - sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION" + sql "set disable_nereids_rules='PRUNE_EMPTY_PARTITION, REWRITE_SIMPLE_AGG_TO_CONSTANT'" sql 'set enable_nereids_planner=true' sql 'set enable_nereids_distribute_planner=false' sql "set ignore_shape_nodes='PhysicalProject'" diff --git a/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.groovy b/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.groovy new file mode 100644 index 00000000000..2aed9f841ad --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/rewrite_simple_agg_to_constant.groovy @@ -0,0 +1,317 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression suite for RewriteSimpleAggToConstantRule: warms the SimpleAggCacheMgr cache,
+// checks via EXPLAIN ("constant exprs") which queries are / are not rewritten, and records
+// query results under order_qt_ tags. The order of those tags must match the .out file.
+suite("rewrite_simple_agg_to_constant") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+
+    sql "DROP DATABASE IF EXISTS test_rewrite_simple_agg_constant"
+    sql "CREATE DATABASE test_rewrite_simple_agg_constant"
+    sql "USE test_rewrite_simple_agg_constant"
+
+    // ========== Create test tables ==========
+
+    // DUP_KEYS table with NOT NULL columns
+    sql """
+        CREATE TABLE dup_tbl (
+            k1 INT NOT NULL,
+            v1 INT NOT NULL,
+            v2 BIGINT NOT NULL,
+            v3 DATE NOT NULL,
+            v4 VARCHAR(128)
+        ) DUPLICATE KEY(k1)
+        DISTRIBUTED BY HASH(k1) BUCKETS 1
+        PROPERTIES('replication_num' = '1');
+    """
+
+    // UNIQUE_KEYS table
+    sql """
+        CREATE TABLE uniq_tbl (
+            k1 INT NOT NULL,
+            v1 INT NOT NULL
+        ) UNIQUE KEY(k1)
+        DISTRIBUTED BY HASH(k1) BUCKETS 1
+        PROPERTIES('replication_num' = '1');
+    """
+
+    // AGG_KEYS table
+    sql """
+        CREATE TABLE agg_tbl (
+            k1 INT NOT NULL,
+            v1 INT SUM NOT NULL
+        ) AGGREGATE KEY(k1)
+        DISTRIBUTED BY HASH(k1) BUCKETS 1
+        PROPERTIES('replication_num' = '1');
+    """
+
+    // ========== Insert test data ==========
+    // dup_tbl gets 5 rows; v4 is NULL in exactly one of them (used by the count(v4) check).
+    sql """
+        INSERT INTO dup_tbl VALUES
+            (1, 10, 100, '2024-01-01', 'aaa'),
+            (2, 20, 200, '2024-06-15', 'bbb'),
+            (3, 30, 300, '2024-12-31', null),
+            (4, 40, 400, '2025-03-01', 'ddd'),
+            (5, 50, 500, '2025-06-30', 'eee');
+    """
+
+    sql """
+        INSERT INTO uniq_tbl VALUES (1, 10), (2, 20), (3, 30);
+    """
+
+    sql """
+        INSERT INTO agg_tbl VALUES (1, 10), (2, 20), (3, 30);
+    """
+
+    // Wait a bit for tablet stats to be reported to FE
+    // NOTE(review): the polling loop below already waits up to 30s for the cache;
+    // confirm whether this fixed sleep is still required.
+    sleep(3000)
+
+    // ===================================================================
+    // Warm up the SimpleAggCacheMgr async cache.
+    //
+    // The first call to getStats()/getRowCount() triggers an async load;
+    // the result is not available until the internal SQL finishes.
+    // We poll until explain shows "constant exprs", which proves the cache
+    // entry is loaded and the rule can fire.
+    // ===================================================================
+    // Trigger cache loads for all columns we'll test
+    sql "SELECT count(*) FROM dup_tbl"
+    sql "SELECT min(v1), max(v1) FROM dup_tbl"
+    sql "SELECT min(v2), max(v2) FROM dup_tbl"
+    sql "SELECT min(v3), max(v3) FROM dup_tbl"
+
+    // Poll until the rule fires (cache is warm)
+    // At most 30 attempts, one second apart.
+    def warmUpSql = "SELECT count(*), min(v1), max(v2), min(v3) FROM dup_tbl"
+    def cacheReady = false
+    for (int i = 0; i < 30; i++) {
+        def explainResult = sql "EXPLAIN ${warmUpSql}"
+        if (explainResult.toString().contains("constant exprs")) {
+            cacheReady = true
+            break
+        }
+        sleep(1000)
+    }
+    if (!cacheReady) {
+        // In cloud mode a cache that never warms is tolerated: the rest of the suite is
+        // skipped instead of failed. Elsewhere it is a hard failure.
+        if (isCloudMode()) {
+            logger.info("SimpleAggCacheMgr cache did not warm up within 30s in cloud mode, skip remaining tests")
+            return
+        }
+        assertTrue(false, "SimpleAggCacheMgr cache did not warm up within 30 seconds")
+    }
+
+    // ===================================================================
+    // Positive tests: verify the rule IS applied.
+    // The cache is confirmed warm from the poll above.
+    // ===================================================================
+
+    // count(*)
+    explain {
+        sql("SELECT count(*) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_count_star """SELECT count(*) FROM dup_tbl;"""
+
+    // count(not-null column)
+    explain {
+        sql("SELECT count(k1) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_count_notnull """SELECT count(k1) FROM dup_tbl;"""
+
+    // min(int)
+    explain {
+        sql("SELECT min(v1) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_min_int """SELECT min(v1) FROM dup_tbl;"""
+
+    // max(int)
+    explain {
+        sql("SELECT max(v1) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_max_int """SELECT max(v1) FROM dup_tbl;"""
+
+    // min(bigint)
+    explain {
+        sql("SELECT min(v2) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_min_bigint """SELECT min(v2) FROM dup_tbl;"""
+
+    // max(bigint)
+    explain {
+        sql("SELECT max(v2) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_max_bigint """SELECT max(v2) FROM dup_tbl;"""
+
+    // min(date)
+    explain {
+        sql("SELECT min(v3) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_min_date """SELECT min(v3) FROM dup_tbl;"""
+
+    // max(date)
+    explain {
+        sql("SELECT max(v3) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_max_date """SELECT max(v3) FROM dup_tbl;"""
+
+    // Mixed: count(*), min, max together
+    explain {
+        sql("SELECT count(*), min(v1), max(v2) FROM dup_tbl")
+        contains "constant exprs"
+    }
+    order_qt_mixed """SELECT count(*), min(v1), max(v2) FROM dup_tbl;"""
+
+    // ===================================================================
+    // Negative tests: these queries should NEVER be rewritten.
+    // The cache is confirmed warm (the poll + positive tests above proved
+    // count/min/max for dup_tbl all hit cache). So if these plans do NOT
+    // contain "constant exprs", the rule actively rejected them.
+    // ===================================================================
+
+    // Non-DUP_KEYS: UNIQUE_KEYS table should not be rewritten
+    explain {
+        sql("SELECT count(*) FROM uniq_tbl")
+        notContains "constant exprs"
+    }
+
+    // Non-DUP_KEYS: AGG_KEYS table should not be rewritten
+    explain {
+        sql("SELECT count(*) FROM agg_tbl")
+        notContains "constant exprs"
+    }
+
+    // GROUP BY present → should not be rewritten
+    explain {
+        sql("SELECT count(*) FROM dup_tbl GROUP BY k1")
+        notContains "constant exprs"
+    }
+
+    // Unsupported aggregate function: SUM (cache for v1 is already warm from min/max above)
+    explain {
+        sql("SELECT sum(v1) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // Unsupported aggregate function: AVG
+    explain {
+        sql("SELECT avg(v1) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // DISTINCT count
+    explain {
+        sql("SELECT count(distinct k1) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // count(nullable column) → cannot guarantee count(v4) equals row count
+    explain {
+        sql("SELECT count(v4) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // min/max on string column → not supported (row count cache is warm)
+    explain {
+        sql("SELECT min(v4) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // Mixed supported (count) and unsupported (sum) → entire query NOT rewritten
+    explain {
+        sql("SELECT count(*), sum(v1) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // Manually specified partition → should not be rewritten
+    explain {
+        sql("SELECT count(*) FROM dup_tbl PARTITION(dup_tbl)")
+        notContains "constant exprs"
+    }
+    explain {
+        sql("SELECT min(v1), max(v2) FROM dup_tbl PARTITION(dup_tbl)")
+        notContains "constant exprs"
+    }
+
+    // Manually specified tablet → should not be rewritten
+    // Uses the first column of the first SHOW TABLETS row as the tablet id
+    // (assumed to be the TabletId column — TODO confirm if the output layout changes).
+    def tabletResult = sql "SHOW TABLETS FROM dup_tbl"
+    def tabletId = tabletResult[0][0]
+    explain {
+        sql("SELECT count(*) FROM dup_tbl TABLET(${tabletId})")
+        notContains "constant exprs"
+    }
+    explain {
+        sql("SELECT min(v1), max(v2) FROM dup_tbl TABLET(${tabletId})")
+        notContains "constant exprs"
+    }
+
+    // TABLESAMPLE → should not be rewritten
+    explain {
+        sql("SELECT count(*) FROM dup_tbl TABLESAMPLE(10 PERCENT)")
+        notContains "constant exprs"
+    }
+    explain {
+        sql("SELECT min(v1), max(v2) FROM dup_tbl TABLESAMPLE(3 ROWS)")
+        notContains "constant exprs"
+    }
+
+    // Sync materialized view (indexSelected = true) → should not be rewritten
+    createMV("""CREATE MATERIALIZED VIEW mv_dup_sum AS SELECT v1 as m1, sum(v2) as m2 FROM dup_tbl GROUP BY v1;""")
+    explain {
+        sql("SELECT count(*) FROM dup_tbl INDEX mv_dup_sum")
+        notContains "constant exprs"
+    }
+
+    // ===================================================================
+    // Verify disabling the rule works.
+    // When the rule is disabled, even simple count(*) should NOT produce constant exprs.
+    // ===================================================================
+    explain {
+        sql("SELECT /*+ SET_VAR(disable_nereids_rules=REWRITE_SIMPLE_AGG_TO_CONSTANT) */ count(*) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+    // ===================================================================
+    // Correctness-only tests for queries that should NOT be rewritten.
+    // The result must still be correct even though the rule does not fire.
+    // ===================================================================
+
+    // count(nullable column) — not rewritten, but result must be correct
+    order_qt_count_nullable """SELECT count(v4) FROM dup_tbl;"""
+
+    // Count on non-DUP_KEYS tables — result correctness
+    order_qt_uniq_count """SELECT count(*) FROM uniq_tbl;"""
+    order_qt_agg_count """SELECT count(*) FROM agg_tbl;"""
+
+    // ===================================================================
+    // Cache invalidation test: inserting new data should invalidate the
+    // cached min/max stats so the rule no longer fires until cache refreshes.
+    // ===================================================================
+    sql "INSERT INTO dup_tbl VALUES (6, 60, 600, '2025-12-01', 'fff');"
+
+    // Right after INSERT the cached stats are stale; explain should NOT
+    // show "constant exprs" because the cache entry has been invalidated.
+    // Only the plan is checked here; no result tag follows this INSERT.
+    explain {
+        sql("SELECT min(v2), max(v2) FROM dup_tbl")
+        notContains "constant exprs"
+    }
+
+}
diff --git a/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.groovy b/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.groovy
new file mode 100644
index 00000000000..caef303ff5b
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/rewrite_simple_agg_to_constant/truncate_version_reset.groovy
@@ -0,0 +1,112 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/**
+ * Regression test for: TRUNCATE TABLE must reset TableAttributes.visibleVersion.
+ *
+ * Bug: before the fix, truncateTableInternal() replaced partition data but did
+ * not call olapTable.resetVisibleVersion(). As a result:
+ * - Partition.visibleVersion was reset to PARTITION_INIT_VERSION (1).
+ * - TableAttributes.visibleVersion kept its old, higher value.
+ * - TableAttributes.visibleVersionTime was never updated.
+ *
+ * Consequence for RewriteSimpleAggToConstantRule / SimpleAggCacheMgr:
+ * The cache entry was keyed by versionTime. Because versionTime did not
+ * change at truncate time, the *caller* saw the same versionTime as the
+ * stale cached entry → cache HIT → the rule returned the pre-truncate
+ * count/min/max instead of the correct post-truncate values.
+ *
+ * The fix adds olapTable.resetVisibleVersion() inside truncateTableInternal(),
+ * which bumps both visibleVersion (back to TABLE_INIT_VERSION = 1) and
+ * visibleVersionTime (to System.currentTimeMillis()). The new versionTime
+ * differs from the cached entry's versionTime → cache MISS → the rule
+ * correctly falls back to BE execution and returns the right result.
+ *
+ * NOTE(review): only count(*) is exercised by this suite; the min/max cache
+ * path after TRUNCATE is not covered here.
+ */
+suite("truncate_version_reset") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+
+    sql "DROP DATABASE IF EXISTS test_truncate_version_reset"
+    sql "CREATE DATABASE test_truncate_version_reset"
+    sql "USE test_truncate_version_reset"
+
+    sql """
+        CREATE TABLE tbl (
+            k1 INT NOT NULL,
+            v1 INT NOT NULL
+        ) DUPLICATE KEY(k1)
+        DISTRIBUTED BY HASH(k1) BUCKETS 1
+        PROPERTIES('replication_num' = '1');
+    """
+
+    sql "INSERT INTO tbl VALUES (1, 10), (2, 20), (3, 30), (4, 40), (5, 50);"
+
+    // -----------------------------------------------------------------------
+    // Warm up SimpleAggCacheMgr for count(*).
+    // Poll until the rule fires (plan contains "constant exprs").
+    // -----------------------------------------------------------------------
+    sql "SELECT count(*) FROM tbl"
+
+    // At most 30 attempts, one second apart.
+    def cacheReady = false
+    for (int i = 0; i < 30; i++) {
+        def explainResult = sql "EXPLAIN SELECT count(*) FROM tbl"
+        if (explainResult.toString().contains("constant exprs")) {
+            cacheReady = true
+            break
+        }
+        sleep(1000)
+    }
+    if (!cacheReady) {
+        // In cloud mode a cache that never warms is tolerated: the suite returns early
+        // instead of failing. Elsewhere it is a hard failure.
+        if (isCloudMode()) {
+            logger.info("SimpleAggCacheMgr did not warm up in cloud mode, skip")
+            return
+        }
+        assertTrue(false, "SimpleAggCacheMgr cache did not warm up within 30 seconds")
+    }
+
+    // Confirm the cache is hot and the rule fires for count(*).
+    explain {
+        sql "SELECT count(*) FROM tbl"
+        contains "constant exprs"
+    }
+    // Confirm the cached count is correct before truncate.
+    order_qt_count_before_truncate "SELECT count(*) FROM tbl;"
+
+    // -----------------------------------------------------------------------
+    // TRUNCATE the table.
+    // After the fix, resetVisibleVersion() is called inside
+    // truncateTableInternal(), which updates visibleVersionTime.
+    // The cache entry's versionTime no longer matches → cache is invalidated.
+    // -----------------------------------------------------------------------
+    sql "TRUNCATE TABLE tbl;"
+
+    // count(*) must return 0.
+    // Without the fix, the stale cache entry (count = 5) would be returned.
+    order_qt_count_after_truncate "SELECT count(*) FROM tbl;"
+
+    // -----------------------------------------------------------------------
+    // Insert new rows after truncate, then verify count(*) reflects them.
+    // This also validates that the version counter is correctly reset so
+    // subsequent transactions start from the right next-version.
+    // -----------------------------------------------------------------------
+    sql "INSERT INTO tbl VALUES (10, 100), (20, 200);"
+
+    // After insert the count must be 2.
+    // Checked inline with assertEquals rather than an order_qt_ tag; the suite's .out file
+    // only records count_before_truncate and count_after_truncate.
+    def count = sql "SELECT count(*) FROM tbl"
+    assertEquals(2L, count[0][0] as long,
+            "count(*) after truncate + insert should be 2, got ${count[0][0]}")
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
