Author: ergawy Date: 2020-12-16T08:26:48-05:00 New Revision: 6551c9ac365ca46e83354703d1a63c671a50258a
URL: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a DIFF: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a.diff LOG: [mlir][spirv] Add parsing and printing support for SpecConstantOperation Adds more support for `SpecConstantOperation` by defining a custom syntax for the op and implementing its parsing and printing. Reviewed By: mravishankar, antiagainst Differential Revision: https://reviews.llvm.org/D92919 Added: Modified: mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td mlir/lib/Dialect/SPIRV/SPIRVOps.cpp mlir/test/Dialect/SPIRV/structure-ops.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index b8e76c3662ec..1ae7d285cd93 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -608,9 +608,12 @@ def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope let autogenSerialization = 0; } -def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> { - let summary = "Yields the result computed in `spv.SpecConstantOperation`'s" - "region back to the parent op."; +def SPV_YieldOp : SPV_Op<"mlir.yield", [ + HasParent<"SpecConstantOperationOp">, NoSideEffect, Terminator]> { + let summary = [{ + Yields the result computed in `spv.SpecConstantOperation`'s + region back to the parent op. + }]; let description = [{ This op is a special terminator whose only purpose is to terminate @@ -639,12 +642,16 @@ def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> { let autogenSerialization = 0; let assemblyFormat = "attr-dict $operand `:` type($operand)"; + + let verifier = [{ return success(); }]; } def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [ - InFunctionScope, NoSideEffect, - IsolatedFromAbove]> { - let summary = "Declare a new specialization constant that results from doing an operation."; + NoSideEffect, InFunctionScope, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + Declare a new specialization constant that results from doing an operation. + }]; let description = [{ This op declares a SPIR-V specialization constant that results from @@ -653,12 +660,8 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [ In the `spv` dialect, this op is modelled as follows: ``` - spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"` - `(`ssa-id (`, ` ssa-id)`)` - `({` - ssa-id = spirv-op - `spv.mlir.yield` ssa-id - `})` `:` function-type + spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps` + generic-spirv-op `:` function-type ``` In particular, an `spv.SpecConstantOperation` contains exactly one @@ -712,17 +715,15 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [ #### Example: ```mlir %0 = spv.constant 1: i32 + %1 = spv.constant 1: i32 - %1 = "spv.SpecConstantOperation"(%0) ({ - %ret = spv.IAdd %0, %0 : i32 - spv.mlir.yield %ret : i32 - }) : (i32) -> i32 + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32 ``` }]; - let arguments = (ins Variadic<AnyType>:$operands); + let arguments = (ins); - let results = (outs AnyType:$results); + let results = (outs AnyType:$result); let regions = (region SizedRegion<1>:$body); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 03e416e95441..43b3c517a4c6 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -3396,35 +3396,39 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) { } //===----------------------------------------------------------------------===// -// spv.mlir.yield +// spv.SpecConstantOperation //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::YieldOp yieldOp) { - Operation *parentOp = yieldOp->getParentOp(); +static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, + OperationState &state) { + Region *body = state.addRegion(); - if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp)) - return yieldOp.emitOpError( - "expected parent op to be 'spv.SpecConstantOperation'"); + if (parser.parseKeyword("wraps")) + return failure(); - Block &block = parentOp->getRegion(0).getBlocks().front(); - Operation &enclosedOp = block.getOperations().front(); + body->push_back(new Block); + Block &block = body->back(); + Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); - if (yieldOp.getOperand().getDefiningOp() != &enclosedOp) - return yieldOp.emitOpError( - "expected operand to be defined by preceeding op"); + if (!wrappedOp) + return failure(); - return success(); -} + OpBuilder builder(parser.getBuilder().getContext()); + builder.setInsertionPointToEnd(&block); + builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0)); + state.location = wrappedOp->getLoc(); -static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, - OperationState &state) { - // TODO: For now, only generic form is supported. - return failure(); + state.addTypes(wrappedOp->getResult(0).getType()); + + if (parser.parseOptionalAttrDict(state.attributes)) + return failure(); + + return success(); } static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { - // TODO - printer.printGenericOp(op); + printer << op.getOperationName() << " wraps "; + printer.printGenericOp(&op.body().front().front()); } static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { @@ -3433,11 +3437,6 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { if (block.getOperations().size() != 2) return constOp.emitOpError("expected exactly 2 nested ops"); - Operation &yieldOp = block.getOperations().back(); - - if (!isa<spirv::YieldOp>(yieldOp)) - return constOp.emitOpError("expected terminator to be a yield op"); - Operation &enclosedOp = block.getOperations().front(); // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below @@ -3457,21 +3456,12 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp)) return constOp.emitOpError("invalid enclosed op"); - if (enclosedOp.getNumOperands() != constOp.getOperands().size()) - return constOp.emitOpError("invalid number of operands; expected ") - << enclosedOp.getNumOperands() << ", actual " - << constOp.getOperands().size(); - - if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments()) - return constOp.emitOpError("invalid number of region arguments; expected ") - << enclosedOp.getNumOperands() << ", actual " - << constOp.getRegion().getNumArguments(); - - for (auto operand : constOp.getOperands()) + for (auto operand : enclosedOp.getOperands()) if (!isa<spirv::ConstantOp, spirv::SpecConstantOp, spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>( operand.getDefiningOp())) - return constOp.emitOpError("invalid operand"); + return constOp.emitOpError( + "invalid operand, must be defined by a constant operation"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index 89a30e23dec9..c0b495115d6c 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -757,6 +757,7 @@ spv.module Logical GLSL450 { // expected-error @+1 {{unsupported composite type}} spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device> } + //===----------------------------------------------------------------------===// // spv.SpecConstantOperation //===----------------------------------------------------------------------===// @@ -765,34 +766,15 @@ spv.module Logical GLSL450 { spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { + // CHECK: [[LHS:%.*]] = spv.constant %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 - - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 + // CHECK: [[RHS:%.*]] = spv.constant + %1 = spv.constant 1: i32 - // expected-error @+1 {{invalid number of operands; expected 2, actual 1}} - %1 = "spv.SpecConstantOperation"(%0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32) -> i32 + // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32 + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32 - spv.ReturnValue %1 : i32 + spv.ReturnValue %2 : i32 } } @@ -801,93 +783,20 @@ spv.module Logical GLSL450 { spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 - - // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32): - %ret = spv.IAdd %lhs, %lhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}} + // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}} spv.mlir.yield %0 : i32 } } // ----- -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.ISub %lhs, %rhs : i32 - // expected-error @+1 {{expected operand to be defined by preceeding op}} - spv.mlir.yield %lhs : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - // expected-error @+1 {{expected exactly 2 nested ops}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - %ret2 = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - // expected-error @+1 {{expected terminator to be a yield op}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.ReturnValue %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - spv.module Logical GLSL450 { spv.func @foo() -> () "None" { %0 = spv.Variable : !spv.ptr<i32, Function> // expected-error @+1 {{invalid enclosed op}} - %2 = "spv.SpecConstantOperation"(%0) ({ - ^bb(%arg0 : !spv.ptr<i32, Function>): - %ret = spv.Load "Function" %arg0 : i32 - spv.mlir.yield %ret : i32 - }) : (!spv.ptr<i32, Function>) -> i32 + %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32 + spv.Return } } @@ -898,11 +807,9 @@ spv.module Logical GLSL450 { %0 = spv.Variable : !spv.ptr<i32, Function> %1 = spv.Load "Function" %0 : i32 - // expected-error @+1 {{invalid operand}} - %2 = "spv.SpecConstantOperation"(%1, %1) ({ - ^bb(%lhs: i32, %rhs: i32): - %ret = spv.IAdd %lhs, %lhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 + // expected-error @+1 {{invalid operand, must be defined by a constant operation}} + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32 + + spv.Return } } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits