Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: 7054cfc71ac450ab5ac9ee505d1096bb1252f9c6
URL: https://github.com/llvm/llvm-project/commit/7054cfc71ac450ab5ac9ee505d1096bb1252f9c6 DIFF: https://github.com/llvm/llvm-project/commit/7054cfc71ac450ab5ac9ee505d1096bb1252f9c6.diff LOG: Fix translation logic. Added: Modified: mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp Removed: ################################################################################ diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp index 4eb8d7de7181..2bd64efa77b6 100644 --- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp +++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp @@ -183,20 +183,16 @@ static constexpr StringLiteral kCppEpiloguePart2 =R"( void EmitCppPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr) { output << kCppPreamblePart1; - // Between Preamble Part 1 and Part 2: // #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)"; - output << layoutStr << ".hpp"; - + output << layoutStr << R"(.hpp")"; output << kCppPreamblePart2; - // Between Preamble Part 2 and Par 3: // __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw( output << R"( __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_)"; output << layoutStr; - output << kCppPreamblePart3; } @@ -210,9 +206,7 @@ void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm: output << R"( constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_)"; output << layoutStr; - output << kCppEpiloguePart1; - // Between Part1 and Part2: // decltype(in_nchw_desc), // decltype(wei_kcyx_desc), @@ -220,7 +214,6 @@ void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm: for (auto desc : tensorDescs) { output << " decltype(" << desc << "),\n"; } - output << 
kCppEpiloguePart2; } @@ -344,28 +337,23 @@ static constexpr StringLiteral kHeaderEpiloguePart2 = R"( void EmitHeaderPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) { output << kHeaderPreamblePart1; - output << R"( struct GridwiseConvolutionImplicitGemm_v4r4_)"; output << layoutStr; - output << kHeaderPreamblePart2; - output << kHeaderPreamblePart3; - output << '\n'; - output << R"( constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};"; output << R"( constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};"; output << R"( constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};"; + output << '\n'; } void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t, std::string> &args) { output << kHeaderEpiloguePart1; - // Between Part1 and Part2 emit: // decltype(wei_e_k_global_desc), // decltype(in_e_b_global_desc), @@ -374,7 +362,6 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t, output << R"( decltype()" << args[i] << "),"; } - output << kHeaderEpiloguePart2; } @@ -437,17 +424,19 @@ void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribu } } -void EmitInterleaveArrayAttrOfStringAttrWithSeparator(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr, const StringRef &separator) { +template<typename T> +void EmitInterleaveArrayAttrWithSeparator(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr, const StringRef &separator) { if (arrayAttr) { interleave(arrayAttr, os, [&](Attribute attr) { - if (auto strAttr = attr.dyn_cast<StringAttr>()) - os << strAttr.getValue(); + if (auto typedAttr = attr.dyn_cast<T>()) + os << typedAttr.getValue(); }, separator); } } -void EmitInterleaveCommaArrayAttrOfStringAttr(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr) { - EmitInterleaveArrayAttrOfStringAttrWithSeparator(os, arrayAttr, ", "); +template<typename T> +void EmitInterleaveCommaArrayAttr(llvm::raw_ostream &os, mlir::ArrayAttr 
&arrayAttr) { + EmitInterleaveArrayAttrWithSeparator<T>(os, arrayAttr, ", "); } void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) { @@ -511,7 +500,8 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) // Start emitting. EmitHeaderPreamble(output, layoutStr, tensorDescs); - f.walk([&output, &srcLayoutAttrCtr, &tensorDescs, &gridwiseGemmArguments](miopen::TransformOp op) { + // First iteration. Output source dimensions. + f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) { // get source_layout attribute. auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout"); if (srcLayoutAttr) { @@ -520,10 +510,17 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) EmitLayoutString(output, srcLayout, "", "", ", "); output << '\n'; - EmitHeaderDimensionLengths(output, srcLayout, tensorDescs[srcLayoutAttrCtr]); + EmitHeaderDimensionLengths(output, srcLayout, tensorDescs[srcLayoutAttrCtr++]); } - output << '\n'; + }); + output << '\n'; + srcLayoutAttrCtr = 0; + // Second iteration. Output the rest. + f.walk([&output, &srcLayoutAttrCtr, &tensorDescs, &gridwiseGemmArguments](miopen::TransformOp op) { + // get source_layout attribute. + auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout"); + // get layout attribute. auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout"); std::string inputTensorName; @@ -549,22 +546,20 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) // get intermediate_layout attribute. 
if (immLayoutAttr) { ins << kVarName[srcLayoutAttrCtr - 1] << "_"; - EmitInterleaveArrayAttrOfStringAttrWithSeparator(ins, immLayoutAttr, "_"); + EmitInterleaveArrayAttrWithSeparator<StringAttr>(ins, immLayoutAttr, "_"); ins << "_desc"; ins.flush(); outs << kVarName[srcLayoutAttrCtr - 1] << "_"; } } - EmitInterleaveArrayAttrOfStringAttrWithSeparator(outs, outputLayoutAttr, "_"); + EmitInterleaveArrayAttrWithSeparator<StringAttr>(outs, outputLayoutAttr, "_"); outs << "_desc"; outs.flush(); // determine gridwise GEMM arguments. auto gridwiseGemmArgPosAttr = op.getAttrOfType<IntegerAttr>("gridwise_gemm_argument_position"); if (gridwiseGemmArgPosAttr) { - llvm::errs() << "gridwise gemm argument pos: " << gridwiseGemmArgPosAttr.getValue() << "\n"; - llvm::errs() << "tensor: " << outputTensorName << "\n"; gridwiseGemmArguments[gridwiseGemmArgPosAttr.getInt()] = outputTensorName; } @@ -572,30 +567,48 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) srcs << " make_tuple("; dsts << " make_tuple("; + // XXX see if we can get better than this. 
+ int convDilationCtr = 0; + for (auto layoutSpec = layoutAttr.begin(); layoutSpec != layoutAttr.end(); ) { if (auto layoutSpecDict = layoutSpec->dyn_cast<DictionaryAttr>()) { auto srcNames = layoutSpecDict.get("source_names").dyn_cast<ArrayAttr>(); auto dstNames = layoutSpecDict.get("names").dyn_cast<ArrayAttr>(); + auto srcDims = layoutSpecDict.get("source_dimensions").dyn_cast<ArrayAttr>(); + auto dstDims = layoutSpecDict.get("dimensions").dyn_cast<ArrayAttr>(); if (auto transform = layoutSpecDict.get("transformation").dyn_cast<StringAttr>()) { - if (transform.getValue() == "PassThrough" || - transform.getValue() == "Merge") { + if (transform.getValue() == "PassThrough") { ops << transform.getValue() << "<"; - EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames); + EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames); ops << ">{}"; + } else if (transform.getValue() == "Merge") { + ops << transform.getValue() << "<" + << "Sequence<"; + EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames); + ops << ">" << ">{}"; } else if (transform.getValue() == "Pad") { ops << transform.getValue() << "<" << "Sequence<"; - EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames); + EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames); ops << ">, InLeftPads, InRightPads" << ">{}"; } else if (transform.getValue() == "Embed") { ops << transform.getValue() << "<" << "Sequence<"; - EmitInterleaveCommaArrayAttrOfStringAttr(ops, dstNames); - ops << ">, Sequence<ConvDilationTBD, ConvDilationTBD, 0>>{}"; + EmitInterleaveCommaArrayAttr<StringAttr>(ops, dstNames); + if (convDilationCtr == 0) { + ops << ">, Sequence<ConvDilationH, ConvDilationH, 0>>{}"; + convDilationCtr++; + } else { + ops << ">, Sequence<ConvDilationW, ConvDilationW, 0>>{}"; + } } - srcs << "Sequence<" << layoutSpecDict.get("source_dimensions") << ">{}"; - dsts << "Sequence<" << layoutSpecDict.get("dimensions") << ">{}"; + srcs << "Sequence<"; + EmitInterleaveCommaArrayAttr<IntegerAttr>(srcs, srcDims); + srcs << 
">{}"; + dsts << "Sequence<"; + EmitInterleaveCommaArrayAttr<IntegerAttr>(dsts, dstDims); + dsts << ">{}"; } } @@ -616,7 +629,7 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) output << " constexpr auto " << outputTensorName << " = transform_tensor_descriptor(\n"; output << " " << inputTensorName << ",\n"; output << operationSpec << srcDimSpec << dstDimSpec; - output << ");\n"; + output << ");\n\n"; }); // TBD get tuning parameters. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits