Author: Wen-Heng (Jack) Chung Date: 2020-06-05T22:18:20-05:00 New Revision: e587a8a23c9812610dd1e79a3e1211e1f4d8aba5
URL: https://github.com/llvm/llvm-project/commit/e587a8a23c9812610dd1e79a3e1211e1f4d8aba5 DIFF: https://github.com/llvm/llvm-project/commit/e587a8a23c9812610dd1e79a3e1211e1f4d8aba5.diff LOG: Initial commit to add MIOpen Conv2D to Transform and GridwiseGemm transform pass. Added: mlir/include/mlir/Dialect/MIOpenOps/Passes.h mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp mlir/test/Dialect/MIOpen/lowering.mlir Modified: mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td index 8ffd66647f3f..1f531d9176ab 100644 --- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td +++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td @@ -37,9 +37,9 @@ class MIOpen_Op<string mnemonic, list<OpTrait> traits = []> : def MIOpen_Conv2DOp : MIOpen_Op<"conv2d">, - Arguments<(ins MemRefRankOf<[F32], [4]>, - MemRefRankOf<[F32], [4]>, - MemRefRankOf<[F32], [4]>)> { + Arguments<(ins MemRefRankOf<[F32], [4]>:$filter, + MemRefRankOf<[F32], [4]>:$input, + MemRefRankOf<[F32], [4]>:$output)> { let summary = "2D convolution"; let description = [{ The `miopen.conv2d` op computes 2D convolution. diff --git a/mlir/include/mlir/Dialect/MIOpenOps/Passes.h b/mlir/include/mlir/Dialect/MIOpenOps/Passes.h new file mode 100644 index 000000000000..6752b71c5598 --- /dev/null +++ b/mlir/include/mlir/Dialect/MIOpenOps/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - MIOpen pass entry points ----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MIOPEN_PASSES_H_ +#define MLIR_DIALECT_MIOPEN_PASSES_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include <memory> + +namespace mlir { +class FuncOp; +class ModuleOp; +template <typename T> class OpPassBase; + +namespace miopen { + +/// Create a pass that lowers MIOpen conv2d operations to transform and +/// gridwise_gemm operations (gridwise_gemm emission is still a TODO). The +/// pass is registered under the command-line name `miopen-lowering`. +std::unique_ptr<OpPassBase<ModuleOp>> createLowerMIOpenOpsPass(); + +} // namespace miopen +} // namespace mlir + +#endif // MLIR_DIALECT_MIOPEN_PASSES_H_ diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp new file mode 100644 index 000000000000..2a00ed675122 --- /dev/null +++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp @@ -0,0 +1,82 @@ +//===- LowerMIOpenOps.cpp - MLIR MIOpen ops lowering passes ---------------===// +// +// Copyright 2020 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This pass converts miopen.conv2d into miopen.transform and +// miopen.gridwise_gemm.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MIOpenOps/MIOpenOps.h" +#include "mlir/Dialect/MIOpenOps/Passes.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; + +namespace { +/// Lowers a miopen.conv2d op to miopen.transform ops (gridwise_gemm: TODO). +struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { + using OpRewritePattern<miopen::Conv2DOp>::OpRewritePattern; + + PatternMatchResult + matchAndRewrite(miopen::Conv2DOp op, PatternRewriter &rewriter) const override { + // Wrap each operand in a miopen.transform with the operand's own type. + rewriter.create<miopen::TransformOp>(op.getLoc(), op.filter().getType(), op.filter()); + // NOTE(review): placeholder -- results unused; input is wrapped 3 times. + rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); + rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); + rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); + rewriter.create<miopen::TransformOp>(op.getLoc(), op.output().getType(), op.output()); + // TODO: emit miopen.gridwise_gemm: + //rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), op.filter(), op.input(), op.output()); + // Finally, erase the conv2d op via the rewriter so the driver is notified.
+ rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +/// Module pass that greedily applies Conv2DOpRewritePattern. +struct LowerMIOpenOpsPass : public ModulePass<LowerMIOpenOpsPass> { + void runOnModule() override; +}; +} // end anonymous namespace + +void LowerMIOpenOpsPass::runOnModule() { + OwningRewritePatternList patterns; + patterns.insert<Conv2DOpRewritePattern>(&getContext()); + applyPatternsGreedily(getModule(), patterns); +} + +std::unique_ptr<OpPassBase<ModuleOp>> mlir::miopen::createLowerMIOpenOpsPass() { + return std::make_unique<LowerMIOpenOpsPass>(); +} + +static PassRegistration<LowerMIOpenOpsPass> + lowerMIOpenOpsPass("miopen-lowering", + "Lower MIOpen conv2d into transform and gridwise_gemm."); diff --git a/mlir/test/Dialect/MIOpen/lowering.mlir b/mlir/test/Dialect/MIOpen/lowering.mlir new file mode 100644 index 000000000000..e7734cef5a29 --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering %s | FileCheck %s + +func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["k", "c", "y", "x"], + input_layout = ["n", "c", "hi", "wi"], + output_layout = ["n", "k", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// TBD-CHECK-NEXT: miopen.gridwise_gemm _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits