This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 68b06e9 fix(dataclass): distinguish containers from dataclasses (#557)
68b06e9 is described below
commit 68b06e9e45a1daa0180994aab1d7ed2695e257b1
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 20 19:50:46 2026 -0700
fix(dataclass): distinguish containers from dataclasses (#557)
---
python/tvm_ffi/dataclasses/c_class.py | 4 ++++
python/tvm_ffi/dataclasses/common.py | 19 +++++++++++++------
python/tvm_ffi/dataclasses/py_class.py | 13 +++++++++----
3 files changed, 26 insertions(+), 10 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index 17f66a4..501c56b 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -165,6 +165,10 @@ def c_class(
unsafe_hash=unsafe_hash,
match_args=match_args,
)
+ # Marker: distinguishes @c_class / @py_class types from FFI containers
+ # (Array, List, Map, Dict) that also have __tvm_ffi_type_info__ but are
+ # not dataclasses. Used by is_dataclass() in common.py.
+ setattr(cls, "__tvm_ffi_is_dataclass__", True)
return cls
return decorator
diff --git a/python/tvm_ffi/dataclasses/common.py b/python/tvm_ffi/dataclasses/common.py
index 4df4a53..bde5fd0 100644
--- a/python/tvm_ffi/dataclasses/common.py
+++ b/python/tvm_ffi/dataclasses/common.py
@@ -45,9 +45,16 @@ _ATOMIC_TYPES: frozenset[type] = frozenset(
def is_dataclass(obj: Any) -> bool:
- """Return True if ``obj`` is a ``@c_class`` / ``@py_class`` type or instance."""
+ """Return True if ``obj`` is a ``@c_class`` / ``@py_class`` type or instance.
+
+ Returns False for FFI container types (:class:`~tvm_ffi.Array`,
+ :class:`~tvm_ffi.List`, :class:`~tvm_ffi.Map`, :class:`~tvm_ffi.Dict`)
+ even though they also carry ``__tvm_ffi_type_info__``; those are
+ reflected through :func:`~tvm_ffi.register_object` directly, not
+ through the dataclass decorators.
+ """
cls = obj if isinstance(obj, type) else type(obj)
- return getattr(cls, "__tvm_ffi_type_info__", None) is not None
+ return getattr(cls, "__tvm_ffi_is_dataclass__", False) is True
def fields(obj_or_cls: Any) -> tuple[Field, ...]:
@@ -63,13 +70,13 @@ def fields(obj_or_cls: Any) -> tuple[Field, ...]:
If ``obj_or_cls`` is not a ``@c_class`` / ``@py_class`` type or instance.
"""
- cls = obj_or_cls if isinstance(obj_or_cls, type) else type(obj_or_cls)
- ti = getattr(cls, "__tvm_ffi_type_info__", None)
- if ti is None:
+ if not is_dataclass(obj_or_cls):
raise TypeError(
f"fields() argument must be a c_class or py_class type or instance, "
f"got {type(obj_or_cls).__name__}"
)
+ cls = obj_or_cls if isinstance(obj_or_cls, type) else type(obj_or_cls)
+ ti = getattr(cls, "__tvm_ffi_type_info__", None)
chain = []
while ti is not None:
chain.append(ti)
@@ -97,7 +104,7 @@ def _is_ffi_dataclass_instance(obj: Any) -> bool:
"""Return True when *obj* is a ``@c_class`` / ``@py_class`` **instance** (not a type)."""
if isinstance(obj, type):
return False
- return getattr(type(obj), "__tvm_ffi_type_info__", None) is not None
+ return is_dataclass(obj)
def _asdict_inner( # noqa: PLR0911, PLR0912
diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index cf41abd..e10e502 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -122,10 +122,11 @@ def _rollback_registration(cls: type, type_info: Any) -> None:
core._rollback_py_class(type_info) # ty: ignore[unresolved-attribute]
# Remove from our own module-level resolution namespace.
_PY_CLASS_BY_MODULE.get(cls.__module__, {}).pop(cls.__name__, None)
- try:
- delattr(cls, "__tvm_ffi_type_info__")
- except AttributeError:
- pass
+ for attr in ("__tvm_ffi_type_info__", "__tvm_ffi_is_dataclass__"):
+ try:
+ delattr(cls, attr)
+ except AttributeError:
+ pass
# ---------------------------------------------------------------------------
@@ -619,6 +620,10 @@ def py_class( # noqa: PLR0913
_rollback_registration(cls, info)
raise
+ # Marker: distinguishes @c_class / @py_class types from FFI containers
+ # (Array, List, Map, Dict) that also have __tvm_ffi_type_info__ but are
+ # not dataclasses. Used by is_dataclass() in common.py.
+ setattr(cls, "__tvm_ffi_is_dataclass__", True)
return cls
# Handle different calling conventions: