https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/87980
This commit generalizes and cleans up the `ValueBoundsConstraintSet` API. The API used to provide function overloads for comparing/computing bounds of: - index-typed SSA value - dimension of shaped value - affine map + operands This commit removes all overloads. There is now a single entry point for each `compare` variant and each `computeBound` variant. These functions now take a `Variable`, which is internally represented as an affine map and map operands. This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that. WIP until I add a test case for `computeBound(AffineMap)`. >From ed12ff5144bc1fe5013698ee19ffcab9f831d7eb Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Mon, 8 Apr 2024 11:25:29 +0000 Subject: [PATCH] [mlir][Interfaces][WIP] `ValueBoundsOpInterface`: `Variable` --- .../mlir/Interfaces/ValueBoundsOpInterface.h | 117 +++--- .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp | 6 +- .../Affine/Transforms/ReifyValueBounds.cpp | 2 +- .../Arith/IR/ValueBoundsOpInterfaceImpl.cpp | 69 ++++ .../Dialect/Arith/Transforms/IntNarrowing.cpp | 2 +- .../Arith/Transforms/ReifyValueBounds.cpp | 4 +- .../lib/Dialect/Linalg/Transforms/Padding.cpp | 6 +- .../Dialect/Linalg/Transforms/Promotion.cpp | 6 +- .../Transforms/IndependenceTransforms.cpp | 5 +- .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 17 +- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 3 +- .../Transforms/IndependenceTransforms.cpp | 3 +- mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 4 +- .../lib/Interfaces/ValueBoundsOpInterface.cpp | 337 ++++++++---------- .../Dialect/Affine/TestReifyValueBounds.cpp | 6 +- 15 files changed, 312 insertions(+), 275 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ 
b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -15,6 +15,7 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/ExtensibleRTTI.h" #include <queue> @@ -111,6 +112,39 @@ class ValueBoundsConstraintSet public: static char ID; + /// A variable that can be added to the constraint set as a "column". The + /// value bounds infrastructure can compute bounds for variables and compare + /// two variables. + /// + /// Internally, a variable is represented as an affine map and operands. + class Variable { + public: + /// Construct a variable for an index-typed attribute or SSA value. + Variable(OpFoldResult ofr); + + /// Construct a variable for an index-typed SSA value. + Variable(Value indexValue); + + /// Construct a variable for a dimension of a shaped value. + Variable(Value shapedValue, int64_t dim); + + /// Construct a variable for an index-typed attribute/SSA value or for a + /// dimension of a shaped value. A non-null dimension must be provided if + /// and only if `ofr` is a shaped value. + Variable(OpFoldResult ofr, std::optional<int64_t> dim); + + /// Construct a variable for a map and its operands. + Variable(AffineMap map, ArrayRef<Variable> mapOperands); + Variable(AffineMap map, ArrayRef<Value> mapOperands); + + MLIRContext *getContext() const { return map.getContext(); } + + private: + friend class ValueBoundsConstraintSet; + AffineMap map; + ValueDimList mapOperands; + }; + /// The stop condition when traversing the backward slice of a shaped value/ /// index-type value. The traversal continues until the stop condition /// evaluates to "true" for a value. @@ -121,35 +155,31 @@ class ValueBoundsConstraintSet using StopConditionFn = std::function<bool( Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>; - /// Compute a bound for the given index-typed value or shape dimension size. 
- /// The computed bound is stored in `resultMap`. The operands of the bound are - /// stored in `mapOperands`. An operand is either an index-type SSA value - /// or a shaped value and a dimension. + /// Compute a bound for the given variable. The computed bound is stored in + /// `resultMap`. The operands of the bound are stored in `mapOperands`. An + /// operand is either an index-type SSA value or a shaped value and a + /// dimension. /// - /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound - /// is computed in terms of values/dimensions for which `stopCondition` - /// evaluates to "true". To that end, the backward slice (reverse use-def - /// chain) of the given value is visited in a worklist-driven manner and the - /// constraint set is populated according to `ValueBoundsOpInterface` for each - /// visited value. + /// The bound is computed in terms of values/dimensions for which + /// `stopCondition` evaluates to "true". To that end, the backward slice + /// (reverse use-def chain) of the given value is visited in a worklist-driven + /// manner and the constraint set is populated according to + /// `ValueBoundsOpInterface` for each visited value. /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. - static LogicalResult computeBound(AffineMap &resultMap, - ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, - StopConditionFn stopCondition, - bool closedUB = false); + static LogicalResult + computeBound(AffineMap &resultMap, ValueDimList &mapOperands, + presburger::BoundType type, const Variable &var, + StopConditionFn stopCondition, bool closedUB = false); /// Compute a bound in terms of the values/dimensions in `dependencies`. The /// computed bound consists of only constant terms and dependent values (or /// dimension sizes thereof). 
static LogicalResult computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, ValueDimList dependencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueDimList dependencies, bool closedUB = false); /// Compute a bound in that is independent of all values in `independencies`. /// @@ -161,13 +191,10 @@ class ValueBoundsConstraintSet /// appear in the computed bound. static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional<int64_t> dim, ValueRange independencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueRange independencies, bool closedUB = false); - /// Compute a constant bound for the given affine map, where dims and symbols - /// are bound to the given operands. The affine map must have exactly one - /// result. + /// Compute a constant bound for the given variable. /// /// This function traverses the backward slice of the given operands in a /// worklist-driven manner until `stopCondition` evaluates to "true". The @@ -182,16 +209,9 @@ class ValueBoundsConstraintSet /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. 
static FailureOr<int64_t> - computeConstantBound(presburger::BoundType type, Value value, - std::optional<int64_t> dim = std::nullopt, + computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr<int64_t> computeConstantBound( - presburger::BoundType type, AffineMap map, ValueDimList mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr<int64_t> computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); /// Compute a constant delta between the given two values. Return "failure" /// if a constant delta could not be determined. @@ -221,9 +241,7 @@ class ValueBoundsConstraintSet /// proven. This could be because the specified relation does in fact not hold /// or because there is not enough information in the constraint set. In other /// words, if we do not know for sure, this function returns "false". - bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); + bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs); /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the /// specified relation could not be proven. This could be because the @@ -233,24 +251,11 @@ class ValueBoundsConstraintSet /// /// This function keeps traversing the backward slice of lhs/rhs until could /// prove the relation or until it ran out of IR. 
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); - static bool compare(AffineMap lhs, ValueDimList lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ValueDimList rhsOperands); - static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ArrayRef<Value> rhsOperands); - - /// Compute whether the given values/dimensions are equal. Return "failure" if + static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs); + + /// Compute whether the given variables are equal. Return "failure" if /// equality could not be determined. - /// - /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are - /// index-typed. - static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2, - std::optional<int64_t> dim1 = std::nullopt, - std::optional<int64_t> dim2 = std::nullopt); + static FailureOr<bool> areEqual(Variable var1, Variable var2); /// Return "true" if the given slices are guaranteed to be overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping. @@ -317,9 +322,6 @@ class ValueBoundsConstraintSet /// /// This function does not analyze any IR and does not populate any additional /// constraints. - bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim); bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); /// Given an affine map with a single result (and map operands), add a new @@ -374,6 +376,7 @@ class ValueBoundsConstraintSet /// constraint system. Return the position of the new column. Any operands /// that were not analyzed yet are put on the worklist. int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + int64_t insert(const Variable &var, bool isSymbol = true); /// Project out the given column in the constraint set. 
void projectOut(int64_t pos); @@ -381,6 +384,8 @@ class ValueBoundsConstraintSet /// Project out all columns for which the condition holds. void projectOut(function_ref<bool(ValueDim)> condition); + void projectOutAnonymous(std::optional<int64_t> except = std::nullopt); + /// Mapping of columns to values/shape dimensions. SmallVector<std::optional<ValueDim>> positionToValueDim; /// Reverse mapping of values/shape dimensions to columns. diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index e0c3abe7a0f71d1..82a9fb0d490882f 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) { mapOperands.push_back(value1); mapOperands.push_back(value2); affine::fullyComposeAffineMapAndOperands(&map, &mapOperands); - ValueDimList valueDims; - for (Value v : mapOperands) - valueDims.push_back({v, std::nullopt}); return ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::EQ, map, valueDims); + presburger::BoundType::EQ, + ValueBoundsConstraintSet::Variable(map, mapOperands)); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 117ee8e8701ad7c..6c59df91e8af781 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB))) return failure(); // Reify bound. 
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp index 90895e381c74b5a..411fc117a4d9f5d 100644 --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -75,6 +75,75 @@ struct MulIOpInterface } }; +struct SelectOpInterface + : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface, + SelectOp> { + + static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + Value value = selectOp.getResult(); + Value condition = selectOp.getCondition(); + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + + if (isa<ShapedType>(condition.getType())) { + // If the condition is a shaped type, the condition is applied + // element-wise. All three operands must have the same shape. + cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim); + cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim); + cstr.bound(value)[*dim] == cstr.getExpr(condition, dim); + return; + } + + // Populate constraints for the true/false values (and all values on the + // backward slice, as long as the current stop condition is not satisfied). + cstr.populateConstraints(trueValue, dim); + cstr.populateConstraints(falseValue, dim); + auto boundsBuilder = cstr.bound(value); + if (dim) + boundsBuilder[*dim]; + + // Compare yielded values. 
+ // If trueValue <= falseValue: + // * result <= falseValue + // * result >= trueValue + if (cstr.compare(/*lhs=*/{trueValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{falseValue, dim})) { + if (dim) { + cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); + cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); + } else { + cstr.bound(value) >= trueValue; + cstr.bound(value) <= falseValue; + } + } + // If falseValue <= trueValue: + // * result <= trueValue + // * result >= falseValue + if (cstr.compare(/*lhs=*/{falseValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{trueValue, dim})) { + if (dim) { + cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); + cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); + } else { + cstr.bound(value) >= falseValue; + cstr.bound(value) <= trueValue; + } + } + } + + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr); + } + + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + populateBounds(cast<SelectOp>(op), dim, cstr); + } +}; } // namespace } // namespace arith } // namespace mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 79fabd6ed2e99a2..f87f3d6350c0221 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> { return failure(); FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, in, /*dim=*/std::nullopt, + presburger::BoundType::UB, in, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(ub)) return failure(); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp 
b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index fad221288f190ed..5bb7d83bf1e3f86 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, + ValueBoundsConstraintSet::Variable(value, dim), stopCondition, + closedUB))) return failure(); // Materialize tensor.dim/memref.dim ops. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index 8c4b70db2489897..518d2e138c02a97 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, // Otherwise, try to compute a constant upper bound for the size value. 
FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, opOperand->get(), - /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true); + presburger::BoundType::UB, + {opOperand->get(), + /*dim=*/i}, + /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(upperBound)) { LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding"); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ac896d6c30d049d..71eb59d40836c1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) { size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); } else { - Value materializedSize = - getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt, + presburger::BoundType::UB, rangeValue.size, /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) - ? materializedSize + ? 
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) : b.create<arith::ConstantIndexOp>(loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 10ba508265e7b9f..1f06318cbd60e04 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, ValueRange independencies) { if (ofr.is<Attribute>()) return ofr; - Value value = ofr.get<Value>(); AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( - boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, + /*closedUB=*/true))) return failure(); return affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 087ffc438a830a3..17a1c016ea16d5a 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -61,12 +61,13 @@ struct ForOpInterface // An EQ constraint can be added if the yielded value (dimension size) // equals the corresponding block argument (dimension size). 
if (cstr.populateAndCompare( - yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ, - iterArg, dim)) { + /*lhs=*/{yieldedValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::EQ, + /*rhs=*/{iterArg, dim})) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { - cstr.bound(value) == initArg; + cstr.bound(value) == cstr.getExpr(initArg); } } } @@ -113,8 +114,9 @@ struct IfOpInterface // * result <= elseValue // * result >= thenValue if (cstr.populateAndCompare( - thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { + /*lhs=*/{thenValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{elseValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); @@ -127,8 +129,9 @@ struct IfOpInterface // * result <= thenValue // * result >= elseValue if (cstr.populateAndCompare( - elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { + /*lhs=*/{elseValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{thenValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 67080d8e301c135..d25efcf50ec566f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, info.isAlignedToInnerTileSize = false; FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, - getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt, + presburger::BoundType::UB, tileSize, /*stopCondition=*/nullptr, /*closedUB=*/true); 
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize); if (!failed(cstSize) && cstInnerSize) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 721730862d49b37..a89ce20048dff3d 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + independencies, + /*closedUB=*/true))) return failure(); return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 2dd91e2f7a17003..15381ec520e2119 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { continue; } FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), srcDim, resultDim); + {op.getSource(), srcDim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++srcDim; @@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { continue; } FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), dim, resultDim); + {op.getSource(), dim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++resultDim; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index fa66da4a0def937..9f220f5f6ceb729 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp 
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -25,6 +25,12 @@ namespace mlir { #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" } // namespace mlir +static Operation *getOwnerOfValue(Value value) { + if (auto bbArg = dyn_cast<BlockArgument>(value)) + return bbArg.getOwner()->getParentOp(); + return value.getDefiningOp(); +} + HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) @@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { return std::nullopt; } +ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr) + : Variable(ofr, std::nullopt) {} + +ValueBoundsConstraintSet::Variable::Variable(Value indexValue) + : Variable(static_cast<OpFoldResult>(indexValue)) {} + +ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim) + : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {} + +ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr, + std::optional<int64_t> dim) { + Builder b(ofr.getContext()); + if (auto constInt = ::getConstantIntValue(ofr)) { + assert(!dim && "expected no dim for index-typed values"); + map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, + b.getAffineConstantExpr(*constInt)); + return; + } + Value value = cast<Value>(ofr); +#ifndef NDEBUG + if (dim) { + assert(isa<ShapedType>(value.getType()) && "expected shaped type"); + } else { + assert(value.getType().isIndex() && "expected index type"); + } +#endif // NDEBUG + map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, + b.getAffineSymbolExpr(0)); + mapOperands.emplace_back(value, dim); +} + +ValueBoundsConstraintSet::Variable::Variable(AffineMap map, + ArrayRef<Variable> mapOperands) { + assert(map.getNumResults() == 1 && "expected single result"); + + // Turn all dims into symbols. 
+ Builder b(map.getContext()); + SmallVector<AffineExpr> dimReplacements, symReplacements; + for (int64_t i = 0; i < map.getNumDims(); ++i) + dimReplacements.push_back(b.getAffineSymbolExpr(i)); + for (int64_t i = 0; i < map.getNumSymbols(); ++i) + symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims())); + AffineMap tmpMap = map.replaceDimsAndSymbols( + dimReplacements, symReplacements, /*numResultDims=*/0, + /*numResultSyms=*/map.getNumSymbols() + map.getNumDims()); + + // Inline operands. + DenseMap<AffineExpr, AffineExpr> replacements; + for (auto [index, var] : llvm::enumerate(mapOperands)) { + assert(var.map.getNumResults() == 1 && "expected single result"); + assert(var.map.getNumDims() == 0 && "expected only symbols"); + SmallVector<AffineExpr> symReplacements; + for (auto valueDim : var.mapOperands) { + auto it = llvm::find(this->mapOperands, valueDim); + if (it != this->mapOperands.end()) { + // There is already a symbol for this operand. + symReplacements.push_back(b.getAffineSymbolExpr( + std::distance(this->mapOperands.begin(), it))); + } else { + // This is a new operand: add a new symbol. + symReplacements.push_back( + b.getAffineSymbolExpr(this->mapOperands.size())); + this->mapOperands.push_back(valueDim); + } + } + replacements[b.getAffineSymbolExpr(index)] = + var.map.getResult(0).replaceSymbols(symReplacements); + } + this->map = tmpMap.replace(replacements, /*numResultDims=*/0, + /*numResultSyms=*/this->mapOperands.size()); +} + +ValueBoundsConstraintSet::Variable::Variable(AffineMap map, + ArrayRef<Value> mapOperands) + : Variable(map, llvm::map_to_vector(mapOperands, + [](Value v) { return Variable(v); })) {} + ValueBoundsConstraintSet::ValueBoundsConstraintSet( MLIRContext *ctx, StopConditionFn stopCondition) : builder(ctx), stopCondition(stopCondition) { @@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value, assert(!valueDimToPosition.contains(valueDim) && "already mapped"); int64_t pos = isSymbol ? 
cstr.appendVar(VarKind::Symbol) : cstr.appendVar(VarKind::SetDim); + LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos + << " for: " << value + << " (dim: " << dim.value_or(kIndexValue) + << ", owner: " << getOwnerOfValue(value)->getName() + << ")\n"); positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); // Update reverse mapping. for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) @@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value, int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) : cstr.appendVar(VarKind::SetDim); + LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos + << "\n"); positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt); // Update reverse mapping. for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) @@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, return pos; } +int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) { + return insert(var.map, var.mapOperands, isSymbol); +} + int64_t ValueBoundsConstraintSet::getPos(Value value, std::optional<int64_t> dim) const { #ifndef NDEBUG @@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, cast<BlockArgument>(value).getOwner()->isEntryBlock()) && "unstructured control flow is not supported"); #endif // NDEBUG - + LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value + << " (dim: " << dim.value_or(kIndexValue) + << ", owner: " << getOwnerOfValue(value)->getName() + << ")\n"); auto it = valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); assert(it != valueDimToPosition.end() && "expected mapped entry"); @@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value, return it != valueDimToPosition.end(); } -static Operation *getOwnerOfValue(Value value) { - if (auto bbArg = 
dyn_cast<BlockArgument>(value)) - return bbArg.getOwner()->getParentOp(); - return value.getDefiningOp(); -} - void ValueBoundsConstraintSet::processWorklist() { LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); while (!worklist.empty()) { @@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut( } } +void ValueBoundsConstraintSet::projectOutAnonymous( + std::optional<int64_t> except) { + int64_t nextPos = 0; + while (nextPos < static_cast<int64_t>(positionToValueDim.size())) { + if (positionToValueDim[nextPos].has_value() || except == nextPos) { + ++nextPos; + } else { + projectOut(nextPos); + // The column was projected out so another column is now at that position. + // Do not increase the counter. + } + } +} + LogicalResult ValueBoundsConstraintSet::computeBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional<int64_t> dim, StopConditionFn stopCondition, - bool closedUB) { -#ifndef NDEBUG - assertValidValueDim(value, dim); -#endif // NDEBUG - + const Variable &var, StopConditionFn stopCondition, bool closedUB) { + MLIRContext *ctx = var.getContext(); int64_t ubAdjustment = closedUB ? 0 : 1; - Builder b(value.getContext()); + Builder b(ctx); mapOperands.clear(); // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. - ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - ValueBoundsConstraintSet cstr(value.getContext(), stopCondition); - assert(!stopCondition(value, dim, cstr) && - "stop condition should not be satisfied for starting point"); - int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false); + ValueBoundsConstraintSet cstr(ctx, stopCondition); + int64_t pos = cstr.insert(var, /*isSymbol=*/false); + assert(pos == 0 && "expected first column"); cstr.processWorklist(); // Project out all variables (apart from `valueDim`) that do not match the // stop condition. 
cstr.projectOut([&](ValueDim p) { - // Do not project out `valueDim`. - if (valueDim == p) - return false; auto maybeDim = p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); return !stopCondition(p.first, maybeDim, cstr); }); + cstr.projectOutAnonymous(/*except=*/pos); // Compute lower and upper bounds for `valueDim`. SmallVector<AffineMap> lb(1), ub(1); - cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, + cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub, /*closedUB=*/true); // Note: There are TODOs in the implementation of `getSliceBounds`. In such a @@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound( LogicalResult ValueBoundsConstraintSet::computeDependentBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional<int64_t> dim, ValueDimList dependencies, - bool closedUB) { + const Variable &var, ValueDimList dependencies, bool closedUB) { return computeBound( - resultMap, mapOperands, type, value, dim, + resultMap, mapOperands, type, var, [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { return llvm::is_contained(dependencies, std::make_pair(v, d)); }, @@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound( LogicalResult ValueBoundsConstraintSet::computeIndependentBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional<int64_t> dim, ValueRange independencies, - bool closedUB) { + const Variable &var, ValueRange independencies, bool closedUB) { // Return "true" if the given value is independent of all values in // `independencies`. I.e., neither the value itself nor any value in the // backward slice (reverse use-def chain) is contained in `independencies`. @@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( // Reify bounds in terms of any independent values. 
return computeBound( - resultMap, mapOperands, type, value, dim, + resultMap, mapOperands, type, var, [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { return isIndependent(v); }, @@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( } FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, Value value, std::optional<int64_t> dim, - StopConditionFn stopCondition, bool closedUB) { -#ifndef NDEBUG - assertValidValueDim(value, dim); -#endif // NDEBUG - - AffineMap map = - AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - Builder(value.getContext()).getAffineDimExpr(0)); - return computeConstantBound(type, map, {{value, dim}}, stopCondition, - closedUB); -} - -FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef<Value> operands, + presburger::BoundType type, const Variable &var, StopConditionFn stopCondition, bool closedUB) { - ValueDimList valueDims; - for (Value v : operands) { - assert(v.getType().isIndex() && "expected index type"); - valueDims.emplace_back(v, std::nullopt); - } - return computeConstantBound(type, map, valueDims, stopCondition, closedUB); -} - -FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, AffineMap map, ValueDimList operands, - StopConditionFn stopCondition, bool closedUB) { - assert(map.getNumResults() == 1 && "expected affine map with one result"); - // Default stop condition if none was specified: Keep adding constraints until // a bound could be computed. int64_t pos = 0; @@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( }; ValueBoundsConstraintSet cstr( - map.getContext(), stopCondition ? stopCondition : defaultStopCondition); - pos = cstr.populateConstraints(map, operands); + var.getContext(), stopCondition ? 
stopCondition : defaultStopCondition); + pos = cstr.populateConstraints(var.map, var.mapOperands); assert(pos == 0 && "expected `map` is the first column"); // Compute constant bound for `valueDim`. @@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, Builder b(value1.getContext()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); - return computeConstantBound(presburger::BoundType::EQ, map, - {{value1, dim1}, {value2, dim2}}); + return computeConstantBound(presburger::BoundType::EQ, + Variable(map, {{value1, dim1}, {value2, dim2}})); } -bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, - std::optional<int64_t> lhsDim, - ComparisonOperator cmp, - OpFoldResult rhs, - std::optional<int64_t> rhsDim) { -#ifndef NDEBUG - if (auto lhsVal = dyn_cast<Value>(lhs)) - assertValidValueDim(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast<Value>(rhs)) - assertValidValueDim(rhsVal, rhsDim); -#endif // NDEBUG - +bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, + ComparisonOperator cmp, + int64_t rhsPos) { // This function returns "true" if "lhs CMP rhs" is proven to hold. // // Example for ComparisonOperator::LE and index-typed values: We would like to @@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, return false; } - // EQ can be expressed as LE and GE. - if (cmp == EQ) - return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && - compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); - - // Construct inequality. For the above example: lhs > rhs. - // `IntegerRelation` inequalities are expressed in the "flattened" form and - // with ">= 0". I.e., lhs - rhs - 1 >= 0. 
- SmallVector<int64_t> eq(cstr.getNumCols(), 0); - auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim, - int64_t factor) { - if (auto constVal = ::getConstantIntValue(ofr)) { - eq[cstr.getNumCols() - 1] += *constVal * factor; - } else { - eq[getPos(cast<Value>(ofr), dim)] += factor; - } - }; - if (cmp == LT || cmp == LE) { - addToEq(lhs, lhsDim, 1); - addToEq(rhs, rhsDim, -1); - } else if (cmp == GT || cmp == GE) { - addToEq(lhs, lhsDim, -1); - addToEq(rhs, rhsDim, 1); - } else { - llvm_unreachable("unsupported comparison operator"); - } - if (cmp == LE || cmp == GE) - eq[cstr.getNumCols() - 1] -= 1; - - // Add inequality to the constraint set and check if it made the constraint - // set empty. - int64_t ineqPos = cstr.getNumInequalities(); - cstr.addInequality(eq); - bool isEmpty = cstr.isEmpty(); - cstr.removeInequality(ineqPos); - return isEmpty; -} - -bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, - ComparisonOperator cmp, - int64_t rhsPos) { - // This function returns "true" if "lhs CMP rhs" is proven to hold. For - // detailed documentation, see `compareValueDims`. - // EQ can be expressed as LE and GE. 
if (cmp == EQ) return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && @@ -712,48 +727,16 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, return isEmpty; } -bool ValueBoundsConstraintSet::populateAndCompare( - OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp, - OpFoldResult rhs, std::optional<int64_t> rhsDim) { -#ifndef NDEBUG - if (auto lhsVal = dyn_cast<Value>(lhs)) - assertValidValueDim(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast<Value>(rhs)) - assertValidValueDim(rhsVal, rhsDim); -#endif // NDEBUG - - if (auto lhsVal = dyn_cast<Value>(lhs)) - populateConstraints(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast<Value>(rhs)) - populateConstraints(rhsVal, rhsDim); - - return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); +bool ValueBoundsConstraintSet::populateAndCompare(Variable lhs, + ComparisonOperator cmp, + Variable rhs) { + int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands); + int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands); + return comparePos(lhsPos, cmp, rhsPos); } -bool ValueBoundsConstraintSet::compare(OpFoldResult lhs, - std::optional<int64_t> lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional<int64_t> rhsDim) { - auto stopCondition = [&](Value v, std::optional<int64_t> dim, - ValueBoundsConstraintSet &cstr) { - // Keep processing as long as lhs/rhs are not mapped. - if (auto lhsVal = dyn_cast<Value>(lhs)) - if (!cstr.isMapped(lhsVal, dim)) - return false; - if (auto rhsVal = dyn_cast<Value>(rhs)) - if (!cstr.isMapped(rhsVal, dim)) - return false; - // Keep processing as long as the relation cannot be proven. 
- return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); - }; - - ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); - return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim); -} - -bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ValueDimList rhsOperands) { +bool ValueBoundsConstraintSet::compare(Variable lhs, ComparisonOperator cmp, + Variable rhs) { int64_t lhsPos = -1, rhsPos = -1; auto stopCondition = [&](Value v, std::optional<int64_t> dim, ValueBoundsConstraintSet &cstr) { @@ -765,39 +748,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, return cstr.comparePos(lhsPos, cmp, rhsPos); }; ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); - lhsPos = cstr.insert(lhs, lhsOperands); - rhsPos = cstr.insert(rhs, rhsOperands); - cstr.processWorklist(); + lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); + rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands); return cstr.comparePos(lhsPos, cmp, rhsPos); } -bool ValueBoundsConstraintSet::compare(AffineMap lhs, - ArrayRef<Value> lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ArrayRef<Value> rhsOperands) { - ValueDimList lhsValueDimOperands = - llvm::map_to_vector(lhsOperands, [](Value v) { - return std::make_pair(v, std::optional<int64_t>()); - }); - ValueDimList rhsValueDimOperands = - llvm::map_to_vector(rhsOperands, [](Value v) { - return std::make_pair(v, std::optional<int64_t>()); - }); - return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs, - rhsValueDimOperands); -} - -FailureOr<bool> -ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2, - std::optional<int64_t> dim1, - std::optional<int64_t> dim2) { - if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ, - value2, dim2)) +FailureOr<bool> ValueBoundsConstraintSet::areEqual(Variable var1, + Variable var2) { + if 
(ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2)) return true; - if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT, - value2, dim2) || - ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT, - value2, dim2)) + if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) || + ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2)) return false; return failure(); } @@ -833,7 +794,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, AffineMap foldedMap = foldAttributesIntoMap(b, map, ofrOperands, valueOperands); FailureOr<int64_t> constBound = computeConstantBound( - presburger::BoundType::EQ, foldedMap, valueOperands); + presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); foundUnknownBound |= failed(constBound); if (succeeded(constBound) && *constBound <= 0) return false; @@ -850,7 +811,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, AffineMap foldedMap = foldAttributesIntoMap(b, map, ofrOperands, valueOperands); FailureOr<int64_t> constBound = computeConstantBound( - presburger::BoundType::EQ, foldedMap, valueOperands); + presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); foundUnknownBound |= failed(constBound); if (succeeded(constBound) && *constBound <= 0) return false; diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index f38631054fb3c14..af4ba7de3df1f6c 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -169,7 +169,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, FailureOr<OpFoldResult> reified = failure(); if (constant) { auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( - *boundType, value, dim, /*stopCondition=*/nullptr); + *boundType, {value, dim}, /*stopCondition=*/nullptr); if (succeeded(reifiedConst)) reified = 
FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst)); @@ -285,8 +285,8 @@ static LogicalResult testEquality(func::FuncOp funcOp) { auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { return ValueBoundsConstraintSet::compare( - /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp, - /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt); + /*lhs=*/op->getOperand(0), cmp, + /*rhs=*/op->getOperand(1)); }; if (compare(*cmpType)) { op->emitRemark("true"); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits