https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/123326
Remove `type.isFloat4E2M1FN()` etc. Use `isa<Float4E2M1FNType>(type)` instead. This is intended to be NFC: every removed predicate is replaced by the `isa<>` check it was implemented with. For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28 Depends on #123321. >From 55825a999595222141f79a812c72c57cebd598d8 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Fri, 17 Jan 2025 12:31:38 +0100 Subject: [PATCH] [mlir][IR] Remove `isF...()` type API for low-precision FP types --- mlir/include/mlir/IR/CommonTypeConstraints.td | 26 ++++++------ mlir/include/mlir/IR/Types.h | 11 ----- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 42 ++++++++++++-------- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 38 +++++++++--------- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 4 +- .../Conversion/LLVMCommon/TypeConverter.cpp | 9 ++--- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 8 ++-- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 +- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 6 +-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 3 +- mlir/lib/IR/Types.cpp | 19 --------- 11 files changed, 75 insertions(+), 95 deletions(-) diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 6f52195c1d7c92..e752cdfb47fbb1 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -329,31 +329,31 @@ def F64 : F<64>; def F80 : F<80>; def F128 : F<128>; -def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">, +def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">, BuildableType<"$_builder.getType<BFloat16Type>()">; -def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">, +def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">, BuildableType<"$_builder.getType<FloatTF32Type>()">; -def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">, +def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">, BuildableType<"$_builder.getType<Float8E4M3FNType>()">; -def F8E5M2 : 
Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">, +def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">, BuildableType<"$_builder.getType<Float8E5M2Type>()">; -def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">, +def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">, BuildableType<"$_builder.getType<Float8E4M3Type>()">; -def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">, +def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">, BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">; -def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">, +def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">, BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">; -def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">, +def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">, BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">; -def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">, +def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">, BuildableType<"$_builder.getType<Float8E3M4Type>()">; -def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">, +def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">, BuildableType<"$_builder.getType<Float4E2M1FNType>()">; -def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">, +def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">, BuildableType<"$_builder.getType<Float6E2M3FNType>()">; -def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">, +def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType>($_self)">, "f6E3M2FN type">, BuildableType<"$_builder.getType<Float6E3M2FNType>()">; -def F8E8M0FNU : 
Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">, +def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">, BuildableType<"$_builder.getType<Float8E8M0FNUType>()">; def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">, diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index acd0f894abbbe6..0e82ad2be907ab 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -125,17 +125,6 @@ class Type { // Convenience predicates. This is only for floating point types, // derived types should use isa/dyn_cast. bool isIndex() const; - bool isFloat4E2M1FN() const; - bool isFloat6E2M3FN() const; - bool isFloat6E3M2FN() const; - bool isFloat8E5M2() const; - bool isFloat8E4M3() const; - bool isFloat8E4M3FN() const; - bool isFloat8E5M2FNUZ() const; - bool isFloat8E4M3FNUZ() const; - bool isFloat8E4M3B11FNUZ() const; - bool isFloat8E3M4() const; - bool isFloat8E8M0FNU() const; bool isBF16() const; bool isF16() const; bool isTF32() const; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 250e4a6bbf8dfd..313d6830b41b2a 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() { } bool mlirTypeIsAFloat4E2M1FN(MlirType type) { - return unwrap(type).isFloat4E2M1FN(); + return llvm::isa<Float4E2M1FNType>(unwrap(type)); } MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { @@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { } bool mlirTypeIsAFloat6E2M3FN(MlirType type) { - return unwrap(type).isFloat6E2M3FN(); + return llvm::isa<Float6E2M3FNType>(unwrap(type)); } MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { @@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { } bool mlirTypeIsAFloat6E3M2FN(MlirType type) { - return unwrap(type).isFloat6E3M2FN(); + return llvm::isa<Float6E3M2FNType>(unwrap(type)); } MlirType 
mlirFloat6E3M2FNTypeGet(MlirContext ctx) { @@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() { } bool mlirTypeIsAFloat8E5M2(MlirType type) { - return unwrap(type).isFloat8E5M2(); + return llvm::isa<Float8E5M2Type>(unwrap(type)); } MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { @@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3(MlirType type) { - return unwrap(type).isFloat8E4M3(); + return llvm::isa<Float8E4M3Type>(unwrap(type)); } MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { @@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3FN(MlirType type) { - return unwrap(type).isFloat8E4M3FN(); + return llvm::isa<Float8E4M3FNType>(unwrap(type)); } MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { @@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { - return unwrap(type).isFloat8E5M2FNUZ(); + return llvm::isa<Float8E5M2FNUZType>(unwrap(type)); } MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { @@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { - return unwrap(type).isFloat8E4M3FNUZ(); + return llvm::isa<Float8E4M3FNUZType>(unwrap(type)); } MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { @@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { - return unwrap(type).isFloat8E4M3B11FNUZ(); + return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type)); } MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { @@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() { } bool mlirTypeIsAFloat8E3M4(MlirType type) { - return unwrap(type).isFloat8E3M4(); + return llvm::isa<Float8E3M4Type>(unwrap(type)); } MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { @@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { } bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { - return 
unwrap(type).isFloat8E8M0FNU(); + return llvm::isa<Float8E8M0FNUType>(unwrap(type)); } MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { @@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() { return wrap(BFloat16Type::getTypeID()); } -bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } +bool mlirTypeIsABF16(MlirType type) { + return llvm::isa<BFloat16Type>(unwrap(type)); +} MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(BFloat16Type::get(unwrap(ctx))); @@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) { MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } -bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } +bool mlirTypeIsAF16(MlirType type) { + return llvm::isa<Float16Type>(unwrap(type)); +} MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(Float16Type::get(unwrap(ctx))); @@ -239,7 +243,9 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() { return wrap(FloatTF32Type::getTypeID()); } -bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); } +bool mlirTypeIsATF32(MlirType type) { + return llvm::isa<FloatTF32Type>(unwrap(type)); +} MlirType mlirTF32TypeGet(MlirContext ctx) { return wrap(FloatTF32Type::get(unwrap(ctx))); @@ -247,7 +253,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) { MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } -bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } +bool mlirTypeIsAF32(MlirType type) { + return llvm::isa<Float32Type>(unwrap(type)); +} MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(Float32Type::get(unwrap(ctx))); @@ -255,7 +263,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) { MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } -bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } +bool mlirTypeIsAF64(MlirType type) { + return llvm::isa<Float64Type>(unwrap(type)); +} MlirType mlirF64TypeGet(MlirContext ctx) { return 
wrap(Float64Type::get(unwrap(ctx))); diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 1564e417a7a48e..5d09d6f1d69523 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } - if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) { + if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() && + chipset >= kGfx940) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = cast<VectorType>(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { - if (sourceBElem.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); - if (sourceBElem.isFloat8E4M3FNUZ()) + if (isa<Float8E4M3FNUZType>(sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { - if (sourceBElem.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); - if (sourceBElem.isFloat8E4M3FNUZ()) + if (isa<Float8E4M3FNUZType>(sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); } } - if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) { + if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() && + chipset >= kGfx940) { Type sourceBElem = cast<VectorType>(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { - if (sourceBElem.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); - if 
(sourceBElem.isFloat8E4M3FNUZ()) + if (isa<Float8E4M3FNUZType>(sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { - if (sourceBElem.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); - if (sourceBElem.isFloat8E4M3FNUZ()) + if (isa<Float8E4M3FNUZType>(sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); } } @@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) + if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); - if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) + if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); return std::nullopt; } @@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( } Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); - if (sourceElemType.isFloat8E5M2FNUZ()) { + if (isa<Float8E5M2FNUZType>(sourceElemType)) { rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, wordSel); - } else if (sourceElemType.isFloat8E4M3FNUZ()) { + } else if (isa<Float8E4M3FNUZType>(sourceElemType)) { rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, wordSel); } @@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); Value result; - if (resultElemType.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(resultElemType)) result = 
rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, existing, wordSel); - else if (resultElemType.isFloat8E4M3FNUZ()) + else if (isa<Float8E4M3FNUZType>(resultElemType)) result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, existing, wordSel); @@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); Value result; - if (resultElemType.isFloat8E5M2FNUZ()) + if (isa<Float8E5M2FNUZType>(resultElemType)) result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch, existing, byteSel); - else if (resultElemType.isFloat8E4M3FNUZ()) + else if (isa<Float8E4M3FNUZType>(resultElemType)) result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch, existing, byteSel); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index a8283023afc53d..33370566996eee 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { return failure(); inType = inVecType.getElementType(); } - return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ()); + return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType)); } void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, @@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { if (inType && inType.getWidth() <= 8 && saturateFP8) // Conversion between 8-bit floats is not supported with truncation enabled. 
return failure(); - return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()); + return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType)); } void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 64bdb248dff430..247a8ab28a44be 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const { return type; // F4, F6, F8 types are converted to integer types with the same bit width. - if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() || - type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() || - type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() || - type.isFloat8E8M0FNU()) + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, + Float8E8M0FNUType>(type)) return IntegerType::get(&getContext(), type.getWidth()); // Other floating-point types: A custom type conversion rule must be diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 34a6b1d506540d..7e97fb84434f89 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering wgmmaK = 8; } else if (inputElemType.isF16() || inputElemType.isBF16()) { wgmmaK = 16; - } else if (inputElemType.isFloat8E4M3FN() || - inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) { + } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) || + inputElemType.isInteger(16)) { wgmmaK = 32; } else if (inputElemType.isInteger(1)) { wgmmaK = 
256; @@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering return NVVM::WGMMATypes::f16; if (elemType.isBF16()) return NVVM::WGMMATypes::bf16; - if (elemType.isFloat8E4M3FN()) + if (isa<Float8E4M3FNType>(elemType)) return NVVM::WGMMATypes::e4m3; - if (elemType.isFloat8E5M2()) + if (isa<Float8E5M2Type>(elemType)) return NVVM::WGMMATypes::e5m2; if (elemType.isInteger(1)) return NVVM::WGMMATypes::b1; diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 492e4781f57810..5af0cb0c7ba1cc 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() { } Type sourceBType = getSourceB().getType(); - if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) { + if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) { int64_t sourceBLen = 1; Type sourceBElem = sourceBType; if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) { sourceBLen = sourceBVector.getNumElements(); sourceBElem = sourceBVector.getElementType(); } - if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ()) + if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem)) return emitOpError("expected both source operands to have f8 elements"); if (sourceLen != sourceBLen) return emitOpError( diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index a027350e8a5f70..47d1b8492e06ec 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -525,8 +525,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { return success(); // F16 += f8 + f8 // F32 += f8 + f8 - if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && - (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && + if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) && + isa<Float8E5M2Type, Float8E4M3FNType>(typeB) && (typeD.isF32() || 
typeD.isF16())) return success(); @@ -548,7 +548,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) { 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256}; if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() || - typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) + isa<Float8E5M2Type, Float8E4M3FNType>(typeA)) if (llvm::is_contained(allowedN, sizeN)) return success(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 83cf4a9415fe68..9590fa6fa5a8b2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -306,8 +306,7 @@ static LogicalResult verifyConvOpModes(T op) { if (inputEType.isInteger(16) && !accType.isInteger(48)) return op.emitOpError("accumulator type for i16 tensor is not i48"); - if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) && - !accType.isF16()) + if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16()) return op.emitOpError("accumulator type for f8 tensor is not f16"); if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index e190902b2e4898..bca90de6f4a8a5 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -34,25 +34,6 @@ Type AbstractType::replaceImmediateSubElements(Type type, MLIRContext *Type::getContext() const { return getDialect().getContext(); } -bool Type::isFloat4E2M1FN() const { return llvm::isa<Float4E2M1FNType>(*this); } -bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); } -bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); } -bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); } -bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); } -bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); } -bool Type::isFloat8E5M2FNUZ() const { - return llvm::isa<Float8E5M2FNUZType>(*this); -} -bool 
Type::isFloat8E4M3FNUZ() const { - return llvm::isa<Float8E4M3FNUZType>(*this); -} -bool Type::isFloat8E4M3B11FNUZ() const { - return llvm::isa<Float8E4M3B11FNUZType>(*this); -} -bool Type::isFloat8E8M0FNU() const { - return llvm::isa<Float8E8M0FNUType>(*this); -} -bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); } bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); } bool Type::isF16() const { return llvm::isa<Float16Type>(*this); } bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits