srishti-pm updated this revision to Diff 440570. srishti-pm marked 6 inline comments as done. srishti-pm added a comment.
Addressed the final comments. Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D124750/new/ https://reviews.llvm.org/D124750 Files: mlir/include/mlir/Transforms/CommutativityUtils.h mlir/lib/Transforms/Utils/CMakeLists.txt mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/test/Transforms/test-commutativity-utils.mlir mlir/test/lib/Dialect/Test/TestOps.td mlir/test/lib/Transforms/CMakeLists.txt mlir/test/lib/Transforms/TestCommutativityUtils.cpp mlir/tools/mlir-opt/mlir-opt.cpp
Index: mlir/tools/mlir-opt/mlir-opt.cpp =================================================================== --- mlir/tools/mlir-opt/mlir-opt.cpp +++ mlir/tools/mlir-opt/mlir-opt.cpp @@ -57,6 +57,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -152,6 +153,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck(); Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,48 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the functionality of the commutativity utility pattern. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +struct CommutativityUtils + : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + populateCommutativityUtilsPatterns(patterns); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); } +} // namespace test +} // namespace mlir Index: mlir/test/lib/Transforms/CMakeLists.txt =================================================================== --- mlir/test/lib/Transforms/CMakeLists.txt +++ mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1162,11 +1162,21 @@ let hasFolder = 1; } +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +def 
TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2); let results = (outs I32); Index: mlir/test/Transforms/test-commutativity-utils.mlir =================================================================== --- /dev/null +++ mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func.func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func.func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[TEST_CONST]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return 
%result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], 
%[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,558 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a commutativity utility pattern and a function to +// populate this pattern. The function is intended to be used inside passes to +// simplify the matching of commutative operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include <queue> + +using namespace mlir; + +/// Stores the "ancestor" of an operand of some op. The operand of any op is +/// produced by a set of ops and block arguments. Each of these ops and block +/// arguments is called an "ancestor" of this operand. +struct Ancestor { + /// Stores true when the "ancestor" is an op and false when the "ancestor" is + /// a block argument. + bool isOp; + + /// Stores the op when the "ancestor" is an op and nullptr when the "ancestor" + /// is a block argument. + Operation *op; + + /// Defines the constructor for `Ancestor`. + Ancestor(Operation *opForAncestor) { + isOp = opForAncestor ? true : false; + op = opForAncestor; + } +}; + +/// Declares various "types" of ancestors. +enum AncestorType { + /// Pertains to a block argument. + BLOCK_ARGUMENT, + + /// Pertains to a non-constant-like op. 
+ NON_CONSTANT_OP, + + /// Pertains to a constant-like op. + CONSTANT_OP +}; + +/// Stores the "key" associated with an ancestor. +struct AncestorKey { + /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on + /// the ancestor. + AncestorType type; + + /// Holds the full op name of the ancestor, for example, "arith.addi", iff + /// `type` is `NON_CONSTANT_OP`. + StringRef opName; + + /// Defines the constructor for `AncestorKey`. + AncestorKey(Ancestor ancestor) { + if (!ancestor.isOp) { + // When `ancestor` is a block argument, we assign `type` as + // `BLOCK_ARGUMENT` and `opName` remains "". + type = BLOCK_ARGUMENT; + } else if (!ancestor.op->hasTrait<OpTrait::ConstantLike>()) { + // When `ancestor` is a non-constant-like op, we assign `type` as + // `NON_CONSTANT_OP` and `opName` as the full op name of `ancestor`. + type = NON_CONSTANT_OP; + opName = ancestor.op->getName().getStringRef(); + } else { + // When `ancestor` is a constant-like op, we assign `type` as + // `CONSTANT_OP` and `opName` remains "". + type = CONSTANT_OP; + } + } + + /// Declares the overloaded operator `<`. + /// `AncestorKey1` is considered < `AncestorKey2` iff: + /// 1. The `type` of `AncestorKey1` is `BLOCK_ARGUMENT` and that of + /// `AncestorKey2` isn't, + /// 2. The `type` of `AncestorKey1` is `NON_CONSTANT_OP` and that of + /// `AncestorKey2` is `CONSTANT_OP`, or + /// 3. Both have the same `type` and the `opName` of `AncestorKey1` is + /// alphabetically smaller than that of `AncestorKey2`. + bool operator<(const AncestorKey &key) const { + if ((type == BLOCK_ARGUMENT && key.type != BLOCK_ARGUMENT) || + (type == NON_CONSTANT_OP && key.type == CONSTANT_OP)) + return true; + if ((key.type == BLOCK_ARGUMENT && type != BLOCK_ARGUMENT) || + (key.type == NON_CONSTANT_OP && type == CONSTANT_OP)) + return false; + return opName < key.opName; + } +}; + +/// Stores the BFS traversal information of an operand. 
+struct OperandBFS { + /// Stores the queue of ancestors of the BFS traversal of an operand at a + /// particular point in time. + std::queue<Ancestor> ancestorQueue; + + /// Stores the list of visited "op" ancestors of the BFS traversal of an + /// operand at a particular point in time. + DenseSet<Operation *> visitedAncestors; + + /// Stores the "key" associated with an operand. This "key" is defined as the + /// list of the "AncestorKeys" associated with the ancestors of this operand, + /// in a breadth-first order. + /// + /// So, if an operand, say `A`, was produced as follows: + /// + /// `<block argument>` `<block argument>` + /// \ / + /// \ / + /// `arith.subi` `arith.constant` + /// \ / + /// `arith.addi` + /// | + /// returns `A` + /// + /// Then, the ancestors of `A`, in the breadth-first order are: + /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and + /// `<block argument>`. + /// + /// Now, as already mentioned, the "AncestorKey" associated with: + /// 1. A block argument is {type: `BLOCK_ARGUMENT`, opName: ""}. + /// 2. A non-constant-like op, for example, `arith.addi`, is {type: + /// `NON_CONSTANT_OP`, opName: "arith.addi"}. + /// 3. A constant-like op, for example, `arith.constant`, is {type: + /// `CONSTANT_OP`, opName: ""}. + /// + /// Thus, the "key" associated with operand `A` is: + /// { + /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, + /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, + /// {type: `CONSTANT_OP`, opName: ""}, + /// {type: `BLOCK_ARGUMENT`, opName: ""}, + /// {type: `BLOCK_ARGUMENT`, opName: ""} + /// } + SmallVector<AncestorKey, 4> key; + + /// Stores true iff the operand has been assigned a sorted position yet. + bool isAssignedSortedPosition = false; + + /// Push an ancestor into the operand's BFS information structure. This + /// entails it being pushed into the queue (always) and inserted into the + /// "visited ancestors" list (iff it is an op rather than a block argument). 
+ void pushAncestor(Operation *op) { + Ancestor ancestor(op); + ancestorQueue.push(ancestor); + if (ancestor.isOp) + visitedAncestors.insert(ancestor.op); + return; + } + + /// Pop the ancestor from the front of the queue. + void popAncestor() { + assert(!ancestorQueue.empty() && + "to pop the ancestor from the front of the queue, the ancestor " + "queue should be non-empty"); + ancestorQueue.pop(); + return; + } + + /// Return the ancestor at the front of the queue. + Ancestor frontAncestor() { + assert(!ancestorQueue.empty() && + "to access the ancestor at the front of the queue, the ancestor " + "queue should be non-empty"); + return ancestorQueue.front(); + } +}; + +/// Returns: +/// -1 if `keyA` < `keyB`, +/// 0 if `keyA` == `keyB`, and +/// 1 if `keyA` > `keyB`. +/// +/// Note that: +/// +/// (A) `keyA` == `keyB` iff: +/// Both these keys, each of which is a list, have the same size and both +/// the elements in each pair of corresponding elements among them is the +/// same. +/// +/// (B) `keyA` < `keyB` iff: +/// 1. In the first unequal pair of corresponding elements among them, +/// `keyA`'s element is smaller, or +/// 2. Both the elements in every pair of corresponding elements are the same +/// in both keys and the size of `keyA` is smaller. +/// +/// (C) `keyB` < `keyA` condition is defined likewise. +static int compareKeys(ArrayRef<AncestorKey> keyA, ArrayRef<AncestorKey> keyB) { + unsigned keyASize = keyA.size(); + unsigned keyBSize = keyB.size(); + unsigned smallestSize = keyASize; + if (keyBSize < smallestSize) + smallestSize = keyBSize; + + for (unsigned i = 0; i < smallestSize; i++) { + if (keyA[i] < keyB[i]) + return -1; + if (keyB[i] < keyA[i]) + return 1; + } + + if (keyASize == keyBSize) + return 0; + if (keyASize < keyBSize) + return -1; + return 1; +} + +/// Goes through all the unassigned operands of `bfsOfOperands` and: +/// 1. Stores the indices of the ones with the smallest key in +/// `smallestKeyIndices`, +/// 2. 
Stores the indices of the ones with the largest key in +/// `largestKeyIndices`, +/// 3. Sets `hasASingleOperandWithSmallestKey` as true if exactly one of them +/// has the smallest key (and as false otherwise), AND, +/// 4. Sets `hasASingleOperandWithLargestKey` as true if exactly one of them has +/// the largest key (and as false otherwise). +static void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + ArrayRef<OperandBFS *> bfsOfOperands, + DenseSet<unsigned> &smallestKeyIndices, + DenseSet<unsigned> &largestKeyIndices, + bool &hasASingleOperandWithSmallestKey, + bool &hasASingleOperandWithLargestKey) { + bool foundAnUnassignedOperand = false; + + // Compute the smallest and largest keys present among the unassigned operands + // of `bfsOfOperands`. + ArrayRef<AncestorKey> smallestKey, largestKey; + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + ArrayRef<AncestorKey> currentKey = bfsOfOperand->key; + if (!foundAnUnassignedOperand) { + foundAnUnassignedOperand = true; + smallestKey = currentKey; + largestKey = currentKey; + continue; + } + if (compareKeys(smallestKey, currentKey) == 1) + smallestKey = currentKey; + if (compareKeys(largestKey, currentKey) == -1) + largestKey = currentKey; + } + + // If there is no unassigned operand, assign the necessary values to the input + // arguments and return. + if (!foundAnUnassignedOperand) { + hasASingleOperandWithSmallestKey = false; + hasASingleOperandWithLargestKey = false; + return; + } + + // Populate `smallestKeyIndices` and `largestKeyIndices` and set + // `hasASingleOperandWithSmallestKey` and `hasASingleOperandWithLargestKey` + // accordingly. 
+ bool smallestKeyFound = false; + bool largestKeyFound = false; + hasASingleOperandWithSmallestKey = true; + hasASingleOperandWithLargestKey = true; + for (auto &indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + unsigned index = indexedBfsOfOperand.index(); + ArrayRef<AncestorKey> currentKey = bfsOfOperand->key; + + if (compareKeys(smallestKey, currentKey) == 0) { + smallestKeyIndices.insert(index); + if (smallestKeyFound) + hasASingleOperandWithSmallestKey = false; + smallestKeyFound = true; + } + + if (compareKeys(largestKey, currentKey) == 0) { + largestKeyIndices.insert(index); + if (largestKeyFound) + hasASingleOperandWithLargestKey = false; + largestKeyFound = true; + } + } + return; +} + +/// Update the key associated with each unassigned operand in `bfsOfOperands`. +/// Updating a key entails making it up-to-date with its associated operand's +/// BFS traversal that has happened till that point in time, i.e., appending the +/// existing key with the front ancestor's "AncestorKey". Note that a +/// key directly reflects the BFS and thus needs to be updated during the +/// progression of the traversal. +static void updateKeys(ArrayRef<OperandBFS *> bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition || + bfsOfOperand->ancestorQueue.empty()) + continue; + + // Append the key of `bfsOfOperand` with its front ancestor's + // "AncestorKey". + Ancestor frontAncestor = bfsOfOperand->frontAncestor(); + AncestorKey frontAncestorKey(frontAncestor); + bfsOfOperand->key.push_back(frontAncestorKey); + } + return; +} + +/// If `keyIndices` contains `indexOfOperand` and either `isTheOnlyKey` is true +/// or the ancestor queue of `bfsOfOperand` is empty, assign the sorted position +/// `positionToAssign` to the operand of `op` at index `indexOfOperand`, and +/// return true. 
Else, return false. +static bool assignSortedPositionTo(OperandBFS *bfsOfOperand, + unsigned indexOfOperand, + DenseSet<unsigned> keyIndices, + bool isTheOnlyKey, + SmallVectorImpl<Value> &sortedOperands, + unsigned positionToAssign, Operation *op) { + if (keyIndices.contains(indexOfOperand) && + (isTheOnlyKey || bfsOfOperand->ancestorQueue.empty())) { + bfsOfOperand->isAssignedSortedPosition = true; + sortedOperands[positionToAssign] = op->getOperand(indexOfOperand); + return true; + } + return false; +} + +/// In each of the operands of `bfsOfOperands`, pop the front ancestor from the +/// queue, if any, and then push its adjacent unvisited ancestors, if any, to +/// the queue (this is the main body of the BFS algorithm). +static void popFrontAndPushAdjacentUnvisitedAncestors( + ArrayRef<OperandBFS *> bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition || + bfsOfOperand->ancestorQueue.empty()) + continue; + Ancestor frontAncestor = bfsOfOperand->frontAncestor(); + bfsOfOperand->popAncestor(); + if (!frontAncestor.isOp) + continue; + for (Value operand : frontAncestor.op->getOperands()) { + Operation *operandDefOp = operand.getDefiningOp(); + if (!operandDefOp || + !bfsOfOperand->visitedAncestors.contains(operandDefOp)) + bfsOfOperand->pushAncestor(operandDefOp); + } + } + return; +} + +/// Sorts the operands of `op` in ascending order of the "key" associated with +/// each operand iff `op` is commutative. +/// +/// After the application of this pattern, since the commutative operands now +/// have a deterministic order in which they occur in an op, the matching of +/// large DAGs becomes much simpler, i.e., requires far fewer checks +/// to be written by a user in their pattern matching function. +/// +/// Some examples of such a sorting: +/// +/// Assume that the sorting is being applied to `foo.commutative`, which is a +/// commutative op. 
+/// +/// Example 1: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul <block argument>, <block argument> +/// %3 = foo.commutative %1, %2 +/// +/// Here, +/// 1. The key associated with %1 is `{{CONSTANT_OP, ""}}`, and, +/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`. +/// +/// Thus, the key of %2 < key of %1, and so, the sorted `foo.commutative` will +/// look like: +/// +/// %3 = foo.commutative %2, %1 +/// +/// Note that in this example, it wasn't necessary to get the entire set of +/// ancestors of operand %2 to decide that it has the smaller key. Just the +/// comparison of `{{CONSTANT_OP, ""}}` and `{{NON_CONSTANT_OP, "foo.mul"}}` is +/// enough to determine this. Thus, this sorting utility does not store the +/// entire set of ancestors of an operand at once. It seeks the ancestors in a +/// breadth-first fashion, one at a time, to create an operand "key" but stops +/// when the relative ordering of that operand can be determined. +/// +/// So, here, +/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`, and, +/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"}}` (instead of +/// `{{NON_CONSTANT_OP, "foo.mul"}, {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, +/// ""}}`). +/// +/// Example 2: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul <block argument>, <block argument> +/// %3 = foo.mul %2, %1 +/// %4 = foo.add %2, %1 +/// %5 = foo.commutative %1, %2, %3, %4 +/// +/// Here, +/// 1. The key associated with %1 is `{{CONSTANT_OP, ""}}`, +/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`, +/// 3. The key associated with %3 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""}}`, and, +/// 4. The key associated with %4 is `{{NON_CONSTANT_OP, "foo.add"}, +/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""}}`. 
+/// +/// And, +/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`, +/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}}`, +/// 3. The key computed for %3 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {NON_CONSTANT_OP, "foo.mul"}}`, and, +/// 4. The key computed for %4 is `{{NON_CONSTANT_OP, "foo.add"}}`. +/// +/// Note that the BFS's for operands %1 and %4 stop after one iteration while +/// those for operands %2 and %3 stop after the second iteration. +/// +/// And, the sorted `foo.commutative` will look like: +/// +/// %5 = foo.commutative %4, %2, %3, %1 +class SortCommutativeOperands : public RewritePattern { +public: + SortCommutativeOperands(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // If `op` is not commutative, do nothing. + if (!op->hasTrait<OpTrait::IsCommutative>()) + return failure(); + + // `bfsOfOperands` stores the BFS traversal information of each operand of + // `op`. For each operand, this information comprises a queue of ancestors + // being visited during the BFS (at a particular point in time), a list of + // visited ancestors (at a particular point in time), its associated key (at + // a particular point in time), and whether or not the operand has been + // assigned a sorted position yet. + SmallVector<OperandBFS *, 2> bfsOfOperands; + + // Initially, each operand's ancestor queue contains the op defining it + // (which is considered its first ancestor). Thus, it acts as the starting + // point for that operand's BFS traversal. 
+ for (Value operand : op->getOperands()) { + OperandBFS *bfsOfOperand = new OperandBFS(); + bfsOfOperand->pushAncestor(operand.getDefiningOp()); + bfsOfOperands.push_back(bfsOfOperand); + } + + // Since none of the operands have been assigned a sorted position yet, the + // smallest unassigned position is set as zero and the largest one is set as + // the number of operands in `op` minus one (N - 1). This is because each + // operand will be assigned a sorted position between 0 and (N - 1), both + // inclusive. + unsigned numOperands = op->getNumOperands(); + unsigned smallestUnassignedPosition = 0; + unsigned largestUnassignedPosition = numOperands - 1; + + // `sortedOperands` will store the list of `op`'s operands in sorted order. + // At first, all elements in it are initialized as null. + SmallVector<Value, 2> sortedOperands(numOperands, nullptr); + + // We perform the BFS traversals of all operands in parallel until each of + // them is assigned a sorted position. During the traversals, we try to + // assign a sorted position to an operand as soon as it is possible (based + // on a comparison of its traversal with the other traversals at that + // particular point in time). + while (llvm::any_of(bfsOfOperands, [](OperandBFS *bfsOfOperand) { + return !bfsOfOperand->isAssignedSortedPosition; + })) { + // Update the keys corresponding to all unassigned operands. + updateKeys(bfsOfOperands); + + // Stores the indices of the unassigned operands whose key is the + // smallest. + DenseSet<unsigned> smallestKeyIndices; + // Stores the indices of the unassigned operands whose key is the largest. + DenseSet<unsigned> largestKeyIndices; + + // Stores true iff there is a single unassigned operand that has the + // smallest key. + bool hasASingleOperandWithSmallestKey; + // Stores true iff there is a single unassigned operand that has the + // largest key. 
+ bool hasASingleOperandWithLargestKey; + + getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + bfsOfOperands, smallestKeyIndices, largestKeyIndices, + hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey); + + // Go through each of the unassigned operands with the smallest key and + // try to assign it a sorted position if possible (ensuring stable + // sorting). + for (auto &indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + // If an unassigned operand has the smallest key and: + // 1. It is the only operand with the smallest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `smallestUnassignedPosition` (which will + // be its new position in the rearranged `op`). + if (assignSortedPositionTo( + bfsOfOperand, /*indexOfOperand=*/indexedBfsOfOperand.index(), + /*keyIndices=*/smallestKeyIndices, + /*isTheOnlyKey=*/hasASingleOperandWithSmallestKey, + /*sortedOperands=*/sortedOperands, + /*positionToAssign=*/smallestUnassignedPosition, /*op=*/op)) + smallestUnassignedPosition++; + } + // Go through each of the unassigned operands with the largest key and try + // to assign it a sorted position if possible (ensuring stable sorting). + for (auto indexedBfsOfOperand : + llvm::enumerate(llvm::reverse(bfsOfOperands))) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + // If an unassigned operand has the largest key and: + // 1. It is the only operand with the largest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `largestUnassignedPosition` (which will + // be its new position in the rearranged `op`). 
+ if (assignSortedPositionTo( + bfsOfOperand, /*indexOfOperand=*/numOperands - + indexedBfsOfOperand.index() - 1, + /*keyIndices=*/largestKeyIndices, + /*isTheOnlyKey=*/hasASingleOperandWithLargestKey, + /*sortedOperands=*/sortedOperands, + /*positionToAssign=*/largestUnassignedPosition, /*op=*/op)) + largestUnassignedPosition--; + } + + // For each operand in `bfsOfOperands`, pop the front ancestor from the + // queue and push its adjacent unvisited ancestors into the queue. + popFrontAndPushAdjacentUnvisitedAncestors(bfsOfOperands); + } + rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); + return success(); + } +}; + +void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { + patterns.add<SortCommutativeOperands>(patterns.getContext()); +} Index: mlir/lib/Transforms/Utils/CMakeLists.txt =================================================================== --- mlir/lib/Transforms/Utils/CMakeLists.txt +++ mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp Index: mlir/include/mlir/Transforms/CommutativityUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,27 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares a function to populate the commutativity utility +// pattern. This function is intended to be used inside passes to simplify the +// matching of commutative operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// Populates the commutativity utility patterns. +void populateCommutativityUtilsPatterns(RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits