srishti-pm updated this revision to Diff 448682.
srishti-pm marked 3 inline comments as done.
srishti-pm added a comment.
Addressed the final NITs.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D124750/new/
https://reviews.llvm.org/D124750
Files:
mlir/include/mlir/Transforms/CommutativityUtils.h
mlir/lib/Transforms/Utils/CMakeLists.txt
mlir/lib/Transforms/Utils/CommutativityUtils.cpp
mlir/test/Transforms/test-commutativity-utils.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestCommutativityUtils.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -57,6 +57,7 @@
void registerVectorizerTestPass();
namespace test {
+void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
void registerMemRefBoundCheck();
@@ -149,6 +150,7 @@
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();
+ mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
mlir::test::registerMemRefBoundCheck();
Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -0,0 +1,48 @@
+//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the functionality of the commutativity utility pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that greedily applies the commutative-operand sorting pattern
+/// (from `populateCommutativityUtilsPatterns`) to each `func.func` it is
+/// scheduled on, folding along the way.
+struct CommutativityUtils
+    : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils)
+
+  /// Command-line flag used to invoke this pass.
+  StringRef getArgument() const final { return "test-commutativity-utils"; }
+
+  StringRef getDescription() const final {
+    return "Test the functionality of the commutativity utility";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet sortingPatterns(&getContext());
+    populateCommutativityUtilsPatterns(sortingPatterns);
+
+    // Whether the greedy driver converged is intentionally ignored.
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(sortingPatterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `CommutativityUtils` test pass via `PassRegistration`; this
+/// is invoked from mlir-opt's test-pass registration list.
+void registerCommutativityUtils() {
+  PassRegistration<CommutativityUtils> registration;
+}
+} // namespace test
+} // namespace mlir
Index: mlir/test/lib/Transforms/CMakeLists.txt
===================================================================
--- mlir/test/lib/Transforms/CMakeLists.txt
+++ mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
+ TestCommutativityUtils.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -1186,11 +1186,21 @@
let hasFolder = 1;
}
+// Binary i32 op that, unlike `TestCommutativeOp`, does not carry the
+// `Commutative` trait, so the commutativity utility never reorders its
+// operands. It serves as an ordinary non-constant ancestor in
+// test-commutativity-utils.mlir.
+def TestAddIOp : TEST_Op<"addi"> {
+  let arguments = (ins
+    I32:$op1,
+    I32:$op2
+  );
+  let results = (outs I32);
+}
+
+// Four-operand commutative op; also exercised by test-commutativity-utils.mlir.
def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
let results = (outs I32);
}
+// Seven-operand commutative op, used to exercise operand sorting on a longer
+// operand list in test-commutativity-utils.mlir.
+def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5,
+                       I32:$op6, I32:$op7);
+  let results = (outs I32);
+}
+
def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
let arguments = (ins I32:$op1, I32:$op2);
let results = (outs I32);
Index: mlir/test/Transforms/test-commutativity-utils.mlir
===================================================================
--- /dev/null
+++ mlir/test/Transforms/test-commutativity-utils.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s
+
+// Every operand below is rooted at a different op, so the first entry of each
+// operand's key already decides the order: non-constant ops sorted by op name
+// (arith.addi < arith.muli < test.addi), followed by the constant-like op.
+// CHECK-LABEL: @test_small_pattern_1
+func.func @test_small_pattern_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+ %0 = arith.constant 45 : i32
+
+ // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
+ %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+ // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+ %2 = arith.addi %arg0, %arg0 : i32
+
+ // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
+ %3 = arith.muli %arg0, %arg0 : i32
+
+ // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
+ %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+
+ // CHECK-NEXT: return %[[RESULT]]
+ return %result : i32
+}
+
+// Covers all three ancestor types: the block argument sorts first, then the
+// non-constant op, then the constant-like ops ordered by op name
+// (arith.constant < test.constant).
+// CHECK-LABEL: @test_small_pattern_2
+// CHECK-SAME: (%[[ARG0:.*]]: i32
+func.func @test_small_pattern_2(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
+ %0 = "test.constant"() {value = 0 : i32} : () -> i32
+
+ // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+ %1 = arith.constant 0 : i32
+
+ // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+ %2 = arith.addi %arg0, %arg0 : i32
+
+ // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[ARITH_CONST]], %[[TEST_CONST]])
+ %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
+
+ // CHECK-NEXT: return %[[RESULT]]
+ return %result : i32
+}
+
+// Larger DAG: most operands of "test.op_large_commutative" are rooted at
+// arith.divsi, so their relative order is only decided several levels deep in
+// their backward slices. The nested "test.op_commutative" gets sorted as well,
+// and repeated operands (%12 and %19) end up adjacent after sorting.
+// CHECK-LABEL: @test_large_pattern
+func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
+ // CHECK-NEXT: arith.divsi
+ %0 = arith.divsi %arg0, %arg1 : i32
+
+ // CHECK-NEXT: arith.divsi
+ %1 = arith.divsi %0, %arg0 : i32
+
+ // CHECK-NEXT: arith.divsi
+ %2 = arith.divsi %1, %arg1 : i32
+
+ // CHECK-NEXT: arith.addi
+ %3 = arith.addi %1, %arg1 : i32
+
+ // CHECK-NEXT: arith.subi
+ %4 = arith.subi %2, %3 : i32
+
+ // CHECK-NEXT: "test.addi"
+ %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+ // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
+ %6 = arith.divsi %4, %5 : i32
+
+ // CHECK-NEXT: arith.divsi
+ %7 = arith.divsi %1, %arg1 : i32
+
+ // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
+ %8 = arith.muli %1, %arg1 : i32
+
+ // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
+ %9 = arith.subi %7, %8 : i32
+
+ // CHECK-NEXT: "test.addi"
+ %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+ // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
+ %11 = arith.divsi %9, %10 : i32
+
+ // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
+ %12 = arith.divsi %6, %arg1 : i32
+
+ // CHECK-NEXT: arith.subi
+ %13 = arith.subi %arg1, %arg0 : i32
+
+ // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
+ %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
+
+ // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
+ %15 = arith.divsi %13, %14 : i32
+
+ // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
+ %16 = arith.addi %2, %15 : i32
+
+ // CHECK-NEXT: arith.subi
+ %17 = arith.subi %16, %arg1 : i32
+
+ // CHECK-NEXT: "test.addi"
+ %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+ // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
+ %19 = arith.divsi %17, %18 : i32
+
+ // CHECK-NEXT: "test.addi"
+ %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
+
+ // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
+ %21 = arith.divsi %17, %20 : i32
+
+ // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
+ %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32
+
+ // CHECK-NEXT: return %[[RESULT]]
+ return %result : i32
+}
Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,317 @@
+//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a commutativity utility pattern and a function to
+// populate this pattern. The function is intended to be used inside passes to
+// simplify the matching of commutative operations by fixing the order of their
+// operands.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include <queue>
+
+using namespace mlir;
+
+namespace {
+/// The possible "types" of ancestors. Here, an ancestor is an op or a block
+/// argument present in the backward slice of a value.
+///
+/// The enumerators are declared in ascending sort order, because
+/// `AncestorKey::operator<` compares the enumerator values: block arguments
+/// come first, then non-constant ops, then constant-like ops.
+enum AncestorType {
+  /// Pertains to a block argument.
+  BLOCK_ARGUMENT,
+
+  /// Pertains to a non-constant-like op.
+  NON_CONSTANT_OP,
+
+  /// Pertains to a constant-like op.
+  CONSTANT_OP
+};
+} // namespace
+
+namespace {
+/// Stores the "key" associated with an ancestor.
+struct AncestorKey {
+  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
+  /// the ancestor.
+  AncestorType type = BLOCK_ARGUMENT;
+
+  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
+  /// `CONSTANT_OP`. Else, holds "".
+  StringRef opName;
+
+  /// Builds the key of an ancestor. A null `op` denotes a block argument.
+  ///
+  /// `explicit` so that an `Operation *` is never silently converted into a
+  /// key.
+  explicit AncestorKey(Operation *op) {
+    // Block arguments keep the member defaults: `BLOCK_ARGUMENT` and "".
+    if (!op)
+      return;
+    type =
+        op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
+    opName = op->getName().getStringRef();
+  }
+
+  /// Overloaded operator `<` for `AncestorKey`.
+  ///
+  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
+  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
+  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
+  /// ones are the ones with smaller op names (lexicographically).
+  ///
+  /// TODO: Include other information like attributes, value type, etc., to
+  /// enhance this comparison. For example, currently this comparison doesn't
+  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
+  /// `addi (in i64)`. Such an enhancement should only be done if the need
+  /// arises.
+  bool operator<(const AncestorKey &key) const {
+    return std::tie(type, opName) < std::tie(key.type, key.opName);
+  }
+};
+} // namespace
+
+namespace {
+/// Stores a commutative operand along with its BFS traversal information.
+///
+/// The traversal is advanced lazily: it only goes as deep as the comparisons
+/// performed so far have required, and its progress is kept in this object.
+struct CommutativeOperand {
+  /// Stores the operand.
+  Value operand;
+
+  /// Stores the queue of ancestors of the operand's BFS traversal at a
+  /// particular point in time. A null entry stands for a block argument.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Stores the ops that have been visited by the BFS traversal at a
+  /// particular point in time. Block arguments are never recorded here.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the operand's "key". This "key" is defined as a list of the
+  /// "AncestorKeys" associated with the ancestors of this operand, in a
+  /// breadth-first order.
+  ///
+  /// So, if an operand, say `A`, was produced as follows:
+  ///
+  ///   `<block argument>`  `<block argument>`
+  ///                 \        /
+  ///                  \      /
+  ///               `arith.subi`    `arith.constant`
+  ///                          \      /
+  ///                        `arith.addi`
+  ///                             |
+  ///                        returns `A`
+  ///
+  /// Then, the ancestors of `A`, in the breadth-first order are:
+  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
+  /// `<block argument>`.
+  ///
+  /// Thus, the "key" associated with operand `A` is:
+  /// {
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
+  ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
+  /// }
+  SmallVector<AncestorKey, 4> key;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" set (iff it is an op rather than a block argument).
+  void pushAncestor(Operation *op) {
+    ancestorQueue.push(op);
+    if (op)
+      visitedAncestors.insert(op);
+  }
+
+  /// Refresh the key.
+  ///
+  /// Refreshing a key entails making it up-to-date with the operand's BFS
+  /// traversal that has happened till that point in time, i.e., appending the
+  /// front ancestor's "AncestorKey" to the existing key. Note that a key
+  /// directly reflects the BFS and thus needs to be refreshed during the
+  /// progression of the traversal. Does nothing once the queue is empty.
+  void refreshKey() {
+    if (!ancestorQueue.empty())
+      key.push_back(AncestorKey(ancestorQueue.front()));
+  }
+
+  /// Pop the front ancestor, if any, from the queue and then push its adjacent
+  /// unvisited ancestors, if any, to the queue (this is the main body of the
+  /// BFS algorithm).
+  void popFrontAndPushAdjacentUnvisitedAncestors() {
+    if (ancestorQueue.empty())
+      return;
+    Operation *frontAncestor = ancestorQueue.front();
+    ancestorQueue.pop();
+    // A block argument (null entry) has no ancestors of its own.
+    if (!frontAncestor)
+      return;
+    // Named `ancestorOperand` to avoid shadowing the `operand` member.
+    for (Value ancestorOperand : frontAncestor->getOperands()) {
+      Operation *operandDefOp = ancestorOperand.getDefiningOp();
+      // Block arguments are enqueued for every use (each one contributes its
+      // own key entry); ops are enqueued at most once.
+      if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
+        pushAncestor(operandDefOp);
+    }
+  }
+};
+} // namespace
+
+namespace {
+/// Sorts the operands of `op` in ascending order of the "key" associated with
+/// each operand iff `op` is commutative. This is a stable sort.
+///
+/// After the application of this pattern, since the commutative operands now
+/// have a deterministic order in which they occur in an op, the matching of
+/// large DAGs becomes much simpler, i.e., requires far fewer checks to be
+/// written by a user in their pattern matching function.
+///
+/// Keys are compared lexicographically: the first unequal pair of
+/// corresponding AncestorKeys decides, and if one key is a prefix of the
+/// other, the shorter key is the smaller one.
+///
+/// Some examples of such a sorting:
+///
+/// Assume that the sorting is being applied to `foo.commutative`, which is a
+/// commutative op.
+///
+/// Example 1:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.commutative %1, %2
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, "foo.const"}
+///      }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+///
+/// The key of %2 < the key of %1
+/// Thus, the sorted `foo.commutative` is:
+/// %3 = foo.commutative %2, %1
+///
+/// Example 2:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.mul %2, %1
+/// %4 = foo.add %2, %1
+/// %5 = foo.commutative %1, %2, %3, %4
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, "foo.const"}
+///      }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+/// 3. The key associated with %3 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, "foo.const"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+/// 4. The key associated with %4 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.add"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, "foo.const"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+///
+/// %4 comes first ("foo.add" < "foo.mul"); %2 precedes %3 because their
+/// second entries differ and a block argument is smaller than any op; %1 is
+/// last because constant-like ops are the largest.
+/// Thus, the sorted `foo.commutative` is:
+/// %5 = foo.commutative %4, %2, %3, %1
+class SortCommutativeOperands : public RewritePattern {
+public:
+  SortCommutativeOperands(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // If `op` is not commutative, do nothing.
+    if (!op->hasTrait<OpTrait::IsCommutative>())
+      return failure();
+
+    // Populate the list of commutative operands. Each BFS starts at the
+    // operand's defining op (null for a block argument), so every key starts
+    // out with exactly one entry.
+    SmallVector<Value, 2> operands = op->getOperands();
+    SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
+    commOperands.reserve(operands.size());
+    for (Value operand : operands) {
+      auto commOperand = std::make_unique<CommutativeOperand>();
+      commOperand->operand = operand;
+      commOperand->pushAncestor(operand.getDefiningOp());
+      commOperand->refreshKey();
+      commOperands.push_back(std::move(commOperand));
+    }
+
+    // Sort the operands. The comparator lazily advances the BFS of the
+    // operands it compares.
+    std::stable_sort(commOperands.begin(), commOperands.end(),
+                     commutativeOperandComparator);
+    SmallVector<Value, 2> sortedOperands;
+    sortedOperands.reserve(commOperands.size());
+    for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
+      sortedOperands.push_back(commOperand->operand);
+    if (sortedOperands == operands)
+      return failure();
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+    return success();
+  }
+
+private:
+  /// Advances the BFS of `commOperand` until its key has an entry at
+  /// `keyIndex` or its traversal is exhausted. Returns true iff
+  /// `commOperand.key[keyIndex]` exists afterwards.
+  static bool hasKeyEntryAt(CommutativeOperand &commOperand,
+                            unsigned keyIndex) {
+    while (commOperand.key.size() <= keyIndex &&
+           !commOperand.ancestorQueue.empty()) {
+      commOperand.popFrontAndPushAdjacentUnvisitedAncestors();
+      commOperand.refreshKey();
+    }
+    return keyIndex < commOperand.key.size();
+  }
+
+  /// Returns true iff the "key" of `commOperandA` < the "key" of
+  /// `commOperandB`, i.e.,
+  /// 1. In the first unequal pair of corresponding AncestorKeys, the
+  ///    AncestorKey in `commOperandA` is smaller, or,
+  /// 2. Both the AncestorKeys in every pair are the same and the size of
+  ///    `commOperandA`'s "key" is smaller.
+  ///
+  /// The BFS state of each operand persists across comparisons, so a key may
+  /// already be longer than the current index, or its traversal already
+  /// exhausted, when a comparison starts. Entries are therefore always
+  /// compared first; key lengths only matter once one key runs out. This
+  /// keeps the comparator a strict weak ordering regardless of the order in
+  /// which `std::stable_sort` performs comparisons.
+  static bool commutativeOperandComparator(
+      const std::unique_ptr<CommutativeOperand> &commOperandA,
+      const std::unique_ptr<CommutativeOperand> &commOperandB) {
+    if (commOperandA->operand == commOperandB->operand)
+      return false;
+
+    // `unique_ptr::operator*` yields a mutable `CommutativeOperand &` even
+    // through a const `unique_ptr &`, so no `const_cast` is needed to advance
+    // the traversals.
+    for (unsigned keyIndex = 0;; ++keyIndex) {
+      bool aHasEntry = hasKeyEntryAt(*commOperandA, keyIndex);
+      bool bHasEntry = hasKeyEntryAt(*commOperandB, keyIndex);
+      // Once either key is exhausted, all earlier entries were equal: the
+      // shorter key (a strict prefix of the other) is the smaller one, and
+      // keys of equal length are equivalent.
+      if (!aHasEntry || !bHasEntry)
+        return !aHasEntry && bHasEntry;
+      const AncestorKey &keyA = commOperandA->key[keyIndex];
+      const AncestorKey &keyB = commOperandB->key[keyIndex];
+      if (keyA < keyB)
+        return true;
+      if (keyB < keyA)
+        return false;
+    }
+  }
+};
+} // namespace
+
+/// Adds the `SortCommutativeOperands` pattern, which sorts the operands of
+/// commutative ops into a deterministic order, to `patterns`.
+void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<SortCommutativeOperands>(ctx);
+}
Index: mlir/lib/Transforms/Utils/CMakeLists.txt
===================================================================
--- mlir/lib/Transforms/Utils/CMakeLists.txt
+++ mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRTransformUtils
+ CommutativityUtils.cpp
ControlFlowSinkUtils.cpp
DialectConversion.cpp
FoldUtils.cpp
Index: mlir/include/mlir/Transforms/CommutativityUtils.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -0,0 +1,27 @@
+//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a function to populate the commutativity utility
+// pattern. This function is intended to be used inside passes to simplify the
+// matching of commutative operations by fixing the order of their operands.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+
+// NOTE(review): only `RewritePatternSet` is needed by this header; a lighter
+// include (e.g. mlir/IR/PatternMatch.h) may suffice — confirm that the
+// implementation file does not rely on this header's transitive includes.
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+/// Populates `patterns` with the commutativity utility pattern, which sorts
+/// the operands of commutative ops into a deterministic order. Patterns that
+/// match commutative ops then only have to consider a single operand order.
+void populateCommutativityUtilsPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits