This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 9c4f598 feat(dataclass): add asdict, astuple, and match_args support
(#556)
9c4f598 is described below
commit 9c4f59878ba11950eaae22fa50d2d6a3c9322043
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 17 18:02:40 2026 -0700
feat(dataclass): add asdict, astuple, and match_args support (#556)
## Summary
Adds three stdlib-parity features to `tvm_ffi.dataclasses`:
1. **`asdict(obj, *, dict_factory=dict)`** — recursively converts a
`@py_class` / `@c_class` instance to a plain Python `dict`. FFI sequence
containers (`Array`, `List`) are converted to `list` and FFI mapping
containers (`Map`, `Dict`) to `dict`, yielding JSON-ready output.
2. **`astuple(obj, *, tuple_factory=tuple)`** — the tuple analogue of
`asdict`, with the same recursion rules.
3. **`match_args: bool = True`** parameter on `@py_class` and `@c_class`
— sets `cls.__match_args__` to the tuple of positional `__init__` field
names (`init=True and not kw_only`), enabling Python 3.10+ `match`
statements. Skipped when the class body already defines
`__match_args__`.
Semantics follow CPython's `dataclasses` module: `asdict`/`astuple`
raise `TypeError` when passed a class instead of an instance, or a
non-dataclass value; kw-only fields
(via `field(kw_only=True)`, decorator-level `kw_only=True`, or the
`KW_ONLY` sentinel) are excluded from `__match_args__`.
## Design notes
- FFI containers (`Array`, `List`, `Map`, `Dict`) and FFI dataclass
instances both carry `__tvm_ffi_type_info__`, so
`_is_ffi_dataclass_instance` cannot tell them apart on its own; the
recursion therefore runs the container `isinstance` check first.
- `_set_match_args` walks the `TypeInfo.parent_type_info` chain in
parent-first order, matching the order of the auto-generated `__init__`
signature.
- Recursion uses an `_ATOMIC_TYPES` frozenset (mirroring stdlib
`dataclasses._ATOMIC_TYPES`) for the fast path on immutable leaves.
## Test plan
- [x] `uv run pytest tests/python/` — 2184 passed, 38 skipped, 3 xfailed
- [x] `pre-commit run --files <touched files>` — all hooks pass (ruff,
ty, format)
- [x] New coverage in `tests/python/test_dataclass_common.py`:
- `TestAsdict` (16 tests): basic, nested, inheritance, FFI
Array/List/Map/Dict recursion, `dict_factory`, result independence,
error paths, `defaultdict` preservation
- `TestAstuple` (11 tests): basic, nested, recursion, `tuple_factory`,
error paths, `defaultdict` preservation
- `TestMatchArgs` (11 tests): defaults, `init=False`, kw_only (field /
decorator / `KW_ONLY` sentinel), inheritance order, `match_args=False`
opt-out, user-defined `__match_args__` override, c_class basic and
inheritance
---
python/tvm_ffi/_dunder.py | 36 ++++
python/tvm_ffi/dataclasses/__init__.py | 7 +-
python/tvm_ffi/dataclasses/c_class.py | 14 +-
python/tvm_ffi/dataclasses/common.py | 173 ++++++++++++++-
python/tvm_ffi/dataclasses/py_class.py | 8 +
tests/python/test_dataclass_common.py | 376 +++++++++++++++++++++++++++++++++
6 files changed, 610 insertions(+), 4 deletions(-)
diff --git a/python/tvm_ffi/_dunder.py b/python/tvm_ffi/_dunder.py
index 46f0e1b..a1a58ab 100644
--- a/python/tvm_ffi/_dunder.py
+++ b/python/tvm_ffi/_dunder.py
@@ -214,6 +214,33 @@ def _make_replace(_type_info: TypeInfo) -> Callable[...,
Any]:
# ---------------------------------------------------------------------------
+def _set_match_args(cls: type, type_info: TypeInfo) -> None:
+ """Set ``cls.__match_args__`` from reflected fields.
+
+ Mirrors stdlib :func:`dataclasses.dataclass` semantics: the tuple
+ contains the names of positional ``__init__`` fields (``init=True``
+ and ``kw_only=False``), walking the parent chain in parent-first
+ order. If ``cls`` already defines ``__match_args__`` in its own
+ ``__dict__``, it is left untouched.
+ """
+ if "__match_args__" in cls.__dict__:
+ return
+ chain: list[TypeInfo] = []
+ ti: TypeInfo | None = type_info
+ while ti is not None:
+ chain.append(ti)
+ ti = ti.parent_type_info
+ names: list[str] = []
+ for ancestor in reversed(chain):
+ for tf in ancestor.fields or ():
+ df = tf.dataclass_field
+ if df is None:
+ continue
+ if df.init and not df.kw_only:
+ names.append(tf.name)
+ setattr(cls, "__match_args__", tuple(names))
+
+
def _install_dataclass_dunders( # noqa: PLR0912, PLR0915
cls: type,
*,
@@ -222,6 +249,7 @@ def _install_dataclass_dunders( # noqa: PLR0912, PLR0915
eq: bool,
order: bool,
unsafe_hash: bool,
+ match_args: bool = True,
py_class_mode: bool = False,
) -> None:
"""Install structural dunder methods on *cls*.
@@ -250,6 +278,11 @@ def _install_dataclass_dunders( # noqa: PLR0912, PLR0915
``NotImplemented`` for unrelated types.
unsafe_hash
If True, install ``__hash__`` using ``RecursiveHash``.
+ match_args
+ If True (default), set ``cls.__match_args__`` to the tuple of
+ positional ``__init__`` field names for use with ``match``
+ statements. Skipped when the class already defines
+ ``__match_args__`` in its body.
py_class_mode
If True, use a ``chandle`` guard for ``__init__`` so that
``super().__init__()`` is a no-op, and wrap user-defined
@@ -393,3 +426,6 @@ def _install_dataclass_dunders( # noqa: PLR0912, PLR0915
cls.__deepcopy__ = _make_deepcopy(type_info) # type:
ignore[attr-defined]
if "__replace__" not in cls.__dict__:
cls.__replace__ = _make_replace(type_info) # type:
ignore[attr-defined]
+
+ if match_args:
+ _set_match_args(cls, type_info)
diff --git a/python/tvm_ffi/dataclasses/__init__.py
b/python/tvm_ffi/dataclasses/__init__.py
index 73fdcca..850e9d4 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -16,17 +16,20 @@
# under the License.
"""FFI dataclass decorators: ``c_class`` for C++-backed types, ``py_class``
for Python-defined types."""
-from tvm_ffi.core import Object
+from tvm_ffi.core import MISSING, Object
from .c_class import c_class
-from .common import fields, is_dataclass, replace
+from .common import asdict, astuple, fields, is_dataclass, replace
from .field import KW_ONLY, Field, field
from .py_class import py_class
__all__ = [
"KW_ONLY",
+ "MISSING",
"Field",
"Object",
+ "asdict",
+ "astuple",
"c_class",
"field",
"fields",
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index f30d657..17f66a4 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -69,6 +69,7 @@ def c_class(
eq: bool = False,
order: bool = False,
unsafe_hash: bool = False,
+ match_args: bool = True,
) -> Callable[[_T], _T]:
"""Register a C++ FFI class and install structural dunder methods.
@@ -105,6 +106,11 @@ def c_class(
*unsafe* because mutable fields contribute to the hash, so mutating
an object while it is in a set or dict key will break invariants.
Defaults to False.
+ match_args
+ If True (default), set ``__match_args__`` to a tuple of the
+ positional ``__init__`` field names (``init=True`` and not
+ ``kw_only``), enabling ``match`` statements. Ignored when the
+ class body already defines ``__match_args__``.
Returns
-------
@@ -151,7 +157,13 @@ def c_class(
_warn_missing_field_annotations(cls, type_info, stacklevel=2)
_attach_field_objects(cls, type_info)
_install_dataclass_dunders(
- cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash
+ cls,
+ init=init,
+ repr=repr,
+ eq=eq,
+ order=order,
+ unsafe_hash=unsafe_hash,
+ match_args=match_args,
)
return cls
diff --git a/python/tvm_ffi/dataclasses/common.py
b/python/tvm_ffi/dataclasses/common.py
index 88c3d62..4df4a53 100644
--- a/python/tvm_ffi/dataclasses/common.py
+++ b/python/tvm_ffi/dataclasses/common.py
@@ -18,11 +18,30 @@
from __future__ import annotations
+import copy
+from collections.abc import Callable
from typing import Any
+from ..container import Array, Dict, List, Map
from .field import Field
-__all__ = ["fields", "is_dataclass", "replace"]
+__all__ = ["asdict", "astuple", "fields", "is_dataclass", "replace"]
+
+# Exact-type fast path for atomic (immutable, non-recursive) values. Mirrors
+# :data:`dataclasses._ATOMIC_TYPES` from the standard library.
+_ATOMIC_TYPES: frozenset[type] = frozenset(
+ {
+ bool,
+ bytes,
+ complex,
+ float,
+ int,
+ str,
+ type(None),
+ type(Ellipsis),
+ type(NotImplemented),
+ }
+)
def is_dataclass(obj: Any) -> bool:
@@ -72,3 +91,155 @@ def replace(obj: Any, /, **changes: Any) -> Any:
still replaceable.
"""
return obj.__replace__(**changes)
+
+
+def _is_ffi_dataclass_instance(obj: Any) -> bool:
+ """Return True when *obj* is a ``@c_class`` / ``@py_class`` **instance**
(not a type)."""
+ if isinstance(obj, type):
+ return False
+ return getattr(type(obj), "__tvm_ffi_type_info__", None) is not None
+
+
+def _asdict_inner( # noqa: PLR0911, PLR0912
+ obj: Any, dict_factory: Callable[..., Any]
+) -> Any:
+ obj_type = type(obj)
+ if obj_type in _ATOMIC_TYPES:
+ return obj
+ # FFI containers are treated as their stdlib analogues so the result is
+ # plain Python data — handy for JSON serialisation, the main use case.
+ if isinstance(obj, (Array, List)):
+ return [_asdict_inner(v, dict_factory) for v in obj]
+ if isinstance(obj, (Map, Dict)):
+ return dict_factory(
+ [
+ (_asdict_inner(k, dict_factory), _asdict_inner(v,
dict_factory))
+ for k, v in obj.items()
+ ]
+ )
+ if _is_ffi_dataclass_instance(obj):
+ fs = fields(obj)
+ if dict_factory is dict:
+ return {f.name: _asdict_inner(getattr(obj, f.name), dict) for f in
fs} # ty: ignore[invalid-argument-type]
+ return dict_factory(
+ [(f.name, _asdict_inner(getattr(obj, f.name), dict_factory)) for f
in fs] # ty: ignore[invalid-argument-type]
+ )
+ if obj_type is list:
+ return [_asdict_inner(v, dict_factory) for v in obj]
+ if obj_type is dict:
+ return {
+ _asdict_inner(k, dict_factory): _asdict_inner(v, dict_factory) for
k, v in obj.items()
+ }
+ if obj_type is tuple:
+ return tuple(_asdict_inner(v, dict_factory) for v in obj)
+ if issubclass(obj_type, tuple):
+ if hasattr(obj, "_fields"): # namedtuple
+ return obj_type(*[_asdict_inner(v, dict_factory) for v in obj])
+ return obj_type(_asdict_inner(v, dict_factory) for v in obj)
+ if issubclass(obj_type, dict):
+ if hasattr(obj_type, "default_factory"):
+ result = obj_type(obj.default_factory)
+ for k, v in obj.items():
+ result[_asdict_inner(k, dict_factory)] = _asdict_inner(v,
dict_factory)
+ return result
+ return obj_type(
+ (_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
for k, v in obj.items()
+ )
+ if issubclass(obj_type, list):
+ return obj_type(_asdict_inner(v, dict_factory) for v in obj)
+ return copy.deepcopy(obj)
+
+
+def _astuple_inner(obj: Any, tuple_factory: Callable[..., Any]) -> Any: #
noqa: PLR0911
+ obj_type = type(obj)
+ if obj_type in _ATOMIC_TYPES:
+ return obj
+ if isinstance(obj, (Array, List)):
+ return [_astuple_inner(v, tuple_factory) for v in obj]
+ if isinstance(obj, (Map, Dict)):
+ return {
+ _astuple_inner(k, tuple_factory): _astuple_inner(v, tuple_factory)
+ for k, v in obj.items()
+ }
+ if _is_ffi_dataclass_instance(obj):
+ return tuple_factory(
+ [_astuple_inner(getattr(obj, f.name), tuple_factory) for f in
fields(obj)] # ty: ignore[invalid-argument-type]
+ )
+ if isinstance(obj, tuple) and hasattr(obj, "_fields"): # namedtuple
+ return obj_type(*[_astuple_inner(v, tuple_factory) for v in obj])
+ if isinstance(obj, (list, tuple)):
+ return obj_type(_astuple_inner(v, tuple_factory) for v in obj)
+ if isinstance(obj, dict):
+ if hasattr(obj_type, "default_factory"):
+ result = obj_type(obj.default_factory)
+ for k, v in obj.items():
+ result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v,
tuple_factory)
+ return result
+ return obj_type(
+ (_astuple_inner(k, tuple_factory), _astuple_inner(v,
tuple_factory))
+ for k, v in obj.items()
+ )
+ return copy.deepcopy(obj)
+
+
+def asdict(obj: Any, *, dict_factory: Callable[..., Any] = dict) -> Any:
+ r"""Return the fields of a ``@c_class`` / ``@py_class`` instance as a new
dict.
+
+ Mirrors :func:`dataclasses.asdict` from the standard library. The
+ function recurses into nested FFI dataclass instances, FFI containers
+ (:class:`~tvm_ffi.Array`, :class:`~tvm_ffi.List`,
+ :class:`~tvm_ffi.Map`, :class:`~tvm_ffi.Dict`), and the built-in
+ ``list`` / ``tuple`` / ``dict``. FFI sequence containers are
+ converted to Python ``list``\ s and FFI mapping containers to
+ Python ``dict``\ s so the result is plain Python data, e.g. for
+ JSON serialisation. Any other value is copied with
+ :func:`copy.deepcopy`.
+
+ Parameters
+ ----------
+ obj
+ A ``@c_class`` / ``@py_class`` instance. Passing a type raises
+ :class:`TypeError`.
+ dict_factory
+ Callable used to construct the outer mapping and any nested
+ mapping recursed from an FFI dataclass. Defaults to :class:`dict`.
+
+ Raises
+ ------
+ TypeError
+ If ``obj`` is not a ``@c_class`` / ``@py_class`` instance.
+
+ """
+ if not _is_ffi_dataclass_instance(obj):
+ raise TypeError("asdict() should be called on c_class / py_class
instances")
+ return _asdict_inner(obj, dict_factory)
+
+
+def astuple(obj: Any, *, tuple_factory: Callable[..., Any] = tuple) -> Any:
+ """Return the fields of a ``@c_class`` / ``@py_class`` instance as a new
tuple.
+
+ Mirrors :func:`dataclasses.astuple` from the standard library. The
+ function recurses into nested FFI dataclass instances, FFI containers
+ (:class:`~tvm_ffi.Array`, :class:`~tvm_ffi.List`,
+ :class:`~tvm_ffi.Map`, :class:`~tvm_ffi.Dict`), and the built-in
+ ``list`` / ``tuple`` / ``dict``. Any other value is copied with
+ :func:`copy.deepcopy`.
+
+ Parameters
+ ----------
+ obj
+ A ``@c_class`` / ``@py_class`` instance. Passing a type raises
+ :class:`TypeError`.
+ tuple_factory
+ Callable used to construct the outer tuple and any nested tuple
+ recursed from an FFI dataclass. Defaults to :class:`tuple`.
+
+ Raises
+ ------
+ TypeError
+ If ``obj`` is not a ``@c_class`` / ``@py_class`` instance.
+
+ """
+ if not _is_ffi_dataclass_instance(obj):
+ raise TypeError("astuple() should be called on c_class / py_class
instances")
+ return _astuple_inner(obj, tuple_factory)
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index 9f3d1f2..70ee25d 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -305,6 +305,7 @@ def _register_fields_into_type(
eq=params["eq"],
order=params["order"],
unsafe_hash=params["unsafe_hash"],
+ match_args=params["match_args"],
py_class_mode=True,
)
return True
@@ -450,6 +451,7 @@ def py_class( # noqa: PLR0913
eq: bool = False,
order: bool = False,
unsafe_hash: bool = False,
+ match_args: bool = True,
kw_only: bool = False,
structural_eq: str | None = None,
slots: bool = True,
@@ -507,6 +509,11 @@ def py_class( # noqa: PLR0913
Requires ``eq=True``.
unsafe_hash
If True, generate ``__hash__`` (unsafe for mutable objects).
+ match_args
+ If True (default), set ``__match_args__`` to a tuple of the
+ positional ``__init__`` field names (``init=True`` and not
+ ``kw_only``), enabling ``match`` statements. Ignored when the
+ class body already defines ``__match_args__``.
kw_only
If True, all fields are keyword-only in ``__init__`` by default.
structural_eq
@@ -551,6 +558,7 @@ def py_class( # noqa: PLR0913
"eq": eq,
"order": order,
"unsafe_hash": unsafe_hash,
+ "match_args": match_args,
"kw_only": kw_only,
"structural_eq": structural_eq,
}
diff --git a/tests/python/test_dataclass_common.py
b/tests/python/test_dataclass_common.py
index 95b860d..238886e 100644
--- a/tests/python/test_dataclass_common.py
+++ b/tests/python/test_dataclass_common.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import collections
import dataclasses as _dc
import itertools
import typing
@@ -29,7 +30,10 @@ import tvm_ffi
import tvm_ffi.testing
from tvm_ffi.core import MISSING, Object
from tvm_ffi.dataclasses import (
+ KW_ONLY,
Field,
+ asdict,
+ astuple,
field,
fields,
is_dataclass,
@@ -250,3 +254,375 @@ class TestReplace:
p2 = replace(p, a=10)
assert p2.a == 10
assert p.a == 3 # still read-only, original untouched
+
+
+# ---------------------------------------------------------------------------
+# asdict
+# ---------------------------------------------------------------------------
+class TestAsdict:
+ def test_c_class_basic(self) -> None:
+ p = tvm_ffi.testing.TestIntPair(1, 2)
+ assert asdict(p) == {"a": 1, "b": 2}
+
+ def test_py_class_basic(self) -> None:
+ @py_class(_k("PCAsdict"))
+ class PC(Object):
+ x: int
+ y: str
+
+ assert asdict(PC(x=1, y="hi")) == {"x": 1, "y": "hi"}
+
+ def test_py_class_nested(self) -> None:
+ @py_class(_k("PCInner"))
+ class Inner(Object):
+ x: int
+
+ @py_class(_k("PCOuter"))
+ class Outer(Object):
+ a: Inner
+ b: Inner
+
+ out = Outer(a=Inner(x=1), b=Inner(x=2))
+ assert asdict(out) == {"a": {"x": 1}, "b": {"x": 2}}
+
+ def test_py_class_inheritance(self) -> None:
+ @py_class(_k("PCParent"))
+ class P(Object):
+ x: int
+
+ @py_class(_k("PCChild"))
+ class C(P):
+ y: str
+
+ assert asdict(C(x=1, y="a")) == {"x": 1, "y": "a"}
+
+ def test_ffi_array_recurses_to_list(self) -> None:
+ @py_class(_k("PCWithArray"))
+ class PC(Object):
+ xs: tvm_ffi.Array
+
+ pc = PC(xs=tvm_ffi.Array([1, 2, 3]))
+ result = asdict(pc)
+ assert result == {"xs": [1, 2, 3]}
+ assert type(result["xs"]) is list
+
+ def test_ffi_list_recurses_to_list(self) -> None:
+ @py_class(_k("PCWithList"))
+ class PC(Object):
+ xs: tvm_ffi.List
+
+ pc = PC(xs=tvm_ffi.List([1, 2, 3]))
+ result = asdict(pc)
+ assert result == {"xs": [1, 2, 3]}
+ assert type(result["xs"]) is list
+
+ def test_ffi_map_recurses_to_dict(self) -> None:
+ @py_class(_k("PCWithMap"))
+ class PC(Object):
+ m: tvm_ffi.Map
+
+ pc = PC(m=tvm_ffi.Map({"a": 1, "b": 2}))
+ result = asdict(pc)
+ assert result == {"m": {"a": 1, "b": 2}}
+ assert type(result["m"]) is dict
+
+ def test_ffi_dict_recurses_to_dict(self) -> None:
+ @py_class(_k("PCWithDict"))
+ class PC(Object):
+ d: tvm_ffi.Dict
+
+ pc = PC(d=tvm_ffi.Dict({"a": 1, "b": 2}))
+ result = asdict(pc)
+ assert result == {"d": {"a": 1, "b": 2}}
+ assert type(result["d"]) is dict
+
+ def test_ffi_array_of_dataclasses(self) -> None:
+ @py_class(_k("PCItem"))
+ class Item(Object):
+ v: int
+
+ @py_class(_k("PCBox"))
+ class Box(Object):
+ items: tvm_ffi.Array
+
+ box = Box(items=tvm_ffi.Array([Item(v=1), Item(v=2)]))
+ assert asdict(box) == {"items": [{"v": 1}, {"v": 2}]}
+
+ def test_ffi_map_of_dataclasses(self) -> None:
+ @py_class(_k("PCItem2"))
+ class Item(Object):
+ v: int
+
+ @py_class(_k("PCBox2"))
+ class Box(Object):
+ items: tvm_ffi.Map
+
+ box = Box(items=tvm_ffi.Map({"a": Item(v=1), "b": Item(v=2)}))
+ assert asdict(box) == {"items": {"a": {"v": 1}, "b": {"v": 2}}}
+
+ def test_dict_factory(self) -> None:
+ @py_class(_k("PCDF"))
+ class PC(Object):
+ x: int
+ y: int
+
+ result = asdict(PC(x=1, y=2), dict_factory=collections.OrderedDict)
+ assert isinstance(result, collections.OrderedDict)
+ assert list(result.items()) == [("x", 1), ("y", 2)]
+
+ def test_dict_factory_recurses(self) -> None:
+ @py_class(_k("PCDFI"))
+ class Inner(Object):
+ v: int
+
+ @py_class(_k("PCDFO"))
+ class Outer(Object):
+ a: Inner
+
+ result = asdict(Outer(a=Inner(v=5)),
dict_factory=collections.OrderedDict)
+ assert isinstance(result, collections.OrderedDict)
+ assert isinstance(result["a"], collections.OrderedDict)
+
+ def test_default_factory_list_independent(self) -> None:
+ """Result must be a fresh ``list``, not aliased to any internal
state."""
+
+ @py_class(_k("PCInd"))
+ class PC(Object):
+ xs: tvm_ffi.Array
+
+ pc = PC(xs=tvm_ffi.Array([1, 2]))
+ d = asdict(pc)
+ d["xs"].append(99)
+ # Mutating the result must not affect the original.
+ assert list(pc.xs) == [1, 2]
+
+ def test_type_raises(self) -> None:
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ asdict(tvm_ffi.testing.TestIntPair) # passing type, not instance
+
+ def test_non_dataclass_raises(self) -> None:
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ asdict(42)
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ asdict([1, 2, 3])
+
+ def test_defaultdict_preserved(self) -> None:
+ """``defaultdict`` round-trips with its ``default_factory`` intact.
+
+ Exercises the ``_asdict_inner`` defaultdict branch directly, since
+ FFI ``Any`` field storage converts a stored ``defaultdict`` into
+ an FFI ``Map`` on readback. Mirrors stdlib
+ ``dataclasses._asdict_inner``'s check on ``type(obj)``.
+ """
+ from tvm_ffi.dataclasses.common import _asdict_inner # noqa: PLC0415
+
+ dd = collections.defaultdict(list)
+ dd["a"].append(1)
+ dd["b"].append(2)
+ result = _asdict_inner(dd, dict)
+ assert type(result) is collections.defaultdict
+ assert result.default_factory is list
+ assert dict(result) == {"a": [1], "b": [2]}
+
+
+# ---------------------------------------------------------------------------
+# astuple
+# ---------------------------------------------------------------------------
+class TestAstuple:
+ def test_c_class_basic(self) -> None:
+ p = tvm_ffi.testing.TestIntPair(1, 2)
+ assert astuple(p) == (1, 2)
+
+ def test_py_class_basic(self) -> None:
+ @py_class(_k("PCAstuple"))
+ class PC(Object):
+ x: int
+ y: str
+
+ assert astuple(PC(x=1, y="hi")) == (1, "hi")
+
+ def test_py_class_nested(self) -> None:
+ @py_class(_k("PCInnerT"))
+ class Inner(Object):
+ x: int
+
+ @py_class(_k("PCOuterT"))
+ class Outer(Object):
+ a: Inner
+ b: Inner
+
+ out = Outer(a=Inner(x=1), b=Inner(x=2))
+ assert astuple(out) == ((1,), (2,))
+
+ def test_py_class_inheritance(self) -> None:
+ @py_class(_k("PCParentT"))
+ class P(Object):
+ x: int
+
+ @py_class(_k("PCChildT"))
+ class C(P):
+ y: str
+
+ assert astuple(C(x=1, y="a")) == (1, "a")
+
+ def test_ffi_array_recurses_to_list(self) -> None:
+ @py_class(_k("PCArrT"))
+ class PC(Object):
+ xs: tvm_ffi.Array
+
+ pc = PC(xs=tvm_ffi.Array([1, 2, 3]))
+ result = astuple(pc)
+ assert result == ([1, 2, 3],)
+ assert type(result[0]) is list
+
+ def test_ffi_map_recurses_to_dict(self) -> None:
+ @py_class(_k("PCMapT"))
+ class PC(Object):
+ m: tvm_ffi.Map
+
+ pc = PC(m=tvm_ffi.Map({"a": 1}))
+ result = astuple(pc)
+ assert result == ({"a": 1},)
+ assert type(result[0]) is dict
+
+ def test_tuple_factory(self) -> None:
+ @py_class(_k("PCTF"))
+ class PC(Object):
+ x: int
+ y: int
+
+ assert astuple(PC(x=1, y=2), tuple_factory=list) == [1, 2]
+
+ def test_tuple_factory_recurses(self) -> None:
+ @py_class(_k("PCTFI"))
+ class Inner(Object):
+ v: int
+
+ @py_class(_k("PCTFO"))
+ class Outer(Object):
+ a: Inner
+
+ result = astuple(Outer(a=Inner(v=5)), tuple_factory=list)
+ assert result == [[5]]
+
+ def test_type_raises(self) -> None:
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ astuple(tvm_ffi.testing.TestIntPair)
+
+ def test_non_dataclass_raises(self) -> None:
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ astuple(42)
+ with pytest.raises(TypeError, match="c_class / py_class instances"):
+ astuple([1, 2, 3])
+
+ def test_defaultdict_preserved(self) -> None:
+ """``defaultdict`` round-trips with its ``default_factory`` intact."""
+ from tvm_ffi.dataclasses.common import _astuple_inner # noqa: PLC0415
+
+ dd = collections.defaultdict(list)
+ dd["a"].append(1)
+ dd["b"].append(2)
+ result = _astuple_inner(dd, tuple)
+ assert type(result) is collections.defaultdict
+ assert result.default_factory is list
+ assert dict(result) == {"a": [1], "b": [2]}
+
+
+# ---------------------------------------------------------------------------
+# __match_args__
+# ---------------------------------------------------------------------------
+def _match_args(cls: type) -> object:
+ """Read ``__match_args__`` without tripping static attribute checks."""
+ return getattr(cls, "__match_args__")
+
+
+class TestMatchArgs:
+ def test_py_class_basic(self) -> None:
+ @py_class(_k("MAPy"))
+ class PC(Object):
+ x: int
+ y: str
+
+ assert _match_args(PC) == ("x", "y")
+
+ def test_py_class_init_false_excluded(self) -> None:
+ @py_class(_k("MAInitFalse"))
+ class PC(Object):
+ x: int
+ y: int = field(default=0, init=False)
+
+ assert _match_args(PC) == ("x",)
+
+ def test_py_class_kw_only_field_excluded(self) -> None:
+ @py_class(_k("MAKwField"))
+ class PC(Object):
+ x: int
+ y: int = field(kw_only=True)
+
+ assert _match_args(PC) == ("x",)
+
+ def test_py_class_kw_only_decorator_excludes_all(self) -> None:
+ @py_class(_k("MAKwDeco"), kw_only=True)
+ class PC(Object):
+ x: int
+ y: int
+
+ assert _match_args(PC) == ()
+
+ def test_py_class_kw_only_sentinel(self) -> None:
+ @py_class(_k("MAKwSent"))
+ class PC(Object):
+ x: int
+ _: KW_ONLY
+ y: int
+ z: int
+
+ assert _match_args(PC) == ("x",)
+
+ def test_py_class_inheritance_parent_first(self) -> None:
+ @py_class(_k("MAParent"))
+ class P(Object):
+ x: int
+ y: str
+
+ @py_class(_k("MAChild"))
+ class C(P):
+ z: float
+
+ assert _match_args(P) == ("x", "y")
+ assert _match_args(C) == ("x", "y", "z")
+
+ def test_py_class_opt_out(self) -> None:
+ @py_class(_k("MAOptOut"), match_args=False)
+ class PC(Object):
+ x: int
+ y: int
+
+ assert "__match_args__" not in PC.__dict__
+
+ def test_py_class_user_defined_preserved(self) -> None:
+ @py_class(_k("MAUser"))
+ class PC(Object):
+ x: int
+ y: int
+ __match_args__ = ("y",)
+
+ assert _match_args(PC) == ("y",)
+
+ def test_c_class_basic(self) -> None:
+ assert _match_args(tvm_ffi.testing.TestIntPair) == ("a", "b")
+
+ def test_c_class_inheritance_parent_first(self) -> None:
+ assert _match_args(tvm_ffi.testing._TestCxxClassDerived) == (
+ "v_i64",
+ "v_i32",
+ "v_f64",
+ "v_f32",
+ )
+
+ def test_tuple_type(self) -> None:
+ @py_class(_k("MATupT"))
+ class PC(Object):
+ x: int
+
+ assert isinstance(_match_args(PC), tuple)