https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/68941
>From 877111a139b2f01037fdbe7c0cb120a2f4e64562 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0...@gmail.com>
Date: Thu, 12 Oct 2023 17:14:29 -0700
Subject: [PATCH 1/2] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)"

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b with limiting types
to i1.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 46 +++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  4 +-
 mlir/test/Dialect/Arith/canonicalize.mlir     | 76 +++++++++++++++++++
 3 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
                  CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
                        "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;

+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+            $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
+        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+        (SelectOp (Arith_OrIOp $predA,
+                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+            $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {

 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SelectI1Simplify, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
+              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
+              SelectNotCond, SelectToExtUI>(context);
 }

 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }

+// CHECK-LABEL: @redundantSelectTrue
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %0, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %0 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
+// CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
+// CHECK-NEXT: return %[[res1]], %[[res2]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+  %one = arith.constant 1 : i1
+  %cond1 = arith.xori %arg0, %one : i1
+  %cond2 = arith.xori %one, %arg0 : i1
+
+  %res1 = arith.select %cond1, %arg1, %arg2 : i32
+  %res2 = arith.select %cond2, %arg3, %arg4 : i32
+  return %res1, %res2 : i32, i32
+}
+
+// CHECK-LABEL: @selAndCond
+// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selAndNotCond
+// CHECK-NEXT: %[[one:.+]] = arith.constant true
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg2 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrCond
+// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg2, %sel : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrNotCond
+// CHECK-NEXT: %[[one:.+]] = arith.constant true
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg3, %sel : i32
+  return %res : i32
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 // CHECK-DAG: %[[T:.*]] = arith.constant true

>From 506e0c83d65845c62737bc915878ae47008bbc28 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0...@gmail.com>
Date: Fri, 13 Oct 2023 10:45:11 -0700
Subject: [PATCH 2/2] extend patterns to handle vector types

Also guard the and/or select-combining patterns so they only fire when
both select conditions have the same type. arith.select accepts a scalar
i1 condition even for vector operands, so the two conditions may be i1
and vector<Nxi1>; folding them into a single arith.andi/arith.ori would
then produce an op with mismatched operand types.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 26 ++++++++---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 18 ++++----
 mlir/test/Dialect/Arith/canonicalize.mlir     | 47 +++++++++++++++++++
 3 files changed, 76 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -237,6 +237,16 @@ def CmpIExtUI :
 // SelectOp
 //===----------------------------------------------------------------------===//

+def GetScalarOrVectorTrueAttribute :
+    NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
+
+// arith.select accepts a scalar i1 condition even for vector operands, so
+// nested selects may have conditions of different types (e.g. i1 and
+// vector<4xi1>). Merging such conditions into one arith.andi/arith.ori would
+// build an op with mismatched operand types; this constraint rules that out.
+def SameConditionTypes :
+    Constraint<CPred<"$0.getType() == $1.getType()">, "same condition types">;
+
 // select(not(pred), a, b) => select(pred, b, a)
 def SelectNotCond :
     Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -256,28 +266,32 @@ def RedundantSelectFalse :
 // select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
 def SelectAndCond :
     Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
-        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y),
+        [(SameConditionTypes $predA, $predB)]>;

 // select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
 def SelectAndNotCond :
     Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
         (SelectOp (Arith_AndIOp $predA,
-                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                    (Arith_XOrIOp $predB,
+                        (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
             $x, $y),
-        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+        [(SameConditionTypes $predA, $predB)]>;

 // select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
 def SelectOrCond :
     Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
-        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y),
+        [(SameConditionTypes $predA, $predB)]>;

 // select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
 def SelectOrNotCond :
     Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
         (SelectOp (Arith_OrIOp $predA,
-                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                    (Arith_XOrIOp $predB,
+                        (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
             $x, $y),
-        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+        [(SameConditionTypes $predA, $predB)]>;

 //===----------------------------------------------------------------------===//
 // IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..02bab31971dcbe4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -113,6 +113,14 @@ static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
   return failure();
 }

+static Attribute getBoolAttribute(Type type, bool value) {
+  auto boolAttr = BoolAttr::get(type.getContext(), value);
+  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
+  if (!shapedType)
+    return boolAttr;
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd canonicalization patterns
 //===----------------------------------------------------------------------===//
@@ -1696,14 +1704,6 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
   llvm_unreachable("unknown cmpi predicate kind");
 }

-static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
-  auto boolAttr = BoolAttr::get(ctx, value);
-  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
-  if (!shapedType)
-    return boolAttr;
-  return DenseElementsAttr::get(shapedType, boolAttr);
-}
-
 static std::optional<int64_t> getIntegerWidth(Type t) {
   if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
     return intType.getWidth();
@@ -1718,7 +1718,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return getBoolAttribute(getType(), getContext(), val);
+    return getBoolAttribute(getType(), val);
   }

   if (matchPattern(adaptor.getRhs(), m_Zero())) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -182,6 +182,18 @@ func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32
   return %res : i32
 }

+// CHECK-LABEL: @selAndNotCondVec
+// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
+  %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // CHECK-LABEL: @selOrCond
 // CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
 // CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
@@ -204,6 +216,41 @@ func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
   return %res : i32
 }

+// CHECK-LABEL: @selOrNotCondVec
+// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
+  %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
+  return %res : vector<4xi32>
+}
+
+// Conditions of different types (i1 vs. vector<4xi1>) cannot be merged into
+// a single arith.andi/arith.ori, so these nested selects must be left as is.
+
+// CHECK-LABEL: @selAndCondMixedTypes
+// CHECK-NEXT: %[[sel:.+]] = arith.select %arg0, %arg2, %arg3
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg1, %[[sel]], %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @selAndCondMixedTypes(%arg0: i1, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi32>
+  %res = arith.select %arg1, %sel, %arg3 : vector<4xi1>, vector<4xi32>
+  return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @selOrCondMixedTypes
+// CHECK-NEXT: %[[sel:.+]] = arith.select %arg0, %arg2, %arg3
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg1, %arg2, %[[sel]]
+// CHECK-NEXT: return %[[res]]
+func.func @selOrCondMixedTypes(%arg0: vector<4xi1>, %arg1: i1, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
+  %res = arith.select %arg1, %arg2, %sel : vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 // CHECK-DAG: %[[T:.*]] = arith.constant true

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits