This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 24054b4  [FEAT] Enhance map_dataclass_to_tuple with JIT unpack_dataclass_to_tuple (#563)
24054b4 is described below

commit 24054b47bd013398e2c7bacb312f2d5e21059d50
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 20 18:41:22 2026 -0400

    [FEAT] Enhance map_dataclass_to_tuple with JIT unpack_dataclass_to_tuple (#563)
    
    This PR introduces `unpack_dataclass_to_tuple`, an optimized utility
    backed by a JIT-compiled per-class unpacker, and uses it in place of
    `dataclasses.astuple` in `map_dataclass_to_tuple`. The new unpacker is
    5-11x faster and does not deep-copy leaf values.
    
    Not deep-copying leaf values is important for runtime use cases where
    dataclass fields contain tensor objects or other large resources that
    must not be copied: leaf values are passed through by reference, not
    duplicated.
    
    ## Design
    
    `unpack_dataclass_to_tuple(x)` recursively converts dataclass instances
    to tuples via JIT-compiled inline attribute access:
    
    - `_extract_dataclass_to_tuple_schema(cls)` -- build a schema from type annotations
    - `_validate_dataclass_to_tuple_schema(schema)` -- validate field names before codegen
    - `_compile_dataclass_to_tuple_schema(prefix, schema)` -- generate the inline tuple expression
    
    Compiled unpackers are cached per thread (TLS) for thread safety. Field
    classification is conservative: a field is marked as a leaf only when
    that is certain (known primitives, or non-list containers of known
    primitives). Unknown types use dynamic dispatch.
    
    ## Performance (vs dataclasses.astuple)
    
    | Case              | unpack | astuple | speedup |
    |-------------------|--------|---------|---------|
    | Config (2 flat)   | 144 ns |  769 ns |   5.3x  |
    | Nested (int + dc) | 152 ns | 1386 ns |   9.1x  |
    | Wide (5 flat)     | 170 ns | 1268 ns |   7.5x  |
    | Deep (2 levels)   | 200 ns | 2305 ns |  11.5x  |
    | WithAny (dynamic) | 254 ns | 1644 ns |   6.5x  |
---
 python/tvm_ffi/utils/kwargs_wrapper.py      |  39 ++--
 python/tvm_ffi/utils/unpack_dataclass.py    | 263 +++++++++++++++++++++++++
 tests/python/utils/test_kwargs_wrapper.py   |   8 +-
 tests/python/utils/test_unpack_dataclass.py | 294 ++++++++++++++++++++++++++++
 tests/scripts/benchmark_kwargs_wrapper.py   |  45 +++--
 tests/scripts/benchmark_unpack_dataclass.py | 140 +++++++++++++
 6 files changed, 753 insertions(+), 36 deletions(-)

diff --git a/python/tvm_ffi/utils/kwargs_wrapper.py 
b/python/tvm_ffi/utils/kwargs_wrapper.py
index df1a8f9..1bad61c 100644
--- a/python/tvm_ffi/utils/kwargs_wrapper.py
+++ b/python/tvm_ffi/utils/kwargs_wrapper.py
@@ -22,12 +22,13 @@ keyword argument support using code generation techniques.
 
 from __future__ import annotations
 
-import dataclasses
 import functools
 import inspect
 import keyword
 from typing import Any, Callable, Iterable
 
+from tvm_ffi.utils.unpack_dataclass import unpack_dataclass_to_tuple
+
 # Sentinel object for missing arguments
 MISSING = object()
 
@@ -35,12 +36,12 @@ MISSING = object()
 _INTERNAL_TARGET_FUNC = "__i_target_func"
 _INTERNAL_MISSING = "__i_MISSING"
 _INTERNAL_DEFAULTS_DICT = "__i_arg_defaults"
-_INTERNAL_ASTUPLE = "__i_astuple"
+_INTERNAL_UNPACK = "__i_unpack"
 _INTERNAL_NAMES = {
     _INTERNAL_TARGET_FUNC,
     _INTERNAL_MISSING,
     _INTERNAL_DEFAULTS_DICT,
-    _INTERNAL_ASTUPLE,
+    _INTERNAL_UNPACK,
 }
 
 
@@ -163,34 +164,34 @@ def _build_wrapper_code(
     call_parts: list[str] = []
     runtime_defaults: dict[str, Any] = {}
 
-    def _wrap_astuple(name: str, expr: str) -> str:
+    def _transform_expr(name: str, expr: str) -> str:
         if name in dc_to_tuple_set:
-            return f"{_INTERNAL_ASTUPLE}({expr})"
+            return f"{_INTERNAL_UNPACK}({expr})"
         return expr
 
     def _add_param_with_default(name: str, default_value: Any) -> None:
         # Directly embed None and bool defaults; use MISSING sentinel for 
others.
         if default_value is None:
             arg_parts.append(f"{name}=None")
-            call_parts.append(_wrap_astuple(name, name))
+            call_parts.append(_transform_expr(name, name))
         elif type(default_value) is bool:
             default_value_str = "True" if default_value else "False"
             arg_parts.append(f"{name}={default_value_str}")
-            call_parts.append(_wrap_astuple(name, name))
+            call_parts.append(_transform_expr(name, name))
         else:
             arg_parts.append(f"{name}={_INTERNAL_MISSING}")
             runtime_defaults[name] = default_value
             base_expr = (
                 f'{_INTERNAL_DEFAULTS_DICT}["{name}"] if {name} is 
{_INTERNAL_MISSING} else {name}'
             )
-            call_parts.append(_wrap_astuple(name, base_expr))
+            call_parts.append(_transform_expr(name, base_expr))
 
     for name in arg_names:
         if name in arg_defaults_dict:
             _add_param_with_default(name, arg_defaults_dict[name])
         else:
             arg_parts.append(name)
-            call_parts.append(_wrap_astuple(name, name))
+            call_parts.append(_transform_expr(name, name))
 
     if kwonly_names:
         arg_parts.append("*")
@@ -199,7 +200,7 @@ def _build_wrapper_code(
                 _add_param_with_default(name, kwonly_defaults[name])
             else:
                 arg_parts.append(name)
-                call_parts.append(_wrap_astuple(name, name))
+                call_parts.append(_transform_expr(name, name))
 
     arg_list = ", ".join(arg_parts)
     call_list = ", ".join(call_parts)
@@ -252,10 +253,13 @@ def make_kwargs_wrapper(
         Optional prototype function to copy metadata (__name__, __doc__, 
__module__,
         __qualname__, __annotations__) from. If None, no metadata is copied.
     map_dataclass_to_tuple
-        Optional list of argument names whose values should be converted from 
dataclass
-        instances to tuples (via ``dataclasses.astuple``) before being passed 
to the
-        target function. This is useful when the target function expects 
flattened tuple
-        arguments but callers pass dataclass instances.
+        Optional list of argument names whose values should be converted from
+        dataclass instances to tuples via ``unpack_dataclass_to_tuple``.
+        Dataclass fields are unpacked to tuple and leaf fields are shallow 
copied.
+
+        Nested dataclass fields are recursed automatically based on type 
annotations.
+        Lists/tuples are recursed element-wise. Dicts are recursed on values.
+        Non-dataclass leaves are passed through unchanged.
 
     Returns
     -------
@@ -296,7 +300,7 @@ def make_kwargs_wrapper(
         _INTERNAL_DEFAULTS_DICT: runtime_defaults,
     }
     if dc_to_tuple_set:
-        exec_globals[_INTERNAL_ASTUPLE] = dataclasses.astuple
+        exec_globals[_INTERNAL_UNPACK] = unpack_dataclass_to_tuple
     namespace: dict[str, Any] = {}
     exec(code_str, exec_globals, namespace)
     new_func = namespace["wrapper"]
@@ -335,9 +339,8 @@ def make_kwargs_wrapper_from_signature(
         These arguments will not be included in the generated wrapper. If a 
name in this iterable
         does not exist in the signature, it is silently ignored.
     map_dataclass_to_tuple
-        Optional list of argument names whose values should be converted from 
dataclass
-        instances to tuples (via ``dataclasses.astuple``) before being passed 
to the
-        target function.
+        Optional list of argument names to unpack via 
``unpack_dataclass_to_tuple``.
+        See ``make_kwargs_wrapper`` for details.
 
     Returns
     -------
diff --git a/python/tvm_ffi/utils/unpack_dataclass.py 
b/python/tvm_ffi/utils/unpack_dataclass.py
new file mode 100644
index 0000000..b85ffc6
--- /dev/null
+++ b/python/tvm_ffi/utils/unpack_dataclass.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Fast dataclass-to-tuple conversion via JIT-compiled unpacking.
+
+This module provides ``unpack_dataclass_to_tuple``, a function that recursively
+converts dataclass instances to tuples. It JIT-compiles a per-class unpacker
+on first call and caches it per-thread for ~5-11x speedup over
+``dataclasses.astuple`` with no deep-copy of leaf values.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import keyword
+import sys
+import threading
+import typing
+from typing import Any
+
+# Support both typing.Union and types.UnionType (PEP 604, Python 3.10+).
+# ``list`` is deliberately absent: list fields are never treated as leaves
+# (see ``_is_known_leaf_type``).
+if sys.version_info >= (3, 10):
+    import types
+
+    _LEAF_CONTAINER_ORIGINS = (tuple, dict, set, frozenset, typing.Union, types.UnionType)
+else:
+    _LEAF_CONTAINER_ORIGINS = (tuple, dict, set, frozenset, typing.Union)
+
+# Type alias for dataclass-to-tuple schema (internal).
+# Schema values:
+#   None     -> leaf, direct attribute access (zero cost)
+#   "unpack" -> dynamic dispatch via unpack_dataclass_to_tuple at runtime
+#   dict     -> nested struct, recurse inline
+# Example: {"x": None, "y": None} -> (__x.x, __x.y,)
+# Example: {"cfg": {"x": None, "y": None}, "data": "unpack"}
+#         -> ((__x.cfg.x, __x.cfg.y,), __dispatch(__x.data),)
+DataclassToTupleSchema = dict  # dict[str, None | str | DataclassToTupleSchema]
+
+# Sentinel value in schema: field should be dynamically dispatched
+UNPACK = "unpack"
+
+# Thread-local holder for the per-thread cache (``_tls.cache``) that maps a
+# class to its JIT-compiled unpack function.
+_tls = threading.local()
+
+# Types known to be safe leaves (never contain dataclass instances)
+_KNOWN_LEAF_TYPES: set[type] = {int, float, str, bool, bytes, complex, type(None)}
+
+
+def _is_known_leaf_type(tp: Any) -> bool:
+    """Check whether a type annotation is definitely a leaf.
+
+    A leaf never holds dataclass instances and can be passed through
+    unchanged. The check is conservative: it returns False whenever it is
+    not certain.
+
+    ``list`` is intentionally never a leaf: list values are routed through
+    dynamic dispatch, which rebuilds them element-wise (preserving the list
+    type, as ``dataclasses.astuple`` does) instead of passing the original
+    object through.
+    """
+    # Inspect parameterized generics first. On Python 3.9/3.10,
+    # ``isinstance(list[int], type)`` is True, so testing ``type`` first
+    # would send aliases such as ``tuple[int, int]`` down the plain-type path
+    # and never examine their arguments.
+    origin = typing.get_origin(tp)
+    if origin is not None:
+        # tuple/dict/set/frozenset/Union are leaves only if every argument is.
+        if origin in _LEAF_CONTAINER_ORIGINS:
+            args = typing.get_args(tp)
+            return bool(args) and all(_is_known_leaf_type(a) for a in args)
+        return False
+    if tp is Ellipsis:
+        # The ``...`` in ``tuple[int, ...]``.
+        return True
+    if isinstance(tp, type):
+        return tp in _KNOWN_LEAF_TYPES
+    return False
+
+
+def _classify_field_type(
+    field_type: Any, memo: set[type] | None = None
+) -> None | str | DataclassToTupleSchema:
+    """Classify a resolved field type into a schema entry.
+
+    Conservative: only returns None (leaf) when we are certain the type
+    cannot contain a dataclass. Otherwise returns UNPACK (dynamic dispatch).
+
+    Parameters
+    ----------
+    field_type
+        The field's annotation, either resolved by ``typing.get_type_hints``
+        or the raw ``dataclasses.Field.type`` (which may be a string).
+    memo
+        Dataclass types already visited by the enclosing schema extraction.
+        A dataclass found in it is dispatched dynamically (UNPACK) instead
+        of being inlined, which prevents infinite recursion.
+
+    Returns
+    -------
+        ``None`` for a leaf, ``UNPACK`` for dynamic dispatch, or a nested
+        schema dict for a dataclass type that is inlined.
+    """
+    # Unresolved string annotations and catch-all types may hold anything.
+    if isinstance(field_type, str) or field_type is Any or field_type is object:
+        return UNPACK
+    if dataclasses.is_dataclass(field_type) and isinstance(field_type, type):
+        # Guard against infinite recursion for self-referential dataclasses
+        if memo is not None and field_type in memo:
+            return UNPACK
+        return _extract_dataclass_to_tuple_schema(field_type, memo=memo)
+    # Known primitive types -> leaf
+    if isinstance(field_type, type) and field_type in _KNOWN_LEAF_TYPES:
+        return None
+    # Generic containers: list is never a leaf (it is rebuilt element-wise).
+    # tuple/dict/set/frozenset/Union are leaves only if all args are known leaves.
+    # Everything else (unknown type) -> UNPACK (conservative).
+    origin = typing.get_origin(field_type)
+    if origin in _LEAF_CONTAINER_ORIGINS:
+        args = typing.get_args(field_type)
+        if args and all(_is_known_leaf_type(a) for a in args):
+            return None
+    return UNPACK
+
+
+def _compile_dataclass_to_tuple_schema(prefix: str, schema: DataclassToTupleSchema) -> str:
+    """Compile a DataclassToTupleSchema into an inline tuple expression.
+
+    Parameters
+    ----------
+    prefix
+        The variable expression to unpack (e.g. "__x" or "__x.field").
+    schema
+        The schema dict mapping field names to:
+        - None: leaf, direct attribute access
+        - "unpack": dynamic dispatch via __dispatch() at runtime
+        - nested dict: recurse inline
+
+    Returns
+    -------
+        A string expression that evaluates to a tuple of the unpacked fields.
+        A trailing comma is always emitted so that a single-field schema
+        still yields a 1-tuple; an empty schema yields ``()``.
+
+    Notes
+    -----
+    Field names are embedded verbatim into the expression, so the schema
+    must have been checked by ``_validate_dataclass_to_tuple_schema`` first.
+    """
+    parts: list[str] = []
+    for field_name, sub_schema in schema.items():
+        field_expr = f"{prefix}.{field_name}"
+        if sub_schema is None:
+            parts.append(field_expr)
+        elif sub_schema == UNPACK:
+            parts.append(f"__dispatch({field_expr})")
+        else:
+            # Nested dataclass: inline its fields under the extended prefix.
+            parts.append(_compile_dataclass_to_tuple_schema(field_expr, sub_schema))
+    return "(" + ", ".join(parts) + (",)" if parts else ")")
+
+
+def _validate_dataclass_to_tuple_schema(schema: DataclassToTupleSchema) -> None:
+    """Validate that a DataclassToTupleSchema contains only safe identifiers.
+
+    This is critical for security since field names are embedded directly
+    in generated code via exec(). The validation ensures:
+    - Keys are strings (type check)
+    - Keys pass str.isidentifier() -- rejects any non-identifier characters
+    - Keys are not Python keywords -- rejects control flow injection
+    - Values are only None, "unpack", or recursively-validated dicts
+
+    Combined with the hardcoded prefix ("__x") and the minimal exec globals,
+    this prevents code injection through crafted field names.
+
+    Raises
+    ------
+    TypeError
+        If ``schema`` (or any nested value other than None / "unpack") is not
+        a dict, or if a field name is not a string.
+    ValueError
+        If a field name is not a valid identifier or is a Python keyword.
+    """
+    if not isinstance(schema, dict):
+        raise TypeError(f"DataclassToTupleSchema must be a dict, got {type(schema).__name__}")
+    for field_name, sub_schema in schema.items():
+        if not isinstance(field_name, str):
+            raise TypeError(f"Schema field name must be a string, got {type(field_name).__name__}")
+        if not field_name.isidentifier():
+            raise ValueError(f"Schema field name {field_name!r} is not a valid Python identifier")
+        if keyword.iskeyword(field_name):
+            raise ValueError(f"Schema field name {field_name!r} is a Python keyword")
+        # Anything other than a leaf marker must itself be a valid schema dict.
+        if sub_schema is not None and sub_schema != UNPACK:
+            _validate_dataclass_to_tuple_schema(sub_schema)
+
+
+def _extract_dataclass_to_tuple_schema(
+    cls: type, *, memo: set[type] | None = None
+) -> DataclassToTupleSchema:
+    """Extract a DataclassToTupleSchema from a dataclass class using type annotations.
+
+    Classification per field (conservative: only leaf when certain):
+    - Known dataclass type -> nested schema (recurse inline)
+    - Known primitive type (int, float, str, bool, bytes, complex) -> None (leaf)
+    - Non-list container with only known-leaf args (dict[str, float],
+      tuple[int, ...]) -> None (leaf)
+    - Any list, or a container with dataclass/unknown args -> "unpack"
+      (dynamic dispatch)
+    - Any, object, unresolved string annotation -> "unpack" (dynamic dispatch)
+    - Unknown class -> "unpack" (dynamic dispatch)
+
+    Uses typing.get_type_hints() to resolve PEP 563 string annotations. If
+    resolution fails, the raw ``Field.type`` values are used instead, so
+    string annotations then classify as "unpack".
+
+    ``memo`` holds the dataclasses on the current expansion path. A dataclass
+    that (directly or indirectly) refers to itself is dispatched dynamically
+    at the point of recursion. A dataclass that merely appears in several
+    sibling fields is inlined in each of them.
+
+    Raises
+    ------
+    TypeError
+        If ``cls`` is not a dataclass class.
+    """
+    if not dataclasses.is_dataclass(cls) or not isinstance(cls, type):
+        raise TypeError(f"Expected a dataclass class, got {cls!r}")
+    if memo is None:
+        memo = set()
+    added = cls not in memo
+    memo.add(cls)
+    try:
+        try:
+            type_hints = typing.get_type_hints(cls)
+        except (NameError, TypeError, AttributeError, SyntaxError):
+            # Unresolvable annotations (e.g. references to local classes, or a
+            # malformed string annotation): fall back to raw Field.type.
+            type_hints = {}
+        schema: DataclassToTupleSchema = {}
+        for f in dataclasses.fields(cls):
+            field_type = type_hints.get(f.name, f.type)
+            schema[f.name] = _classify_field_type(field_type, memo=memo)
+    finally:
+        # Path-scoped recursion guard: once this class is fully expanded, a
+        # later sibling field of the same type may be inlined again.
+        if added:
+            memo.discard(cls)
+    return schema
+
+
+def unpack_dataclass_to_tuple(x: Any) -> Any:
+    """Fast recursively unpack a dataclass value to tuple representation.
+
+    - Dataclass instances are unpacked to tuples of their field values.
+    - Lists and tuples are recursed element-wise and keep their container
+      type (a list stays a list, a namedtuple stays a namedtuple), as
+      ``dataclasses.astuple`` does.
+    - Dicts are recursed on values, returning a new plain dict.
+    - All other values are returned as-is (leaf passthrough).
+
+    This function optimizes speed via JIT-compiling the conversion per dataclass
+    class and caching it per-thread. It brings about 5-11x speedup vs
+    ``dataclasses.astuple`` and does not deep-copy leaf values.
+
+    Parameters
+    ----------
+    x
+        The value to unpack.
+
+    Returns
+    -------
+        The unpacked tuple representation, or ``x`` unchanged if it's a leaf.
+
+    """
+    try:
+        cache = _tls.cache
+    except AttributeError:
+        # First call on this thread: create its private cache.
+        cache = _tls.cache = {}
+
+    cls = type(x)
+    fn = cache.get(cls)
+    if fn is not None:
+        return fn(x)
+
+    # Cache miss -- classify the type
+    if dataclasses.is_dataclass(cls) and isinstance(cls, type):
+        schema = _extract_dataclass_to_tuple_schema(cls)
+        # Field names are embedded directly in the generated code string, so
+        # validate them first (isidentifier + iskeyword checks) to rule out
+        # code injection via crafted field names.
+        _validate_dataclass_to_tuple_schema(schema)
+        code_expr = _compile_dataclass_to_tuple_schema("__x", schema)
+        code = f"def __unpack(__x): return {code_expr}"
+        namespace: dict[str, Any] = {}
+        # exec() would silently insert the real builtins into a globals dict
+        # that lacks "__builtins__"; pass an empty one so the generated code
+        # can reach nothing but __dispatch (it needs only attribute access).
+        exec(code, {"__builtins__": {}, "__dispatch": unpack_dataclass_to_tuple}, namespace)
+        fn = namespace["__unpack"]
+        cache[cls] = fn
+        return fn(x)
+    if isinstance(x, (list, tuple)):
+        items = [unpack_dataclass_to_tuple(e) for e in x]
+        if isinstance(x, tuple) and hasattr(x, "_fields"):
+            # namedtuple: the constructor takes fields positionally, not an
+            # iterable, so ``cls(items)`` would raise TypeError.
+            return cls(*items)
+        return cls(items)
+    if isinstance(x, dict):
+        return {k: unpack_dataclass_to_tuple(v) for k, v in x.items()}
+    # True leaf -- cache identity so next call is just dict.get + return
+    cache[cls] = _LEAF_IDENTITY
+    return x
+
+
+def _LEAF_IDENTITY(x: Any) -> Any:
+    """Hand back ``x`` untouched.
+
+    Stored in the per-thread cache for types that are neither dataclasses
+    nor list/tuple/dict, so repeat calls for those types skip classification.
+    """
+    return x
diff --git a/tests/python/utils/test_kwargs_wrapper.py 
b/tests/python/utils/test_kwargs_wrapper.py
index 2b731b2..61362c1 100644
--- a/tests/python/utils/test_kwargs_wrapper.py
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -373,7 +373,7 @@ def test_optimized_default_types() -> None:
 
 
 def test_map_dataclass_to_tuple() -> None:
-    """Test that dataclass arguments are converted to tuples via 
dataclasses.astuple."""
+    """Test map_dataclass_to_tuple in make_kwargs_wrapper."""
 
     @dataclasses.dataclass
     class Config:
@@ -388,7 +388,7 @@ def test_map_dataclass_to_tuple() -> None:
     def target(*args: Any) -> tuple[Any, ...]:
         return args
 
-    # Basic: one dataclass arg converted to tuple
+    # Basic: one dataclass arg converted
     wrapper = make_kwargs_wrapper(target, ["a", "cfg"], 
map_dataclass_to_tuple=["cfg"])
     result = wrapper(1, Config(x=10, y=20))
     assert result == (1, (10, 20))
@@ -402,7 +402,7 @@ def test_map_dataclass_to_tuple() -> None:
     result = wrapper(Config(x=1, y=2), Config(x=3, y=4))
     assert result == ((1, 2), (3, 4))
 
-    # Nested dataclass (astuple recurses)
+    # Nested dataclass (auto-recursion via type annotations)
     wrapper = make_kwargs_wrapper(target, ["a", "nested"], 
map_dataclass_to_tuple=["nested"])
     result = wrapper(1, Nested(value=5, cfg=Config(x=10, y=20)))
     assert result == (1, (5, (10, 20)))
@@ -412,7 +412,7 @@ def test_map_dataclass_to_tuple() -> None:
     result = wrapper(1, Config(x=10, y=20), 3)
     assert result == (1, (10, 20), 3)
 
-    # With defaults: dataclass arg has a default
+    # With defaults
     default_cfg = Config(x=0, y=0)
     wrapper = make_kwargs_wrapper(
         target, ["a", "cfg"], arg_defaults=(default_cfg,), 
map_dataclass_to_tuple=["cfg"]
diff --git a/tests/python/utils/test_unpack_dataclass.py 
b/tests/python/utils/test_unpack_dataclass.py
new file mode 100644
index 0000000..3a557ea
--- /dev/null
+++ b/tests/python/utils/test_unpack_dataclass.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import dataclasses
+import sys
+from typing import Any
+
+import pytest
+from tvm_ffi.utils.unpack_dataclass import (
+    _extract_dataclass_to_tuple_schema,
+    _validate_dataclass_to_tuple_schema,
+    unpack_dataclass_to_tuple,
+)
+
+
+# Module-level dataclass definitions for schema extraction tests.
+# Must be at module level so typing.get_type_hints() can resolve
+# cross-references with PEP 563 (from __future__ import annotations).
+@dataclasses.dataclass
+class _ExtractConfig:
+    x: int
+    y: int
+
+
+@dataclasses.dataclass
+class _ExtractNested:
+    value: int
+    cfg: _ExtractConfig
+
+
+@dataclasses.dataclass
+class _ExtractWithAny:
+    data: Any
+    scale: int
+
+
+@dataclasses.dataclass
+class _ExtractWithList:
+    items: list[_ExtractConfig]
+    scale: int
+
+
+@dataclasses.dataclass
+class _ExtractWithLeafList:
+    values: list[int]
+    name: str
+
+
+@dataclasses.dataclass
+class _ExtractWithDict:
+    mapping: dict[str, _ExtractConfig]
+    count: int
+
+
+@dataclasses.dataclass
+class _ExtractWithLeafDict:
+    mapping: dict[str, int]
+    count: int
+
+
+@dataclasses.dataclass
+class _ExtractWithTuple:
+    pair: tuple[_ExtractConfig, int]
+    flag: bool
+
+
+@dataclasses.dataclass
+class _ExtractWithOptional:
+    value: int | None
+    name: str
+
+
+@dataclasses.dataclass
+class _ExtractWithLeafListInt:
+    items: list[int]
+    scale: int
+
+
+def test_unpack_dataclass_to_tuple() -> None:
+    """Test unpack_dataclass_to_tuple JIT-compiled unpacking."""
+
+    @dataclasses.dataclass
+    class Config:
+        x: int
+        y: int
+
+    @dataclasses.dataclass
+    class Nested:
+        value: int
+        cfg: Config
+
+    @dataclasses.dataclass
+    class Deep:
+        nested: Nested
+        flag: bool
+
+    # Flat dataclass
+    assert unpack_dataclass_to_tuple(Config(x=1, y=2)) == (1, 2)
+
+    # Nested dataclass is unpacked recursively. These classes are local, so
+    # get_type_hints cannot resolve the cross-reference to Config; the field
+    # goes through dynamic dispatch instead of inline recursion.
+    assert unpack_dataclass_to_tuple(Nested(value=5, cfg=Config(x=10, y=20))) == (5, (10, 20))
+
+    # Deep nesting
+    assert unpack_dataclass_to_tuple(
+        Deep(nested=Nested(value=5, cfg=Config(x=10, y=20)), flag=True)
+    ) == ((5, (10, 20)), True)
+
+    # Leaf passthrough
+    assert unpack_dataclass_to_tuple(42) == 42
+    assert unpack_dataclass_to_tuple("hello") == "hello"
+    assert unpack_dataclass_to_tuple(None) is None
+
+    # List recursion: list of dataclasses -> list of tuples (list type preserved)
+    assert unpack_dataclass_to_tuple([Config(x=1, y=2), Config(x=3, y=4)]) == [(1, 2), (3, 4)]
+
+    # Tuple recursion
+    assert unpack_dataclass_to_tuple((Config(x=1, y=2), 5)) == ((1, 2), 5)
+
+    # Dict recursion (recurses values)
+    assert unpack_dataclass_to_tuple({"a": Config(x=1, y=2), "b": 3}) == {"a": (1, 2), "b": 3}
+
+    # Leaf values are NOT copied (no deep copy)
+    class Holder:
+        pass
+
+    @dataclasses.dataclass
+    class WithObj:
+        obj: Any
+        val: int
+
+    h = Holder()
+    result = unpack_dataclass_to_tuple(WithObj(obj=h, val=1))
+    assert result == (h, 1)
+    assert result[0] is h  # same object reference, no copy
+
+    # Dynamic dispatch: Any-typed field receives a dataclass at runtime
+    @dataclasses.dataclass
+    class WithAnyField:
+        data: Any
+        scale: int
+
+    # data is Any -> schema marks it as "unpack" -> __dispatch called at runtime
+    # When data is a dataclass, it should be recursively unpacked
+    result = unpack_dataclass_to_tuple(WithAnyField(data=Config(x=1, y=2), scale=3))
+    assert result == ((1, 2), 3)
+
+    # When data is a plain value, passthrough
+    result = unpack_dataclass_to_tuple(WithAnyField(data=42, scale=3))
+    assert result == (42, 3)
+
+    # When data is a list of dataclasses, recurse each element
+    result = unpack_dataclass_to_tuple(
+        WithAnyField(data=[Config(x=1, y=2), Config(x=3, y=4)], scale=5)
+    )
+    assert result == ([(1, 2), (3, 4)], 5)
+
+    # When data is a nested dataclass
+    result = unpack_dataclass_to_tuple(
+        WithAnyField(data=Nested(value=10, cfg=Config(x=1, y=2)), scale=5)
+    )
+    assert result == ((10, (1, 2)), 5)
+
+    # When data is a dict with dataclass values
+    result = unpack_dataclass_to_tuple(
+        WithAnyField(data={"a": Config(x=1, y=2), "b": Config(x=3, y=4)}, scale=5)
+    )
+    assert result == ({"a": (1, 2), "b": (3, 4)}, 5)
+
+    # Self-referential dataclass (linked list): should not infinite recurse
+    @dataclasses.dataclass
+    class Node:
+        value: int
+        next: Node | None
+
+    # Build a short linked list
+    node = Node(value=1, next=Node(value=2, next=None))
+    result = unpack_dataclass_to_tuple(node)
+    # 'next' is never inlined. Here Node is local, so the string annotation
+    # cannot be resolved and falls back to UNPACK. Even when resolved, a Union
+    # containing a dataclass is classified as UNPACK. Dynamic dispatch then
+    # unpacks the nested Node at runtime.
+    assert result == (1, (2, None))
+
+
+def test_validate_dataclass_to_tuple_schema() -> None:
+    """Test internal schema validation."""
+    # Valid schemas
+    _validate_dataclass_to_tuple_schema({"x": None, "y": None})
+    _validate_dataclass_to_tuple_schema({"cfg": {"x": None, "y": None}, "scale": None})
+    _validate_dataclass_to_tuple_schema({"data": "unpack", "scale": None})
+
+    # Invalid: not a dict
+    with pytest.raises(TypeError, match="must be a dict"):
+        _validate_dataclass_to_tuple_schema([1, 2])  # type: ignore[arg-type]
+
+    # Invalid: field value that is neither None, "unpack", nor a dict
+    with pytest.raises(TypeError, match="must be a dict"):
+        _validate_dataclass_to_tuple_schema({"x": 1})
+
+    # Invalid: non-string key
+    with pytest.raises(TypeError, match="must be a string"):
+        _validate_dataclass_to_tuple_schema({123: None})
+
+    # Invalid: not a valid identifier
+    with pytest.raises(ValueError, match="not a valid Python identifier"):
+        _validate_dataclass_to_tuple_schema({"not-valid": None})
+
+    # Invalid: Python keyword
+    with pytest.raises(ValueError, match="is a Python keyword"):
+        _validate_dataclass_to_tuple_schema({"class": None})
+
+    # Invalid: nested schema with bad key
+    with pytest.raises(ValueError, match="not a valid Python identifier"):
+        _validate_dataclass_to_tuple_schema({"cfg": {"x": None, "y!": None}})
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10),
+    reason="list[X]/dict[X,Y]/int|None not evaluable by get_type_hints on Python < 3.10",
+)
+def test_extract_dataclass_to_tuple_schema() -> None:
+    """Test schema extraction from dataclass types."""
+    # Flat: all known leaf types -> None
+    schema = _extract_dataclass_to_tuple_schema(_ExtractConfig)
+    assert schema == {"x": None, "y": None}
+
+    # Nested: known dataclass field -> nested schema
+    schema = _extract_dataclass_to_tuple_schema(_ExtractNested)
+    assert schema == {"value": None, "cfg": {"x": None, "y": None}}
+
+    # Any field -> "unpack" (dynamic dispatch)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithAny)
+    assert schema == {"data": "unpack", "scale": None}
+
+    # list[Config] -> "unpack" (container with dataclass element)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithList)
+    assert schema == {"items": "unpack", "scale": None}
+
+    # list[int] -> "unpack" (lists are never leaves; they are rebuilt
+    # element-wise via dynamic dispatch even when all elements are leaves)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafList)
+    assert schema == {"values": "unpack", "name": None}
+
+    # list[int] standalone field also gets UNPACK
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafListInt)
+    assert schema == {"items": "unpack", "scale": None}
+
+    # dict[str, Config] -> "unpack" (container with dataclass value type)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithDict)
+    assert schema == {"mapping": "unpack", "count": None}
+
+    # dict[str, int] -> None (container with only known leaf types)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafDict)
+    assert schema == {"mapping": None, "count": None}
+
+    # tuple[Config, int] -> "unpack" (tuple containing a dataclass)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithTuple)
+    assert schema == {"pair": "unpack", "flag": None}
+
+    # Optional[int] (int | None) -> None (Union of known leaves)
+    schema = _extract_dataclass_to_tuple_schema(_ExtractWithOptional)
+    assert schema == {"value": None, "name": None}
+
+    # Non-dataclass raises
+    with pytest.raises(TypeError, match="Expected a dataclass class"):
+        _extract_dataclass_to_tuple_schema(int)
+
+    # Locally-defined classes with built-in type annotations resolve fine
+    @dataclasses.dataclass
+    class LocalConfig:
+        x: int
+        y: int
+
+    schema = _extract_dataclass_to_tuple_schema(LocalConfig)
+    assert schema == {"x": None, "y": None}
+
+    # Locally-defined classes referencing other local classes can't be resolved
+    # by get_type_hints -- all fields fall back to UNPACK
+    @dataclasses.dataclass
+    class LocalNested:
+        val: int
+        cfg: LocalConfig
+
+    schema = _extract_dataclass_to_tuple_schema(LocalNested)
+    # get_type_hints fails for the class -> all fields become "unpack"
+    assert schema == {"val": "unpack", "cfg": "unpack"}
diff --git a/tests/scripts/benchmark_kwargs_wrapper.py 
b/tests/scripts/benchmark_kwargs_wrapper.py
index 185400c..56bff57 100644
--- a/tests/scripts/benchmark_kwargs_wrapper.py
+++ b/tests/scripts/benchmark_kwargs_wrapper.py
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 import time
 from typing import Any
 
@@ -32,51 +33,67 @@ def target_func(*args: Any) -> None:
     pass
 
 
+@dataclasses.dataclass
+class Config:  # fixture for the map_dataclass_to_tuple wrapper benchmark case
+    x: int
+    y: int
+    z: int
+
+
 def benchmark_kwargs_wrapper(repeat: int = 1000000) -> None:
-    """Benchmark kwargs wrapper with integer arguments."""
-    # Create test arguments
+    """Time ``repeat`` calls per case: kwargs wrapper (int args) and dataclass unpacking."""
     x = 1
     y = 2
     z = 3
+    cfg = Config(x=x, y=y, z=z)
 
-    # Create wrapper with two optional kwargs
     wrapper = make_kwargs_wrapper(target_func, ["x", "y", "z"], arg_defaults=(None, None))
+    wrapper_dc = make_kwargs_wrapper(target_func, ["cfg"], map_dataclass_to_tuple=["cfg"])
+    # Warm up JIT cache
+    wrapper_dc(cfg)
 
-    # Benchmark 1: Direct call to target function (baseline)
-    start = time.time()
+    # Direct call (baseline)
+    start = time.perf_counter()  # monotonic, high-resolution; time.time() is wall-clock
     for _ in range(repeat):
         target_func(x, y, z)
-    end = time.time()
+    end = time.perf_counter()
     print_speed("target_func(x, y, z)", (end - start) / repeat)
 
-    # Benchmark 2: Wrapper with all positional arguments
-    start = time.time()
+    # Wrapper with positional args
+    start = time.perf_counter()
     for _ in range(repeat):
         wrapper(x, y, z)
-    end = time.time()
+    end = time.perf_counter()
     print_speed("wrapper(x, y, z)", (end - start) / repeat)
 
-    # Benchmark 3: Wrapper with positional + kwargs
-    start = time.time()
-    for _ in range(repeat):
-        wrapper(x, y=y, z=z)
-    end = time.time()
-    print_speed("wrapper(x, y=y, z=z)", (end - start) / repeat)
-
-    # Benchmark 4: Wrapper with all kwargs
-    start = time.time()
+    # Wrapper with kwargs
+    start = time.perf_counter()
     for _ in range(repeat):
         wrapper(x=x, y=y, z=z)
-    end = time.time()
+    end = time.perf_counter()
     print_speed("wrapper(x=x, y=y, z=z)", (end - start) / repeat)
 
-    # Benchmark 5: Wrapper with defaults
-    start = time.time()
+    # Wrapper with defaults
+    start = time.perf_counter()
     for _ in range(repeat):
         wrapper(x)
-    end = time.time()
+    end = time.perf_counter()
     print_speed("wrapper(x) [with defaults]", (end - start) / repeat)
 
+    # Wrapper with dataclass unpack
+    start = time.perf_counter()
+    for _ in range(repeat):
+        wrapper_dc(cfg)
+    end = time.perf_counter()
+    print_speed("wrapper_dc(cfg) [map_dataclass_to_tuple]", (end - start) / repeat)
+
+    # Manual unpack (best possible)
+    start = time.perf_counter()
+    for _ in range(repeat):
+        target_func(cfg.x, cfg.y, cfg.z)
+    end = time.perf_counter()
+    print_speed("target_func(cfg.x, cfg.y, cfg.z) [manual]", (end - start) / repeat)
+
 
 if __name__ == "__main__":
     print("Benchmarking kwargs_wrapper overhead...")
diff --git a/tests/scripts/benchmark_unpack_dataclass.py 
b/tests/scripts/benchmark_unpack_dataclass.py
new file mode 100644
index 0000000..453f5a6
--- /dev/null
+++ b/tests/scripts/benchmark_unpack_dataclass.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark unpack_dataclass_to_tuple vs dataclasses.astuple."""
+
+from __future__ import annotations
+
+import dataclasses
+import time
+from typing import Any
+
+from tvm_ffi.utils.unpack_dataclass import unpack_dataclass_to_tuple
+
+
+def print_speed(name: str, speed: float) -> None:  # one row: padded label, then sec/call
+    print(f"{name:<60}", speed, "sec/call")
+
+
+def benchmark_unpack(unpack_fn: Any, val: Any, name: str, repeat: int = 1000000) -> None:
+    """Time ``repeat`` calls of ``unpack_fn(val)`` after one warm-up call; print sec/call."""
+    unpack_fn(val)  # warm-up: fills any per-class JIT cache before the timed loop
+    start = time.perf_counter()  # monotonic, high-resolution; time.time() is wall-clock
+    for _ in range(repeat):
+        unpack_fn(val)
+    end = time.perf_counter()
+    print_speed(name, (end - start) / repeat)
+
+
+@dataclasses.dataclass
+class Config:  # flat: two int leaf fields
+    x: int
+    y: int
+
+
+@dataclasses.dataclass
+class Nested:  # one int leaf plus one nested dataclass field
+    value: int
+    cfg: Config
+
+
+@dataclasses.dataclass
+class Wide:  # flat: five int leaf fields
+    a: int
+    b: int
+    c: int
+    d: int
+    e: int
+
+
+@dataclasses.dataclass
+class Deep:  # two levels of nesting: Deep -> Nested -> Config
+    nested: Nested
+    flag: bool
+
+
+@dataclasses.dataclass
+class WithList:  # list[Config] field: a container of dataclasses
+    items: list[Config]
+    scale: int
+
+
+@dataclasses.dataclass
+class WithAny:  # Any-typed field: benchmarked as the "[dynamic]" dispatch case
+    data: Any
+    scale: int
+
+
+@dataclasses.dataclass
+class ConfigAny:  # all fields Any: holds either leaf values or dataclasses
+    x: Any
+    y: Any
+
+
+def noop(x: Any) -> Any:  # identity function; timed as the "[baseline]" rows
+    return x
+
+
+if __name__ == "__main__":
+    cfg = Config(x=1, y=2)
+    nested = Nested(value=5, cfg=Config(x=10, y=20))
+    wide = Wide(a=1, b=2, c=3, d=4, e=5)
+    deep = Deep(nested=Nested(value=5, cfg=Config(x=10, y=20)), flag=True)
+    with_list = WithList(items=[Config(x=1, y=2), Config(x=3, y=4), Config(x=5, y=6)], scale=7)
+    with_any_dc = WithAny(data=Config(x=1, y=2), scale=3)
+    with_any_int = WithAny(data=42, scale=3)
+    cfg_any = ConfigAny(x=1, y=2)
+    cfg_any_nested = ConfigAny(x=Config(x=1, y=2), y=3)
+    astuple = dataclasses.astuple
+
+    # Correctness validation (all checks run before any timing starts)
+    assert unpack_dataclass_to_tuple(cfg) == astuple(cfg) == (1, 2)
+    assert unpack_dataclass_to_tuple(nested) == astuple(nested) == (5, (10, 20))
+    assert unpack_dataclass_to_tuple(wide) == astuple(wide) == (1, 2, 3, 4, 5)
+    assert unpack_dataclass_to_tuple(deep) == astuple(deep) == ((5, (10, 20)), True)
+    assert unpack_dataclass_to_tuple(with_list) == ([(1, 2), (3, 4), (5, 6)], 7)
+    assert unpack_dataclass_to_tuple(with_any_dc) == astuple(with_any_dc) == ((1, 2), 3)
+    assert unpack_dataclass_to_tuple(with_any_int) == astuple(with_any_int) == (42, 3)
+    assert unpack_dataclass_to_tuple(cfg_any) == astuple(cfg_any) == (1, 2)
+    assert unpack_dataclass_to_tuple(cfg_any_nested) == astuple(cfg_any_nested) == ((1, 2), 3)
+    assert unpack_dataclass_to_tuple(42) == 42
+
+    print("Benchmarking unpack_dataclass_to_tuple vs dataclasses.astuple...")
+    print("-" * 90)
+    benchmark_unpack(noop, cfg, "noop(Config) [baseline]")
+    benchmark_unpack(noop, 42, "noop(int) [baseline]")
+    benchmark_unpack(unpack_dataclass_to_tuple, 42, "unpack_dataclass_to_tuple(int) [leaf]")
+    benchmark_unpack(unpack_dataclass_to_tuple, cfg, "unpack_dataclass_to_tuple(Config)")
+    benchmark_unpack(astuple, cfg, "dataclasses.astuple(Config)")
+    benchmark_unpack(unpack_dataclass_to_tuple, nested, "unpack_dataclass_to_tuple(Nested)")
+    benchmark_unpack(astuple, nested, "dataclasses.astuple(Nested)")
+    benchmark_unpack(unpack_dataclass_to_tuple, wide, "unpack_dataclass_to_tuple(Wide)")
+    benchmark_unpack(astuple, wide, "dataclasses.astuple(Wide)")
+    benchmark_unpack(unpack_dataclass_to_tuple, deep, "unpack_dataclass_to_tuple(Deep)")
+    benchmark_unpack(astuple, deep, "dataclasses.astuple(Deep)")
+    benchmark_unpack(unpack_dataclass_to_tuple, with_list, "unpack_dataclass_to_tuple(WithList)")
+    benchmark_unpack(astuple, with_list, "dataclasses.astuple(WithList)")
+    benchmark_unpack(unpack_dataclass_to_tuple, with_any_dc, "unpack(WithAny(Config)) [dynamic]")
+    benchmark_unpack(astuple, with_any_dc, "dataclasses.astuple(WithAny(Config))")
+    benchmark_unpack(unpack_dataclass_to_tuple, with_any_int, "unpack(WithAny(int)) [dynamic]")
+    benchmark_unpack(astuple, with_any_int, "dataclasses.astuple(WithAny(int))")
+    benchmark_unpack(unpack_dataclass_to_tuple, cfg_any, "unpack(ConfigAny(1,2)) [all Any, leaf]")
+    benchmark_unpack(astuple, cfg_any, "dataclasses.astuple(ConfigAny(1,2))")
+    benchmark_unpack(
+        unpack_dataclass_to_tuple, cfg_any_nested, "unpack(ConfigAny(Config,3)) [Any, nested]"
+    )
+    benchmark_unpack(astuple, cfg_any_nested, "dataclasses.astuple(ConfigAny(Config,3))")
+    print("-" * 90)


Reply via email to