Author: Adrian Kuegel Date: 2020-12-11T13:20:53+01:00 New Revision: 91220705632ed20dd06d1c0dc21b888302ee324e
URL: https://github.com/llvm/llvm-project/commit/91220705632ed20dd06d1c0dc21b888302ee324e DIFF: https://github.com/llvm/llvm-project/commit/91220705632ed20dd06d1c0dc21b888302ee324e.diff LOG: [mlir] Expose target configuration for lowering to ROCDL. Differential Revision: https://reviews.llvm.org/D93028 Added: Modified: mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 8be0a7cad017..233b947bcfed 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -22,7 +22,7 @@ namespace gpu { class GPUModuleOp; } -/// Configure target to convert from to convert from the GPU dialect to NVVM. +/// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); /// Collect a set of patterns to convert from the GPU dialect to NVVM. diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h index 677782b2dc67..5fa798bf2834 100644 --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -14,6 +14,7 @@ namespace mlir { class LLVMTypeConverter; class OwningRewritePatternList; +class ConversionTarget; template <typename OpT> class OperationPass; @@ -26,6 +27,9 @@ class GPUModuleOp; void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); +/// Configure target to convert from the GPU dialect to ROCDL. +void configureGpuToROCDLConversionLegality(ConversionTarget &target); + /// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The /// index bitwidth used for the lowering of the device side index computations /// is configurable. diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index e8c8a1fc3eb9..4ed1f0761c92 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -69,14 +69,7 @@ struct LowerGpuOpsToROCDLOpsPass populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToROCDLConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); - target.addIllegalDialect<gpu::GPUDialect>(); - target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp, - LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op, - LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>(); - target.addIllegalOp<FuncOp>(); - target.addLegalDialect<ROCDL::ROCDLDialect>(); - // TODO: Remove once we support replacing non-root ops. - target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>(); + configureGpuToROCDLConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } @@ -84,6 +77,19 @@ struct LowerGpuOpsToROCDLOpsPass } // anonymous namespace +void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { + target.addIllegalOp<FuncOp>(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<ROCDL::ROCDLDialect>(); + target.addIllegalDialect<gpu::GPUDialect>(); + target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp, + LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, + LLVM::SinOp, LLVM::SqrtOp>(); + + // TODO: Remove once we support replacing non-root ops. + target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>(); +} + void mlir::populateGpuToROCDLConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateWithGenerated(converter.getDialect()->getContext(), patterns); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits