https://github.com/ftynse created 
https://github.com/llvm/llvm-project/pull/171143

Port the bindings for non-shaped builtin types in IRTypes.cpp to use the 
`mlir_type_subclass` mechanism used by non-builtin types. This is part of a 
longer-term cleanup to only support one subclassing mechanism. Eventually, the 
`PyConcreteType` mechanism will be removed.

This required surgery in the type casters and the `mlir_type_subclass` logic
to avoid circular imports of the `_mlir.ir` module that would otherwise occur
when using `mlir_type_subclass` to define classes in the `_mlir.ir` module.

Tests are updated to use the `.get_static_typeid()` function instead of the 
`.static_typeid` property that was specific to builtin types due to the 
`PyConcreteType` mechanism. The change should be NFC otherwise.

>From deac26450350ba40b9f9357f68ec3a5e458b43d6 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <[email protected]>
Date: Mon, 8 Dec 2025 15:50:41 +0100
Subject: [PATCH] [mlir][py] partially use mlir_type_subclass for IRTypes.cpp

Port the bindings for non-shaped builtin types in IRTypes.cpp to use the
`mlir_type_subclass` mechanism used by non-builtin types. This is part of a
longer-term cleanup to only support one subclassing mechanism. Eventually, the
`PyConcreteType` mechanism will be removed.

This required surgery in the type casters and the `mlir_type_subclass` logic
to avoid circular imports of the `_mlir.ir` module that would otherwise occur
when using `mlir_type_subclass` to define classes in the `_mlir.ir` module.

Tests are updated to use the `.get_static_typeid()` function instead of the
`.static_typeid` property that was specific to builtin types due to the
`PyConcreteType` mechanism. The change should be NFC otherwise.
---
 .../mlir/Bindings/Python/NanobindAdaptors.h   |   41 +-
 mlir/lib/Bindings/Python/IRTypes.cpp          | 1029 ++++++-----------
 mlir/lib/Bindings/Python/MainModule.cpp       |   15 +
 mlir/test/python/dialects/arith_dialect.py    |    8 +-
 mlir/test/python/ir/builtin_types.py          |   11 +-
 mlir/test/python/ir/value.py                  |    6 +-
 6 files changed, 425 insertions(+), 685 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 6594670abaaa7..f678f57527e97 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -371,16 +371,26 @@ struct type_caster<MlirTypeID> {
     }
     return false;
   }
-  static handle from_cpp(MlirTypeID v, rv_policy,
-                         cleanup_list *cleanup) noexcept {
+
+  /// Converts `v` to an `ir.TypeID` instance using the given `ir` module
+  /// instead of importing it, so callers that already hold the module (e.g.
+  /// while defining classes inside it) do not trigger a recursive import.
+  /// Returns a new reference, or None if `v` is null.
+  static handle
+  from_cpp_given_module(MlirTypeID v,
+                        const nanobind::module_ &module) noexcept {
     if (v.ptr == nullptr)
-      return nanobind::none();
+      return nanobind::none().release();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
-    return mlir::python::irModule()
-        .attr("TypeID")
+    return module.attr("TypeID")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
+  }
+
+  static handle from_cpp(MlirTypeID v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    return from_cpp_given_module(v, mlir::python::irModule());
   };
 };
 
@@ -602,9 +608,13 @@ class mlir_type_subclass : public pure_subclass {
-  /// Subclasses by looking up the super-class dynamically.
+  /// Subclasses by looking up the super-class dynamically. Pass `mlirIrModule`
+  /// when defining classes inside the `ir` module to avoid a recursive import.
   mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
                      IsAFunctionTy isaFunction,
-                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
-      : mlir_type_subclass(scope, typeClassName, isaFunction,
-                           irModule().attr("Type"), getTypeIDFunction) {}
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+                     const nanobind::module_ *mlirIrModule = nullptr)
+      : mlir_type_subclass(
+            scope, typeClassName, isaFunction,
+            (mlirIrModule != nullptr ? *mlirIrModule : irModule()).attr("Type"),
+            getTypeIDFunction, mlirIrModule) {}
 
   /// Subclasses with a provided mlir.ir.Type super-class. This must
   /// be used if the subclass is being defined in the same extension module
@@ -613,7 +622,8 @@ class mlir_type_subclass : public pure_subclass {
   mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
                      IsAFunctionTy isaFunction,
                      const nanobind::object &superCls,
-                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+                     const nanobind::module_ *mlirIrModule = nullptr)
       : pure_subclass(scope, typeClassName, superCls) {
     // Casting constructor. Note that it is hard, if not impossible, to properly
     // call chain to parent `__init__` in nanobind due to its special handling
@@ -672,9 +682,19 @@ class mlir_type_subclass : public pure_subclass {
           nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))
           // clang-format on
       );
-      nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-          .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-              getTypeIDFunction())(nanobind::cpp_function(
+
+      // Directly call the caster implementation given the "ir" module,
+      // otherwise it may trigger recursive import as the default caster
+      // attempts to import the "ir" module. The caster returns a new
+      // reference, so take ownership of it to avoid leaking the TypeID.
+      MlirTypeID typeID = getTypeIDFunction();
+      mlirIrModule = mlirIrModule ? mlirIrModule : &irModule();
+      nanobind::object pyTypeID = nanobind::steal(
+          nanobind::detail::type_caster<MlirTypeID>::from_cpp_given_module(
+              typeID, *mlirIrModule));
+
+      mlirIrModule->attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(pyTypeID)(
+          nanobind::cpp_function(
               [thisClass = thisClass](const nanobind::object &mlirType) {
                 return thisClass(mlirType);
               }));
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..2e4090c358c47 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -18,13 +18,13 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace mlir;
 using namespace mlir::python;
 
 using llvm::SmallVector;
-using llvm::Twine;
 
 namespace {
 
@@ -34,480 +34,368 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIntegerTypeGetTypeID;
-  static constexpr const char *pyClassName = "IntegerType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_signless",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signless integer type");
-    c.def_static(
-        "get_signed",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signed integer type");
-    c.def_static(
-        "get_unsigned",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create an unsigned integer type");
-    c.def_prop_ro(
-        "width",
-        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
-        "Returns the width of the integer type");
-    c.def_prop_ro(
-        "is_signless",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSignless(self);
-        },
-        "Returns whether this is a signless integer");
-    c.def_prop_ro(
-        "is_signed",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSigned(self);
-        },
-        "Returns whether this is a signed integer");
-    c.def_prop_ro(
-        "is_unsigned",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsUnsigned(self);
-        },
-        "Returns whether this is an unsigned integer");
-  }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIndexTypeGetTypeID;
-  static constexpr const char *pyClassName = "IndexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirIndexTypeGet(context->get());
-          return PyIndexType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a index type.");
-  }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
-  static constexpr const char *pyClassName = "FloatType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro(
-        "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
-        "Returns the width of the floating-point type");
-  }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
-    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat4E2M1FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float4E2M1FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
-          return PyFloat4E2M1FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
-    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E2M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E2M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
-          return PyFloat6E2M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
-    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E3M2FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E3M2FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
-          return PyFloat6E3M2FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
-    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
-          return PyFloat8E4M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2TypeGet(context->get());
-          return PyFloat8E5M2Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3TypeGet(context->get());
-          return PyFloat8E4M3Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
-    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
-          return PyFloat8E4M3FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
-    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3B11FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
-          return PyFloat8E4M3B11FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
-    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
-          return PyFloat8E5M2FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E3M4TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E3M4Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E3M4TypeGet(context->get());
-          return PyFloat8E3M4Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
-    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E8M0FNUTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E8M0FNUType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
-          return PyFloat8E8M0FNUType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
-  }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirBFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "BF16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirBF16TypeGet(context->get());
-          return PyBF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a bf16 type.");
-  }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "F16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF16TypeGet(context->get());
-          return PyF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f16 type.");
-  }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatTF32TypeGetTypeID;
-  static constexpr const char *pyClassName = "FloatTF32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirTF32TypeGet(context->get());
-          return PyTF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a tf32 type.");
-  }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat32TypeGetTypeID;
-  static constexpr const char *pyClassName = "F32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF32TypeGet(context->get());
-          return PyF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f32 type.");
-  }
-};
+static void populateIRTypesModule(const nanobind::module_ &m) {
+  using namespace nanobind_adaptors;
 
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat64TypeGetTypeID;
-  static constexpr const char *pyClassName = "F64Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF64TypeGet(context->get());
-          return PyF64Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f64 type.");
-  }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirNoneTypeGetTypeID;
-  static constexpr const char *pyClassName = "NoneType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirNoneTypeGet(context->get());
-          return PyNoneType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a none type.");
-  }
-};
-
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirComplexTypeGetTypeID;
-  static constexpr const char *pyClassName = "ComplexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType) {
-          // The element must be a floating point or integer scalar type.
-          if (mlirTypeIsAIntegerOrFloat(elementType)) {
-            MlirType t = mlirComplexTypeGet(elementType);
-            return PyComplexType(elementType.getContext(), t);
-          }
-          throw nb::value_error(
-              (Twine("invalid '") +
-               nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
-               "' and expected floating point or integer type.")
-                  .str()
-                  .c_str());
-        },
-        "Create a complex type");
-    c.def_prop_ro(
-        "element_type",
-        [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
-              .maybeDownCast();
-        },
-        "Returns element type.");
-  }
-};
+  mlir_type_subclass integerType(m, "IntegerType", mlirTypeIsAInteger,
+                                 mlirIntegerTypeGetTypeID, &m);
+  integerType.def_classmethod(
+      "get_signless",
+      [](const nb::object &cls, unsigned width, MlirContext ctx) {
+        return cls(mlirIntegerTypeGet(ctx, width));
+      },
+      nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signless integer type");
+  integerType.def_classmethod(
+      "get_signed",
+      [](const nb::object &cls, unsigned width, MlirContext ctx) {
+        return cls(mlirIntegerTypeSignedGet(ctx, width));
+      },
+      nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signed integer type");
+  integerType.def_classmethod(
+      "get_unsigned",
+      [](const nb::object &cls, unsigned width, MlirContext ctx) {
+        return cls(mlirIntegerTypeUnsignedGet(ctx, width));
+      },
+      nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create an unsigned integer type");
+  integerType.def_property_readonly(
+      "width", [](MlirType self) { return mlirIntegerTypeGetWidth(self); },
+      "Returns the width of the integer type");
+  integerType.def_property_readonly(
+      "is_signless",
+      [](MlirType self) { return mlirIntegerTypeIsSignless(self); },
+      "Returns whether this is a signless integer");
+  integerType.def_property_readonly(
+      "is_signed", [](MlirType self) { return mlirIntegerTypeIsSigned(self); },
+      "Returns whether this is a signed integer");
+  integerType.def_property_readonly(
+      "is_unsigned",
+      [](MlirType self) { return mlirIntegerTypeIsUnsigned(self); },
+      "Returns whether this is an unsigned integer");
+
+  // IndexType
+  mlir_type_subclass indexType(m, "IndexType", mlirTypeIsAIndex,
+                               mlirIndexTypeGetTypeID, &m);
+
+  indexType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirIndexTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a index type.");
+
+  // FloatType (base class for specific float types)
+  mlir_type_subclass floatType(m, "FloatType", mlirTypeIsAFloat, nullptr, &m);
+  floatType.def_property_readonly(
+      "width", [](MlirType self) { return mlirFloatTypeGetWidth(self); },
+      "Returns the width of the floating-point type");
+
+  // Float4E2M1FNType
+  mlir_type_subclass float4E2M1FNType(
+      m, "Float4E2M1FNType", mlirTypeIsAFloat4E2M1FN, floatType.get_class(),
+      mlirFloat4E2M1FNTypeGetTypeID, &m);
+  float4E2M1FNType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat4E2M1FNTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float4_e2m1fn type.");
+
+  // Float6E2M3FNType
+  mlir_type_subclass float6E2M3FNType(
+      m, "Float6E2M3FNType", mlirTypeIsAFloat6E2M3FN, floatType.get_class(),
+      mlirFloat6E2M3FNTypeGetTypeID, &m);
+  float6E2M3FNType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat6E2M3FNTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float6_e2m3fn type.");
+
+  // Float6E3M2FNType
+  mlir_type_subclass float6E3M2FNType(
+      m, "Float6E3M2FNType", mlirTypeIsAFloat6E3M2FN, floatType.get_class(),
+      mlirFloat6E3M2FNTypeGetTypeID, &m);
+  float6E3M2FNType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat6E3M2FNTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float6_e3m2fn type.");
+
+  // Float8E4M3FNType
+  mlir_type_subclass float8E4M3FNType(
+      m, "Float8E4M3FNType", mlirTypeIsAFloat8E4M3FN, floatType.get_class(),
+      mlirFloat8E4M3FNTypeGetTypeID, &m);
+  float8E4M3FNType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E4M3FNTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e4m3fn type.");
+
+  // Float8E5M2Type
+  mlir_type_subclass float8E5M2Type(m, "Float8E5M2Type", mlirTypeIsAFloat8E5M2,
+                                    floatType.get_class(),
+                                    mlirFloat8E5M2TypeGetTypeID, &m);
+  float8E5M2Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E5M2TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e5m2 type.");
+
+  // Float8E4M3Type
+  mlir_type_subclass float8E4M3Type(m, "Float8E4M3Type", mlirTypeIsAFloat8E4M3,
+                                    floatType.get_class(),
+                                    mlirFloat8E4M3TypeGetTypeID, &m);
+  float8E4M3Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E4M3TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e4m3 type.");
+
+  // Float8E4M3FNUZType
+  mlir_type_subclass float8E4M3FNUZType(
+      m, "Float8E4M3FNUZType", mlirTypeIsAFloat8E4M3FNUZ, floatType.get_class(),
+      mlirFloat8E4M3FNUZTypeGetTypeID, &m);
+  float8E4M3FNUZType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E4M3FNUZTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e4m3fnuz type.");
+
+  // Float8E4M3B11FNUZType
+  mlir_type_subclass float8E4M3B11FNUZType(
+      m, "Float8E4M3B11FNUZType", mlirTypeIsAFloat8E4M3B11FNUZ,
+      floatType.get_class(), mlirFloat8E4M3B11FNUZTypeGetTypeID, &m);
+  float8E4M3B11FNUZType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E4M3B11FNUZTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e4m3b11fnuz type.");
+
+  // Float8E5M2FNUZType
+  mlir_type_subclass float8E5M2FNUZType(
+      m, "Float8E5M2FNUZType", mlirTypeIsAFloat8E5M2FNUZ, floatType.get_class(),
+      mlirFloat8E5M2FNUZTypeGetTypeID, &m);
+  float8E5M2FNUZType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E5M2FNUZTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e5m2fnuz type.");
+
+  // Float8E3M4Type
+  mlir_type_subclass float8E3M4Type(m, "Float8E3M4Type", mlirTypeIsAFloat8E3M4,
+                                    floatType.get_class(),
+                                    mlirFloat8E3M4TypeGetTypeID, &m);
+  float8E3M4Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E3M4TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e3m4 type.");
+
+  // Float8E8M0FNUType
+  mlir_type_subclass float8E8M0FNUType(
+      m, "Float8E8M0FNUType", mlirTypeIsAFloat8E8M0FNU, floatType.get_class(),
+      mlirFloat8E8M0FNUTypeGetTypeID, &m);
+  float8E8M0FNUType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirFloat8E8M0FNUTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(),
+      "Create a float8_e8m0fnu type.");
+
+  // BF16Type
+  mlir_type_subclass bf16Type(m, "BF16Type", mlirTypeIsABF16,
+                              floatType.get_class(), mlirBFloat16TypeGetTypeID,
+                              &m);
+  bf16Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirBF16TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a bf16 type.");
+
+  // F16Type
+  mlir_type_subclass f16Type(m, "F16Type", mlirTypeIsAF16,
+                             floatType.get_class(), mlirFloat16TypeGetTypeID,
+                             &m);
+  f16Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirF16TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f16 type.");
+
+  // FloatTF32Type
+  mlir_type_subclass tf32Type(m, "FloatTF32Type", mlirTypeIsATF32,
+                              floatType.get_class(), 
mlirFloatTF32TypeGetTypeID,
+                              &m);
+  tf32Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirTF32TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a tf32 type.");
+
+  // F32Type
+  mlir_type_subclass f32Type(m, "F32Type", mlirTypeIsAF32,
+                             floatType.get_class(), mlirFloat32TypeGetTypeID,
+                             &m);
+  f32Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirF32TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f32 type.");
+
+  // F64Type
+  mlir_type_subclass f64Type(m, "F64Type", mlirTypeIsAF64,
+                             floatType.get_class(), mlirFloat64TypeGetTypeID,
+                             &m);
+  f64Type.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirF64TypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f64 type.");
+
+  // NoneType
+  mlir_type_subclass noneType(m, "NoneType", mlirTypeIsANone,
+                              mlirNoneTypeGetTypeID, &m);
+  noneType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirContext ctx) {
+        return cls(mlirNoneTypeGet(ctx));
+      },
+      nb::arg("cls"), nb::arg("context") = nb::none(), "Create a none type.");
+
+  // ComplexType
+  mlir_type_subclass complexType(m, "ComplexType", mlirTypeIsAComplex,
+                                 mlirComplexTypeGetTypeID, &m);
+  complexType.def_classmethod(
+      "get",
+      [](const nb::object &cls, MlirType elementType) {
+        // The element must be a floating point or integer scalar type.
+        if (mlirTypeIsAIntegerOrFloat(elementType)) {
+          return cls(mlirComplexTypeGet(elementType));
+        }
+        throw nb::value_error("Invalid element type for ComplexType: expected "
+                              "floating point or integer type.");
+      },
+      "Create a complex type");
+  complexType.def_property_readonly(
+      "element_type",
+      [](MlirType self) { return mlirComplexTypeGetElementType(self); },
+      "Returns element type.");
+
+  // TupleType
+  mlir_type_subclass tupleType(m, "TupleType", mlirTypeIsATuple,
+                               mlirTupleTypeGetTypeID, &m);
+  tupleType.def_classmethod(
+      "get_tuple",
+      [](const nb::object &cls, std::vector<MlirType> elements,
+         MlirContext ctx) {
+        return cls(mlirTupleTypeGet(ctx, elements.size(), elements.data()));
+      },
+      nb::arg("cls"), nb::arg("elements"), nb::arg("context") = nb::none(),
+      "Create a tuple type");
+  tupleType.def(
+      "get_type",
+      [](MlirType self, intptr_t pos) {
+        return mlirTupleTypeGetType(self, pos);
+      },
+      nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+  tupleType.def_property_readonly(
+      "num_types", [](MlirType self) { return mlirTupleTypeGetNumTypes(self); 
},
+      "Returns the number of types contained in a tuple.");
+
+  // FunctionType
+  mlir_type_subclass functionType(m, "FunctionType", mlirTypeIsAFunction,
+                                  mlirFunctionTypeGetTypeID, &m);
+  functionType.def_classmethod(
+      "get",
+      [](const nb::object &cls, std::vector<MlirType> inputs,
+         std::vector<MlirType> results, MlirContext ctx) {
+        return cls(mlirFunctionTypeGet(ctx, inputs.size(), inputs.data(),
+                                       results.size(), results.data()));
+      },
+      nb::arg("cls"), nb::arg("inputs"), nb::arg("results"),
+      nb::arg("context") = nb::none(),
+      "Gets a FunctionType from a list of input and result types");
+  functionType.def_property_readonly(
+      "inputs",
+      [](MlirType self) {
+        nb::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetInput(self, i));
+        }
+        return types;
+      },
+      "Returns the list of input types in the FunctionType.");
+  functionType.def_property_readonly(
+      "results",
+      [](MlirType self) {
+        nb::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetResult(self, i));
+        }
+        return types;
+      },
+      "Returns the list of result types in the FunctionType.");
+
+  // OpaqueType
+  mlir_type_subclass opaqueType(m, "OpaqueType", mlirTypeIsAOpaque,
+                                mlirOpaqueTypeGetTypeID, &m);
+  opaqueType.def_classmethod(
+      "get",
+      [](const nb::object &cls, const std::string &dialectNamespace,
+         const std::string &typeData, MlirContext ctx) {
+        MlirStringRef dialectNs = mlirStringRefCreate(dialectNamespace.data(),
+                                                      dialectNamespace.size());
+        MlirStringRef data =
+            mlirStringRefCreate(typeData.data(), typeData.size());
+        return cls(mlirOpaqueTypeGet(ctx, dialectNs, data));
+      },
+      nb::arg("cls"), nb::arg("dialect_namespace"), nb::arg("buffer"),
+      nb::arg("context") = nb::none(),
+      "Create an unregistered (opaque) dialect type.");
+  opaqueType.def_property_readonly(
+      "dialect_namespace",
+      [](MlirType self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
+        return nb::str(stringRef.data, stringRef.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  opaqueType.def_property_readonly(
+      "data",
+      [](MlirType self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
+        return nb::str(stringRef.data, stringRef.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
 
 } // namespace
 
@@ -977,202 +865,17 @@ class PyUnrankedMemRefType
   }
 };
 
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirTupleTypeGetTypeID;
-  static constexpr const char *pyClassName = "TupleType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_tuple",
-        [](const std::vector<PyType> &elements,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirElements;
-          mlirElements.reserve(elements.size());
-          for (const auto &element : elements)
-            mlirElements.push_back(element.get());
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        mlirElements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        "Create a tuple type");
-    c.def_static(
-        "get_tuple",
-        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        elements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get_tuple(elements: Sequence[Type], context: Context | 
None = None) -> TupleType"),
-        // clang-format on
-        "Create a tuple type");
-    c.def(
-        "get_type",
-        [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
-              .maybeDownCast();
-        },
-        nb::arg("pos"), "Returns the pos-th type in the tuple type.");
-    c.def_prop_ro(
-        "num_types",
-        [](PyTupleType &self) -> intptr_t {
-          return mlirTupleTypeGetNumTypes(self);
-        },
-        "Returns the number of types contained in a tuple.");
-  }
-};
-
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFunctionTypeGetTypeID;
-  static constexpr const char *pyClassName = "FunctionType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirInputs;
-          mlirInputs.reserve(inputs.size());
-          for (const auto &input : inputs)
-            mlirInputs.push_back(input.get());
-          std::vector<MlirType> mlirResults;
-          mlirResults.reserve(results.size());
-          for (const auto &result : results)
-            mlirResults.push_back(result.get());
-
-          MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
-                                           mlirInputs.data(), results.size(),
-                                           mlirResults.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        "Gets a FunctionType from a list of input and result types");
-    c.def_static(
-        "get",
-        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
-           DefaultingPyMlirContext context) {
-          MlirType t =
-              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
-                                  results.size(), results.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], 
context: Context | None = None) -> FunctionType"),
-        // clang-format on
-        "Gets a FunctionType from a list of input and result types");
-    c.def_prop_ro(
-        "inputs",
-        [](PyFunctionType &self) {
-          MlirType t = self;
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetInput(t, i));
-          }
-          return types;
-        },
-        "Returns the list of input types in the FunctionType.");
-    c.def_prop_ro(
-        "results",
-        [](PyFunctionType &self) {
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetResult(self, i));
-          }
-          return types;
-        },
-        "Returns the list of result types in the FunctionType.");
-  }
-};
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
-  return mlirStringRefCreate(s.data(), s.size());
-}
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirOpaqueTypeGetTypeID;
-  static constexpr const char *pyClassName = "OpaqueType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &dialectNamespace, const std::string &typeData,
-           DefaultingPyMlirContext context) {
-          MlirType type = mlirOpaqueTypeGet(context->get(),
-                                            toMlirStringRef(dialectNamespace),
-                                            toMlirStringRef(typeData));
-          return PyOpaqueType(context->getRef(), type);
-        },
-        nb::arg("dialect_namespace"), nb::arg("buffer"),
-        nb::arg("context") = nb::none(),
-        "Create an unregistered (opaque) dialect type.");
-    c.def_prop_ro(
-        "dialect_namespace",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the dialect namespace for the Opaque type as a string.");
-    c.def_prop_ro(
-        "data",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the data for the Opaque type as a string.");
-  }
-};
-
 } // namespace
 
 void mlir::python::populateIRTypes(nb::module_ &m) {
-  PyIntegerType::bind(m);
-  PyFloatType::bind(m);
-  PyIndexType::bind(m);
-  PyFloat4E2M1FNType::bind(m);
-  PyFloat6E2M3FNType::bind(m);
-  PyFloat6E3M2FNType::bind(m);
-  PyFloat8E4M3FNType::bind(m);
-  PyFloat8E5M2Type::bind(m);
-  PyFloat8E4M3Type::bind(m);
-  PyFloat8E4M3FNUZType::bind(m);
-  PyFloat8E4M3B11FNUZType::bind(m);
-  PyFloat8E5M2FNUZType::bind(m);
-  PyFloat8E3M4Type::bind(m);
-  PyFloat8E8M0FNUType::bind(m);
-  PyBF16Type::bind(m);
-  PyF16Type::bind(m);
-  PyTF32Type::bind(m);
-  PyF32Type::bind(m);
-  PyF64Type::bind(m);
-  PyNoneType::bind(m);
-  PyComplexType::bind(m);
+  // Bind the non-shaped builtin types via the mlir_type_subclass mechanism.
+  populateIRTypesModule(m);
+
+  // Shaped types are not ported yet; keep their class-based bindings.
   PyShapedType::bind(m);
   PyVectorType::bind(m);
   PyRankedTensorType::bind(m);
   PyUnrankedTensorType::bind(m);
   PyMemRefType::bind(m);
   PyUnrankedMemRefType::bind(m);
-  PyTupleType::bind(m);
-  PyFunctionType::bind(m);
-  PyOpaqueType::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp 
b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..fb73beda4cf88 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -145,6 +145,21 @@ NB_MODULE(_mlir, m) {
 
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
+  irModule.def(
+      MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+      [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+        return nb::cpp_function([mlirTypeID, replace](
+                                    nb::callable typeCaster) -> nb::object {
+          PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+          return typeCaster;
+        });
+      },
+      // clang-format off
+    nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: 
bool = False) "
+                      "-> typing.Callable[[typing.Callable[[T], U]], 
typing.Callable[[T], U]]"),
+      // clang-format on
+      "typeid"_a, nb::kw_only(), "replace"_a = false,
+      "Register a type caster for casting MLIR types to custom user types.");
   populateIRCore(irModule);
   populateIRAffine(irModule);
   populateIRAttributes(irModule);
diff --git a/mlir/test/python/dialects/arith_dialect.py 
b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..ad318238b77c6 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -54,10 +54,10 @@ def _binary_op(lhs, rhs, op: str) -> "ArithValue":
         op = getattr(arith, f"{op}Op")
         return op(lhs, rhs).result
 
-    @register_value_caster(F16Type.static_typeid)
-    @register_value_caster(F32Type.static_typeid)
-    @register_value_caster(F64Type.static_typeid)
-    @register_value_caster(IntegerType.static_typeid)
+    @register_value_caster(F16Type.get_static_typeid())
+    @register_value_caster(F32Type.get_static_typeid())
+    @register_value_caster(F64Type.get_static_typeid())
+    @register_value_caster(IntegerType.get_static_typeid())
     class ArithValue(Value):
         def __init__(self, v):
             super().__init__(v)
diff --git a/mlir/test/python/ir/builtin_types.py 
b/mlir/test/python/ir/builtin_types.py
index 54863253fc770..20509050eda9f 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -185,7 +185,7 @@ def testStandardTypeCasts():
     try:
         tillegal = IntegerType(Type.parse("f32", ctx))
     except ValueError as e:
-        # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+        # CHECK: ValueError: Cannot cast type to IntegerType (from 
F32Type(f32))
         print("ValueError:", e)
     else:
         print("Exception not produced")
