This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c4f598  feat(dataclass): add asdict, astuple, and match_args support (#556)
9c4f598 is described below

commit 9c4f59878ba11950eaae22fa50d2d6a3c9322043
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 17 18:02:40 2026 -0700

    feat(dataclass): add asdict, astuple, and match_args support (#556)
    
    ## Summary
    
    Adds three stdlib-parity features to `tvm_ffi.dataclasses` (which now
    also re-exports `MISSING`):
    
    1. **`asdict(obj, *, dict_factory=dict)`** — recursively converts a
    `@py_class` / `@c_class` instance to a plain Python `dict`. FFI
    sequence containers (`Array`, `List`) become `list`s and FFI mapping
    containers (`Map`, `Dict`) become mappings built by `dict_factory`
    (plain `dict`s by default), yielding JSON-ready output.
    2. **`astuple(obj, *, tuple_factory=tuple)`** — the tuple analogue of
    `asdict`, with the same recursion rules.
    3. **`match_args: bool = True`** parameter on `@py_class` and `@c_class`
    — sets `cls.__match_args__` to the tuple of positional `__init__` field
    names (`init=True and not kw_only`), enabling Python 3.10+ `match`
    statements. Skipped when the class body already defines
    `__match_args__`.
    
    Semantics follow CPython's `dataclasses` module: `asdict`/`astuple`
    raise `TypeError` for types and non-dataclass values; kw-only fields
    (via `field(kw_only=True)`, decorator-level `kw_only=True`, or the
    `KW_ONLY` sentinel) are excluded from `__match_args__`.
    
    ## Design notes
    
    - FFI container instances (`Array`, `List`, `Map`, `Dict`) carry
    `__tvm_ffi_type_info__` just like FFI dataclass instances, so during
    recursion the container `isinstance` checks run before
    `_is_ffi_dataclass_instance` is consulted.
    - `_set_match_args` walks the `TypeInfo.parent_type_info` chain in
    parent-first order, matching the order of the auto-generated `__init__`
    signature.
    - Recursion uses an `_ATOMIC_TYPES` frozenset (a subset of stdlib
    `dataclasses._ATOMIC_TYPES`) as an exact-type fast path for immutable
    leaves.
    
    ## Test plan
    
    - [x] `uv run pytest tests/python/` — 2184 passed, 38 skipped, 3 xfailed
    - [x] `pre-commit run --files <touched files>` — all hooks pass (ruff,
    ty, format)
    - [x] New coverage in `tests/python/test_dataclass_common.py`:
      - `TestAsdict` (16 tests): basic, nested, inheritance, FFI
        Array/List/Map/Dict recursion, `dict_factory`, result independence,
        `defaultdict` preservation, error paths
      - `TestAstuple` (11 tests): basic, nested, inheritance, recursion,
        `tuple_factory`, `defaultdict` preservation, error paths
      - `TestMatchArgs` (11 tests): defaults, `init=False`, kw_only (field /
        decorator / `KW_ONLY` sentinel), inheritance order, `match_args=False`
        opt-out, user-defined `__match_args__` override, c_class basic and
        inheritance
---
 python/tvm_ffi/_dunder.py              |  36 ++++
 python/tvm_ffi/dataclasses/__init__.py |   7 +-
 python/tvm_ffi/dataclasses/c_class.py  |  14 +-
 python/tvm_ffi/dataclasses/common.py   | 173 ++++++++++++++-
 python/tvm_ffi/dataclasses/py_class.py |   8 +
 tests/python/test_dataclass_common.py  | 376 +++++++++++++++++++++++++++++++++
 6 files changed, 610 insertions(+), 4 deletions(-)

diff --git a/python/tvm_ffi/_dunder.py b/python/tvm_ffi/_dunder.py
index 46f0e1b..a1a58ab 100644
--- a/python/tvm_ffi/_dunder.py
+++ b/python/tvm_ffi/_dunder.py
@@ -214,6 +214,33 @@ def _make_replace(_type_info: TypeInfo) -> Callable[..., 
Any]:
 # ---------------------------------------------------------------------------
 
 
+def _set_match_args(cls: type, type_info: TypeInfo) -> None:
+    """Set ``cls.__match_args__`` from reflected fields.
+
+    Mirrors stdlib :func:`dataclasses.dataclass` semantics: the tuple
+    contains the names of positional ``__init__`` fields (``init=True``
+    and ``kw_only=False``), walking the parent chain in parent-first
+    order.  If ``cls`` already defines ``__match_args__`` in its own
+    ``__dict__``, it is left untouched.
+    """
+    if "__match_args__" in cls.__dict__:
+        return
+    chain: list[TypeInfo] = []
+    ti: TypeInfo | None = type_info
+    while ti is not None:
+        chain.append(ti)
+        ti = ti.parent_type_info
+    names: list[str] = []
+    for ancestor in reversed(chain):
+        for tf in ancestor.fields or ():
+            df = tf.dataclass_field
+            if df is None:
+                continue
+            if df.init and not df.kw_only:
+                names.append(tf.name)
+    setattr(cls, "__match_args__", tuple(names))
+
+
 def _install_dataclass_dunders(  # noqa: PLR0912, PLR0915
     cls: type,
     *,
@@ -222,6 +249,7 @@ def _install_dataclass_dunders(  # noqa: PLR0912, PLR0915
     eq: bool,
     order: bool,
     unsafe_hash: bool,
+    match_args: bool = True,
     py_class_mode: bool = False,
 ) -> None:
     """Install structural dunder methods on *cls*.
@@ -250,6 +278,11 @@ def _install_dataclass_dunders(  # noqa: PLR0912, PLR0915
         ``NotImplemented`` for unrelated types.
     unsafe_hash
         If True, install ``__hash__`` using ``RecursiveHash``.
+    match_args
+        If True (default), set ``cls.__match_args__`` to the tuple of
+        positional ``__init__`` field names for use with ``match``
+        statements.  Skipped when the class already defines
+        ``__match_args__`` in its body.
     py_class_mode
         If True, use a ``chandle`` guard for ``__init__`` so that
         ``super().__init__()`` is a no-op, and wrap user-defined
@@ -393,3 +426,6 @@ def _install_dataclass_dunders(  # noqa: PLR0912, PLR0915
         cls.__deepcopy__ = _make_deepcopy(type_info)  # type: 
ignore[attr-defined]
     if "__replace__" not in cls.__dict__:
         cls.__replace__ = _make_replace(type_info)  # type: 
ignore[attr-defined]
+
+    if match_args:
+        _set_match_args(cls, type_info)
diff --git a/python/tvm_ffi/dataclasses/__init__.py 
b/python/tvm_ffi/dataclasses/__init__.py
index 73fdcca..850e9d4 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -16,17 +16,20 @@
 # under the License.
 """FFI dataclass decorators: ``c_class`` for C++-backed types, ``py_class`` 
for Python-defined types."""
 
-from tvm_ffi.core import Object
+from tvm_ffi.core import MISSING, Object
 
 from .c_class import c_class
-from .common import fields, is_dataclass, replace
+from .common import asdict, astuple, fields, is_dataclass, replace
 from .field import KW_ONLY, Field, field
 from .py_class import py_class
 
 __all__ = [
     "KW_ONLY",
+    "MISSING",
     "Field",
     "Object",
+    "asdict",
+    "astuple",
     "c_class",
     "field",
     "fields",
diff --git a/python/tvm_ffi/dataclasses/c_class.py 
b/python/tvm_ffi/dataclasses/c_class.py
index f30d657..17f66a4 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -69,6 +69,7 @@ def c_class(
     eq: bool = False,
     order: bool = False,
     unsafe_hash: bool = False,
+    match_args: bool = True,
 ) -> Callable[[_T], _T]:
     """Register a C++ FFI class and install structural dunder methods.
 
@@ -105,6 +106,11 @@ def c_class(
         *unsafe* because mutable fields contribute to the hash, so mutating
         an object while it is in a set or dict key will break invariants.
         Defaults to False.
+    match_args
+        If True (default), set ``__match_args__`` to a tuple of the
+        positional ``__init__`` field names (``init=True`` and not
+        ``kw_only``), enabling ``match`` statements.  Ignored when the
+        class body already defines ``__match_args__``.
 
     Returns
     -------
@@ -151,7 +157,13 @@ def c_class(
         _warn_missing_field_annotations(cls, type_info, stacklevel=2)
         _attach_field_objects(cls, type_info)
         _install_dataclass_dunders(
-            cls, init=init, repr=repr, eq=eq, order=order, 
unsafe_hash=unsafe_hash
+            cls,
+            init=init,
+            repr=repr,
+            eq=eq,
+            order=order,
+            unsafe_hash=unsafe_hash,
+            match_args=match_args,
         )
         return cls
 
diff --git a/python/tvm_ffi/dataclasses/common.py 
b/python/tvm_ffi/dataclasses/common.py
index 88c3d62..4df4a53 100644
--- a/python/tvm_ffi/dataclasses/common.py
+++ b/python/tvm_ffi/dataclasses/common.py
@@ -18,11 +18,30 @@
 
 from __future__ import annotations
 
+import copy
+from collections.abc import Callable
 from typing import Any
 
+from ..container import Array, Dict, List, Map
 from .field import Field
 
-__all__ = ["fields", "is_dataclass", "replace"]
+__all__ = ["asdict", "astuple", "fields", "is_dataclass", "replace"]
+
+# Exact-type fast path for atomic (immutable, non-recursive) values.  Mirrors
+# :data:`dataclasses._ATOMIC_TYPES` from the standard library.
+_ATOMIC_TYPES: frozenset[type] = frozenset(
+    {
+        bool,
+        bytes,
+        complex,
+        float,
+        int,
+        str,
+        type(None),
+        type(Ellipsis),
+        type(NotImplemented),
+    }
+)
 
 
 def is_dataclass(obj: Any) -> bool:
@@ -72,3 +91,155 @@ def replace(obj: Any, /, **changes: Any) -> Any:
     still replaceable.
     """
     return obj.__replace__(**changes)
+
+
+def _is_ffi_dataclass_instance(obj: Any) -> bool:
+    """Return True when *obj* is a ``@c_class`` / ``@py_class`` **instance** 
(not a type)."""
+    if isinstance(obj, type):
+        return False
+    return getattr(type(obj), "__tvm_ffi_type_info__", None) is not None
+
+
+def _asdict_inner(  # noqa: PLR0911, PLR0912
+    obj: Any, dict_factory: Callable[..., Any]
+) -> Any:
+    """Recursively convert *obj* for :func:`asdict`.
+
+    Dispatch order matters:
+
+    1. Atomic values (exact type in ``_ATOMIC_TYPES``) are returned as is.
+    2. FFI containers come next, because they also carry
+       ``__tvm_ffi_type_info__``.
+    3. Then FFI dataclass instances.
+    4. Then built-in ``list`` / ``dict`` / ``tuple``, exact types before
+       subclasses.
+
+    Anything else is deep-copied.
+
+    ``dict_factory`` receives a list of ``(key, value)`` pairs.  It builds
+    the mapping for FFI dataclass instances and for FFI ``Map`` / ``Dict``
+    values; built-in dicts are rebuilt with their own type.
+    """
+    obj_type = type(obj)
+    # Exact-type membership: subclasses of atomic types fall through.
+    if obj_type in _ATOMIC_TYPES:
+        return obj
+    # FFI containers are treated as their stdlib analogues so the result is
+    # plain Python data — handy for JSON serialisation, the main use case.
+    if isinstance(obj, (Array, List)):
+        return [_asdict_inner(v, dict_factory) for v in obj]
+    if isinstance(obj, (Map, Dict)):
+        # NOTE(review): keys are converted too; a key that becomes a list or
+        # dict is unhashable and makes the default ``dict`` factory raise.
+        return dict_factory(
+            [
+                (_asdict_inner(k, dict_factory), _asdict_inner(v, 
dict_factory))
+                for k, v in obj.items()
+            ]
+        )
+    if _is_ffi_dataclass_instance(obj):
+        fs = fields(obj)
+        # Fast path for the default factory: build the dict directly instead
+        # of materialising a list of pairs first.
+        if dict_factory is dict:
+            return {f.name: _asdict_inner(getattr(obj, f.name), dict) for f in 
fs}  # ty: ignore[invalid-argument-type]
+        return dict_factory(
+            [(f.name, _asdict_inner(getattr(obj, f.name), dict_factory)) for f 
in fs]  # ty: ignore[invalid-argument-type]
+        )
+    # Built-in containers: exact types first (cheap), then subclasses, which
+    # are rebuilt with their own type.
+    if obj_type is list:
+        return [_asdict_inner(v, dict_factory) for v in obj]
+    if obj_type is dict:
+        return {
+            _asdict_inner(k, dict_factory): _asdict_inner(v, dict_factory) for 
k, v in obj.items()
+        }
+    if obj_type is tuple:
+        return tuple(_asdict_inner(v, dict_factory) for v in obj)
+    if issubclass(obj_type, tuple):
+        if hasattr(obj, "_fields"):  # namedtuple
+            # namedtuple constructors take the field values positionally.
+            return obj_type(*[_asdict_inner(v, dict_factory) for v in obj])
+        return obj_type(_asdict_inner(v, dict_factory) for v in obj)
+    if issubclass(obj_type, dict):
+        if hasattr(obj_type, "default_factory"):
+            # defaultdict-like: the constructor takes the factory, not the
+            # items, so entries are copied one by one.
+            result = obj_type(obj.default_factory)
+            for k, v in obj.items():
+                result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, 
dict_factory)
+            return result
+        return obj_type(
+            (_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) 
for k, v in obj.items()
+        )
+    if issubclass(obj_type, list):
+        return obj_type(_asdict_inner(v, dict_factory) for v in obj)
+    # Unknown value: copy it so the result never aliases the source object.
+    return copy.deepcopy(obj)
+
+
+def _astuple_inner(obj: Any, tuple_factory: Callable[..., Any]) -> Any:  # 
noqa: PLR0911
+    obj_type = type(obj)
+    if obj_type in _ATOMIC_TYPES:
+        return obj
+    if isinstance(obj, (Array, List)):
+        return [_astuple_inner(v, tuple_factory) for v in obj]
+    if isinstance(obj, (Map, Dict)):
+        return {
+            _astuple_inner(k, tuple_factory): _astuple_inner(v, tuple_factory)
+            for k, v in obj.items()
+        }
+    if _is_ffi_dataclass_instance(obj):
+        return tuple_factory(
+            [_astuple_inner(getattr(obj, f.name), tuple_factory) for f in 
fields(obj)]  # ty: ignore[invalid-argument-type]
+        )
+    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
+        return obj_type(*[_astuple_inner(v, tuple_factory) for v in obj])
+    if isinstance(obj, (list, tuple)):
+        return obj_type(_astuple_inner(v, tuple_factory) for v in obj)
+    if isinstance(obj, dict):
+        if hasattr(obj_type, "default_factory"):
+            result = obj_type(obj.default_factory)
+            for k, v in obj.items():
+                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, 
tuple_factory)
+            return result
+        return obj_type(
+            (_astuple_inner(k, tuple_factory), _astuple_inner(v, 
tuple_factory))
+            for k, v in obj.items()
+        )
+    return copy.deepcopy(obj)
+
+
+def asdict(obj: Any, *, dict_factory: Callable[..., Any] = dict) -> Any:
+    r"""Return the fields of a ``@c_class`` / ``@py_class`` instance as a new 
dict.
+
+    Mirrors :func:`dataclasses.asdict` from the standard library.  The
+    function recurses into nested FFI dataclass instances, FFI containers
+    (:class:`~tvm_ffi.Array`, :class:`~tvm_ffi.List`,
+    :class:`~tvm_ffi.Map`, :class:`~tvm_ffi.Dict`), and the built-in
+    ``list`` / ``tuple`` / ``dict``.  FFI sequence containers are
+    converted to Python ``list``\ s and FFI mapping containers to
+    Python ``dict``\ s so the result is plain Python data, e.g. for
+    JSON serialisation.  Any other value is copied with
+    :func:`copy.deepcopy`.
+
+    Parameters
+    ----------
+    obj
+        A ``@c_class`` / ``@py_class`` instance.  Passing a type raises
+        :class:`TypeError`.
+    dict_factory
+        Callable used to construct the outer mapping and any nested
+        mapping recursed from an FFI dataclass.  Defaults to :class:`dict`.
+
+    Raises
+    ------
+    TypeError
+        If ``obj`` is not a ``@c_class`` / ``@py_class`` instance.
+
+    """
+    if not _is_ffi_dataclass_instance(obj):
+        raise TypeError("asdict() should be called on c_class / py_class 
instances")
+    return _asdict_inner(obj, dict_factory)
+
+
+def astuple(obj: Any, *, tuple_factory: Callable[..., Any] = tuple) -> Any:
+    """Return the fields of a ``@c_class`` / ``@py_class`` instance as a new 
tuple.
+
+    Mirrors :func:`dataclasses.astuple` from the standard library.  The
+    function recurses into nested FFI dataclass instances, FFI containers
+    (:class:`~tvm_ffi.Array`, :class:`~tvm_ffi.List`,
+    :class:`~tvm_ffi.Map`, :class:`~tvm_ffi.Dict`), and the built-in
+    ``list`` / ``tuple`` / ``dict``.  Any other value is copied with
+    :func:`copy.deepcopy`.
+
+    Parameters
+    ----------
+    obj
+        A ``@c_class`` / ``@py_class`` instance.  Passing a type raises
+        :class:`TypeError`.
+    tuple_factory
+        Callable used to construct the outer tuple and any nested tuple
+        recursed from an FFI dataclass.  Defaults to :class:`tuple`.
+
+    Raises
+    ------
+    TypeError
+        If ``obj`` is not a ``@c_class`` / ``@py_class`` instance.
+
+    """
+    if not _is_ffi_dataclass_instance(obj):
+        raise TypeError("astuple() should be called on c_class / py_class 
instances")
+    return _astuple_inner(obj, tuple_factory)
diff --git a/python/tvm_ffi/dataclasses/py_class.py 
b/python/tvm_ffi/dataclasses/py_class.py
index 9f3d1f2..70ee25d 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -305,6 +305,7 @@ def _register_fields_into_type(
         eq=params["eq"],
         order=params["order"],
         unsafe_hash=params["unsafe_hash"],
+        match_args=params["match_args"],
         py_class_mode=True,
     )
     return True
@@ -450,6 +451,7 @@ def py_class(  # noqa: PLR0913
     eq: bool = False,
     order: bool = False,
     unsafe_hash: bool = False,
+    match_args: bool = True,
     kw_only: bool = False,
     structural_eq: str | None = None,
     slots: bool = True,
@@ -507,6 +509,11 @@ def py_class(  # noqa: PLR0913
         Requires ``eq=True``.
     unsafe_hash
         If True, generate ``__hash__`` (unsafe for mutable objects).
+    match_args
+        If True (default), set ``__match_args__`` to a tuple of the
+        positional ``__init__`` field names (``init=True`` and not
+        ``kw_only``), enabling ``match`` statements.  Ignored when the
+        class body already defines ``__match_args__``.
     kw_only
         If True, all fields are keyword-only in ``__init__`` by default.
     structural_eq
@@ -551,6 +558,7 @@ def py_class(  # noqa: PLR0913
         "eq": eq,
         "order": order,
         "unsafe_hash": unsafe_hash,
+        "match_args": match_args,
         "kw_only": kw_only,
         "structural_eq": structural_eq,
     }
diff --git a/tests/python/test_dataclass_common.py 
b/tests/python/test_dataclass_common.py
index 95b860d..238886e 100644
--- a/tests/python/test_dataclass_common.py
+++ b/tests/python/test_dataclass_common.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import collections
 import dataclasses as _dc
 import itertools
 import typing
@@ -29,7 +30,10 @@ import tvm_ffi
 import tvm_ffi.testing
 from tvm_ffi.core import MISSING, Object
 from tvm_ffi.dataclasses import (
+    KW_ONLY,
     Field,
+    asdict,
+    astuple,
     field,
     fields,
     is_dataclass,
@@ -250,3 +254,375 @@ class TestReplace:
         p2 = replace(p, a=10)
         assert p2.a == 10
         assert p.a == 3  # still read-only, original untouched
+
+
+# ---------------------------------------------------------------------------
+# asdict
+# ---------------------------------------------------------------------------
+class TestAsdict:
+    """Tests for :func:`tvm_ffi.dataclasses.asdict`."""
+
+    def test_c_class_basic(self) -> None:
+        p = tvm_ffi.testing.TestIntPair(1, 2)
+        assert asdict(p) == {"a": 1, "b": 2}
+
+    def test_py_class_basic(self) -> None:
+        @py_class(_k("PCAsdict"))
+        class PC(Object):
+            x: int
+            y: str
+
+        assert asdict(PC(x=1, y="hi")) == {"x": 1, "y": "hi"}
+
+    def test_py_class_nested(self) -> None:
+        @py_class(_k("PCInner"))
+        class Inner(Object):
+            x: int
+
+        @py_class(_k("PCOuter"))
+        class Outer(Object):
+            a: Inner
+            b: Inner
+
+        out = Outer(a=Inner(x=1), b=Inner(x=2))
+        assert asdict(out) == {"a": {"x": 1}, "b": {"x": 2}}
+
+    def test_py_class_inheritance(self) -> None:
+        @py_class(_k("PCParent"))
+        class P(Object):
+            x: int
+
+        @py_class(_k("PCChild"))
+        class C(P):
+            y: str
+
+        # Parent fields come first, then the child's own fields.
+        assert asdict(C(x=1, y="a")) == {"x": 1, "y": "a"}
+
+    def test_ffi_array_recurses_to_list(self) -> None:
+        @py_class(_k("PCWithArray"))
+        class PC(Object):
+            xs: tvm_ffi.Array
+
+        pc = PC(xs=tvm_ffi.Array([1, 2, 3]))
+        result = asdict(pc)
+        assert result == {"xs": [1, 2, 3]}
+        # Must be a built-in list, not an FFI container.
+        assert type(result["xs"]) is list
+
+    def test_ffi_list_recurses_to_list(self) -> None:
+        @py_class(_k("PCWithList"))
+        class PC(Object):
+            xs: tvm_ffi.List
+
+        pc = PC(xs=tvm_ffi.List([1, 2, 3]))
+        result = asdict(pc)
+        assert result == {"xs": [1, 2, 3]}
+        assert type(result["xs"]) is list
+
+    def test_ffi_map_recurses_to_dict(self) -> None:
+        @py_class(_k("PCWithMap"))
+        class PC(Object):
+            m: tvm_ffi.Map
+
+        pc = PC(m=tvm_ffi.Map({"a": 1, "b": 2}))
+        result = asdict(pc)
+        assert result == {"m": {"a": 1, "b": 2}}
+        # Must be a built-in dict, not an FFI container.
+        assert type(result["m"]) is dict
+
+    def test_ffi_dict_recurses_to_dict(self) -> None:
+        @py_class(_k("PCWithDict"))
+        class PC(Object):
+            d: tvm_ffi.Dict
+
+        pc = PC(d=tvm_ffi.Dict({"a": 1, "b": 2}))
+        result = asdict(pc)
+        assert result == {"d": {"a": 1, "b": 2}}
+        assert type(result["d"]) is dict
+
+    def test_ffi_array_of_dataclasses(self) -> None:
+        @py_class(_k("PCItem"))
+        class Item(Object):
+            v: int
+
+        @py_class(_k("PCBox"))
+        class Box(Object):
+            items: tvm_ffi.Array
+
+        box = Box(items=tvm_ffi.Array([Item(v=1), Item(v=2)]))
+        assert asdict(box) == {"items": [{"v": 1}, {"v": 2}]}
+
+    def test_ffi_map_of_dataclasses(self) -> None:
+        @py_class(_k("PCItem2"))
+        class Item(Object):
+            v: int
+
+        @py_class(_k("PCBox2"))
+        class Box(Object):
+            items: tvm_ffi.Map
+
+        box = Box(items=tvm_ffi.Map({"a": Item(v=1), "b": Item(v=2)}))
+        assert asdict(box) == {"items": {"a": {"v": 1}, "b": {"v": 2}}}
+
+    def test_dict_factory(self) -> None:
+        @py_class(_k("PCDF"))
+        class PC(Object):
+            x: int
+            y: int
+
+        result = asdict(PC(x=1, y=2), dict_factory=collections.OrderedDict)
+        assert isinstance(result, collections.OrderedDict)
+        # Field declaration order is preserved.
+        assert list(result.items()) == [("x", 1), ("y", 2)]
+
+    def test_dict_factory_recurses(self) -> None:
+        @py_class(_k("PCDFI"))
+        class Inner(Object):
+            v: int
+
+        @py_class(_k("PCDFO"))
+        class Outer(Object):
+            a: Inner
+
+        result = asdict(Outer(a=Inner(v=5)), 
dict_factory=collections.OrderedDict)
+        assert isinstance(result, collections.OrderedDict)
+        assert isinstance(result["a"], collections.OrderedDict)
+
+    def test_default_factory_list_independent(self) -> None:
+        """Result must be a fresh ``list``, not aliased to any internal 
state."""
+
+        @py_class(_k("PCInd"))
+        class PC(Object):
+            xs: tvm_ffi.Array
+
+        pc = PC(xs=tvm_ffi.Array([1, 2]))
+        d = asdict(pc)
+        d["xs"].append(99)
+        # Mutating the result must not affect the original.
+        assert list(pc.xs) == [1, 2]
+
+    def test_type_raises(self) -> None:
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            asdict(tvm_ffi.testing.TestIntPair)  # passing type, not instance
+
+    def test_non_dataclass_raises(self) -> None:
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            asdict(42)
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            asdict([1, 2, 3])
+
+    def test_defaultdict_preserved(self) -> None:
+        """``defaultdict`` round-trips with its ``default_factory`` intact.
+
+        Exercises the ``_asdict_inner`` defaultdict branch directly, since
+        FFI ``Any`` field storage converts a stored ``defaultdict`` into
+        an FFI ``Map`` on readback.  Mirrors stdlib
+        ``dataclasses._asdict_inner``'s check on ``type(obj)``.
+        """
+        from tvm_ffi.dataclasses.common import _asdict_inner  # noqa: PLC0415
+
+        dd = collections.defaultdict(list)
+        dd["a"].append(1)
+        dd["b"].append(2)
+        result = _asdict_inner(dd, dict)
+        assert type(result) is collections.defaultdict
+        assert result.default_factory is list
+        assert dict(result) == {"a": [1], "b": [2]}
+
+
+# ---------------------------------------------------------------------------
+# astuple
+# ---------------------------------------------------------------------------
+class TestAstuple:
+    """Tests for :func:`tvm_ffi.dataclasses.astuple`."""
+
+    def test_c_class_basic(self) -> None:
+        p = tvm_ffi.testing.TestIntPair(1, 2)
+        assert astuple(p) == (1, 2)
+
+    def test_py_class_basic(self) -> None:
+        @py_class(_k("PCAstuple"))
+        class PC(Object):
+            x: int
+            y: str
+
+        assert astuple(PC(x=1, y="hi")) == (1, "hi")
+
+    def test_py_class_nested(self) -> None:
+        @py_class(_k("PCInnerT"))
+        class Inner(Object):
+            x: int
+
+        @py_class(_k("PCOuterT"))
+        class Outer(Object):
+            a: Inner
+            b: Inner
+
+        # Nested dataclasses are themselves converted to tuples.
+        out = Outer(a=Inner(x=1), b=Inner(x=2))
+        assert astuple(out) == ((1,), (2,))
+
+    def test_py_class_inheritance(self) -> None:
+        @py_class(_k("PCParentT"))
+        class P(Object):
+            x: int
+
+        @py_class(_k("PCChildT"))
+        class C(P):
+            y: str
+
+        assert astuple(C(x=1, y="a")) == (1, "a")
+
+    def test_ffi_array_recurses_to_list(self) -> None:
+        @py_class(_k("PCArrT"))
+        class PC(Object):
+            xs: tvm_ffi.Array
+
+        pc = PC(xs=tvm_ffi.Array([1, 2, 3]))
+        result = astuple(pc)
+        # FFI sequences become lists, not tuples.
+        assert result == ([1, 2, 3],)
+        assert type(result[0]) is list
+
+    def test_ffi_map_recurses_to_dict(self) -> None:
+        @py_class(_k("PCMapT"))
+        class PC(Object):
+            m: tvm_ffi.Map
+
+        pc = PC(m=tvm_ffi.Map({"a": 1}))
+        result = astuple(pc)
+        assert result == ({"a": 1},)
+        assert type(result[0]) is dict
+
+    def test_tuple_factory(self) -> None:
+        @py_class(_k("PCTF"))
+        class PC(Object):
+            x: int
+            y: int
+
+        assert astuple(PC(x=1, y=2), tuple_factory=list) == [1, 2]
+
+    def test_tuple_factory_recurses(self) -> None:
+        @py_class(_k("PCTFI"))
+        class Inner(Object):
+            v: int
+
+        @py_class(_k("PCTFO"))
+        class Outer(Object):
+            a: Inner
+
+        result = astuple(Outer(a=Inner(v=5)), tuple_factory=list)
+        assert result == [[5]]
+
+    def test_type_raises(self) -> None:
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            astuple(tvm_ffi.testing.TestIntPair)
+
+    def test_non_dataclass_raises(self) -> None:
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            astuple(42)
+        with pytest.raises(TypeError, match="c_class / py_class instances"):
+            astuple([1, 2, 3])
+
+    def test_defaultdict_preserved(self) -> None:
+        """``defaultdict`` round-trips with its ``default_factory`` intact."""
+        # Calls the private helper directly: a defaultdict stored in an FFI
+        # field is read back as an FFI Map, so it cannot reach this branch
+        # through astuple() itself.
+        from tvm_ffi.dataclasses.common import _astuple_inner  # noqa: PLC0415
+
+        dd = collections.defaultdict(list)
+        dd["a"].append(1)
+        dd["b"].append(2)
+        result = _astuple_inner(dd, tuple)
+        assert type(result) is collections.defaultdict
+        assert result.default_factory is list
+        assert dict(result) == {"a": [1], "b": [2]}
+
+
+# ---------------------------------------------------------------------------
+# __match_args__
+# ---------------------------------------------------------------------------
+def _match_args(cls: type) -> object:
+    """Read ``__match_args__`` without tripping static attribute checks."""
+    return getattr(cls, "__match_args__")
+
+
+class TestMatchArgs:
+    """Tests for the ``match_args`` option of ``@py_class`` / ``@c_class``."""
+
+    def test_py_class_basic(self) -> None:
+        @py_class(_k("MAPy"))
+        class PC(Object):
+            x: int
+            y: str
+
+        assert _match_args(PC) == ("x", "y")
+
+    def test_py_class_init_false_excluded(self) -> None:
+        @py_class(_k("MAInitFalse"))
+        class PC(Object):
+            x: int
+            y: int = field(default=0, init=False)
+
+        assert _match_args(PC) == ("x",)
+
+    def test_py_class_kw_only_field_excluded(self) -> None:
+        @py_class(_k("MAKwField"))
+        class PC(Object):
+            x: int
+            y: int = field(kw_only=True)
+
+        assert _match_args(PC) == ("x",)
+
+    def test_py_class_kw_only_decorator_excludes_all(self) -> None:
+        @py_class(_k("MAKwDeco"), kw_only=True)
+        class PC(Object):
+            x: int
+            y: int
+
+        assert _match_args(PC) == ()
+
+    def test_py_class_kw_only_sentinel(self) -> None:
+        # Every field after the KW_ONLY sentinel is keyword-only.
+        @py_class(_k("MAKwSent"))
+        class PC(Object):
+            x: int
+            _: KW_ONLY
+            y: int
+            z: int
+
+        assert _match_args(PC) == ("x",)
+
+    def test_py_class_inheritance_parent_first(self) -> None:
+        @py_class(_k("MAParent"))
+        class P(Object):
+            x: int
+            y: str
+
+        @py_class(_k("MAChild"))
+        class C(P):
+            z: float
+
+        assert _match_args(P) == ("x", "y")
+        assert _match_args(C) == ("x", "y", "z")
+
+    def test_py_class_opt_out(self) -> None:
+        @py_class(_k("MAOptOut"), match_args=False)
+        class PC(Object):
+            x: int
+            y: int
+
+        # Only the class's own namespace is checked; a base class could
+        # still provide the attribute through inheritance.
+        assert "__match_args__" not in PC.__dict__
+
+    def test_py_class_user_defined_preserved(self) -> None:
+        @py_class(_k("MAUser"))
+        class PC(Object):
+            x: int
+            y: int
+            __match_args__ = ("y",)
+
+        # The class-body value must not be overwritten by the generated one.
+        assert _match_args(PC) == ("y",)
+
+    def test_c_class_basic(self) -> None:
+        assert _match_args(tvm_ffi.testing.TestIntPair) == ("a", "b")
+
+    def test_c_class_inheritance_parent_first(self) -> None:
+        # Expected order: fields reflected on the parent type first, then the
+        # derived type's own fields.
+        assert _match_args(tvm_ffi.testing._TestCxxClassDerived) == (
+            "v_i64",
+            "v_i32",
+            "v_f64",
+            "v_f32",
+        )
+
+    def test_tuple_type(self) -> None:
+        @py_class(_k("MATupT"))
+        class PC(Object):
+            x: int
+
+        assert isinstance(_match_args(PC), tuple)

Reply via email to