gemini-code-assist[bot] commented on code in PR #18886:
URL: https://github.com/apache/tvm/pull/18886#discussion_r2898560604
##########
src/s_tir/schedule/primitive/cache_index_helpers.cc:
##########
@@ -0,0 +1,492 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file cache_index_helpers.cc
+ * \brief Implementation of analysis tools and utility functions used by the cache_index
+ * primitive, extracted from common_subexpr_elim_tools.
+ */
+
+#include "cache_index_helpers.h"
+
+#include <tvm/arith/analyzer.h> // For the arith::Analyzer::Simplify() method simplifying terms
+#include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm> // For std::find_if
+#include <unordered_map> // For the hashtable datatype
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires to define here
+// such static attribute, otherwise it causes a linking error.
+ComputationCache ComputationsDoneBy::cache_;

Review Comment:

The `ComputationsDoneBy` class uses a static `ComputationCache` member (`cache_`) to store results of IR analysis.
This cache is shared by all instances of the class and by all threads. Because it is a plain `std::unordered_map` that is both read and written without any synchronization, concurrent use from multiple threads (common in TVM's parallel auto-tuning environments such as MetaSchedule) is a data race: undefined behavior that can corrupt the map or crash the process. In addition, the cache is never cleared and holds `ObjectRef`s (reference-counted smart pointers) to IR nodes, which keeps those nodes alive for the lifetime of the process; in long-running processes (e.g., a TVM RPC server) this is a persistent memory leak.

##########
src/tir/transform/common_subexpr_elim.cc:
##########
@@ -19,807 +19,754 @@
 /*!
  * \file common_subexpr_elim.cc
- * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
- which rewrites statements and expressions in order to eliminate
- redundant computations. In order to achieve that, common (sub-)
- expressions are introduced into variables with let-in bindings,
- and the places where the expression was used are replaced with
- the freshly introduced variable.
+ * \brief Two-phase Common Subexpression Elimination (CSE) for TIR.
+ *
+ * Architecture overview
+ * ---------------------
+ * The pass is structured as two cooperating phases (single plan, single rewrite):
+ *
+ * Phase 1 — **CSEPlanner** (analysis, no mutation)
+ * Walks the TIR tree bottom-up and builds:
+ * - A *scope tree* that mirrors the nesting structure of For/If/While/AttrStmt.
+ * - An *expression DAG* mapping each structurally-unique eligible expression
+ * to its occurrence count, LCA scope, first-use location, and direct
+ * children (which shallower expressions it contains).
+ * From this it produces a *plan* in a single pass (shallower expressions
+ * first): two tables describing what to insert where (InsertBeforeTable)
+ * and what to replace (ExprRemapTable). Shallower-first processing with
+ * repr propagation resolves all CSE opportunities without a cascade loop.
+ * + * Phase 2 — **CSERewriter** (mechanical mutation) + * Consumes the plan and performs two kinds of edits: + * - Inserts `Bind(cse_var, expr)` statements at the planned insertion points. + * - Replaces every occurrence of a CSE'd expression with its variable. + * Insertions are handled by overriding VisitStmt and wrapping in SeqStmt; + * SeqStmt flattening handles correct nesting. + * + * Eligibility rules + * ----------------- + * An expression is eligible for CSE if: + * - It is not a leaf (Var, IntImm, FloatImm, StringImm). + * - It does not contain Call or BufferLoad (side-effects / memory dependence). + * - It is not Ramp or Broadcast (hardware-specific vector ops). + * + * Scope tree + * ---------- + * Each For, IfThenElse (each branch), While, and AttrStmt body creates a new + * scope. The scope tree enables computing the Lowest Common Ancestor (LCA) of + * all scopes where an expression occurs, which determines the correct insertion + * point — the narrowest scope that dominates all uses. 
*/ -#include "common_subexpr_elim.h" - #include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/extra/structural_hash.h> #include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/string.h> -#include <tvm/ir/transform.h> // For the class Pass and the class PassContext -#include <tvm/tir/analysis.h> // For the analysis which gives the size of an expr +#include <tvm/ir/transform.h> +#include <tvm/tir/analysis.h> #include <tvm/tir/expr.h> #include <tvm/tir/expr_functor.h> -#include <tvm/tir/function.h> // For the class PrimFunc +#include <tvm/tir/function.h> #include <tvm/tir/stmt.h> #include <tvm/tir/stmt_functor.h> -#include <tvm/tir/transform.h> // For the decl of the function returning the pass +#include <tvm/tir/transform.h> -#include <algorithm> // For the algorithm std::find -#include <iostream> +#include <algorithm> #include <string> +#include <unordered_map> #include <utility> #include <vector> -#include "../analysis/check_contains.h" // For the visitor CheckContains -#include "common_subexpr_elim_tools.h" // For the auxiliary analysis (visitors) and tools -#include "replace_selected_expr.h" // For the mutator ReplaceSelectedExpr +#include "../analysis/check_contains.h" namespace tvm { namespace tir { -/*! - * \brief Check whether a computation is forbidden for being treated by the CSE pass. - The important thing about forbidden computations is that not only we won't want - to collect them for the CSE pass, but we also won't even want to collect computations - that contain them. - The reason is that reusing such computations would change the semantics of the program, - and therefore before doing any introduction of var or any reuse of already introduced - variables, we will make sure that the computation being considered is not forbidden, and - that it does not even contain a forbidden computation. 
- * \param expr The expression to check - * \return Whether `expr` is a forbidden computation or not - */ -bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) { - // Function calls, loads and buffer loads are absolutely forbidden as introducing them into - // variables would change the semantics of the program. - return (expr.as<CallNode>() != nullptr || expr.as<BufferLoadNode>() != nullptr); -} +// ============================================================================ +// Plan interface types (internal, C++ only) +// ============================================================================ /*! - * \brief Predicate used for verifying that a computation is eligible for being treated by - the CSE pass, i.e. for being introduced into a variable / for being replaced by a - variable. - Being eligible is a conjunction of a few conditions, like not being an atom (constant - or variable), not being a forbidden node, not containing a forbidden node, etc. - * \param expr The expression to check - * \return Whether `expr` is an eligible computation or not + * \brief Map from expression to CSE variable, keyed by structural equality. + * + * Used by CSERewriter to look up whether a visited expression should be + * replaced by a previously-introduced CSE variable. 
*/ -bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) { - return ( - // In order to be eligible, the given expression should not be a constant - (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) && - (expr.as<StringImmNode>() == nullptr) - // and it should not be a variable - && (expr.as<VarNode>() == nullptr) - // and it should not be a forbidden computation (function calls and loads) - && (!ForbiddenComputation(expr)) - // and it should not even contain a forbidden computation (function calls and loads) - // the reason is that we don't want to register expressions like (x + f(y)) or - // (x + Mem[i]) as introducing them into variables could change the semantics - && (!CheckContains::ExprContains(expr, ForbiddenComputation)) - // and it should not be a ramp node or a broadcast node due to some internals TVM - // constraints (which check for these node explicitely without performing any - // evaluation first, so if they have been put into variables it fails) - && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr)); -} +using ExprRemapTable = std::unordered_map<PrimExpr, Var, ffi::StructuralHash, ExprDeepEqual>; /*! - * \brief Predicate used (when considering eligible computations) for only diving into - expressions that are allowed to contain eligible computations. Customize this predicate - if you want to make it forbidden to rewrite inside a specific node, like inside - a Load node for instance. - * \param expr The expression to check - * \return Whether `expr` can contain some eligible computations or not, and therefore - if recursing inside `expr` is necessary. + * \brief Map from statement (by pointer identity) to a list of Bind + * statements that should be inserted immediately before it. + * + * Pointer identity (ObjectPtrHash/Equal) is used because the insertion + * point is a specific child of a SeqStmt, not a structurally-equivalent + * statement elsewhere in the tree. 
*/ -bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) { - // Uncomment the next line to prevent the collection and the replacement of eligible computations - // inside the index of Load nodes. We initially thought that this would be needed in order to - // not harm the indexing mode of the CPU, but as we are still far from ASM code, we - // finally want to perform such simplifications, which tend to happen fairly frequently. - - // return (expr.as<BufferLoadNode>() == nullptr) - return true; -} +using InsertBeforeTable = + std::unordered_map<Stmt, std::vector<Stmt>, ObjectPtrHash, ObjectPtrEqual>; -/*! - * \brief Implements an order on pairs (expression,frequency). First attempts to compare them - using the size of the expression. If it is the same, decides something else still - deterministic. - * \param a The first pair - * \param b The second pair - * \return A boolean telling if the first pair `a` comes before the second pair `b` - * \note We need this order to be deterministic in order to have a fully deterministic pass, - * as we will deal with elements that are coming from a hashtable, but the order in which - * they appeared in the hashtable was based on some runtime addresses, so it can potentially - * change with every execution. - */ -bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(const std::pair<PrimExpr, size_t>& a, - const std::pair<PrimExpr, size_t>& b) { - size_t a_size = CalculateExprComplexity(a.first); - size_t b_size = CalculateExprComplexity(b.first); - return a_size > b_size; -} +// ============================================================================ +// CSEPlanner: Phase 1 — scan tree, build scope tree + expression table +// ============================================================================ /*! - * \brief Generates a new fresh variable, whose name will be cse_vi. 
- * \param type_annotation The type of the new variable to generate - * \return A new variable of type `type_annotation` called cse_vi where i is the first available - integer. + * \brief Phase 1 of the two-phase CSE pass. + * + * CSEPlanner is a read-only visitor that scans the TIR tree bottom-up and builds: + * 1. A **scope tree** (vector of ScopeEntry) reflecting For/If/While/AttrStmt nesting. + * 2. An **expression DAG** (ExprTable) where each node is an eligible expression + * with occurrence count, expr_depth, LCA scope, first-use location, and + * direct children (other table entries reachable without passing through + * another table entry). Children and expr_depth are computed incrementally + * during the bottom-up scan — no separate traversal needed. + * + * After scanning, ComputePlan() converts the internal state into two output tables: + * - InsertBeforeTable: where to insert `Bind(cse_var, expr)` statements. + * - ExprRemapTable: which expressions to replace with their CSE variable. 
+ * + * Usage: + * \code + * auto [insert_before, expr_remap] = CSEPlanner::Plan(body, params); + * \endcode */ -Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { - // Increase `num_last_try_` for this new attempt - num_last_try_++; - // Builds the variable name, which is cse_vi where i will go up from 1 - std::string prefix = "cse_v"; - std::string name = prefix.append(std::to_string(num_last_try_)); - // Builds a ffi::String using the std::string - ffi::String string_name(name); - - // Check that the name that we want to use for the new variable isn't already being used - // (names don't really have to be unique as they are just hints, and having the same name - // doesn't means that it's the same variable, but it's clearer for dumps) - if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) { - // If the name is already used, call ourselves recursively for trying with the next one - return GenerateNewVar(type_annotation); +class CSEPlanner : public StmtExprVisitor { + public: + /*! + * \brief Run the planner on a function body (static entry point). + * + * Creates a planner instance, initializes the root scope, scans the body, + * and returns the computed plan. + * + * \param body The TIR function body to analyze. + * \param params The function's formal parameters. Currently unused but + * reserved for future parameter-aware optimizations. + * \return A pair of (InsertBeforeTable, ExprRemapTable) describing the + * planned CSE transformations. + */ + static std::pair<InsertBeforeTable, ExprRemapTable> Plan(const Stmt& body, + const ffi::Array<Var>& params) { + CSEPlanner planner; + // Root scope (no parent, depth 0, no creator statement) + planner.scopes_.push_back({-1, 0, Stmt()}); + planner.current_scope_ = 0; + // Initialize current_seq_child_ to the body itself so that when the + // body is not a SeqStmt, first_use_stmt still points to a valid stmt. 
+ planner.current_seq_child_ = body; + // Scan the tree + planner.VisitStmt(body); + // Convert scan results into the plan + return planner.ComputePlan(); } - // Increase `nb_var_` for this new generation of variable that we have just done - nb_var_++; - - // Return a new Variable using the name built and the given type_annotation - return (Var(string_name, type_annotation)); -} - -/*! - * \brief Gives the number of variables generated by the CSE on the current function - (i.e., getter for `nb_var_`). - * \return A copy of `nb_var_` - */ -int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; } + private: + /*! + * \brief One node in the scope tree. + * + * The scope tree mirrors the nesting structure of the TIR program. + * Each scope-creating statement (For, IfThenElse branch, While, AttrStmt) + * gets its own ScopeEntry. The root scope (depth 0) represents the function + * body itself. + */ + struct ScopeEntry { + /*! \brief Parent scope ID (-1 for root). */ + int parent; + /*! \brief Distance from root (root = 0). */ + int depth; + /*! + * \brief The statement that created this scope (e.g. ForNode). + * + * Null for the root scope. Used as the insertion point when a CSE + * binding must be placed before the scope. + */ + Stmt creator_stmt; + }; -/*! - * \brief Toplevel (static) method that performs Common Subexpression Elimination on - a given statement (which should be the body of a PrimFunc). This method should be - called for each PrimFunc definition. 
- * \param stmt The statement of the function being analyzed, on which we want to perform CSE - * \param context_init The initial context, which should contain the formal parameters - of the function being analyzed - * \return A new statement where CSE has been performed - */ -Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init, - bool identify_equiv_terms) { - // As this function is being called for each PrimFunc definition, we create a new instance - // for the one we are having now. - CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init, - identify_equiv_terms); - return common_subexpression_eliminator.VisitStmt(stmt); -} + /*! + * \brief Node in the expression DAG built during the bottom-up scan. + * + * The planner maintains one ExprEntry per structurally-unique eligible + * expression (keyed by ExprDeepEqual). Since expressions are recorded + * bottom-up (children before parents), the DAG children are naturally + * discovered when a node is first added. Fields like expr_depth are + * computed incrementally from children — no separate traversal needed. + */ + struct ExprEntry { + /*! \brief Total number of occurrences across all scopes. */ + int count{0}; + /*! + * \brief Nesting depth of eligible sub-expressions (leaf eligible = 1). + * + * Computed from children: `1 + max(child.expr_depth)`, or 1 if no children. + * Used to sort entries so that shallower expressions are processed first. + */ + int expr_depth{0}; + /*! \brief The expression itself (first occurrence). */ + PrimExpr repr; + /*! + * \brief Scope ID of the Lowest Common Ancestor of all scopes containing an occurrence. + * + * Determines the outermost valid insertion point. + */ + int lca_scope{-1}; + /*! + * \brief Scope ID where the first occurrence was found. + * + * When lca_scope == first_use_scope, the binding is inserted before first_use_stmt. + */ + int first_use_scope{-1}; + /*! 
+ * \brief The SeqStmt child (or body statement) containing the first occurrence. + * + * Used as the insertion point when the LCA equals the first-use scope. + */ + Stmt first_use_stmt; + /*! + * \brief Direct children in the expression DAG: (child_expr, multiplicity). + * + * A "direct child" is an eligible table entry reachable from this expression + * without passing through another table entry. Multiplicity counts how many + * times the child appears (e.g., 2 for `(x+y) * (x+y)` with child `x+y`). + * Populated during RecordExpr (bottom-up: children already in table). + */ + std::vector<std::pair<PrimExpr, int>> children; + /*! + * \brief Number of occurrences consumed by parent expressions' CSE bindings. + * + * Computed after the DAG is fully built, before plan generation. + * Independent count = count - consumed; only entries with independent >= 2 + * are CSE candidates. + */ + int consumed{0}; + }; -/*! - * \brief Protected constructor of CommonSubexpressionEliminator. - * \param context_init The context at the beginning of the CSE pass. It should contain the - formal parameters of the function that will be analyzed - */ -CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt, - const Context& context_init, - bool identify_equiv_terms) - : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) { - // The initial scope level (from ScopeStack's constructor) does not need - // EnterContextScope() because it should never be popped -- it persists - // for the lifetime of the CSE pass and holds the function parameters. -} + /*! \brief Expression table keyed by structural equality (ExprDeepEqual). */ + using ExprTable = std::unordered_map<PrimExpr, ExprEntry, ffi::StructuralHash, ExprDeepEqual>; + + // ------------------------------------------------------------------ + // Eligibility predicates + // ------------------------------------------------------------------ + + /*! 
+ * \brief Check if an expression node type is forbidden for CSE. + * + * Call nodes may have side effects. BufferLoad nodes depend on memory + * state and cannot be safely hoisted or deduplicated. + * + * \param expr The expression to check. + * \return true if the expression is a Call or BufferLoad. + */ + static bool IsForbiddenNode(const PrimExpr& expr) { + return (expr.as<CallNode>() != nullptr || expr.as<BufferLoadNode>() != nullptr); + } -/*! - * \brief The method which overrides the generic dispatcher of StmtExprMutator. - Entry point to the common subexpression elimination mutator for expressions. - * \param expr The expression to mutate - */ -PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { - bool variables_created = false; // Will be needed for knowing if the CSE has created new vars - PrimExpr result = expr; - - // Obtain the (syntactic) eligible computations done by the input expression, and keep it as - // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the - // number of time this exact syntactic computation is being computed. - ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy( - expr, IsEligibleComputation, CanContainEligibleComputations); - - // Transform the hashtable of *syntactic* eligible computations into a vector of pairs - // containing *semantic* entities, i.e. where equivalent computations are merged. 
- std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr = - SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_); - - // Sort the vector of semantic entities by decreasing size - std::stable_sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), - OrderOnExprAndFrequency); - - // For each computation done (considering them from biggest to smallest) - for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { - std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i]; - - bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this" - - // The predicate later used (when doing replacements) to select expressions that are - // equivalent to the current computation (`computation_and_nb.first`) - std::function<bool(const PrimExpr&)> predicate_selector = - [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) { - // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check - // that `current_expr` is an eligible computation even if we know that - // `computation_and_nb.first` is eligible by construction, in case that one day the - // equivalence relation would not preserve the eligibility any more (even though that - // would probably be a very weird equivalence). - return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) && - IsEligibleComputation(current_expr)); - }; - - // See if there is a pair (`var`, `value`) in the context where `value` is semantically - // equivalent to `computation_and_nb.first` - auto it_on_var = std::find_if( - context_.begin(), context_.end(), - [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) { - // Note : safe to call value() as we check has_value() just before - return (var_and_value.second.has_value() && - EquivalentTerms(var_and_value.second.value(), computation_and_nb.first, - ident_equiv_terms)); - }); + /*! 
+ * \brief Check if an expression is eligible for common subexpression elimination. + * + * An expression is eligible if it represents a non-trivial pure computation: + * - Not a leaf (Var, IntImm, FloatImm, StringImm — no computation to save). + * - Not a Call or BufferLoad (side effects / memory dependence). + * - Not Ramp or Broadcast (hardware-specific vector construction). + * - Does not transitively contain any forbidden node. + * + * \param expr The expression to check. + * \return true if the expression can participate in CSE. + */ + static bool IsEligible(const PrimExpr& expr) { + if (expr.as<IntImmNode>() || expr.as<FloatImmNode>() || expr.as<StringImmNode>() || + expr.as<VarNode>()) { + return false; + } + if (IsForbiddenNode(expr)) return false; + if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false; + if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; + return true; + } - // Case where we have a perfectly equivalent computation already available in a variable - // introduced (i.e, present in context_). - // Note that this case is needed when the user has written something like - // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by - // an already existing variable holding A, when such a variable happens to exist. - if (it_on_var != context_.end()) { - // Replace in the current `result` everything that is selected by the selector with - // the existing variable, without diving into expressions in which we don't have the - // right to dive. - result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr( - result, predicate_selector, it_on_var->first, CanContainEligibleComputations); - } else { - // The current computation is not equivalent to a computation already done. We will - // need to see if we want to introduce it. 
- - // --- Chunk needed for reusing the UndefinedVars() analysis --- - // 1 - Wraps the computation into a statement - Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first); - // 2.1 - Transform the context into a vector of variables instead of pairs - std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value = - [](const std::pair<Var, MaybeValue>& pair) { return pair.first; }; - std::vector<Var> vector_vars_known = VectorMap(context_, forget_value); - // 2.2 - Transform the std::vector into an Array - ffi::Array<Var> array_vars_known = ffi::Array<Var>(vector_vars_known); - // --- End of chunk needed for reusing the UndefinedVars() analysis --- - - // We use the UndefinedVars() analysis to get the undefined vars of the computation - ffi::Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); - - // Check if we can introduce it : if it contains no undefined variables and if we want - // to introduce it according to the predicate - if (vars_undefined.empty() && - PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) { - // Create a new variable for this computation - Var new_var = GenerateNewVar(computation_and_nb.first.dtype()); - // Replace in the current `result` everything that is selected by the selector with - // the new variable, without diving into expressions in which we don't have the - // right to dive. - result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var, - CanContainEligibleComputations); - // Build a let-in that introduces the new variable in the current `result` - result = Let(new_var, computation_and_nb.first, result); - // We don't add the variable to the context because the invariant is that the - // context is the context in which 'result' makes sense, and we've just updated it. 
- } else { - // Here it's not doable to introduce (via a let-in) the computation at this level - // as it contains variables that are not yet declared, and/or because the predicate - // did not select it. - // Either way, we will simply add to the vector of computations the direct subexprs - // of the current computation, as these ones might be good candidates - // for being introduced into variables. - // Note that we don't need to add all of its subexpressions, but only its *direct* - // subexpressions as we consider them from biggest to smallest, and if they were - // all added at once, then there could be dependencies between them, as commoning - // one of them could remove some other possibilities. - - // Computing the direct subexpressions will return a small number of direct - // subexpressions (typically 0 to 3) - std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions( - computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations); - // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by - // decreasing size/complexity), and it will only insert at locations > i as the - // direct subexprs are necessarily smaller than the current computation. - InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, - identify_equiv_terms_); + // ------------------------------------------------------------------ + // Expression substitution + // ------------------------------------------------------------------ + + /*! + * \brief Replace all occurrences of `target` in `body` with `replacement`. + * + * Uses structural equality (ExprDeepEqual) to find matches. Stops recursing + * into a sub-tree once a match is found (the replacement is a leaf Var). + * + * \param body The expression to transform. + * \param target The sub-expression to find. + * \param replacement The expression to substitute in (typically a CSE Var). + * \return The transformed expression. 
+ */ + static PrimExpr SubstituteSubexpr(const PrimExpr& body, const PrimExpr& target, + const PrimExpr& replacement) { + struct Replacer : public ExprMutator { + ExprDeepEqual eq; + PrimExpr target, replacement; + PrimExpr VisitExpr(const PrimExpr& e) final { + if (eq(e, target)) return replacement; + return ExprMutator::VisitExpr(e); } - } - // Note : we do not remove the current element, as we never look back in the local vector - } // End of for loop - - // If the CSE pass has created some variables, then we run it again as more commoning could - // potentially happen using the new variables introduced - if (variables_created) { - result = VisitExpr(result); - } else { - // But if no changes were performed, we recurse inside the children by calling the dispatcher. - // Calling the dispatcher to the specific treatments, which will update the context - // appropriately before doing the recursive calls on the children nodes - result = StmtExprMutator::VisitExpr(result); + }; + Replacer r; + r.target = target; + r.replacement = replacement; + return r.VisitExpr(body); } - return result; -} + // ------------------------------------------------------------------ + // Scope tree operations + // ------------------------------------------------------------------ + + /*! + * \brief Allocate a new child scope in the scope tree. + * + * \param parent The parent scope ID. + * \param creator_stmt The statement that creates this scope (e.g. ForNode). + * Stored for later use as an insertion point. + * \return The ID of the newly allocated scope. + */ + int AllocScope(int parent, Stmt creator_stmt) { + int id = static_cast<int>(scopes_.size()); + scopes_.push_back({parent, scopes_[parent].depth + 1, std::move(creator_stmt)}); + return id; + } -/*! - * \brief The method which overrides the specific treatment for a LetNode. - * - * The let-in expression introduces a new variable binding that is only visible - * within the body. 
We use context_scope_.WithNewScope to automatically clean up - * the binding when the body has been visited, replacing the old manual - * save/restore of context_. - */ -PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { - // At this point, we have already done the generic treatment of introducing (via let-in) what - // was doable at the toplevel of the given let-in. - - // Recurse on the `value` field for potentially rewriting it - PrimExpr value_new = VisitExpr(op->value); - - // Visit the body in a new scope. The let-in variable binding is added to the - // context inside the scope and automatically removed when the scope exits. - PrimExpr body_new = context_scope_.WithNewScope([&]() -> PrimExpr { - EnterContextScope(); - // Augment the context with the association (`var`, `value`) for the body - context_.push_back({op->var, MaybeValue(op->value)}); - // Recurse on the `body` (with this extended context) - // The recursive call will have potentially done new simplifications, because in this recursive - // call `var` will be a part of the context. - return VisitExpr(op->body); - }); - - // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might - // have been done. - - // If the `value` and the `body` of the let-in have been rewritten to the same thing - if (value_new.same_as(op->value) && body_new.same_as(op->body)) { - // then return a reference to the same node - return ffi::GetRef<PrimExpr>(op); - } else { - // Otherwise return a let-in built with the new `value_new` and the new `body_new` that - // have just been obtained - return Let(op->var, value_new, body_new, op->span); + /*! + * \brief Compute the Lowest Common Ancestor of two scope IDs. + * + * Walks both scopes upward to the same depth, then walks both upward + * in lockstep until they meet. This is the standard LCA algorithm for + * trees with parent pointers. + * + * \param a First scope ID. + * \param b Second scope ID. 
+   * \return The scope ID of the LCA.
+   */
+  int LCA(int a, int b) const {
+    while (scopes_[a].depth > scopes_[b].depth) a = scopes_[a].parent;
+    while (scopes_[b].depth > scopes_[a].depth) b = scopes_[b].parent;
+    while (a != b) {
+      a = scopes_[a].parent;
+      b = scopes_[b].parent;
+    }
+    return a;
   }
-}
-/*!
- * \brief The method which overrides the generic dispatcher of StmtExprMutator.
-          Entry point to the common subexpression elimination mutator for statements.
- * \param stmt The statement to mutate.
- */
-Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
-  bool variables_created = false;  // Will be needed for knowing if the CSE has created new vars
-  Stmt result = stmt;
-
-  // Obtain the (syntactic) eligible computations done by the input statement, and keep it as
-  // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
-  // number of time this exact syntactic computation is being computed.
-  ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy(
-      stmt, IsEligibleComputation, CanContainEligibleComputations);
-
-  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
-  // containing *semantic* entities, i.e. where equivalent computations are merged.
-  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
-      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);
-
-  // Sort the vector of semantic entities by decreasing size
-  std::stable_sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
-                   OrderOnExprAndFrequency);
-
-  // For each computation done (considering them from biggest to smallest)
-  for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
-    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];
-
-    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"
-
-    // The predicate later used (when doing replacements) to select expressions that are
-    // equivalent to the current computation (`computation_and_nb.first`)
-    std::function<bool(const PrimExpr&)> predicate_selector =
-        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
-          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
-          // that `current_expr` is an eligible computation even if we know that
-          // `computation_and_nb.first` is eligible by construction, in case that one day the
-          // equivalence relation would not preserve the eligibility any more (even though that
-          // would probably be a very weird equivalence).
-          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
-                  IsEligibleComputation(current_expr));
-        };
-
-    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
-    // equivalent to `computation_and_nb.first`
-    auto it_on_var = std::find_if(
-        context_.begin(), context_.end(),
-        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
-          // Note : safe to call value() as we check has_value() just before
-          return (var_and_value.second.has_value() &&
-                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
-                                  ident_equiv_terms));
-        });
+  /*!
+   * \brief Find the statement to insert a CSE binding before.
+   *
+   * Two cases:
+   *   - LCA == first-use scope: insert before the first_use_stmt directly.
+   *   - LCA is an ancestor: walk from first_use_scope upward to find the
+   *     scope-creating statement that is a direct child of the LCA scope,
+   *     and insert before that statement.
+   *
+   * \param entry The expression entry containing scope and first-use metadata.
+   * \return The statement before which the CSE Bind should be inserted.
+   */
+  Stmt FindInsertionStmt(const ExprEntry& entry) const {
+    if (entry.first_use_scope == entry.lca_scope) {
+      return entry.first_use_stmt;
+    }
+    int s = entry.first_use_scope;
+    while (scopes_[s].parent != entry.lca_scope) s = scopes_[s].parent;
+    return scopes_[s].creator_stmt;
+  }
-    // Case where we have a perfectly equivalent computation already available in a variable
-    // introduced (i.e, present in context_).
-    // Note that this case is needed when the user has written something like
-    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
-    // an already existing variable holding A, when such a variable happens to exist.
-    if (it_on_var != context_.end()) {
-      // Replace in the current `result` everything that is selected by the selector with
-      // the existing variable, without diving into expressions in which we don't have the
-      // right to dive.
-      result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
-          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+  // ------------------------------------------------------------------
+  // Expression recording
+  // ------------------------------------------------------------------
+
+  /*!
+   * \brief Record an occurrence of an expression in the expression table.
+   *
+   * On first occurrence: initializes the entry, records direct children
+   * (AST children that are in the table), and computes expr_depth from
+   * children. On subsequent occurrences: updates the LCA scope.
+   *
+   * \param e The expression to record.
+   * \param ast_children The direct AST children of e (passed by the caller
+   *        who knows the node structure: op->a, op->b, etc.).
+   */
+  void RecordExpr(const PrimExpr& e, std::initializer_list<PrimExpr> ast_children) {
+    if (!IsEligible(e)) return;
+    ExprEntry& entry = table_[e];
+    bool is_first_occurrence = (entry.count == 0);
+    if (is_first_occurrence) {
+      entry.lca_scope = current_scope_;
+      entry.first_use_scope = current_scope_;
+      entry.first_use_stmt = current_seq_child_;
+      entry.repr = e;
+      // Build DAG edges: check which AST children are eligible table entries.
+      // Since we visit bottom-up, children are already in the table.
+      CollectChildren(entry, ast_children);
     } else {
-      // The current computation is not equivalent to a computation already done. We will
-      // need to see if we want to introduce it.
-
-      // --- Chunk needed for reusing the UndefinedVars() analysis ---
-      // 1 - Wraps the computation into a statement
-      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
-      // 2.1 - Transform the context into a vector of variables instead of pairs
-      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
-          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
-      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
-      // 2.2 - Transform the std::vector into an Array
-      ffi::Array<Var> array_vars_known = ffi::Array<Var>(vector_vars_known);
-      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
-
-      // We use the UndefinedVars() analysis to get the undefined vars of the computation
-      ffi::Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
-
-      // Check if we can introduce it : if it contains no undefined variables and if we want
-      // to introduce it according to the predicate
-      if (vars_undefined.empty() &&
-          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
-        // Create a new variable for this computation
-        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
-        variables_created = true;
-        // Replace in the current `result` everything that is selected by the selector with
-        // the new variable, without diving into expressions in which we don't have the
-        // right to dive.
-        result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(result, predicate_selector, new_var,
-                                                                CanContainEligibleComputations);
-        // Build a bind that introduces the new variable before the current `result`
-        result = SeqStmt({Bind(new_var, computation_and_nb.first), result});
-        // We don't add the variable to the context because the invariant is that the
-        // context is the context in which 'result' makes sense, and we've just updated it.
-      } else {
-        // Here it's not doable to introduce (via a let-in) the computation at this level
-        // as it contains variables that are not yet declared, and/or because the predicate
-        // did not select it.
-        // Either way, we will simply add to the vector of computations the direct subexprs
-        // of the current computation, as these ones might be good candidates
-        // for being introduced into variables.
-        // Note that we don't need to add all of its subexpressions, but only its *direct*
-        // subexpressions as we consider them from biggest to smallest, and if they were
-        // all added at once, then there could be dependencies between them, as commoning
-        // one of them could remove some other possibilities.
-
-        // Computing the direct subexpressions will return a small number of direct
-        // subexpressions (typically 0 to 3)
-        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
-            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
-        // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
-        // decreasing size/complexity), and it will only insert at locations > i as the
-        // direct subexprs are necessarily smaller than the current computation.
-        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
-                                                 identify_equiv_terms_);
-      }
+      // Widen the insertion scope to cover all occurrences.
+      entry.lca_scope = LCA(entry.lca_scope, current_scope_);
     }
-    // Note : we do not remove the current element, as we never look back in the local vector
-  }  // End of for loop
-
-  // If the CSE pass has created some variables, then we run it again as more commoning could
-  // potentially happen using the new variables introduced
-  if (variables_created) {
-    result = VisitStmt(result);
-  } else {
-    // But if no changes were performed, we recurse inside the children by calling the dispatcher.
-    // Calling the dispatcher to the specific treatments, which will update the context
-    // appropriately before doing the recursive calls on the children nodes
-    result = StmtExprMutator::VisitStmt(result);
+    entry.count += 1;
   }
-  return result;
-}
-
-/*!
- * \brief The method which overrides the specific treatment for a BindNode.
- *
- * BindNode adds a (var, value) entry to the flat context_ vector. This entry
- * persists across subsequent SeqStmt siblings in the same scope, enabling CSE
- * to find common subexpressions that reference bind-defined variables.
- * Cleanup happens automatically when the enclosing body-carrying statement's
- * scope exits (via ContextScopeLevel's destructor), so no manual save/restore
- * is needed here.
- */
-Stmt CommonSubexpressionEliminator::VisitStmt_(const BindNode* op) {
-  // Recurse on the `value` field for potentially rewriting it
-  PrimExpr value_new = VisitExpr(op->value);
-
-  // Augment the context with the association (`var`, `value`).
-  // This persists across SeqStmt siblings and is cleaned up by the
-  // enclosing scope's ContextScopeLevel destructor.
-  context_.push_back({op->var, MaybeValue(op->value)});
-
-  // Rebuild the Bind if value changed
-  if (value_new.same_as(op->value)) {
-    return ffi::GetRef<Stmt>(op);
-  } else {
-    return Bind(op->var, value_new, op->span);
+  /*!
+   * \brief Populate children and expr_depth for a newly created entry.
+   *
+   * Each AST child (e.g. op->a, op->b) that exists in the table becomes
+   * a DAG child. Multiplicity tracks duplicates (e.g. `(x+y)*(x+y)` has
+   * child `x+y` with multiplicity 2). expr_depth is 1 + max child depth.
+   */
+  void CollectChildren(ExprEntry& entry, std::initializer_list<PrimExpr> ast_children) {
+    ExprDeepEqual eq;
+    int max_child_depth = 0;
+    for (const PrimExpr& child : ast_children) {
+      auto it = table_.find(child);
+      if (it == table_.end()) continue;
+      max_child_depth = std::max(max_child_depth, it->second.expr_depth);
+      // Check if this child was already seen (handles multiplicity).
+      bool already_recorded = false;
+      for (auto& [existing_child, multiplicity] : entry.children) {
+        if (eq(existing_child, child)) {
+          multiplicity++;
+          already_recorded = true;
+          break;
+        }
+      }
+      if (!already_recorded) entry.children.push_back({child, 1});
+    }
+    entry.expr_depth = 1 + max_child_depth;
   }
-}
-/*!
- * \brief Whether a Bind value is trivial (constant or variable), meaning it cannot
- *        contribute eligible computations for CSE and can be safely batched.
- */
-static bool IsTrivialBindValue(const PrimExpr& value) {
-  return value.as<IntImmNode>() != nullptr || value.as<FloatImmNode>() != nullptr ||
-         value.as<StringImmNode>() != nullptr || value.as<VarNode>() != nullptr;
-}
+  // ------------------------------------------------------------------
+  // Visitor overrides — expressions
+  // ------------------------------------------------------------------
+  // Each arithmetic/comparison/logical/cast/select node visitor calls the
+  // base class to recurse into children first, then records the full
+  // expression. This bottom-up order ensures that sub-expressions are
+  // recorded before their parents.
+  // ------------------------------------------------------------------
+
+  using StmtExprVisitor::VisitExpr_;
+
+  // Binary arithmetic operators (op->a, op->b)
+#define CSE_VISIT_BINARY(NodeType)                         \
+  void VisitExpr_(const NodeType* op) override {           \
+    StmtExprVisitor::VisitExpr_(op);                       \
+    RecordExpr(ffi::GetRef<PrimExpr>(op), {op->a, op->b}); \
+  }
+  CSE_VISIT_BINARY(AddNode)
+  CSE_VISIT_BINARY(SubNode)
+  CSE_VISIT_BINARY(MulNode)
+  CSE_VISIT_BINARY(DivNode)
+  CSE_VISIT_BINARY(ModNode)
+  CSE_VISIT_BINARY(FloorDivNode)
+  CSE_VISIT_BINARY(FloorModNode)
+  CSE_VISIT_BINARY(MinNode)
+  CSE_VISIT_BINARY(MaxNode)
+  CSE_VISIT_BINARY(EQNode)
+  CSE_VISIT_BINARY(NENode)
+  CSE_VISIT_BINARY(LTNode)
+  CSE_VISIT_BINARY(LENode)
+  CSE_VISIT_BINARY(GTNode)
+  CSE_VISIT_BINARY(GENode)
+  CSE_VISIT_BINARY(AndNode)
+  CSE_VISIT_BINARY(OrNode)
+#undef CSE_VISIT_BINARY
+
+  void VisitExpr_(const NotNode* op) override {
+    StmtExprVisitor::VisitExpr_(op);
+    RecordExpr(ffi::GetRef<PrimExpr>(op), {op->a});
+  }
+  void VisitExpr_(const CastNode* op) override {
+    StmtExprVisitor::VisitExpr_(op);
+    RecordExpr(ffi::GetRef<PrimExpr>(op), {op->value});
+  }
+  void VisitExpr_(const SelectNode* op) override {
+    StmtExprVisitor::VisitExpr_(op);
+    RecordExpr(ffi::GetRef<PrimExpr>(op), {op->condition, op->true_value, op->false_value});
+  }
-/*!
- * \brief The method which overrides the specific treatment for a SeqStmtNode.
- *
- * Processes the flat sequence using a hybrid strategy that avoids the O(n^2)
- * complexity of wrapping remaining siblings after every single Bind node:
- *
- *   - Trivial Bind nodes (constant/variable values) are batched: their values
- *     are visited via VisitExpr, context_ is augmented, but the expensive
- *     cross-sibling CSE is deferred until the batch ends.
- *   - Non-trivial Bind nodes (whose values may contain eligible computations)
- *     use the wrap-remaining-siblings pattern to enable cross-sibling CSE.
- *   - After any Bind (trivial batch end or non-trivial), remaining siblings are
- *     wrapped into a body and VisitStmt is called once for cross-sibling CSE.
- *   - Non-Bind children are visited individually via VisitStmt.
- *
- * This reduces the common case of many consecutive trivial Binds (e.g., variable
- * definitions with constant values) from O(n^2) to O(n), while preserving full
- * CSE effectiveness for non-trivial Bind values.
- *
- * Context cleanup is handled automatically by ScopeStack.
- */
-Stmt CommonSubexpressionEliminator::VisitStmt_(const SeqStmtNode* op) {
-  ffi::Array<Stmt> new_seq;
-  size_t i = 0;
-
-  while (i < op->seq.size()) {
-    if (auto* bind = op->seq[i].as<BindNode>()) {
-      // Batch consecutive trivial Bind nodes (constant/variable values).
-      // These can't contribute common subexpressions, so it's safe to defer
-      // the cross-sibling CSE until the entire batch is processed.
-      if (IsTrivialBindValue(bind->value)) {
-        while (i < op->seq.size()) {
-          auto* b = op->seq[i].as<BindNode>();
-          if (!b || !IsTrivialBindValue(b->value)) break;
-          PrimExpr value_new = VisitExpr(b->value);
-          context_.push_back({b->var, MaybeValue(b->value)});
-          Stmt bind_new =
-              value_new.same_as(b->value) ? ffi::GetRef<Stmt>(b) : Bind(b->var, value_new, b->span);
-          new_seq.push_back(bind_new);
-          ++i;
-        }
-      } else {
-        // Non-trivial Bind: visit value, augment context, then wrap remaining
-        // siblings and call VisitStmt for cross-sibling CSE.
-        PrimExpr value_new = VisitExpr(bind->value);
-        context_.push_back({bind->var, MaybeValue(bind->value)});
-        Stmt bind_new = value_new.same_as(bind->value) ? ffi::GetRef<Stmt>(bind)
-                                                       : Bind(bind->var, value_new, bind->span);
-        new_seq.push_back(bind_new);
-        ++i;
-      }
-      // After the Bind (batch or single), wrap remaining siblings [i..end) and
-      // call VisitStmt once for cross-sibling CSE with the updated context.
-      if (i < op->seq.size()) {
-        Stmt body;
-        if (i + 1 == op->seq.size()) {
-          body = op->seq[i];
-        } else {
-          ffi::Array<Stmt> rest;
-          for (size_t j = i; j < op->seq.size(); ++j) rest.push_back(op->seq[j]);
-          body = SeqStmt(rest);
-        }
-        Stmt body_new = VisitStmt(body);
-        // Flatten the result.
-        if (auto* inner = body_new.as<SeqStmtNode>()) {
-          for (const auto& s : inner->seq) new_seq.push_back(s);
-        } else {
-          new_seq.push_back(body_new);
-        }
-        return SeqStmt::Flatten(new_seq);
-      }
-    } else {
-      // Non-Bind child: visit individually via VisitStmt.
-      Stmt child_new = VisitStmt(op->seq[i]);
-      if (auto* inner = child_new.as<SeqStmtNode>()) {
-        for (const auto& s : inner->seq) new_seq.push_back(s);
-      } else {
-        new_seq.push_back(child_new);
-      }
-      ++i;
+  // ------------------------------------------------------------------
+  // Visitor overrides — statements
+  // ------------------------------------------------------------------
+
+  /*!
+   * \brief Visit a SeqStmt, tracking which child is currently being visited.
+   *
+   * Updates current_seq_child_ before visiting each child so that RecordExpr
+   * can associate first occurrences with the correct SeqStmt child for later
+   * insertion-point determination.
+   */
+  void VisitStmt_(const SeqStmtNode* op) override {
+    for (const auto& child : op->seq) {
+      current_seq_child_ = child;
+      VisitStmt(child);
     }
   }
-  return SeqStmt::Flatten(new_seq);
-}
+  /*! \brief For loops create a new scope for their body. */
+  void VisitStmt_(const ForNode* op) override {
+    int saved = current_scope_;
+    current_scope_ = AllocScope(saved, ffi::GetRef<Stmt>(op));
+    StmtExprVisitor::VisitStmt_(op);
+    current_scope_ = saved;
+  }

Review Comment:

The loop bounds (`min` and `extent`) of the `ForNode` are visited inside the new scope created for the loop body. This can lead to suboptimal CSE: if an expression is common to both the loop extent and the loop body, the CSE `Bind` statement will be placed *inside* the loop, so the expression is re-evaluated on every iteration.

The loop bounds should be visited in the parent scope, before the new scope for the body is created. This would be consistent with how other nodes such as `IfThenElseNode` and `WhileNode` are handled.

```cpp
void VisitStmt_(const ForNode* op) override {
  // Visit the loop bounds in the enclosing scope, so that an expression shared
  // between the bounds and the body can be bound before the loop.
  VisitExpr(op->min);
  VisitExpr(op->extent);
  // Only the loop body is recorded in the new scope.
  int saved = current_scope_;
  current_scope_ = AllocScope(saved, ffi::GetRef<Stmt>(op));
  VisitStmt(op->body);
  current_scope_ = saved;
}
```
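The effect of the visiting order on where a binding can land can be checked on a toy model of the scope tree. The following is a minimal, self-contained sketch under stated assumptions: `Scope`, `ScopeTree`, `HoistTarget`, and `BindScopeForSharedExpr` are hypothetical names, not TVM APIs; statements are represented by strings; and only the scope bookkeeping is modeled (no expression table and no mutator). `ScopeTree::LCA` and `HoistTarget` follow the same walks as `LCA()` and the ancestor case of `FindInsertionStmt()` in the diff.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the PR's scope records: parent ID, depth, and the
// statement that created the scope (a string here instead of a TVM Stmt).
struct Scope {
  int parent;           // -1 for the root scope
  int depth;            // 0 for the root scope
  std::string creator;  // scope-creating statement, e.g. "for i"
};

struct ScopeTree {
  // Scope 0 is the root (the function body).
  std::vector<Scope> scopes = {Scope{-1, 0, "<root>"}};

  // Same shape as AllocScope() in the diff: append a child and return its ID.
  int Alloc(int parent, const std::string& creator) {
    int id = static_cast<int>(scopes.size());
    scopes.push_back(Scope{parent, scopes[parent].depth + 1, creator});
    return id;
  }

  // Same walk as LCA() in the diff: lift the deeper scope to the same depth,
  // then climb both in lockstep until they meet.
  int LCA(int a, int b) const {
    while (scopes[a].depth > scopes[b].depth) a = scopes[a].parent;
    while (scopes[b].depth > scopes[a].depth) b = scopes[b].parent;
    while (a != b) {
      a = scopes[a].parent;
      b = scopes[b].parent;
    }
    return a;
  }
};

// Mirrors the ancestor case of FindInsertionStmt(): climb from the first-use
// scope to the direct child of the LCA and return that child's creator.
// Precondition: lca_scope is a strict ancestor of first_use_scope.
std::string HoistTarget(const ScopeTree& tree, int first_use_scope, int lca_scope) {
  assert(first_use_scope != lca_scope);
  int s = first_use_scope;
  while (tree.scopes[s].parent != lca_scope) s = tree.scopes[s].parent;
  return tree.scopes[s].creator;
}

// Scope ID assigned to an expression that occurs once in a loop extent and
// once in the loop body. bounds_in_parent == false models the current ForNode
// visitor (bounds recorded after the body scope is allocated); true models the
// suggested fix (bounds recorded before it).
int BindScopeForSharedExpr(bool bounds_in_parent) {
  ScopeTree tree;
  const int root = 0;
  const int loop_body = tree.Alloc(root, "for i in range(n * m)");
  const int extent_occurrence = bounds_in_parent ? root : loop_body;
  const int body_occurrence = loop_body;
  return tree.LCA(extent_occurrence, body_occurrence);
}
```

With the current visitor, both occurrences are recorded in the loop-body scope, so `BindScopeForSharedExpr(false)` returns 1 (the body scope). With the bounds visited first, `BindScopeForSharedExpr(true)` returns 0 (the root scope), which is where a binding hoisted above the loop belongs. For an expression used in two sibling loops, the LCA is their common parent, and `HoistTarget` returns the first loop's creator statement as the insertion point.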
