https://github.com/matthias-springer created 
https://github.com/llvm/llvm-project/pull/123326

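Remove the `Type::isFloat...()` predicates for low-precision floating-point types (`type.isFloat4E2M1FN()`, `type.isFloat8E5M2()`, etc.). Use `isa<Float4E2M1FNType>(type)` etc. instead. The `isBF16()`/`isF16()`/`isTF32()`/`isF32()`/`isF64()` predicates are kept, but the C API and the `BF16`/`TF32` ODS constraints now use `isa<>` as well.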

For details, see: 
https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Depends on #123321.


From 55825a999595222141f79a812c72c57cebd598d8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Fri, 17 Jan 2025 12:31:38 +0100
Subject: [PATCH] [mlir][IR] Remove `isF...()` type API for low-precision FP
 types

---
 mlir/include/mlir/IR/CommonTypeConstraints.td | 26 ++++++------
 mlir/include/mlir/IR/Types.h                  | 11 -----
 mlir/lib/CAPI/IR/BuiltinTypes.cpp             | 40 +++++++++++--------
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 38 +++++++++---------
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  4 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  9 ++---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  8 ++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  4 +-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    |  6 +--
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  3 +-
 mlir/lib/IR/Types.cpp                         | 19 ---------
 11 files changed, 73 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6f52195c1d7c92..e752cdfb47fbb1 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -329,31 +329,31 @@ def F64 : F<64>;
 def F80 : F<80>;
 def F128 : F<128>;
 
-def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
+def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
            BuildableType<"$_builder.getType<BFloat16Type>()">;
-def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
            BuildableType<"$_builder.getType<FloatTF32Type>()">;
-def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
+def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
                BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
-def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
+def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
              BuildableType<"$_builder.getType<Float8E5M2Type>()">;
-def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
+def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
              BuildableType<"$_builder.getType<Float8E4M3Type>()">;
-def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
+def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
-def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
-def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
+def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
-def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
+def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
              BuildableType<"$_builder.getType<Float8E3M4Type>()">;
-def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
+def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
                BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
-def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
                BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
-def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType>($_self)">, "f6E3M2FN type">,
                BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
-def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
                 BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index acd0f894abbbe6..0e82ad2be907ab 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,17 +125,6 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
-  bool isFloat4E2M1FN() const;
-  bool isFloat6E2M3FN() const;
-  bool isFloat6E3M2FN() const;
-  bool isFloat8E5M2() const;
-  bool isFloat8E4M3() const;
-  bool isFloat8E4M3FN() const;
-  bool isFloat8E5M2FNUZ() const;
-  bool isFloat8E4M3FNUZ() const;
-  bool isFloat8E4M3B11FNUZ() const;
-  bool isFloat8E3M4() const;
-  bool isFloat8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 250e4a6bbf8dfd..313d6830b41b2a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
-  return unwrap(type).isFloat4E2M1FN();
+  return llvm::isa<Float4E2M1FNType>(unwrap(type));
 }
 
 MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
-  return unwrap(type).isFloat6E2M3FN();
+  return llvm::isa<Float6E2M3FNType>(unwrap(type));
 }
 
 MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
-  return unwrap(type).isFloat6E3M2FN();
+  return llvm::isa<Float6E3M2FNType>(unwrap(type));
 }
 
 MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E5M2(MlirType type) {
-  return unwrap(type).isFloat8E5M2();
+  return llvm::isa<Float8E5M2Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3(MlirType type) {
-  return unwrap(type).isFloat8E4M3();
+  return llvm::isa<Float8E4M3Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
-  return unwrap(type).isFloat8E4M3FN();
+  return llvm::isa<Float8E4M3FNType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E5M2FNUZ();
+  return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E4M3FNUZ();
+  return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E4M3B11FNUZ();
+  return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E3M4(MlirType type) {
-  return unwrap(type).isFloat8E3M4();
+  return llvm::isa<Float8E3M4Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
-  return unwrap(type).isFloat8E8M0FNU();
+  return llvm::isa<Float8E8M0FNUType>(unwrap(type));
 }
 
 MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
 
-bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+bool mlirTypeIsABF16(MlirType type) {
+  return llvm::isa<BFloat16Type>(unwrap(type));
+}
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
   return wrap(BFloat16Type::get(unwrap(ctx)));
@@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
 
-bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+bool mlirTypeIsAF16(MlirType type) {
+  return llvm::isa<Float16Type>(unwrap(type));
+}
 
 MlirType mlirF16TypeGet(MlirContext ctx) {
   return wrap(Float16Type::get(unwrap(ctx)));
@@ -239,7 +243,9 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
   return wrap(FloatTF32Type::getTypeID());
 }
 
-bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+bool mlirTypeIsATF32(MlirType type) {
+  return llvm::isa<FloatTF32Type>(unwrap(type));
+}
 
 MlirType mlirTF32TypeGet(MlirContext ctx) {
   return wrap(FloatTF32Type::get(unwrap(ctx)));
@@ -247,7 +253,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
 
-bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+bool mlirTypeIsAF32(MlirType type) {
+  return llvm::isa<Float32Type>(unwrap(type));
+}
 
 MlirType mlirF32TypeGet(MlirContext ctx) {
   return wrap(Float32Type::get(unwrap(ctx)));
@@ -255,7 +263,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
 
-bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+bool mlirTypeIsAF64(MlirType type) {
+  return llvm::isa<Float64Type>(unwrap(type));
+}
 
 MlirType mlirF64TypeGet(MlirContext ctx) {
   return wrap(Float64Type::get(unwrap(ctx)));
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1564e417a7a48e..5d09d6f1d69523 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
+      chipset >= kGfx940) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
+      chipset >= kGfx940) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
+  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
+  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
   return std::nullopt;
 }
@@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (isa<Float8E4M3FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (isa<Float8E4M3FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index a8283023afc53d..33370566996eee 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+  return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 64bdb248dff430..247a8ab28a44be 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
     return type;
 
   // F4, F6, F8 types are converted to integer types with the same bit width.
-  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
-      type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
-      type.isFloat8E8M0FNU())
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+          Float8E8M0FNUType>(type))
     return IntegerType::get(&getContext(), type.getWidth());
 
   // Other floating-point types: A custom type conversion rule must be
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 34a6b1d506540d..7e97fb84434f89 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
         wgmmaK = 8;
       } else if (inputElemType.isF16() || inputElemType.isBF16()) {
         wgmmaK = 16;
-      } else if (inputElemType.isFloat8E4M3FN() ||
-                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
+                 inputElemType.isInteger(16)) {
         wgmmaK = 32;
       } else if (inputElemType.isInteger(1)) {
         wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
           return NVVM::WGMMATypes::f16;
         if (elemType.isBF16())
           return NVVM::WGMMATypes::bf16;
-        if (elemType.isFloat8E4M3FN())
+        if (isa<Float8E4M3FNType>(elemType))
           return NVVM::WGMMATypes::e4m3;
-        if (elemType.isFloat8E5M2())
+        if (isa<Float8E5M2Type>(elemType))
           return NVVM::WGMMATypes::e5m2;
         if (elemType.isInteger(1))
           return NVVM::WGMMATypes::b1;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 492e4781f57810..5af0cb0c7ba1cc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index a027350e8a5f70..47d1b8492e06ec 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -525,8 +525,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
     return success();
   // F16 += f8 + f8
   // F32 += f8 + f8
-  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
-      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
+      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
       (typeD.isF32() || typeD.isF16()))
     return success();
 
@@ -548,7 +548,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
   if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
-      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
     if (llvm::is_contained(allowedN, sizeN))
       return success();
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 83cf4a9415fe68..9590fa6fa5a8b2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -306,8 +306,7 @@ static LogicalResult verifyConvOpModes(T op) {
   if (inputEType.isInteger(16) && !accType.isInteger(48))
     return op.emitOpError("accumulator type for i16 tensor is not i48");
 
-  if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) &&
-      !accType.isF16())
+  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
     return op.emitOpError("accumulator type for f8 tensor is not f16");
 
   if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e190902b2e4898..bca90de6f4a8a5 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,25 +34,6 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
-bool Type::isFloat4E2M1FN() const { return llvm::isa<Float4E2M1FNType>(*this); }
-bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
-bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
-bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
-bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
-bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
-bool Type::isFloat8E5M2FNUZ() const {
-  return llvm::isa<Float8E5M2FNUZType>(*this);
-}
-bool Type::isFloat8E4M3FNUZ() const {
-  return llvm::isa<Float8E4M3FNUZType>(*this);
-}
-bool Type::isFloat8E4M3B11FNUZ() const {
-  return llvm::isa<Float8E4M3B11FNUZType>(*this);
-}
-bool Type::isFloat8E8M0FNU() const {
-  return llvm::isa<Float8E8M0FNUType>(*this);
-}
-bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to