Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: 1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd
URL: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd DIFF: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd.diff LOG: Add Op transform logic. Improve Op translate logic. Revise tests. Added: Modified: mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp mlir/test/Dialect/MIOpen/lowering.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp index 2bd64efa77b6..cda706c4112c 100644 --- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp +++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp @@ -343,12 +343,17 @@ struct GridwiseConvolutionImplicitGemm_v4r4_)"; output << kHeaderPreamblePart2; output << kHeaderPreamblePart3; output << '\n'; - output << R"( - constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};"; - output << R"( - constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};"; - output << R"( - constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};"; + + // TBD: remove these interim checks. 
+ if (tensorDescs.size() > 0) + output << R"( + constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};"; + if (tensorDescs.size() > 1) + output << R"( + constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};"; + if (tensorDescs.size() > 2) + output << R"( + constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};"; output << '\n'; } @@ -358,7 +363,7 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t, // decltype(wei_e_k_global_desc), // decltype(in_e_b_global_desc), // decltype(out_k_b_global_desc), - for (int i = 0; i < 3; ++i) { + for (unsigned i = 0; i < args.size(); ++i) { output << R"( decltype()" << args[i] << "),"; } @@ -396,7 +401,9 @@ void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attr case 'H': case 'W': output << llvm::toUpper(strAttr.getValue()[0]); - output << llvm::toUpper(strAttr.getValue()[1]); + // XXX: fix this. + if (strAttr.getValue().size() > 1) + output << llvm::toUpper(strAttr.getValue()[1]); break; default: output << llvm::toUpper(strAttr.getValue()[0]); diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp index 2a00ed675122..27311cb8cfb9 100644 --- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp +++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/MIOpenOps/Passes.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" @@ -37,6 +38,8 @@ #include "mlir/Transforms/Passes.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" + using namespace mlir; struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { @@ -44,15 +47,450 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { PatternMatchResult matchAndRewrite(miopen::Conv2DOp op, PatternRewriter &rewriter) const override { 
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.filter().getType(), op.filter()); + auto filterLayoutAttr = op.getAttrOfType<ArrayAttr>("filter_layout"); + auto inputLayoutAttr = op.getAttrOfType<ArrayAttr>("input_layout"); + auto outputLayoutAttr = op.getAttrOfType<ArrayAttr>("output_layout"); + + // TBD: handle dilations, strides, padding. + + // Transform filter tensor. + auto filterType = op.filter().getType().dyn_cast<MemRefType>(); + auto filterShape = filterType.getShape(); + auto filterElementType = filterType.getElementType(); + + llvm::SmallVector<int64_t, 2> transformedFilterShape; + transformedFilterShape.set_size(filterShape.size() - 2); + // TBD: compute transformed filter shape dimensions. + std::fill(transformedFilterShape.begin(), transformedFilterShape.end(), -1); + auto transformedFilterMemRefType = MemRefType::get(transformedFilterShape, filterElementType); + + llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs; + + // TBD: set layout attribute. + // TBD: Merge part. 
+ llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart1Specs; + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3), + }, op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("c", op.getContext()), + StringAttr::get("y", op.getContext()), + StringAttr::get("x", op.getContext()) + }, op.getContext()))); + + // TBD: Passthrough part. 
+ llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart2Specs; + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + }, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("k", op.getContext()) + }, op.getContext()))); + + auto transformedFilterLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(transformedFilterLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(transformedFilterLayoutPart2Specs, op.getContext()) + }, op.getContext())); + transformedFilterAttrs.push_back(transformedFilterLayoutAttr); + + // set source_layout attribute. + auto filterSrcLayoutAttr = rewriter.getNamedAttr("source_layout", filterLayoutAttr); + transformedFilterAttrs.push_back(filterSrcLayoutAttr); + // set output_layout attribute. + auto filterOutputLayoutAttr = rewriter.getNamedAttr("output_layout", + ArrayAttr::get({ + StringAttr::get("gemmK", op.getContext()), + StringAttr::get("gemmM", op.getContext()) + }, op.getContext())); + transformedFilterAttrs.push_back(filterOutputLayoutAttr); + // set gridwise_gemm_argument_pos attribute. 
+ auto filterGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)); + transformedFilterAttrs.push_back(filterGridwiseGemmArgPosAttr); + auto gemmA = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedFilterMemRefType, op.filter(), transformedFilterAttrs); + + + // Transform input tensor. + // Input tensor step 1: padded input. + auto inputType = op.input().getType().dyn_cast<MemRefType>(); + auto inputShape = inputType.getShape(); + auto inputElementType = inputType.getElementType(); + + // TBD: compute padded input shape dimensions. + + llvm::SmallVector<NamedAttribute, 3> paddedInputAttrs; + + // TBD: set layout attribute. + // TBD: part 1: Passthrough. + llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart1Specs; + paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext()))); + paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + }, op.getContext()))); + paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()) + }, op.getContext()))); + + // TBD: part 2: Passthrough. 
+ llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart2Specs; + paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext()))); + paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + }, op.getContext()))); + paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ci", op.getContext()) + }, op.getContext()))); + + // TBD: part 3: Pad. + llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart3Specs; + paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3) + }, op.getContext()))); + paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names", + ArrayAttr::get({ + StringAttr::get("hipad", op.getContext()), + StringAttr::get("wipad", op.getContext()), + }, op.getContext()))); + paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Pad", op.getContext()))); + // TBD: padding parmeters. 
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0) + }, op.getContext()))); + paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3) + }, op.getContext()))); + paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("hi", op.getContext()), + StringAttr::get("wi", op.getContext()) + }, op.getContext()))); + + auto paddedInputLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(paddedInputLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(paddedInputLayoutPart2Specs, op.getContext()), + DictionaryAttr::get(paddedInputLayoutPart3Specs, op.getContext()) + }, op.getContext())); + paddedInputAttrs.push_back(paddedInputLayoutAttr); + + // set source_layout attribute. + auto inputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", inputLayoutAttr); + paddedInputAttrs.push_back(inputSrcLayoutAttr); + // TBD: set output_layout attribute. + auto paddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()), + StringAttr::get("ci", op.getContext()), + StringAttr::get("hi", op.getContext()), + StringAttr::get("wi", op.getContext()) + }, op.getContext())); + paddedInputAttrs.push_back(paddedInputOutputLayoutAttr); + auto paddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), inputType, op.input(), paddedInputAttrs); + + // Input tensor step 2 : embedded input. + llvm::SmallVector<int64_t, 6> embeddedInputShape; + embeddedInputShape.set_size(inputShape.size() + 2); + // TBD: compute embedded input shape dimensions. 
+ std::fill(embeddedInputShape.begin(), embeddedInputShape.end(), -1); + auto embeddedInputMemRefType = MemRefType::get(embeddedInputShape, inputElementType); + + llvm::SmallVector<NamedAttribute, 3> embeddedInputAttrs; + + // TBD: set layout attribute. + // TBD: part 1: Passthrough. + llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart1Specs; + embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext()))); + embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + }, op.getContext()))); + embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()) + }, op.getContext()))); + + // TBD: part 2: Passthrough. 
+ llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart2Specs; + embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext()))); + embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + }, op.getContext()))); + embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ci", op.getContext()) + }, op.getContext()))); + // TBD: part 3: Embed. + llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart3Specs; + embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3) + }, op.getContext()))); + embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names", + ArrayAttr::get({ + StringAttr::get("y", op.getContext()), + StringAttr::get("ho", op.getContext()), + }, op.getContext()))); + embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext()))); + // TBD: padding parmeters. 
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0) + }, op.getContext()))); + embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2) + }, op.getContext()))); + embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("hipad", op.getContext()), + }, op.getContext()))); + + // TBD: part 4: Embed. + llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart4Specs; + embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 4), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 5) + }, op.getContext()))); + embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("names", + ArrayAttr::get({ + StringAttr::get("x", op.getContext()), + StringAttr::get("wo", op.getContext()), + }, op.getContext()))); + embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext()))); + // TBD: embed parmeters. 
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("parameters", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0) + }, op.getContext()))); + embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3) + }, op.getContext()))); + embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("wipad", op.getContext()) + }, op.getContext()))); + + auto embeddedInputLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(embeddedInputLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(embeddedInputLayoutPart2Specs, op.getContext()), + DictionaryAttr::get(embeddedInputLayoutPart3Specs, op.getContext()), + DictionaryAttr::get(embeddedInputLayoutPart4Specs, op.getContext()) + }, op.getContext())); + embeddedInputAttrs.push_back(embeddedInputLayoutAttr); + + + // TBD: set intermediate_layout attribute. + auto embeddedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()), + StringAttr::get("ci", op.getContext()), + StringAttr::get("hipad", op.getContext()), + StringAttr::get("wipad", op.getContext()) + }, op.getContext())); + embeddedInputAttrs.push_back(embeddedInputImmLayoutAttr); + // TBD: set output_layout attribute. 
+ auto embeddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()), + StringAttr::get("ci", op.getContext()), + StringAttr::get("y", op.getContext()), + StringAttr::get("ho", op.getContext()), + StringAttr::get("x", op.getContext()), + StringAttr::get("wo", op.getContext()) + }, op.getContext())); + embeddedInputAttrs.push_back(embeddedInputOutputLayoutAttr); + auto embeddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), embeddedInputMemRefType, ArrayRef<Value>(paddedInput), embeddedInputAttrs); + + // Input tensor step 3: transformed input. + llvm::SmallVector<int64_t, 2> transformedInputShape; + transformedInputShape.set_size(inputShape.size() - 2); + // TBD: compute transformed input shape dimensions. + std::fill(transformedInputShape.begin(), transformedInputShape.end(), -1); + auto transformedInputMemRefType = MemRefType::get(transformedInputShape, inputElementType); + + llvm::SmallVector<NamedAttribute, 3> transformedInputAttrs; + + // TBD: set layout attribute. + // TBD: Part 1: Merge. 
+ llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart1Specs; + transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext()))); + transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 4) + }, op.getContext()))); + transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ci", op.getContext()), + StringAttr::get("y", op.getContext()), + StringAttr::get("x", op.getContext()) + }, op.getContext()))); + + // TBD: Part 2: Merge. 
+ llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart2Specs; + transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext()))); + transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 5) + }, op.getContext()))); + transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()), + StringAttr::get("ho", op.getContext()), + StringAttr::get("wo", op.getContext()) + }, op.getContext()))); + + auto transformedInputLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(transformedInputLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(transformedInputLayoutPart2Specs, op.getContext()) + }, op.getContext())); + transformedInputAttrs.push_back(transformedInputLayoutAttr); + + // TBD: set intermediate_layout attribute. + auto transformedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout", + ArrayAttr::get({ + StringAttr::get("ni", op.getContext()), + StringAttr::get("ci", op.getContext()), + StringAttr::get("y", op.getContext()), + StringAttr::get("ho", op.getContext()), + StringAttr::get("x", op.getContext()), + StringAttr::get("wo", op.getContext()) + }, op.getContext())); + transformedInputAttrs.push_back(transformedInputImmLayoutAttr); + // TBD: set output_layout attribute. 
+ auto transformedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", + ArrayAttr::get({ + StringAttr::get("gemmK", op.getContext()), + StringAttr::get("gemmN", op.getContext()), + }, op.getContext())); + transformedInputAttrs.push_back(transformedInputOutputLayoutAttr); + + // set gridwise_gemm_argument_pos attribute. + auto inputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)); + transformedInputAttrs.push_back(inputGridwiseGemmArgPosAttr); + auto gemmB = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedInputMemRefType, ArrayRef<Value>(embeddedInput), transformedInputAttrs); + + + // Transform output tensor. + auto outputType = op.output().getType().dyn_cast<MemRefType>(); + auto outputShape = outputType.getShape(); + auto outputElementType = outputType.getElementType(); + + llvm::SmallVector<int64_t, 2> transformedOutputShape; + transformedOutputShape.set_size(outputShape.size() - 2); + // TBD: compute transformed output shape dimensions. + std::fill(transformedOutputShape.begin(), transformedOutputShape.end(), -1); + auto transformedOutputMemRefType = MemRefType::get(transformedOutputShape, outputElementType); + + llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs; + + // TBD: set layout attribute. + // TBD: Part 1: Passthrough. 
+ llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs; + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), + }, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ + StringAttr::get("ko", op.getContext()) + }, op.getContext()))); + + // TBD: Part 2: Merge. + llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs; + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({ + IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), + IntegerAttr::get(IntegerType::get(32, op.getContext()), 3), + }, op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({ StringAttr::get("no", op.getContext()), + StringAttr::get("ho", op.getContext()), + 
StringAttr::get("wo", op.getContext()) + }, op.getContext()))); + + auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext()) + }, op.getContext())); + transformedOutputAttrs.push_back(transformedOutputLayoutAttr); - rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); - rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); - rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input()); + // set source_layout attribute. + auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", outputLayoutAttr); + transformedOutputAttrs.push_back(outputSrcLayoutAttr); + // TBD: set output_layout attribute. + auto transformedOutputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", + ArrayAttr::get({ + StringAttr::get("gemmM", op.getContext()), + StringAttr::get("gemmN", op.getContext()), + }, op.getContext())); + transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr); - rewriter.create<miopen::TransformOp>(op.getLoc(), op.output().getType(), op.output()); + // TBD: set gridwise_gemm_argument_pos attribute. + auto outputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", + IntegerAttr::get(IntegerType::get(32, op.getContext()), 2)); + transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr); + auto gemmC = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedOutputMemRefType, op.output(), transformedOutputAttrs); - //rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), op.filter(), op.input(), op.output()); + // Emit miopen.gridwise_gemm op. + rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), gemmA, gemmB, gemmC); // Finally, erase the original Conv2D op. 
op.erase(); diff --git a/mlir/test/Dialect/MIOpen/lowering.mlir b/mlir/test/Dialect/MIOpen/lowering.mlir index e7734cef5a29..5907fbd41ebd 100644 --- a/mlir/test/Dialect/MIOpen/lowering.mlir +++ b/mlir/test/Dialect/MIOpen/lowering.mlir @@ -3,8 +3,8 @@ func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { miopen.conv2d(%filter, %input, %output) { filter_layout = ["k", "c", "y", "x"], - input_layout = ["n", "c", "hi", "wi"], - output_layout = ["n", "k", "ho", "wo"], + input_layout = ["ni", "ci", "hi", "wi"], + output_layout = ["no", "ko", "ho", "wo"], dilations = [1, 1], strides = [1, 1], padding = [0, 0] @@ -18,4 +18,4 @@ func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, // CHECK-NEXT: miopen.transform // CHECK-NEXT: miopen.transform // CHECK-NEXT: miopen.transform -// TBD-CHECK-NEXT: miopen.gridwise_gemm +// CHECK-NEXT: miopen.gridwise_gemm _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits