srishti-pm updated this revision to Diff 428240. srishti-pm added a comment.
Removed the requirement that operands defined by constant-like ops be sorted alphabetically among themselves. All constants are now treated as equal by the sorting algorithm, which no longer distinguishes between, say, `arith.constant` and `tf.Const`. The reason is that several dialects have canonicalizations that move constants to the right without distinguishing among the constants; treating all constants as equal keeps this utility from clashing with those canonicalizations. Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D124750/new/ https://reviews.llvm.org/D124750 Files: clang/docs/tools/clang-formatted-files.txt mlir/include/mlir/Transforms/CommutativityUtils.h mlir/lib/Transforms/Utils/CMakeLists.txt mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/test/Transforms/test-commutativity-utils.mlir mlir/test/lib/Dialect/Test/TestOps.td mlir/test/lib/Transforms/CMakeLists.txt mlir/test/lib/Transforms/TestCommutativityUtils.cpp mlir/tools/mlir-opt/mlir-opt.cpp
Index: mlir/tools/mlir-opt/mlir-opt.cpp =================================================================== --- mlir/tools/mlir-opt/mlir-opt.cpp +++ mlir/tools/mlir-opt/mlir-opt.cpp @@ -56,6 +56,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -146,6 +147,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck(); Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,67 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the commutativity utility (`sortCommutativeOperands`).
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace test; + +namespace { + +/// Sorts the operands of each matched `TestCommutativeOp`. +struct SmallPattern : public OpRewritePattern<TestCommutativeOp> { + using OpRewritePattern<TestCommutativeOp>::OpRewritePattern; + LogicalResult matchAndRewrite(TestCommutativeOp testCommOp, + PatternRewriter &rewriter) const override { + sortCommutativeOperands(testCommOp.getOperation(), rewriter); + // NOTE(review): success() is returned even when no operand was moved, so + // the greedy driver presumably never reaches a fixpoint and stops only at + // its iteration limit; the `(void)` cast in runOnOperation hides that. + return success(); + } +}; + +/// Sorts the operands of each matched `TestLargeCommutativeOp`. +struct LargePattern : public OpRewritePattern<TestLargeCommutativeOp> { + using OpRewritePattern<TestLargeCommutativeOp>::OpRewritePattern; + LogicalResult matchAndRewrite(TestLargeCommutativeOp testLargeCommOp, + PatternRewriter &rewriter) const override { + sortCommutativeOperands(testLargeCommOp.getOperation(), rewriter); + // NOTE(review): same unconditional success() as in SmallPattern. + return success(); + } +}; + +/// Test pass that greedily applies both patterns to each function. +struct CommutativityUtils + : public PassWrapper<CommutativityUtils, OperationPass<FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + patterns.add<LargePattern, SmallPattern>(context); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); } +} // namespace test +} // namespace mlir Index: mlir/test/lib/Transforms/CMakeLists.txt =================================================================== --- mlir/test/lib/Transforms/CMakeLists.txt +++
mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1101,11 +1101,21 @@ let hasFolder = 1; } +// Non-commutative binary op; test-commutativity-utils.mlir uses it as a +// source of operands. +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +// Commutative op with seven operands, for exercising larger sorts. +def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2); let results = (outs I32); Index: mlir/test/Transforms/test-commutativity-utils.mlir =================================================================== --- /dev/null +++ mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+ + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // Both constant-like operands get the same key, so they are tied; they must + // end up last while keeping their original relative order. + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARG0]], %[[TEST_CONST]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8):
(i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // %19 is passed twice; its two uses stay adjacent after sorting. + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL21]], %[[VAL19]], %[[VAL19]], %[[VAL6]], %[[VAL11]], %[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %19, %21): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,396 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a utility that is intended to be used inside a pass or +// an individual pattern to simplify the matching of commutative operations. +// Note that this utility can also be used inside PDL patterns in conjunction +// with the `pdl.apply_native_rewrite` op. +// The entry point is `mlir::sortCommutativeOperands`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include <queue>
+
+#define DEBUG_TYPE "commutativity-utils"
+
+using namespace mlir;
+
+namespace {
+/// Stores the BFS traversal information of an operand.
+struct OperandBFS {
+  /// Stores the queue of ancestors of the BFS traversal of an operand at a
+  /// particular point in time. A null entry stands for a block argument.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Stores the set of visited ancestors of the BFS traversal of an operand at
+  /// a particular point in time. Block arguments (null entries) are never
+  /// recorded here.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the key corresponding to the BFS traversal of an operand at a
+  /// particular point in time.
+  /// Some examples:
+  ///   1. If the BFS has seen `arith.addi`, the key will store the string
+  ///      "1arith.addi".
+  ///   2. If the BFS has seen `arg5`, the key will store the string "2".
+  ///   3. If the BFS has seen `arith.constant`, the key will store the string
+  ///      "3".
+  ///   4. If the BFS has seen `arith.addi`, `test.constant`, `scf.if`,
+  ///      `tf.Add`, `arith.constant`, and `arg5` (in BFS order), the key will
+  ///      store the string "1arith.addi31scf.if1tf.Add32".
+  ///
+  /// With this definition of "key", sorting operands by ascending key places
+  /// (1) the ones defined by non-constant-like ops first, followed by (2) block
+  /// arguments, which are finally followed by (3) the ones defined by
+  /// constant-like ops. Within category (1), operands are ordered
+  /// alphabetically w.r.t. the dialect name and op name. Constant-like ops are
+  /// not distinguished from one another.
+  ///
+  /// As an example of the comparison of keys, consider the following
+  /// commutative op (foo.op):
+  ///   e = foo.div f, g
+  ///   c = foo.constant
+  ///   b = foo.add e, d
+  ///   a = foo.add c, d
+  ///   s = foo.op a, b,
+  /// The key associated with operand `a` will be "1foo.add3", and the key
+  /// associated with operand `b` will be "1foo.add1foo.div". Thus,
+  ///   key of `a` > key of `b`,
+  /// which means that a "sorted" foo.op would look like:
+  ///   s = foo.op b, a (instead of a, b).
+  std::string key;
+
+  /// Stores true iff the operand has already been assigned a sorted position.
+  bool isAssignedSortedPosition = false;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" set (iff it is not null, i.e., corresponds to an op
+  /// rather than a block argument).
+  void pushAncestor(Operation *ancestor) {
+    ancestorQueue.push(ancestor);
+    if (ancestor)
+      visitedAncestors.insert(ancestor);
+  }
+
+  /// Pop the ancestor from the front of the queue.
+  void popAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to pop the ancestor from the front of the queue, the ancestor "
+           "queue should be non-empty");
+    ancestorQueue.pop();
+  }
+
+  /// Return the ancestor at the front of the queue.
+  Operation *frontAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to access the ancestor at the front of the queue, the ancestor "
+           "queue should be non-empty");
+    return ancestorQueue.front();
+  }
+};
+} // namespace
+
+/// Returns true iff at least one unassigned operand exists. An unassigned
+/// operand refers to one which has not been assigned a sorted position yet.
+static bool
+hasAtLeastOneUnassignedOperand(ArrayRef<OperandBFS> bfsOfOperands) {
+  return llvm::any_of(bfsOfOperands, [](const OperandBFS &bfsOfOperand) {
+    return !bfsOfOperand.isAssignedSortedPosition;
+  });
+}
+
+/// Goes through all the unassigned operands of `bfsOfOperands` and:
+///   1. Stores the indices of the ones with the smallest key in
+///      `smallestKeyIndices`,
+///   2. Stores the indices of the ones with the largest key in
+///      `largestKeyIndices`,
+///   3. Sets `hasASingleOperandWithSmallestKey` as true if exactly one of them
+///      has the smallest key (and as false otherwise), AND,
+///   4. Sets `hasASingleOperandWithLargestKey` as true if exactly one of them
+///      has the largest key (and as false otherwise).
+/// If every operand is already assigned, both flags are set to false and the
+/// index sets are left untouched.
+static void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+    ArrayRef<OperandBFS> bfsOfOperands, DenseSet<unsigned> &smallestKeyIndices,
+    DenseSet<unsigned> &largestKeyIndices,
+    bool &hasASingleOperandWithSmallestKey,
+    bool &hasASingleOperandWithLargestKey) {
+  // Find the smallest and largest keys among the unassigned operands. The keys
+  // are referenced in place rather than copied.
+  const std::string *smallestKey = nullptr;
+  const std::string *largestKey = nullptr;
+  for (const OperandBFS &bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand.isAssignedSortedPosition)
+      continue;
+    const std::string &currentKey = bfsOfOperand.key;
+    if (!smallestKey || currentKey < *smallestKey)
+      smallestKey = &currentKey;
+    if (!largestKey || currentKey > *largestKey)
+      largestKey = &currentKey;
+  }
+
+  // If there is no unassigned operand, there is nothing to report.
+  if (!smallestKey) {
+    hasASingleOperandWithSmallestKey = false;
+    hasASingleOperandWithLargestKey = false;
+    return;
+  }
+
+  // Populate the index sets and count how many unassigned operands share the
+  // smallest and the largest key.
+  unsigned numWithSmallestKey = 0;
+  unsigned numWithLargestKey = 0;
+  for (unsigned index = 0, e = bfsOfOperands.size(); index < e; ++index) {
+    const OperandBFS &bfsOfOperand = bfsOfOperands[index];
+    if (bfsOfOperand.isAssignedSortedPosition)
+      continue;
+    if (bfsOfOperand.key == *smallestKey) {
+      smallestKeyIndices.insert(index);
+      ++numWithSmallestKey;
+    }
+    if (bfsOfOperand.key == *largestKey) {
+      largestKeyIndices.insert(index);
+      ++numWithLargestKey;
+    }
+  }
+  hasASingleOperandWithSmallestKey = numWithSmallestKey == 1;
+  hasASingleOperandWithLargestKey = numWithLargestKey == 1;
+}
+
+/// Update the key associated with each unassigned operand in `bfsOfOperands`.
+/// Updating a key entails making it up-to-date with its associated operand's
+/// BFS traversal that has happened till that point in time. Note that a key
+/// directly reflects the BFS and thus needs to be updated after every change in
+/// the BFS queue, as the traversal happens.
+static void updateKeys(MutableArrayRef<OperandBFS> bfsOfOperands) {
+  for (OperandBFS &bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand.isAssignedSortedPosition ||
+        bfsOfOperand.ancestorQueue.empty())
+      continue;
+
+    Operation *frontAncestor = bfsOfOperand.frontAncestor();
+    if (!frontAncestor) {
+      // Block argument: "2" positions the operand between operands defined by
+      // non-constant-like and constant-like operations.
+      bfsOfOperand.key += '2';
+    } else if (frontAncestor->hasTrait<OpTrait::ConstantLike>()) {
+      // Constant-like op: "3" positions the operand after operands defined by
+      // non-constant-like operations or block arguments. The op name is
+      // deliberately omitted so that all constants compare equal.
+      bfsOfOperand.key += '3';
+    } else {
+      // Non-constant-like op: "1" followed by the op name positions the
+      // operand before block arguments and constants, and orders such
+      // operands alphabetically by op name.
+      StringRef opName = frontAncestor->getName().getStringRef();
+      bfsOfOperand.key += '1';
+      bfsOfOperand.key.append(opName.begin(), opName.end());
+    }
+  }
+}
+
+/// Rewrite `op`, i.e., rearrange its operands in a "sorted" order.
+/// The operands of an op are considered to be "sorted" iff:
+///   1. The op is not commutative, OR,
+///   2. It is commutative and its operands are in ascending order of the
+///      "keys" associated with them.
+/// Operands with identical keys keep their original relative order, so
+/// rewriting an already sorted op leaves it unchanged.
+///
+/// `operandDefOps` stores the list of ops defining the operands of `op` (in
+/// the order in which they appear in `op`). If an operand is a block argument,
+/// the corresponding entry is null.
+///
+/// The rewriter is only notified when the operand order actually changes.
+static void rewriteCommutativeOperands(Operation *op,
+                                       ArrayRef<Operation *> operandDefOps,
+                                       PatternRewriter &rewriter) {
+  // If `op` is not commutative, do nothing.
+  if (!op->hasTrait<OpTrait::IsCommutative>())
+    return;
+
+  // With fewer than two operands there is nothing to reorder. Returning early
+  // also keeps `numOperands - 1` below from wrapping around for zero operands.
+  unsigned numOperands = op->getNumOperands();
+  if (numOperands < 2)
+    return;
+  assert(operandDefOps.size() == numOperands &&
+         "expected one entry in `operandDefOps` per operand of `op`");
+
+  // `bfsOfOperands` stores the BFS traversal information of each operand of
+  // `op`: its queue of ancestors, its visited ancestors, its key, and whether
+  // it has been assigned a sorted position yet. The entries are held by value
+  // so that they are released when this function returns.
+  //
+  // Initially, each operand's ancestor queue contains the op defining it (its
+  // first ancestor), which is the starting point of its BFS traversal.
+  SmallVector<OperandBFS, 2> bfsOfOperands(numOperands);
+  for (unsigned index = 0; index < numOperands; ++index)
+    bfsOfOperands[index].pushAncestor(operandDefOps[index]);
+
+  // Each operand will be assigned a sorted position between 0 and N - 1, both
+  // inclusive. Positions are handed out from both ends towards the middle.
+  unsigned smallestUnassignedPosition = 0;
+  unsigned largestUnassignedPosition = numOperands - 1;
+
+  // `sortedOperands` will store the list of `op`'s operands in sorted order.
+  // At first, all elements in it are null.
+  SmallVector<Value, 2> sortedOperands(numOperands);
+
+  // The BFS traversals of all operands advance in lockstep until each operand
+  // is assigned a sorted position. An operand is assigned a position as soon
+  // as a comparison of its traversal with the other traversals allows it.
+  while (hasAtLeastOneUnassignedOperand(bfsOfOperands)) {
+    // Update the keys corresponding to all unassigned operands.
+    updateKeys(bfsOfOperands);
+
+    DenseSet<unsigned> smallestKeyIndices;
+    DenseSet<unsigned> largestKeyIndices;
+    bool hasASingleOperandWithSmallestKey;
+    bool hasASingleOperandWithLargestKey;
+    getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+        bfsOfOperands, smallestKeyIndices, largestKeyIndices,
+        hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey);
+
+    // An unassigned operand with the smallest key gets the smallest unassigned
+    // position if it is the only operand with that key, or if its BFS is
+    // complete (its key can no longer grow). Visiting the indices in ascending
+    // order keeps tied operands in their original relative order.
+    for (unsigned index = 0; index < numOperands; ++index) {
+      OperandBFS &bfsOfOperand = bfsOfOperands[index];
+      if (!bfsOfOperand.isAssignedSortedPosition &&
+          smallestKeyIndices.contains(index) &&
+          (hasASingleOperandWithSmallestKey ||
+           bfsOfOperand.ancestorQueue.empty())) {
+        bfsOfOperand.isAssignedSortedPosition = true;
+        sortedOperands[smallestUnassignedPosition++] = op->getOperand(index);
+      }
+    }
+
+    // Likewise for the largest key and the largest unassigned position. These
+    // positions are filled from the back, so the indices are visited in
+    // descending order to keep tied operands in their original relative
+    // order. Visiting them in ascending order would reverse such ties on every
+    // application, and repeated sorting would never reach a fixed point.
+    for (unsigned index = numOperands; index-- > 0;) {
+      OperandBFS &bfsOfOperand = bfsOfOperands[index];
+      if (!bfsOfOperand.isAssignedSortedPosition &&
+          largestKeyIndices.contains(index) &&
+          (hasASingleOperandWithLargestKey ||
+           bfsOfOperand.ancestorQueue.empty())) {
+        bfsOfOperand.isAssignedSortedPosition = true;
+        sortedOperands[largestUnassignedPosition--] = op->getOperand(index);
+      }
+    }
+
+    // Advance the BFS of every still unassigned operand by one step: pop the
+    // front ancestor and push the ops defining its operands that have not
+    // been visited yet. Block arguments (null) are pushed every time they are
+    // encountered; they contribute to the key but are never expanded.
+    for (OperandBFS &bfsOfOperand : bfsOfOperands) {
+      if (bfsOfOperand.isAssignedSortedPosition ||
+          bfsOfOperand.ancestorQueue.empty())
+        continue;
+      Operation *frontAncestor = bfsOfOperand.frontAncestor();
+      bfsOfOperand.popAncestor();
+      if (!frontAncestor)
+        continue;
+      for (Value operand : frontAncestor->getOperands()) {
+        Operation *defOp = operand.getDefiningOp();
+        if (!defOp || !bfsOfOperand.visitedAncestors.contains(defOp))
+          bfsOfOperand.pushAncestor(defOp);
+      }
+    }
+  }
+
+  // Only notify the rewriter when the operand order actually changes, so that
+  // an already sorted op is not reported as modified.
+  if (llvm::equal(sortedOperands, op->getOperands()))
+    return;
+  rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+}
+
+/// Sorts `op`.
+/// "Sorting" `op` means to "sort" the ops defining each of its operands
+/// followed by rearranging its operands in the "sorted" order. Before the
+/// rearrangement, it is important to sort the ops defining its operands so that
+/// the rearrangement is deterministic. In other words, if these ops were not
+/// sorted, the rearrangement would be non-deterministic and would thus make
+/// this utility useless.
+void mlir::sortCommutativeOperands(Operation *op, PatternRewriter &rewriter) {
+  assert(op && "the input argument `op` must not be null");
+
+  // The ops defining the operands of `op` are sorted before `op` itself, which
+  // amounts to a post-order walk of the use-def graph rooted at `op`. The walk
+  // is iterative and sorts every op at most once. A plain recursion would
+  // revisit an op once per use-def path reaching it (which can be exponential
+  // in the size of a DAG), could exhaust the stack on long chains, and would
+  // not terminate on the use-def cycles that graph regions allow.
+  //
+  // Each worklist entry holds an op and the index of its next operand whose
+  // defining op still has to be visited.
+  DenseSet<Operation *> visited;
+  SmallVector<std::pair<Operation *, unsigned>, 8> worklist;
+  visited.insert(op);
+  worklist.emplace_back(op, 0);
+  while (!worklist.empty()) {
+    Operation *current = worklist.back().first;
+    unsigned operandIndex = worklist.back().second;
+    if (operandIndex < current->getNumOperands()) {
+      ++worklist.back().second;
+      Operation *operandDefOp =
+          current->getOperand(operandIndex).getDefiningOp();
+      if (operandDefOp && visited.insert(operandDefOp).second)
+        worklist.emplace_back(operandDefOp, 0);
+      continue;
+    }
+
+    // All ops defining the operands of `current` have been sorted, so now
+    // rewrite `current`, i.e., rearrange its operands in a "sorted" order.
+    worklist.pop_back();
+    SmallVector<Operation *, 2> operandDefOps;
+    for (Value operand : current->getOperands())
+      operandDefOps.push_back(operand.getDefiningOp());
+    rewriteCommutativeOperands(current, operandDefOps, rewriter);
+  }
+}
Index: mlir/lib/Transforms/Utils/CMakeLists.txt =================================================================== --- mlir/lib/Transforms/Utils/CMakeLists.txt +++ mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp Index: mlir/include/mlir/Transforms/CommutativityUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,28 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares a utility that is intended to be used inside a pass +// or an individual pattern to simplify the matching of commutative operations. +// Note that this utility can also be used inside PDL patterns in conjunction +// with the `pdl.apply_native_rewrite` op.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +namespace mlir { + +class Operation; +class PatternRewriter; + +/// Sorts the operands of `op` if it is commutative, after first sorting the +/// ops that (transitively) define its operands. Operands defined by +/// non-constant-like ops come first, then block arguments, then operands +/// defined by constant-like ops. Operand updates go through `rewriter`. +void sortCommutativeOperands(Operation *op, PatternRewriter &rewriter); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H Index: clang/docs/tools/clang-formatted-files.txt =================================================================== --- clang/docs/tools/clang-formatted-files.txt +++ clang/docs/tools/clang-formatted-files.txt @@ -7888,6 +7888,7 @@ mlir/include/mlir/Tools/PDLL/ODS/Dialect.h mlir/include/mlir/Tools/PDLL/ODS/Operation.h mlir/include/mlir/Tools/PDLL/Parser/Parser.h +mlir/include/mlir/Transforms/CommutativityUtils.h mlir/include/mlir/Transforms/ControlFlowSinkUtils.h mlir/include/mlir/Transforms/DialectConversion.h mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -8448,6 +8449,7 @@ mlir/lib/Transforms/StripDebugInfo.cpp mlir/lib/Transforms/SymbolDCE.cpp mlir/lib/Transforms/SymbolPrivatize.cpp +mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/lib/Transforms/Utils/FoldUtils.cpp
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits