llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> This commit generalizes and cleans up the `ValueBoundsConstraintSet` API. The API used to provide function overloads for comparing/computing bounds of: - index-typed SSA value - dimension of shaped value - affine map + operands This commit removes all overloads. There is now a single entry point for each `compare` variant and each `computeBound` variant. These functions now take a `Variable`, which is internally represented as an affine map and map operands. This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that. WIP until I add a test case for `computeBounds(AffineMap)`. --- Patch is 47.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87980.diff 15 Files Affected: - (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+61-56) - (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+2-4) - (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+1-1) - (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+69) - (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1) - (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+3-1) - (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2) - (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+2-4) - (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+2-3) - (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+10-7) - (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+1-2) - (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+2-1) - (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+2-2) - (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp 
(+149-188) - (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+3-3) ``````````diff diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -15,6 +15,7 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/ExtensibleRTTI.h" #include <queue> @@ -111,6 +112,39 @@ class ValueBoundsConstraintSet public: static char ID; + /// A variable that can be added to the constraint set as a "column". The + /// value bounds infrastructure can compute bounds for variables and compare + /// two variables. + /// + /// Internally, a variable is represented as an affine map and operands. + class Variable { + public: + /// Construct a variable for an index-typed attribute or SSA value. + Variable(OpFoldResult ofr); + + /// Construct a variable for an index-typed SSA value. + Variable(Value indexValue); + + /// Construct a variable for a dimension of a shaped value. + Variable(Value shapedValue, int64_t dim); + + /// Construct a variable for an index-typed attribute/SSA value or for a + /// dimension of a shaped value. A non-null dimension must be provided if + /// and only if `ofr` is a shaped value. + Variable(OpFoldResult ofr, std::optional<int64_t> dim); + + /// Construct a variable for a map and its operands. + Variable(AffineMap map, ArrayRef<Variable> mapOperands); + Variable(AffineMap map, ArrayRef<Value> mapOperands); + + MLIRContext *getContext() const { return map.getContext(); } + + private: + friend class ValueBoundsConstraintSet; + AffineMap map; + ValueDimList mapOperands; + }; + /// The stop condition when traversing the backward slice of a shaped value/ /// index-type value. 
The traversal continues until the stop condition /// evaluates to "true" for a value. @@ -121,35 +155,31 @@ class ValueBoundsConstraintSet using StopConditionFn = std::function<bool( Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>; - /// Compute a bound for the given index-typed value or shape dimension size. - /// The computed bound is stored in `resultMap`. The operands of the bound are - /// stored in `mapOperands`. An operand is either an index-type SSA value - /// or a shaped value and a dimension. + /// Compute a bound for the given variable. The computed bound is stored in + /// `resultMap`. The operands of the bound are stored in `mapOperands`. An + /// operand is either an index-type SSA value or a shaped value and a + /// dimension. /// - /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound - /// is computed in terms of values/dimensions for which `stopCondition` - /// evaluates to "true". To that end, the backward slice (reverse use-def - /// chain) of the given value is visited in a worklist-driven manner and the - /// constraint set is populated according to `ValueBoundsOpInterface` for each - /// visited value. + /// The bound is computed in terms of values/dimensions for which + /// `stopCondition` evaluates to "true". To that end, the backward slice + /// (reverse use-def chain) of the given value is visited in a worklist-driven + /// manner and the constraint set is populated according to + /// `ValueBoundsOpInterface` for each visited value. /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. 
- static LogicalResult computeBound(AffineMap &resultMap, - ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, - StopConditionFn stopCondition, - bool closedUB = false); + static LogicalResult + computeBound(AffineMap &resultMap, ValueDimList &mapOperands, + presburger::BoundType type, const Variable &var, + StopConditionFn stopCondition, bool closedUB = false); /// Compute a bound in terms of the values/dimensions in `dependencies`. The /// computed bound consists of only constant terms and dependent values (or /// dimension sizes thereof). static LogicalResult computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, ValueDimList dependencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueDimList dependencies, bool closedUB = false); /// Compute a bound in that is independent of all values in `independencies`. /// @@ -161,13 +191,10 @@ class ValueBoundsConstraintSet /// appear in the computed bound. static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, ValueRange independencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueRange independencies, bool closedUB = false); - /// Compute a constant bound for the given affine map, where dims and symbols - /// are bound to the given operands. The affine map must have exactly one - /// result. + /// Compute a constant bound for the given variable. /// /// This function traverses the backward slice of the given operands in a /// worklist-driven manner until `stopCondition` evaluates to "true". The @@ -182,16 +209,9 @@ class ValueBoundsConstraintSet /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. 
static FailureOr<int64_t> - computeConstantBound(presburger::BoundType type, Value value, - std::optional<int64_t> dim = std::nullopt, + computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr<int64_t> computeConstantBound( - presburger::BoundType type, AffineMap map, ValueDimList mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr<int64_t> computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); /// Compute a constant delta between the given two values. Return "failure" /// if a constant delta could not be determined. @@ -221,9 +241,7 @@ class ValueBoundsConstraintSet /// proven. This could be because the specified relation does in fact not hold /// or because there is not enough information in the constraint set. In other /// words, if we do not know for sure, this function returns "false". - bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); + bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs); /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the /// specified relation could not be proven. This could be because the @@ -233,24 +251,11 @@ class ValueBoundsConstraintSet /// /// This function keeps traversing the backward slice of lhs/rhs until could /// prove the relation or until it ran out of IR. 
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); - static bool compare(AffineMap lhs, ValueDimList lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ValueDimList rhsOperands); - static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ArrayRef<Value> rhsOperands); - - /// Compute whether the given values/dimensions are equal. Return "failure" if + static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs); + + /// Compute whether the given variables are equal. Return "failure" if /// equality could not be determined. - /// - /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are - /// index-typed. - static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2, - std::optional<int64_t> dim1 = std::nullopt, - std::optional<int64_t> dim2 = std::nullopt); + static FailureOr<bool> areEqual(Variable var1, Variable var2); /// Return "true" if the given slices are guaranteed to be overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping. @@ -317,9 +322,6 @@ class ValueBoundsConstraintSet /// /// This function does not analyze any IR and does not populate any additional /// constraints. - bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); /// Given an affine map with a single result (and map operands), add a new @@ -374,6 +376,7 @@ class ValueBoundsConstraintSet /// constraint system. Return the position of the new column. Any operands /// that were not analyzed yet are put on the worklist. int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + int64_t insert(const Variable &var, bool isSymbol = true); /// Project out the given column in the constraint set. 
void projectOut(int64_t pos); @@ -381,6 +384,8 @@ class ValueBoundsConstraintSet /// Project out all columns for which the condition holds. void projectOut(function_ref<bool(ValueDim)> condition); + void projectOutAnonymous(std::optional<int64_t> except = std::nullopt); + /// Mapping of columns to values/shape dimensions. SmallVector<std::optional<ValueDim>> positionToValueDim; /// Reverse mapping of values/shape dimensions to columns. diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index e0c3abe7a0f71d1..82a9fb0d490882f 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) { mapOperands.push_back(value1); mapOperands.push_back(value2); affine::fullyComposeAffineMapAndOperands(&map, &mapOperands); - ValueDimList valueDims; - for (Value v : mapOperands) - valueDims.push_back({v, std::nullopt}); return ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::EQ, map, valueDims); + presburger::BoundType::EQ, + ValueBoundsConstraintSet::Variable(map, mapOperands)); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 117ee8e8701ad7c..6c59df91e8af781 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB))) return failure(); // Reify bound. 
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp index 90895e381c74b5a..411fc117a4d9f5d 100644 --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -75,6 +75,75 @@ struct MulIOpInterface } }; +struct SelectOpInterface + : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface, + SelectOp> { + + static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + Value value = selectOp.getResult(); + Value condition = selectOp.getCondition(); + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + + if (isa<ShapedType>(condition.getType())) { + // If the condition is a shaped type, the condition is applied + // element-wise. All three operands must have the same shape. + cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim); + cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim); + cstr.bound(value)[*dim] == cstr.getExpr(condition, dim); + return; + } + + // Populate constraints for the true/false values (and all values on the + // backward slice, as long as the current stop condition is not satisfied). + cstr.populateConstraints(trueValue, dim); + cstr.populateConstraints(falseValue, dim); + auto boundsBuilder = cstr.bound(value); + if (dim) + boundsBuilder[*dim]; + + // Compare yielded values. 
+ // If trueValue <= falseValue: + // * result <= falseValue + // * result >= trueValue + if (cstr.compare(/*lhs=*/{trueValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{falseValue, dim})) { + if (dim) { + cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); + cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); + } else { + cstr.bound(value) >= trueValue; + cstr.bound(value) <= falseValue; + } + } + // If falseValue <= trueValue: + // * result <= trueValue + // * result >= falseValue + if (cstr.compare(/*lhs=*/{falseValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{trueValue, dim})) { + if (dim) { + cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); + cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); + } else { + cstr.bound(value) >= falseValue; + cstr.bound(value) <= trueValue; + } + } + } + + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr); + } + + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + populateBounds(cast<SelectOp>(op), dim, cstr); + } +}; } // namespace } // namespace arith } // namespace mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 79fabd6ed2e99a2..f87f3d6350c0221 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> { return failure(); FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, in, /*dim=*/std::nullopt, + presburger::BoundType::UB, in, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(ub)) return failure(); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp 
b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index fad221288f190ed..5bb7d83bf1e3f86 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, + ValueBoundsConstraintSet::Variable(value, dim), stopCondition, + closedUB))) return failure(); // Materialize tensor.dim/memref.dim ops. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index 8c4b70db2489897..518d2e138c02a97 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, // Otherwise, try to compute a constant upper bound for the size value. 
FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, opOperand->get(), - /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true); + presburger::BoundType::UB, + {opOperand->get(), + /*dim=*/i}, + /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(upperBound)) { LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding"); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ac896d6c30d049d..71eb59d40836c1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) { size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); } else { - Value materializedSize = - getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt, + presburger::BoundType::UB, rangeValue.size, /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) - ? materializedSize + ? 
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) : b.create<arith::ConstantIndexOp>(loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 10ba508265e7b9f..1f06318cbd60e04 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, ValueRange independencies) { if (ofr.is<Attribute>()) return ofr; - Value value = ofr.get<Value>(); AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( - boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, + /*closedUB=*/true))) return failure(); return affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 087ffc438a830a3..17a1c016ea16d5a 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -61,12 +61,13 @@ struct ForOpInterface // An EQ constraint can be added if the y... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/87980 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits