Author: Wen-Heng (Jack) Chung Date: 2020-06-05T22:18:19-05:00 New Revision: ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc
URL: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc DIFF: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc.diff LOG: Add parse / print logic to MIOpen ops. Revise test cases along the way. Added: Modified: mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp mlir/test/Dialect/MIOpen/ops.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td index 1304f16f3b30..8ffd66647f3f 100644 --- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td +++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td @@ -35,10 +35,36 @@ class MIOpen_Op<string mnemonic, list<OpTrait> traits = []> : let parser = [{ return ::parse$cppClass(parser, result); }]; } -def MIOpen_Conv2DOp : MIOpen_Op<"conv2d">; +def MIOpen_Conv2DOp : + MIOpen_Op<"conv2d">, + Arguments<(ins MemRefRankOf<[F32], [4]>, + MemRefRankOf<[F32], [4]>, + MemRefRankOf<[F32], [4]>)> { + let summary = "2D convolution"; + let description = [{ + The `miopen.conv2d` op computes 2D convolution. + }]; +} -def MIOpen_TransformOp : MIOpen_Op<"transform">; +def MIOpen_TransformOp : + MIOpen_Op<"transform">, + Arguments<(ins AnyMemRef)>, + Results<(outs AnyMemRef)> { + let summary = "Tensor transformation"; + let description = [{ + The `miopen.transform` op transforms tensor coordinates. + }]; +} -def MIOpen_GridwiseGemmOp : MIOpen_Op<"gridwise_gemm">; +def MIOpen_GridwiseGemmOp : + MIOpen_Op<"gridwise_gemm">, + Arguments<(ins MemRefRankOf<[F32], [2]>, + MemRefRankOf<[F32], [2]>, + MemRefRankOf<[F32], [2]>)> { + let summary = "Gridwise GEMM"; + let description = [{ + The `miopen.gridwise_gemm` op computes gridwise GEMM.
+ }]; +} #endif // MIOPEN_OPS diff --git a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp index 9408e17b2831..b41423435e33 100644 --- a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp +++ b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp @@ -40,8 +40,6 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context) #define GET_OP_LIST #include "mlir/Dialect/MIOpenOps/MIOpenOps.cpp.inc" >(); - - //addInterfaces<LoopSideEffectsInterface>(); } //===----------------------------------------------------------------------===// @@ -49,11 +47,19 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context) //===----------------------------------------------------------------------===// static ParseResult parseConv2DOp(OpAsmParser &parser, OperationState &result) { - return success(); + SmallVector<OpAsmParser::OperandType, 3> ops; + SmallVector<Type, 3> types; + return failure( + parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonTypeList(types) || + parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands)); } static void print(OpAsmPrinter &p, Conv2DOp op) { - p << Conv2DOp::getOperationName(); + p << op.getOperationName() << "(" << op.getOperands() << ")"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperandTypes(); } static LogicalResult verify(Conv2DOp op) { @@ -65,11 +71,23 @@ static LogicalResult verify(Conv2DOp op) { //===----------------------------------------------------------------------===// static ParseResult parseTransformOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType src; + Type srcType, dstType; + return failure( + parser.parseLParen() || + parser.parseOperand(src) || + parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(src, srcType, result.operands) || + parser.parseKeywordType("to", dstType) || 
parser.addTypeToList(dstType, result.types)); - return success(); } static void print(OpAsmPrinter &p, TransformOp op) { - p << TransformOp::getOperationName(); + p << op.getOperationName() << "(" << op.getOperand() << ")"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand()->getType() << " to " << op.getType(); } static LogicalResult verify(TransformOp op) { @@ -81,11 +99,19 @@ static LogicalResult verify(TransformOp op) { //===----------------------------------------------------------------------===// static ParseResult parseGridwiseGemmOp(OpAsmParser &parser, OperationState &result) { - return success(); + SmallVector<OpAsmParser::OperandType, 3> ops; + SmallVector<Type, 3> types; + return failure( + parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonTypeList(types) || + parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands)); } static void print(OpAsmPrinter &p, GridwiseGemmOp op) { - p << GridwiseGemmOp::getOperationName(); + p << op.getOperationName() << "(" << op.getOperands() << ")"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperandTypes(); } static LogicalResult verify(GridwiseGemmOp op) { diff --git a/mlir/test/Dialect/MIOpen/ops.mlir b/mlir/test/Dialect/MIOpen/ops.mlir index a37e54110186..9b3b2e3db27e 100644 --- a/mlir/test/Dialect/MIOpen/ops.mlir +++ b/mlir/test/Dialect/MIOpen/ops.mlir @@ -3,22 +3,132 @@ // Run: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { - miopen.conv2d + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["k", "c", "y", "x"], + input_layout = ["n", "c", "hi", "wi"], + output_layout = ["n", "k", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> return } // 
CHECK-LABEL: func @miopen_conv2d // CHECK-NEXT: miopen.conv2d -func @miopen_transform(%memref : memref<?x?x?x?xf32>) { - miopen.transform +// test 1-1 dimension mappings. +func @miopen_transform_1_to_1(%memref: memref<?x?x?x?xf32>) { + %transformed_memref = miopen.transform(%memref) { + layout = [ + { + dimensions = [0], + names = ["n"], + transformation = "passthrough", + source_dimensions = [0], + source_names = ["n"] + }, + { + dimensions = [1], + names = ["c"], + transformation = "passthrough", + source_dimensions = [1], + source_names = ["c"] + }, + { + dimensions = [2], + names = ["hipad"], + transformation = "pad", + parameters = [0, 0], + source_dimensions = [2], + source_names = ["hi"] + }, + { + dimensions = [3], + names = ["wipad"], + transformation = "pad", + parameters = [0, 0], + source_dimensions = [3], + source_names = ["wi"] + } + ] + } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_transform_1_to_1 +// CHECK-NEXT: miopen.transform + +// test multiple source dimensions map to 1 target dimension. +func @miopen_transform_n_to_1(%memref : memref<?x?x?x?xf32>) { + %transformed_memref = miopen.transform(%memref) { + layout = [ + { + dimensions = [0], + names = ["gemmK"], + transformation = "merge", + source_dimensions = [1, 2, 3], + source_names = ["c", "y", "x"] + }, + { + dimensions = [1], + names = ["gemmM"], + transformation = "passthrough", + source_dimensions = [0], + source_names = ["n"] + } + ] + } : memref<?x?x?x?xf32> to memref<?x?xf32> + return +} +// CHECK-LABEL: func @miopen_transform_n_to_1 +// CHECK-NEXT: miopen.transform + +// test 1 source dimension map to multiple target dimensions.
+func @miopen_transform_1_to_n(%memref : memref<?x?x?x?xf32>) { + %transformed_memref = miopen.transform(%memref) { + layout = [ + { + dimensions = [0], + names = ["n"], + transformation = "passthrough", + source_dimensions = [0], + source_names = ["n"] + }, + { + dimensions = [1], + names = ["c"], + transformation = "passthrough", + source_dimensions = [1], + source_names = ["c"] + }, + { + dimensions = [2, 3], + names = ["y", "ho"], + transformation = "embed", + parameters = [1, 1, 0], + source_dimensions = [2], + source_names = ["hipad"] + }, + { + dimensions = [4, 5], + names = ["x", "wo"], + transformation = "embed", + parameters = [1, 1, 0], + source_dimensions = [3], + source_names = ["wipad"] + } + ] + } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?xf32> return } -// CHECK-LABEL: func @miopen_transform +// CHECK-LABEL: func @miopen_transform_1_to_n // CHECK-NEXT: miopen.transform func @miopen_gridwise_gemm(%A : memref<?x?xf32>, %B : memref<?x?xf32>, %C : memref<?x?xf32>) { - miopen.gridwise_gemm + miopen.gridwise_gemm(%A, %B, %C) { + parameters = [ + ] + } : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32> return } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits