srishti-pm updated this revision to Diff 441253. srishti-pm added a comment.
Addressed all of Jeff's comments. Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D124750/new/ https://reviews.llvm.org/D124750 Files: mlir/include/mlir/Transforms/CommutativityUtils.h mlir/lib/Transforms/Utils/CMakeLists.txt mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/test/Transforms/test-commutativity-utils.mlir mlir/test/lib/Dialect/Test/TestOps.td mlir/test/lib/Transforms/CMakeLists.txt mlir/test/lib/Transforms/TestCommutativityUtils.cpp mlir/tools/mlir-opt/mlir-opt.cpp
Index: mlir/tools/mlir-opt/mlir-opt.cpp =================================================================== --- mlir/tools/mlir-opt/mlir-opt.cpp +++ mlir/tools/mlir-opt/mlir-opt.cpp @@ -57,6 +57,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -152,6 +153,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck(); Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,48 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the functionality of the commutativity utility pattern. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +struct CommutativityUtils + : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + populateCommutativityUtilsPatterns(patterns); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); } +} // namespace test +} // namespace mlir Index: mlir/test/lib/Transforms/CMakeLists.txt =================================================================== --- mlir/test/lib/Transforms/CMakeLists.txt +++ mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1162,11 +1162,21 @@ let hasFolder = 1; } +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +def 
TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2); let results = (outs I32); Index: mlir/test/Transforms/test-commutativity-utils.mlir =================================================================== --- /dev/null +++ mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func.func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func.func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[TEST_CONST]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return 
%result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], 
%[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,549 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a commutativity utility pattern and a function to +// populate this pattern. The function is intended to be used inside passes to +// simplify the matching of commutative operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include <queue> + +using namespace mlir; + +/// Stores the "ancestor" of an operand of some op. The operand of any op is +/// produced by a set of ops and block arguments. Each of these ops and block +/// arguments is called an "ancestor" of this operand. +struct Ancestor { + /// Stores true when the "ancestor" is an op and false when the "ancestor" is + /// a block argument. + bool isOp; + + /// Stores the op when the "ancestor" is an op and nullptr when the "ancestor" + /// is a block argument. + Operation *op; + + /// Defines the constructor for `Ancestor`. + Ancestor(Operation *opForAncestor) { + isOp = opForAncestor ? true : false; + op = opForAncestor; + } +}; + +/// Declares various "types" of ancestors. +enum AncestorType { + /// Pertains to a block argument. + BLOCK_ARGUMENT, + + /// Pertains to a non-constant-like op. 
+ NON_CONSTANT_OP, + + /// Pertains to a constant-like op. + CONSTANT_OP +}; + +/// Stores the "key" associated with an ancestor. +struct AncestorKey { + /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on + /// the ancestor. + AncestorType type; + + /// Holds the full op name of the ancestor, for example, "arith.addi", iff + /// `type` is `NON_CONSTANT_OP`. + StringRef opName; + + /// Defines the constructor for `AncestorKey`. + AncestorKey(Ancestor ancestor) { + if (!ancestor.isOp) { + // When `ancestor` is a block argument, we assign `type` as + // `BLOCK_ARGUMENT` and `opName` remains "". + type = BLOCK_ARGUMENT; + } else if (!ancestor.op->hasTrait<OpTrait::ConstantLike>()) { + // When `ancestor` is a non-constant-like op, we assign `type` as + // `NON_CONSTANT_OP` and `opName` as the full op name of `ancestor`. + type = NON_CONSTANT_OP; + opName = ancestor.op->getName().getStringRef(); + } else { + // When `ancestor` is a constant-like op, we assign `type` as + // `CONSTANT_OP` and `opName` remains "". + type = CONSTANT_OP; + } + } + + /// Declares the overloaded operator `<`. + /// `AncestorKey1` is considered < `AncestorKey2` iff: + /// 1. The `type` of `AncestorKey1` is `BLOCK_ARGUMENT` and that of + /// `AncestorKey2` isn't, + /// 2. The `type` of `AncestorKey1` is `NON_CONSTANT_OP` and that of + /// `AncestorKey2` is `CONSTANT_OP`, or + /// 3. Both have the same `type` and the `opName` of `AncestorKey1` is + /// alphabetically smaller than that of `AncestorKey2`. + bool operator<(const AncestorKey &key) const { + if ((type == BLOCK_ARGUMENT && key.type != BLOCK_ARGUMENT) || + (type == NON_CONSTANT_OP && key.type == CONSTANT_OP)) + return true; + if ((key.type == BLOCK_ARGUMENT && type != BLOCK_ARGUMENT) || + (key.type == NON_CONSTANT_OP && type == CONSTANT_OP)) + return false; + return opName < key.opName; + } +}; + +/// Stores the BFS traversal information of an operand. 
+struct OperandBFS { + /// Stores the original position of the operand in its unsorted op. + unsigned originalPosition; + + /// Stores the queue of ancestors of the BFS traversal of an operand at a + /// particular point in time. + std::queue<Ancestor> ancestorQueue; + + /// Stores the list of visited "op" ancestors of the BFS traversal of an + /// operand at a particular point in time. + DenseSet<Operation *> visitedAncestors; + + /// Stores the "key" associated with an operand. This "key" is defined as the + /// list of the "AncestorKeys" associated with the ancestors of this operand, + /// in a breadth-first order. + /// + /// So, if an operand, say `A`, was produced as follows: + /// + /// `<block argument>` `<block argument>` + /// \ / + /// \ / + /// `arith.subi` `arith.constant` + /// \ / + /// `arith.addi` + /// | + /// returns `A` + /// + /// Then, the ancestors of `A`, in the breadth-first order are: + /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and + /// `<block argument>`. + /// + /// Now, as already mentioned, the "AncestorKey" associated with: + /// 1. A block argument is {type: `BLOCK_ARGUMENT`, opName: ""}. + /// 2. A non-constant-like op, for example, `arith.addi`, is {type: + /// `NON_CONSTANT_OP`, opName: "arith.addi"}. + /// 3. A constant-like op, for example, `arith.constant`, is {type: + /// `CONSTANT_OP`, opName: ""}. + /// + /// Thus, the "key" associated with operand `A` is: + /// { + /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, + /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, + /// {type: `CONSTANT_OP`, opName: ""}, + /// {type: `BLOCK_ARGUMENT`, opName: ""}, + /// {type: `BLOCK_ARGUMENT`, opName: ""} + /// } + SmallVector<AncestorKey, 4> key; + + /// Stores true iff the operand has been assigned a sorted position yet. + bool isSorted = false; + + /// Push an ancestor into the operand's BFS information structure. 
This + /// entails it being pushed into the queue (always) and inserted into the + /// "visited ancestors" list (iff it is an op rather than a block argument). + void pushAncestor(Operation *op) { + Ancestor ancestor(op); + ancestorQueue.push(ancestor); + if (ancestor.isOp) + visitedAncestors.insert(ancestor.op); + return; + } + + /// Pop the ancestor from the front of the queue. + void popAncestor() { + assert(!ancestorQueue.empty() && + "to pop the ancestor from the front of the queue, the ancestor " + "queue should be non-empty"); + ancestorQueue.pop(); + return; + } + + /// Return the ancestor at the front of the queue. + Ancestor frontAncestor() { + assert(!ancestorQueue.empty() && + "to access the ancestor at the front of the queue, the ancestor " + "queue should be non-empty"); + return ancestorQueue.front(); + } +}; + +/// Returns: +/// -1 if `keyA` < `keyB`, +/// 0 if `keyA` == `keyB`, and +/// 1 if `keyA` > `keyB`. +/// +/// Note that: +/// +/// (A) `keyA` == `keyB` iff: +/// Both these keys, each of which is a list, have the same size and both +/// the elements in each pair of corresponding elements among them is the +/// same. +/// +/// (B) `keyA` < `keyB` iff: +/// 1. In the first unequal pair of corresponding elements among them, +/// `keyA`'s element is smaller, or +/// 2. Both the elements in every pair of corresponding elements are the same +/// in both keys and the size of `keyA` is smaller. +/// +/// (C) `keyB` < `keyA` condition is defined likewise. 
+static int compareKeys(ArrayRef<AncestorKey> keyA, ArrayRef<AncestorKey> keyB) { + unsigned keyASize = keyA.size(); + unsigned keyBSize = keyB.size(); + unsigned smallestSize = keyASize; + if (keyBSize < smallestSize) + smallestSize = keyBSize; + + for (unsigned i = 0; i < smallestSize; i++) { + if (keyA[i] < keyB[i]) + return -1; + if (keyB[i] < keyA[i]) + return 1; + } + + if (keyASize == keyBSize) + return 0; + if (keyASize < keyBSize) + return -1; + return 1; +} + +/// Refresh the key associated with each unsorted operand in `bfsOfOperands`. +/// Refreshing a key entails making it up-to-date with its associated operand's +/// BFS traversal that has happened till that point in time, i.e, appending the +/// existing key with the front ancestor's "AncestorKey". Note that a +/// key directly reflects the BFS and thus needs to be refreshed during the +/// progression of the traversal. +static void refreshKeys(ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands) { + for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isSorted || bfsOfOperand->ancestorQueue.empty()) + continue; + + // Append the key of `bfsOfOperand` with the its front ancestor's + // "AncestorKey". + Ancestor frontAncestor = bfsOfOperand->frontAncestor(); + AncestorKey frontAncestorKey(frontAncestor); + bfsOfOperand->key.push_back(frontAncestorKey); + } + return; +} + +// Compute the smallest key present among the unsorted operands in +// `bfsOfOperands`. 
+static ArrayRef<AncestorKey> computeTheSmallestUnsortedKey( + ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands) { + ArrayRef<AncestorKey> smallestKey; + bool foundUnsortedOperand = false; + for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isSorted) + continue; + + ArrayRef<AncestorKey> currentKey = bfsOfOperand->key; + if (!foundUnsortedOperand) { + foundUnsortedOperand = true; + smallestKey = currentKey; + continue; + } + if (compareKeys(smallestKey, currentKey) == 1) + smallestKey = currentKey; + } + return smallestKey; +} + +/// In `bfsOfOperandsWithKey`, store the BFS traversal information of all the +/// unsorted operands (of `bfsOfOperands`) whose key = `key`. And, store true in +/// `hasOneOperandWithKey` iff there is exactly one such operand. +static void getBFSOfOperandsWithKey( + ArrayRef<AncestorKey> key, + SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfOperandsWithKey, + ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands, + bool &hasOneOperandWithKey) { + bool keyFound = false; + hasOneOperandWithKey = true; + for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isSorted) + continue; + + ArrayRef<AncestorKey> currentKey = bfsOfOperand->key; + if (compareKeys(key, currentKey) == 0) { + bfsOfOperandsWithKey.push_back( + std::make_unique<OperandBFS>(*bfsOfOperand)); + if (keyFound) + hasOneOperandWithKey = false; + keyFound = true; + } + } +} + +/// Returns true and stores: +/// 1. the BFS traversal information of all the unsorted operands (of +/// `bfsOfOperands`) which have the smallest key in +/// `bfsOfSmallestUnsortedOperands`, and, +/// 2. true in `hasOneSmallestOperand` when there exists exactly one such +/// operand, +/// iff there exists at least one such operand.
+static bool getBFSOfSmallestUnsortedOperands( + SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfSmallestUnsortedOperands, + ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands, + bool &hasOneSmallestOperand) { + + // If there exists no unsorted operand, return false. + if (llvm::all_of(bfsOfOperands, + [](const std::unique_ptr<OperandBFS> &bfsOfOperand) { + return bfsOfOperand->isSorted; + })) + return false; + + // Get the smallest key present among the unsorted operands. + ArrayRef<AncestorKey> smallestKey = + computeTheSmallestUnsortedKey(bfsOfOperands); + + // Set `bfsOfSmallestUnsortedOperands` and `hasOneSmallestOperand`. + getBFSOfOperandsWithKey( + /*key=*/smallestKey, + /*bfsOfOperandsWithKey=*/bfsOfSmallestUnsortedOperands, bfsOfOperands, + /*hasOneOperandWithKey=*/hasOneSmallestOperand); + + return true; +} + +/// Shift the BFS traversal information of `bfsOfOperandToShift` to +/// `frontPosition` in `bfsOfOperands` iff either `isUnique` is true or the BFS +/// traversal of `bfsOfOperandToShift` is complete. 
+static bool shiftBFSOfOperandToFront( + std::unique_ptr<OperandBFS> &bfsOfOperandToShift, unsigned frontPosition, + bool isUnique, + SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfOperands) { + if (!isUnique && !bfsOfOperandToShift->ancestorQueue.empty()) + return false; + + assert(frontPosition >= 0 && frontPosition < bfsOfOperands.size() && + "`frontPosition` should be valid"); + unsigned positionOfOperandToShift; + bool foundOperandToShift = false; + for (auto &indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + std::unique_ptr<OperandBFS> &bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isSorted) + continue; + if (bfsOfOperandToShift->originalPosition == + bfsOfOperand->originalPosition) { + positionOfOperandToShift = indexedBfsOfOperand.index(); + foundOperandToShift = true; + break; + } + } + assert(foundOperandToShift && + "`operandToShift` should be present in `bfsOfOperands`"); + assert(positionOfOperandToShift >= frontPosition && + "`operandToShift` should be positioned after `frontPosition`"); + + for (int p = int(positionOfOperandToShift) - 1; p >= int(frontPosition); p--) + bfsOfOperands[p + 1] = std::move(bfsOfOperands[p]); + bfsOfOperands[frontPosition] = std::move(bfsOfOperandToShift); + bfsOfOperands[frontPosition]->isSorted = true; + return true; +} + +/// Among the unsorted operands in `bfsOfOperands`, shift the ones with the +/// smallest key to the smallest unsorted positions. Then, update +/// `smallestUnsortedPosition` to store the smallest unsorted position (i.e., +/// the smallest position containing an unsorted operand). +static void shiftTheSmallestUnsortedOperandsToTheSmallestUnsortedPositions( + SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfOperands, + unsigned &smallestUnsortedPosition) { + // Stores true iff there is exactly one unsorted operand that has the + // smallest key. 
+ bool hasOneSmallestOperand; + + // Store the BFS traversal information of the unsorted operands with the + // smallest key in `bfsOfSmallestUnsortedOperands`. + SmallVector<std::unique_ptr<OperandBFS>, 2> bfsOfSmallestUnsortedOperands; + getBFSOfSmallestUnsortedOperands(bfsOfSmallestUnsortedOperands, bfsOfOperands, + hasOneSmallestOperand); + + // Shift an operand of `bfsOfSmallestUnsortedOperands` iff it is either unique + // or has completed its BFS traversal. + for (std::unique_ptr<OperandBFS> &bfsOfSmallestUnsortedOperand : + bfsOfSmallestUnsortedOperands) { + if (shiftBFSOfOperandToFront( + /*bfsOfOperandToShift=*/bfsOfSmallestUnsortedOperand, + /*frontPosition=*/smallestUnsortedPosition, + /*isUnique=*/hasOneSmallestOperand, bfsOfOperands)) + smallestUnsortedPosition++; + } +} + +/// In each of the unsorted operands of `bfsOfOperands`, pop the front ancestor +/// from the queue, if any, and then push its adjacent unvisited ancestors, if +/// any, to the queue (this is the main body of the BFS algorithm). +static void popFrontAndPushAdjacentUnvisitedAncestors( + ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands) { + for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isSorted || bfsOfOperand->ancestorQueue.empty()) + continue; + Ancestor frontAncestor = bfsOfOperand->frontAncestor(); + bfsOfOperand->popAncestor(); + if (!frontAncestor.isOp) + continue; + for (Value operand : frontAncestor.op->getOperands()) { + Operation *operandDefOp = operand.getDefiningOp(); + if (!operandDefOp || + !bfsOfOperand->visitedAncestors.contains(operandDefOp)) + bfsOfOperand->pushAncestor(operandDefOp); + } + } + return; +} + +/// Sorts the operands of `op` in ascending order of the "key" associated with +/// each operand iff `op` is commutative.
+/// +/// After the application of this pattern, since the commutative operands now +/// have a deterministic order in which they occur in an op, the matching of +/// large DAGs becomes much simpler, i.e., requires far fewer checks +/// to be written by a user in her/his pattern matching function. +/// +/// Some examples of such a sorting: +/// +/// Assume that the sorting is being applied to `foo.commutative`, which is a +/// commutative op. +/// +/// Example 1: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul <block argument>, <block argument> +/// %3 = foo.commutative %1, %2 +/// +/// Here, +/// 1. The key associated with %1 is `{{CONSTANT_OP, ""}}`, and, +/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`. +/// +/// Thus, the key of %2 < key of %1, and so, the sorted `foo.commutative` will +/// look like: +/// +/// %3 = foo.commutative %2, %1 +/// +/// Note that in this example, it wasn't necessary to get the entire set of +/// ancestors of operand %2 to decide that it has the smaller key. Just the +/// comparison of `{{CONSTANT_OP, ""}}` and `{{NON_CONSTANT_OP, "foo.mul"}}` is +/// enough to determine this. Thus, this sorting utility does not store the +/// entire set of ancestors of an operand at once. It seeks the ancestors in a +/// breadth-first fashion, one at a time, to create an operand "key" but stops +/// when the relative ordering of that operand can be determined. +/// +/// So, here, +/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`, and, +/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"}}` (instead of +/// `{{NON_CONSTANT_OP, "foo.mul"}, {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, +/// ""}}`). +/// +/// Example 2: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul <block argument>, <block argument> +/// %3 = foo.mul %2, %1 +/// %4 = foo.add %2, %1 +/// %5 = foo.commutative %1, %2, %3, %4 +/// +/// Here, +/// 1.
The key associated with %1 is `{{CONSTANT_OP, ""}}`, +/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`, +/// 3. The key associated with %3 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""}}`, and, +/// 4. The key associated with %4 is `{{NON_CONSTANT_OP, "foo.add"}, +/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""}}`. +/// +/// And, +/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`, +/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}}`, +/// 3. The key computed for %3 is `{{NON_CONSTANT_OP, "foo.mul"}, +/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}}`, and, +/// 4. The key computed for %4 is `{{NON_CONSTANT_OP, "foo.add"}}`. +/// +/// Note that the BFS's for operands %1 and %4 stop after one iteration, that +/// for operand %2 stops after the second iteration, and that for operand %3 +/// stops after the third iteration. +/// +/// And, since a block argument is ordered before a non-constant-like op, the +/// sorted `foo.commutative` will look like: +/// +/// %5 = foo.commutative %4, %2, %3, %1 +class SortCommutativeOperands : public RewritePattern { +public: + SortCommutativeOperands(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // If `op` is not commutative, do nothing. + if (!op->hasTrait<OpTrait::IsCommutative>()) + return failure(); + + // Store the BFS traversal information of each operand in `bfsOfOperands`. + SmallVector<std::unique_ptr<OperandBFS>, 2> bfsOfOperands; + + // Each operand's BFS starts with its first ancestor (i.e., the op defining + // it or the block argument).
+ SmallVector<Value, 2> unsortedOperands = op->getOperands(); + for (auto &indexedOperand : llvm::enumerate(unsortedOperands)) { + Value operand = indexedOperand.value(); + bfsOfOperands.push_back(std::make_unique<OperandBFS>()); + bfsOfOperands.back()->originalPosition = indexedOperand.index(); + bfsOfOperands.back()->pushAncestor(operand.getDefiningOp()); + } + + // Since none of the operands have been assigned a sorted position yet, the + // smallest unsorted position is zero. + unsigned numOperands = op->getNumOperands(); + unsigned smallestUnsortedPosition = 0; + + // We perform the BFS traversals of all operands parallelly until each of + // them is assigned a sorted position. During each iteration, the BFS's move + // ahead and we shift the smallest unsorted operands to the smallest + // unsorted positions. + while (smallestUnsortedPosition < numOperands - 1) { + // Refresh the keys of all unsorted operands. + refreshKeys(bfsOfOperands); + + // Among the unsorted operands, shift the ones with the smallest key to + // the smallest unsorted positions. Then, update + // `smallestUnsortedPosition` + shiftTheSmallestUnsortedOperandsToTheSmallestUnsortedPositions( + bfsOfOperands, smallestUnsortedPosition); + + // For each unsorted operand, pop the front ancestor from the BFS queue + // and push its adjacent unvisited ancestors into the queue (this is the + // main body of the BFS algorithm). + popFrontAndPushAdjacentUnvisitedAncestors(bfsOfOperands); + } + + // Store the list of `op`'s sorted operands in `sortedOperands`. + SmallVector<Value, 2> sortedOperands; + for (std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) + sortedOperands.push_back(op->getOperand(bfsOfOperand->originalPosition)); + + // If the operands were already sorted, return failure. + if (unsortedOperands == sortedOperands) + return failure(); + + // Else, replace the existing operands with the sorted ones and return + // success. 
+ rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); + return success(); + } +}; + +void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { + patterns.add<SortCommutativeOperands>(patterns.getContext()); +} Index: mlir/lib/Transforms/Utils/CMakeLists.txt =================================================================== --- mlir/lib/Transforms/Utils/CMakeLists.txt +++ mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp Index: mlir/include/mlir/Transforms/CommutativityUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,27 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares a function to populate the commutativity utility +// pattern. This function is intended to be used inside passes to simplify the +// matching of commutative operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// Populates the commutativity utility patterns. +void populateCommutativityUtilsPatterns(RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits