srishti-pm updated this revision to Diff 440570.
srishti-pm marked 6 inline comments as done.
srishti-pm added a comment.

Addressed the final comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D124750/new/

https://reviews.llvm.org/D124750

Files:
  mlir/include/mlir/Transforms/CommutativityUtils.h
  mlir/lib/Transforms/Utils/CMakeLists.txt
  mlir/lib/Transforms/Utils/CommutativityUtils.cpp
  mlir/test/Transforms/test-commutativity-utils.mlir
  mlir/test/lib/Dialect/Test/TestOps.td
  mlir/test/lib/Transforms/CMakeLists.txt
  mlir/test/lib/Transforms/TestCommutativityUtils.cpp
  mlir/tools/mlir-opt/mlir-opt.cpp

Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -57,6 +57,7 @@
 void registerVectorizerTestPass();
 
 namespace test {
+void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
 void registerMemRefBoundCheck();
@@ -152,6 +153,7 @@
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();
   mlir::test::registerMemRefBoundCheck();
Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -0,0 +1,52 @@
+//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the functionality of the commutativity utility pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that greedily applies the commutativity utility patterns, which
+/// reorder the operands of commutative ops, to each function it runs on.
+struct CommutativityUtils
+    : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils)
+
+  StringRef getArgument() const final { return "test-commutativity-utils"; }
+  StringRef getDescription() const final {
+    return "Test the functionality of the commutativity utility";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet sortingPatterns(&getContext());
+    populateCommutativityUtilsPatterns(sortingPatterns);
+
+    // The driver's convergence result is intentionally ignored; the rewritten
+    // IR is what the FileCheck tests inspect.
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(sortingPatterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `-test-commutativity-utils` pass.
+void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); }
+} // namespace test
+} // namespace mlir
Index: mlir/test/lib/Transforms/CMakeLists.txt
===================================================================
--- mlir/test/lib/Transforms/CMakeLists.txt
+++ mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
+  TestCommutativityUtils.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -1162,11 +1162,25 @@
   let hasFolder = 1;
 }
 
+// A binary op without the `Commutative` trait, used to build operand DAGs in
+// the commutativity utility tests.
+def TestAddIOp : TEST_Op<"addi"> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
   let results = (outs I32);
 }
 
+// A commutative op with seven operands, used to test operand sorting on ops
+// with more operands than `test.op_commutative`.
+def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7);
+  let results = (outs I32);
+}
+
 def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2);
   let results = (outs I32);
Index: mlir/test/Transforms/test-commutativity-utils.mlir
===================================================================
--- /dev/null
+++ mlir/test/Transforms/test-commutativity-utils.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s
+
+// The pass sorts the operands of every commutative op: block arguments come
+// first, then values defined by non-constant ops (ordered by op name), then
+// values defined by constant-like ops. Ties are broken by comparing deeper
+// ancestors in breadth-first order; fully tied operands keep their order.
+
+// Non-constant ops are ordered by name (arith.addi < arith.muli < test.addi)
+// and the constant goes last.
+// CHECK-LABEL: @test_small_pattern_1
+func.func @test_small_pattern_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %0 = arith.constant 45 : i32
+
+  // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
+  %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
+  %3 = arith.muli %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// The block argument goes first; the two constant-like ops have equal keys and
+// keep their original relative order at the end.
+// CHECK-LABEL: @test_small_pattern_2
+// CHECK-SAME: (%[[ARG0:.*]]: i32
+func.func @test_small_pattern_2(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
+  %0 = "test.constant"() {value = 0 : i32} : () -> i32
+
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %1 = arith.constant 0 : i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[TEST_CONST]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// Most operands are defined by arith.divsi, so their relative order is decided
+// by comparing deeper ancestors in breadth-first order.
+// CHECK-LABEL: @test_large_pattern
+func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK-NEXT: arith.divsi
+  %0 = arith.divsi %arg0, %arg1 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %1 = arith.divsi %0, %arg0 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %2 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.addi
+  %3 = arith.addi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %4 = arith.subi %2, %3 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
+  %6 = arith.divsi %4, %5 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %7 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
+  %8 = arith.muli %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
+  %9 = arith.subi %7, %8 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
+  %11 = arith.divsi %9, %10 : i32
+
+  // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
+  %12 = arith.divsi %6, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %13 = arith.subi %arg1, %arg0 : i32
+
+  // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
+  %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
+  %15 = arith.divsi %13, %14 : i32
+
+  // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
+  %16 = arith.addi %2, %15 : i32
+
+  // CHECK-NEXT: arith.subi
+  %17 = arith.subi %16, %arg1 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
+  %19 = arith.divsi %17, %18 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
+  %21 = arith.divsi %17, %20 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
+  %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,582 @@
+//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a commutativity utility pattern and a function to
+// populate this pattern. The function is intended to be used inside passes to
+// simplify the matching of commutative operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include <memory>
+#include <queue>
+
+using namespace mlir;
+
+namespace {
+
+/// Stores the "ancestor" of an operand of some op. The operand of any op is
+/// produced by a set of ops and block arguments. Each of these ops and block
+/// arguments is called an "ancestor" of this operand.
+struct Ancestor {
+  /// Stores true when the "ancestor" is an op and false when the "ancestor" is
+  /// a block argument.
+  bool isOp;
+
+  /// Stores the op when the "ancestor" is an op and nullptr when the "ancestor"
+  /// is a block argument.
+  Operation *op;
+
+  /// Constructs an `Ancestor` from the op defining an operand, or from nullptr
+  /// when that operand is a block argument.
+  explicit Ancestor(Operation *opForAncestor)
+      : isOp(opForAncestor != nullptr), op(opForAncestor) {}
+};
+
+/// Declares various "types" of ancestors.
+enum AncestorType {
+  /// Pertains to a block argument.
+  BLOCK_ARGUMENT,
+
+  /// Pertains to a non-constant-like op.
+  NON_CONSTANT_OP,
+
+  /// Pertains to a constant-like op.
+  CONSTANT_OP
+};
+
+/// Stores the "key" associated with an ancestor.
+struct AncestorKey {
+  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
+  /// the ancestor.
+  AncestorType type;
+
+  /// Holds the full op name of the ancestor, for example, "arith.addi", iff
+  /// `type` is `NON_CONSTANT_OP`.
+  StringRef opName;
+
+  /// Defines the constructor for `AncestorKey`.
+  explicit AncestorKey(Ancestor ancestor) {
+    if (!ancestor.isOp) {
+      // When `ancestor` is a block argument, we assign `type` as
+      // `BLOCK_ARGUMENT` and `opName` remains "".
+      type = BLOCK_ARGUMENT;
+    } else if (!ancestor.op->hasTrait<OpTrait::ConstantLike>()) {
+      // When `ancestor` is a non-constant-like op, we assign `type` as
+      // `NON_CONSTANT_OP` and `opName` as the full op name of `ancestor`.
+      type = NON_CONSTANT_OP;
+      opName = ancestor.op->getName().getStringRef();
+    } else {
+      // When `ancestor` is a constant-like op, we assign `type` as
+      // `CONSTANT_OP` and `opName` remains "".
+      type = CONSTANT_OP;
+    }
+  }
+
+  /// Declares the overloaded operator `<`.
+  /// `AncestorKey1` is considered < `AncestorKey2` iff:
+  ///   1. The `type` of `AncestorKey1` is `BLOCK_ARGUMENT` and that of
+  ///   `AncestorKey2` isn't,
+  ///   2. The `type` of `AncestorKey1` is `NON_CONSTANT_OP` and that of
+  ///   `AncestorKey2` is `CONSTANT_OP`, or
+  ///   3. Both have the same `type` and the `opName` of `AncestorKey1` is
+  ///   alphabetically smaller than that of `AncestorKey2`.
+  bool operator<(const AncestorKey &key) const {
+    if ((type == BLOCK_ARGUMENT && key.type != BLOCK_ARGUMENT) ||
+        (type == NON_CONSTANT_OP && key.type == CONSTANT_OP))
+      return true;
+    if ((key.type == BLOCK_ARGUMENT && type != BLOCK_ARGUMENT) ||
+        (key.type == NON_CONSTANT_OP && type == CONSTANT_OP))
+      return false;
+    return opName < key.opName;
+  }
+};
+
+/// Stores the BFS traversal information of an operand.
+struct OperandBFS {
+  /// Stores the queue of ancestors of the BFS traversal of an operand at a
+  /// particular point in time.
+  std::queue<Ancestor> ancestorQueue;
+
+  /// Stores the list of visited "op" ancestors of the BFS traversal of an
+  /// operand at a particular point in time.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the "key" associated with an operand. This "key" is defined as the
+  /// list of the "AncestorKeys" associated with the ancestors of this operand,
+  /// in a breadth-first order.
+  ///
+  /// So, if an operand, say `A`, was produced as follows:
+  ///
+  /// `<block argument>`  `<block argument>`
+  ///             \          /
+  ///              \        /
+  ///             `arith.subi`           `arith.constant`
+  ///                       \            /
+  ///                        `arith.addi`
+  ///                              |
+  ///                         returns `A`
+  ///
+  /// Then, the ancestors of `A`, in the breadth-first order are:
+  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
+  /// `<block argument>`.
+  ///
+  /// Now, as already mentioned, the "AncestorKey" associated with:
+  /// 1. A block argument is {type: `BLOCK_ARGUMENT`, opName: ""}.
+  /// 2. A non-constant-like op, for example, `arith.addi`, is {type:
+  /// `NON_CONSTANT_OP`, opName: "arith.addi"}.
+  /// 3. A constant-like op, for example, `arith.constant`, is {type:
+  /// `CONSTANT_OP`, opName: ""}.
+  ///
+  /// Thus, the "key" associated with operand `A` is:
+  /// {
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
+  ///  {type: `CONSTANT_OP`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
+  /// }
+  SmallVector<AncestorKey, 4> key;
+
+  /// Stores true iff the operand has been assigned a sorted position yet.
+  bool isAssignedSortedPosition = false;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" list (iff it is an op rather than a block argument).
+  void pushAncestor(Operation *op) {
+    Ancestor ancestor(op);
+    ancestorQueue.push(ancestor);
+    if (ancestor.isOp)
+      visitedAncestors.insert(ancestor.op);
+  }
+
+  /// Pop the ancestor from the front of the queue.
+  void popAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to pop the ancestor from the front of the queue, the ancestor "
+           "queue should be non-empty");
+    ancestorQueue.pop();
+  }
+
+  /// Return the ancestor at the front of the queue.
+  Ancestor frontAncestor() const {
+    assert(!ancestorQueue.empty() &&
+           "to access the ancestor at the front of the queue, the ancestor "
+           "queue should be non-empty");
+    return ancestorQueue.front();
+  }
+};
+
+} // namespace
+
+/// Returns:
+/// -1 if `keyA` < `keyB`,
+///  0 if `keyA` == `keyB`, and
+///  1 if `keyA` > `keyB`.
+///
+/// Note that:
+///
+/// (A) `keyA` == `keyB` iff:
+///   Both these keys, each of which is a list, have the same size and both
+///   the elements in each pair of corresponding elements among them is the
+///   same.
+///
+/// (B) `keyA` < `keyB` iff:
+///   1. In the first unequal pair of corresponding elements among them,
+///   `keyA`'s element is smaller, or
+///   2. Both the elements in every pair of corresponding elements are the same
+///   in both keys and the size of `keyA` is smaller.
+///
+/// (C) `keyB` < `keyA` condition is defined likewise.
+static int compareKeys(ArrayRef<AncestorKey> keyA, ArrayRef<AncestorKey> keyB) {
+  unsigned keyASize = keyA.size();
+  unsigned keyBSize = keyB.size();
+  unsigned smallestSize = keyASize;
+  if (keyBSize < smallestSize)
+    smallestSize = keyBSize;
+
+  for (unsigned i = 0; i < smallestSize; i++) {
+    if (keyA[i] < keyB[i])
+      return -1;
+    if (keyB[i] < keyA[i])
+      return 1;
+  }
+
+  if (keyASize == keyBSize)
+    return 0;
+  if (keyASize < keyBSize)
+    return -1;
+  return 1;
+}
+
+/// Goes through all the unassigned operands of `bfsOfOperands` and:
+/// 1. Stores the indices of the ones with the smallest key in
+/// `smallestKeyIndices`,
+/// 2. Stores the indices of the ones with the largest key in
+/// `largestKeyIndices`,
+/// 3. Sets `hasASingleOperandWithSmallestKey` as true if exactly one of them
+/// has the smallest key (and as false otherwise), AND,
+/// 4. Sets `hasASingleOperandWithLargestKey` as true if exactly one of them has
+/// the largest key (and as false otherwise).
+static void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+    ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands,
+    DenseSet<unsigned> &smallestKeyIndices,
+    DenseSet<unsigned> &largestKeyIndices,
+    bool &hasASingleOperandWithSmallestKey,
+    bool &hasASingleOperandWithLargestKey) {
+  bool foundAnUnassignedOperand = false;
+
+  // Compute the smallest and largest keys present among the unassigned operands
+  // of `bfsOfOperands`.
+  ArrayRef<AncestorKey> smallestKey, largestKey;
+  for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand->isAssignedSortedPosition)
+      continue;
+
+    ArrayRef<AncestorKey> currentKey = bfsOfOperand->key;
+    if (!foundAnUnassignedOperand) {
+      foundAnUnassignedOperand = true;
+      smallestKey = currentKey;
+      largestKey = currentKey;
+      continue;
+    }
+    if (compareKeys(smallestKey, currentKey) == 1)
+      smallestKey = currentKey;
+    if (compareKeys(largestKey, currentKey) == -1)
+      largestKey = currentKey;
+  }
+
+  // If there is no unassigned operand, assign the necessary values to the input
+  // arguments and return.
+  if (!foundAnUnassignedOperand) {
+    hasASingleOperandWithSmallestKey = false;
+    hasASingleOperandWithLargestKey = false;
+    return;
+  }
+
+  // Populate `smallestKeyIndices` and `largestKeyIndices` and set
+  // `hasASingleOperandWithSmallestKey` and `hasASingleOperandWithLargestKey`
+  // accordingly.
+  bool smallestKeyFound = false;
+  bool largestKeyFound = false;
+  hasASingleOperandWithSmallestKey = true;
+  hasASingleOperandWithLargestKey = true;
+  for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) {
+    OperandBFS *bfsOfOperand = indexedBfsOfOperand.value().get();
+    if (bfsOfOperand->isAssignedSortedPosition)
+      continue;
+
+    unsigned index = indexedBfsOfOperand.index();
+    ArrayRef<AncestorKey> currentKey = bfsOfOperand->key;
+
+    if (compareKeys(smallestKey, currentKey) == 0) {
+      smallestKeyIndices.insert(index);
+      if (smallestKeyFound)
+        hasASingleOperandWithSmallestKey = false;
+      smallestKeyFound = true;
+    }
+
+    if (compareKeys(largestKey, currentKey) == 0) {
+      largestKeyIndices.insert(index);
+      if (largestKeyFound)
+        hasASingleOperandWithLargestKey = false;
+      largestKeyFound = true;
+    }
+  }
+}
+
+/// Update the key associated with each unassigned operand in `bfsOfOperands`.
+/// Updating a key entails making it up-to-date with its associated operand's
+/// BFS traversal that has happened till that point in time, i.e., appending the
+/// existing key with the front ancestor's "AncestorKey". Note that a
+/// key directly reflects the BFS and thus needs to be updated during the
+/// progression of the traversal.
+static void updateKeys(ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands) {
+  for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand->isAssignedSortedPosition ||
+        bfsOfOperand->ancestorQueue.empty())
+      continue;
+
+    // Append the key of `bfsOfOperand` with its front ancestor's
+    // "AncestorKey".
+    Ancestor frontAncestor = bfsOfOperand->frontAncestor();
+    AncestorKey frontAncestorKey(frontAncestor);
+    bfsOfOperand->key.push_back(frontAncestorKey);
+  }
+}
+
+/// If `keyIndices` contains `indexOfOperand` and either `isTheOnlyKey` is true
+/// or the ancestor queue of `bfsOfOperand` is empty, assign the sorted position
+/// `positionToAssign` to the operand of `op` at index `indexOfOperand`, and
+/// return true. Else, return false.
+static bool assignSortedPositionTo(OperandBFS *bfsOfOperand,
+                                   unsigned indexOfOperand,
+                                   const DenseSet<unsigned> &keyIndices,
+                                   bool isTheOnlyKey,
+                                   SmallVectorImpl<Value> &sortedOperands,
+                                   unsigned positionToAssign, Operation *op) {
+  if (keyIndices.contains(indexOfOperand) &&
+      (isTheOnlyKey || bfsOfOperand->ancestorQueue.empty())) {
+    bfsOfOperand->isAssignedSortedPosition = true;
+    sortedOperands[positionToAssign] = op->getOperand(indexOfOperand);
+    return true;
+  }
+  return false;
+}
+
+/// In each of the operands of `bfsOfOperands`, pop the front ancestor from the
+/// queue, if any, and then push its adjacent unvisited ancestors, if any, to
+/// the queue (this is the main body of the BFS algorithm).
+static void popFrontAndPushAdjacentUnvisitedAncestors(
+    ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands) {
+  for (const std::unique_ptr<OperandBFS> &bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand->isAssignedSortedPosition ||
+        bfsOfOperand->ancestorQueue.empty())
+      continue;
+    Ancestor frontAncestor = bfsOfOperand->frontAncestor();
+    bfsOfOperand->popAncestor();
+    if (!frontAncestor.isOp)
+      continue;
+    for (Value operand : frontAncestor.op->getOperands()) {
+      Operation *operandDefOp = operand.getDefiningOp();
+      if (!operandDefOp ||
+          !bfsOfOperand->visitedAncestors.contains(operandDefOp))
+        bfsOfOperand->pushAncestor(operandDefOp);
+    }
+  }
+}
+
+namespace {
+
+/// Sorts the operands of `op` in ascending order of the "key" associated with
+/// each operand iff `op` is commutative.
+///
+/// After the application of this pattern, since the commutative operands now
+/// have a deterministic order in which they occur in an op, the matching of
+/// large DAGs becomes much simpler, i.e., requires far fewer checks to be
+/// written by a user in their pattern matching function.
+///
+/// Some examples of such a sorting:
+///
+/// Assume that the sorting is being applied to `foo.commutative`, which is a
+/// commutative op.
+///
+/// Example 1:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.commutative %1, %2
+///
+/// Here,
+/// 1. The key associated with %1 is `{{CONSTANT_OP, ""}}`, and,
+/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"},
+/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`.
+///
+/// Thus, the key of %2 < key of %1, and so, the sorted `foo.commutative` will
+/// look like:
+///
+/// %3 = foo.commutative %2, %1
+///
+/// Note that in this example, it wasn't necessary to get the entire set of
+/// ancestors of operand %2 to decide that it has the smaller key. Just the
+/// comparison of `{{CONSTANT_OP, ""}}` and `{{NON_CONSTANT_OP, "foo.mul"}}` is
+/// enough to determine this. Thus, this sorting utility does not store the
+/// entire set of ancestors of an operand at once. It seeks the ancestors in a
+/// breadth-first fashion, one at a time, to create an operand "key" but stops
+/// when the relative ordering of that operand can be determined.
+///
+/// So, here,
+/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`, and,
+/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"}}` (instead of
+/// `{{NON_CONSTANT_OP, "foo.mul"}, {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT,
+/// ""}}`).
+///
+/// Example 2:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.mul %2, %1
+/// %4 = foo.add %2, %1
+/// %5 = foo.commutative %1, %2, %3, %4
+///
+/// Here,
+/// 1. The key associated with %1 is `{{CONSTANT_OP, ""}}`,
+/// 2. The key associated with %2 is `{{NON_CONSTANT_OP, "foo.mul"},
+/// {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}}`,
+/// 3. The key associated with %3 is `{{NON_CONSTANT_OP, "foo.mul"},
+/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""},
+/// {BLOCK_ARGUMENT, ""}}`, and,
+/// 4. The key associated with %4 is `{{NON_CONSTANT_OP, "foo.add"},
+/// {NON_CONSTANT_OP, "foo.mul"}, {CONSTANT_OP, ""}, {BLOCK_ARGUMENT, ""},
+/// {BLOCK_ARGUMENT, ""}}`.
+///
+/// And,
+/// 1. The key computed for %1 is `{{CONSTANT_OP, ""}}`,
+/// 2. The key computed for %2 is `{{NON_CONSTANT_OP, "foo.mul"},
+/// {BLOCK_ARGUMENT, ""}}`,
+/// 3. The key computed for %3 is `{{NON_CONSTANT_OP, "foo.mul"},
+/// {NON_CONSTANT_OP, "foo.mul"}}`, and,
+/// 4. The key computed for %4 is `{{NON_CONSTANT_OP, "foo.add"}}`.
+///
+/// Note that the BFS's for operands %1 and %4 stop after one iteration while
+/// those for operands %2 and %3 stop after the second iteration.
+///
+/// And, the sorted `foo.commutative` will look like:
+///
+/// %5 = foo.commutative %4, %2, %3, %1
+///
+/// Here, %2 is placed before %3 because the second elements of their computed
+/// keys are `{BLOCK_ARGUMENT, ""}` and `{NON_CONSTANT_OP, "foo.mul"}`
+/// respectively, and a block argument is considered smaller than an op.
+///
+/// This pattern fails to match, leaving `op` untouched, when `op` is not
+/// commutative, has fewer than two operands, or already has its operands in
+/// sorted order. Reporting success without modifying `op` would prevent greedy
+/// rewrite drivers from reaching a fixed point.
+class SortCommutativeOperands : public RewritePattern {
+public:
+  explicit SortCommutativeOperands(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // If `op` is not commutative, do nothing.
+    if (!op->hasTrait<OpTrait::IsCommutative>())
+      return failure();
+
+    // An op with fewer than two operands has nothing to sort.
+    unsigned numOperands = op->getNumOperands();
+    if (numOperands < 2)
+      return failure();
+
+    // `bfsOfOperands` stores the BFS traversal information of each operand of
+    // `op`. For each operand, this information comprises a queue of ancestors
+    // being visited during the BFS (at a particular point in time), a list of
+    // visited ancestors (at a particular point in time), its associated key (at
+    // a particular point in time), and whether or not the operand has been
+    // assigned a sorted position yet.
+    SmallVector<std::unique_ptr<OperandBFS>, 2> bfsOfOperands;
+
+    // Initially, each operand's ancestor queue contains the op defining it
+    // (which is considered its first ancestor). Thus, it acts as the starting
+    // point for that operand's BFS traversal.
+    for (Value operand : op->getOperands()) {
+      auto bfsOfOperand = std::make_unique<OperandBFS>();
+      bfsOfOperand->pushAncestor(operand.getDefiningOp());
+      bfsOfOperands.push_back(std::move(bfsOfOperand));
+    }
+
+    // Since none of the operands have been assigned a sorted position yet, the
+    // smallest unassigned position is set as zero and the largest one is set as
+    // the number of operands in `op` minus one (N - 1). This is because each
+    // operand will be assigned a sorted position between 0 and (N - 1), both
+    // inclusive.
+    unsigned smallestUnassignedPosition = 0;
+    unsigned largestUnassignedPosition = numOperands - 1;
+
+    // `sortedOperands` will store the list of `op`'s operands in sorted order.
+    // At first, all elements in it are initialized as null.
+    SmallVector<Value, 2> sortedOperands(numOperands, nullptr);
+
+    // We perform the BFS traversals of all operands in lockstep until each of
+    // them is assigned a sorted position. During the traversals, we try to
+    // assign a sorted position to an operand as soon as it is possible (based
+    // on a comparison of its traversal with the other traversals at that
+    // particular point in time).
+    while (llvm::any_of(bfsOfOperands,
+                        [](const std::unique_ptr<OperandBFS> &bfsOfOperand) {
+                          return !bfsOfOperand->isAssignedSortedPosition;
+                        })) {
+      // Update the keys corresponding to all unassigned operands.
+      updateKeys(bfsOfOperands);
+
+      // Stores the indices of the unassigned operands whose key is the
+      // smallest.
+      DenseSet<unsigned> smallestKeyIndices;
+      // Stores the indices of the unassigned operands whose key is the largest.
+      DenseSet<unsigned> largestKeyIndices;
+
+      // Stores true iff there is a single unassigned operand that has the
+      // smallest key.
+      bool hasASingleOperandWithSmallestKey = false;
+      // Stores true iff there is a single unassigned operand that has the
+      // largest key.
+      bool hasASingleOperandWithLargestKey = false;
+
+      getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+          bfsOfOperands, smallestKeyIndices, largestKeyIndices,
+          hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey);
+
+      // Go through each of the unassigned operands with the smallest key and
+      // try to assign it a sorted position if possible (ensuring stable
+      // sorting).
+      for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) {
+        OperandBFS *bfsOfOperand = indexedBfsOfOperand.value().get();
+        if (bfsOfOperand->isAssignedSortedPosition)
+          continue;
+
+        // If an unassigned operand has the smallest key and:
+        // 1. It is the only operand with the smallest key, OR,
+        // 2. Its BFS is complete,
+        // then,
+        // this operand is assigned the `smallestUnassignedPosition` (which will
+        // be its new position in the rearranged `op`).
+        if (assignSortedPositionTo(
+                bfsOfOperand, /*indexOfOperand=*/indexedBfsOfOperand.index(),
+                /*keyIndices=*/smallestKeyIndices,
+                /*isTheOnlyKey=*/hasASingleOperandWithSmallestKey,
+                /*sortedOperands=*/sortedOperands,
+                /*positionToAssign=*/smallestUnassignedPosition, /*op=*/op))
+          smallestUnassignedPosition++;
+      }
+      // Go through each of the unassigned operands with the largest key and try
+      // to assign it a sorted position if possible (ensuring stable sorting).
+      for (auto indexedBfsOfOperand :
+           llvm::enumerate(llvm::reverse(bfsOfOperands))) {
+        OperandBFS *bfsOfOperand = indexedBfsOfOperand.value().get();
+        if (bfsOfOperand->isAssignedSortedPosition)
+          continue;
+
+        // If an unassigned operand has the largest key and:
+        // 1. It is the only operand with the largest key, OR,
+        // 2. Its BFS is complete,
+        // then,
+        // this operand is assigned the `largestUnassignedPosition` (which will
+        // be its new position in the rearranged `op`).
+        if (assignSortedPositionTo(
+                bfsOfOperand, /*indexOfOperand=*/numOperands -
+                                  indexedBfsOfOperand.index() - 1,
+                /*keyIndices=*/largestKeyIndices,
+                /*isTheOnlyKey=*/hasASingleOperandWithLargestKey,
+                /*sortedOperands=*/sortedOperands,
+                /*positionToAssign=*/largestUnassignedPosition, /*op=*/op))
+          largestUnassignedPosition--;
+      }
+
+      // For each operand in `bfsOfOperands`, pop the front ancestor from the
+      // queue and push its adjacent unvisited ancestors into the queue.
+      popFrontAndPushAdjacentUnvisitedAncestors(bfsOfOperands);
+    }
+
+    // If the operands are already sorted, leave `op` untouched and report
+    // failure: reporting success without modifying `op` would prevent greedy
+    // rewrite drivers from reaching a fixed point.
+    if (llvm::equal(sortedOperands, op->getOperands()))
+      return failure();
+
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
+  patterns.add<SortCommutativeOperands>(patterns.getContext());
+}
Index: mlir/lib/Transforms/Utils/CMakeLists.txt
===================================================================
--- mlir/lib/Transforms/Utils/CMakeLists.txt
+++ mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  CommutativityUtils.cpp
   ControlFlowSinkUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp
Index: mlir/include/mlir/Transforms/CommutativityUtils.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -0,0 +1,31 @@
+//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a function to populate the commutativity utility
+// pattern. This function is intended to be used inside passes to simplify the
+// matching of commutative operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+
+// NOTE(review): only `RewritePatternSet` is used below; including
+// "mlir/IR/PatternMatch.h" instead would likely suffice. Confirm before changing.
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+/// Populates `patterns` with a pattern that reorders the operands of every op
+/// with the `OpTrait::IsCommutative` trait into a deterministic order, so that
+/// patterns matching DAGs rooted at commutative ops need fewer checks.
+void populateCommutativityUtilsPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to