@@ -302,7 +302,7 @@ def testComplexType():
         try:
             complex_invalid = ComplexType.get(index)
         except ValueError as e:
-            # CHECK: invalid 'Type(index)' and expected floating point or 
integer type.
+            # CHECK: Invalid element type for ComplexType: expected floating 
point or integer type.
             print(e)
         else:
             print("Exception not produced")
@@ -714,7 +714,8 @@ def testTypeIDs():
         # mlirTypeGetTypeID(self) for an instance.
         # CHECK: all equal
         for t1, t2 in types:
-            tid1, tid2 = t1.static_typeid, Type(t2).typeid
+            # TODO: remove the alternative once mlir_type_subclass transition 
is complete.
+            tid1, tid2 = t1.static_typeid if hasattr(t1, "static_typeid") else 
t1.get_static_typeid(), Type(t2).typeid
             assert tid1 == tid2 and hash(tid1) == hash(
                 tid2
             ), f"expected hash and value equality {t1} {t2}"
@@ -728,7 +729,9 @@ def testTypeIDs():
 
         # CHECK: all equal
         for t1, t2 in typeid_dict.items():
-            assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == 
hash(
+            # TODO: remove the alternative once mlir_type_subclass transition 
is complete.
+            tid1 = t1.static_typeid if hasattr(t1, "static_typeid") else 
t1.get_static_typeid()
+            assert tid1 == t2.typeid and hash(tid1) == hash(
                 t2.typeid
             ), f"expected hash and value equality {t1} {t2}"
         else:
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 4a241afb8e89d..9d9b6c2090974 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -361,7 +361,7 @@ def __init__(self, v):
         def __str__(self):
             return super().__str__().replace(Value.__name__, 
NOPBlockArg.__name__)
 
-    @register_value_caster(IntegerType.static_typeid)
+    @register_value_caster(IntegerType.get_static_typeid())
     def cast_int(v) -> Value:
         print("in caster", v.__class__.__name__)
         if isinstance(v, OpResult):
@@ -425,7 +425,7 @@ def reduction(arg0, arg1):
 
     try:
 
-        @register_value_caster(IntegerType.static_typeid)
+        @register_value_caster(IntegerType.get_static_typeid())
         def dont_cast_int_shouldnt_register(v):
             ...
 
@@ -433,7 +433,7 @@ def dont_cast_int_shouldnt_register(v):
         # CHECK: Value caster is already registered: {{.*}}cast_int
         print(e)
 
-    @register_value_caster(IntegerType.static_typeid, replace=True)
+    @register_value_caster(IntegerType.get_static_typeid(), replace=True)
     def dont_cast_int(v) -> OpResult:
         assert isinstance(v, OpResult)
         print("don't cast", v.result_number, v)

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to