This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 24054b4 [FEAT] Enhance map_dataclass_to_tuple with JIT
unpack_dataclass_to_tuple (#563)
24054b4 is described below
commit 24054b47bd013398e2c7bacb312f2d5e21059d50
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 20 18:41:22 2026 -0400
[FEAT] Enhance map_dataclass_to_tuple with JIT unpack_dataclass_to_tuple
(#563)
This PR introduces an optimized `unpack_dataclass_to_tuple` util, a
JIT-compiled per-class unpacker, and uses it to replace
`dataclasses.astuple` in `map_dataclass_to_tuple`. It brings a 5-11x
speedup and does not deep-copy leaf values.
Not deep-copying leaves is important for runtime use cases where dataclass
fields contain tensor objects or other large resources that must not be
copied -- leaf values are passed through by reference, not duplicated.
## Design
`unpack_dataclass_to_tuple(x)` recursively converts dataclass instances
to tuples via JIT-compiled inline attribute access:
- `_extract_dataclass_to_tuple_schema(cls)` -- inspect type annotations
- `_validate_dataclass_to_tuple_schema(schema)` -- codegen security
- `_compile_dataclass_to_tuple_schema(prefix, schema)` -- inline code
Compiled unpackers are cached per-thread (TLS) for thread safety. Field
classification is conservative: a field is marked as a leaf only when that
is certain (known primitives, or tuple/dict/set/frozenset/Union of known
primitives). list fields and unknown types use dynamic dispatch.
## Performance (vs dataclasses.astuple)
| Case | unpack | astuple | speedup |
|-------------------|--------|---------|---------|
| Config (2 flat) | 144 ns | 769 ns | 5.3x |
| Nested (int + dc) | 152 ns | 1386 ns | 9.1x |
| Wide (5 flat) | 170 ns | 1268 ns | 7.5x |
| Deep (2 levels) | 200 ns | 2305 ns | 11.5x |
| WithAny (dynamic) | 254 ns | 1644 ns | 6.5x |
---
python/tvm_ffi/utils/kwargs_wrapper.py | 39 ++--
python/tvm_ffi/utils/unpack_dataclass.py | 263 +++++++++++++++++++++++++
tests/python/utils/test_kwargs_wrapper.py | 8 +-
tests/python/utils/test_unpack_dataclass.py | 294 ++++++++++++++++++++++++++++
tests/scripts/benchmark_kwargs_wrapper.py | 45 +++--
tests/scripts/benchmark_unpack_dataclass.py | 140 +++++++++++++
6 files changed, 753 insertions(+), 36 deletions(-)
diff --git a/python/tvm_ffi/utils/kwargs_wrapper.py
b/python/tvm_ffi/utils/kwargs_wrapper.py
index df1a8f9..1bad61c 100644
--- a/python/tvm_ffi/utils/kwargs_wrapper.py
+++ b/python/tvm_ffi/utils/kwargs_wrapper.py
@@ -22,12 +22,13 @@ keyword argument support using code generation techniques.
from __future__ import annotations
-import dataclasses
import functools
import inspect
import keyword
from typing import Any, Callable, Iterable
+from tvm_ffi.utils.unpack_dataclass import unpack_dataclass_to_tuple
+
# Sentinel object for missing arguments
MISSING = object()
@@ -35,12 +36,12 @@ MISSING = object()
_INTERNAL_TARGET_FUNC = "__i_target_func"
_INTERNAL_MISSING = "__i_MISSING"
_INTERNAL_DEFAULTS_DICT = "__i_arg_defaults"
-_INTERNAL_ASTUPLE = "__i_astuple"
+_INTERNAL_UNPACK = "__i_unpack"
_INTERNAL_NAMES = {
_INTERNAL_TARGET_FUNC,
_INTERNAL_MISSING,
_INTERNAL_DEFAULTS_DICT,
- _INTERNAL_ASTUPLE,
+ _INTERNAL_UNPACK,
}
@@ -163,34 +164,34 @@ def _build_wrapper_code(
call_parts: list[str] = []
runtime_defaults: dict[str, Any] = {}
- def _wrap_astuple(name: str, expr: str) -> str:
+ def _transform_expr(name: str, expr: str) -> str:
if name in dc_to_tuple_set:
- return f"{_INTERNAL_ASTUPLE}({expr})"
+ return f"{_INTERNAL_UNPACK}({expr})"
return expr
def _add_param_with_default(name: str, default_value: Any) -> None:
# Directly embed None and bool defaults; use MISSING sentinel for
others.
if default_value is None:
arg_parts.append(f"{name}=None")
- call_parts.append(_wrap_astuple(name, name))
+ call_parts.append(_transform_expr(name, name))
elif type(default_value) is bool:
default_value_str = "True" if default_value else "False"
arg_parts.append(f"{name}={default_value_str}")
- call_parts.append(_wrap_astuple(name, name))
+ call_parts.append(_transform_expr(name, name))
else:
arg_parts.append(f"{name}={_INTERNAL_MISSING}")
runtime_defaults[name] = default_value
base_expr = (
f'{_INTERNAL_DEFAULTS_DICT}["{name}"] if {name} is
{_INTERNAL_MISSING} else {name}'
)
- call_parts.append(_wrap_astuple(name, base_expr))
+ call_parts.append(_transform_expr(name, base_expr))
for name in arg_names:
if name in arg_defaults_dict:
_add_param_with_default(name, arg_defaults_dict[name])
else:
arg_parts.append(name)
- call_parts.append(_wrap_astuple(name, name))
+ call_parts.append(_transform_expr(name, name))
if kwonly_names:
arg_parts.append("*")
@@ -199,7 +200,7 @@ def _build_wrapper_code(
_add_param_with_default(name, kwonly_defaults[name])
else:
arg_parts.append(name)
- call_parts.append(_wrap_astuple(name, name))
+ call_parts.append(_transform_expr(name, name))
arg_list = ", ".join(arg_parts)
call_list = ", ".join(call_parts)
@@ -252,10 +253,13 @@ def make_kwargs_wrapper(
Optional prototype function to copy metadata (__name__, __doc__,
__module__,
__qualname__, __annotations__) from. If None, no metadata is copied.
map_dataclass_to_tuple
- Optional list of argument names whose values should be converted from
dataclass
- instances to tuples (via ``dataclasses.astuple``) before being passed
to the
- target function. This is useful when the target function expects
flattened tuple
- arguments but callers pass dataclass instances.
+ Optional list of argument names whose values should be converted from
+ dataclass instances to tuples via ``unpack_dataclass_to_tuple``.
+ Dataclass instances are unpacked to tuples; leaf field values are passed by reference, not copied.
+
+ Nested dataclass fields are recursed automatically based on type
annotations.
+ Lists/tuples are recursed element-wise. Dicts are recursed on values.
+ Non-dataclass leaves are passed through unchanged.
Returns
-------
@@ -296,7 +300,7 @@ def make_kwargs_wrapper(
_INTERNAL_DEFAULTS_DICT: runtime_defaults,
}
if dc_to_tuple_set:
- exec_globals[_INTERNAL_ASTUPLE] = dataclasses.astuple
+ exec_globals[_INTERNAL_UNPACK] = unpack_dataclass_to_tuple
namespace: dict[str, Any] = {}
exec(code_str, exec_globals, namespace)
new_func = namespace["wrapper"]
@@ -335,9 +339,8 @@ def make_kwargs_wrapper_from_signature(
These arguments will not be included in the generated wrapper. If a
name in this iterable
does not exist in the signature, it is silently ignored.
map_dataclass_to_tuple
- Optional list of argument names whose values should be converted from
dataclass
- instances to tuples (via ``dataclasses.astuple``) before being passed
to the
- target function.
+ Optional list of argument names to unpack via
``unpack_dataclass_to_tuple``.
+ See ``make_kwargs_wrapper`` for details.
Returns
-------
diff --git a/python/tvm_ffi/utils/unpack_dataclass.py
b/python/tvm_ffi/utils/unpack_dataclass.py
new file mode 100644
index 0000000..b85ffc6
--- /dev/null
+++ b/python/tvm_ffi/utils/unpack_dataclass.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Fast dataclass-to-tuple conversion via JIT-compiled unpacking.
+
+This module provides ``unpack_dataclass_to_tuple``, a function that recursively
+converts dataclass instances to tuples. It JIT-compiles a per-class unpacker
+on first call and caches it per-thread for ~5-11x speedup over
+``dataclasses.astuple`` with no deep-copy of leaf values.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import keyword
+import sys
+import threading
+import typing
+from typing import Any
+
+# Support both typing.Union and types.UnionType (PEP 604, Python 3.10+)
+if sys.version_info >= (3, 10):
+ import types
+
+ _LEAF_CONTAINER_ORIGINS = (tuple, dict, set, frozenset, typing.Union,
types.UnionType)
+else:
+ _LEAF_CONTAINER_ORIGINS = (tuple, dict, set, frozenset, typing.Union)
+
+# Type alias for dataclass-to-tuple schema (internal).
+# Schema values:
+# None -> leaf, direct attribute access (zero cost)
+# "unpack" -> dynamic dispatch via unpack_dataclass_to_tuple at runtime
+# dict -> nested struct, recurse inline
+# Example: {"x": None, "y": None} -> (__x.x, __x.y,)
+# Example: {"cfg": {"x": None, "y": None}, "data": "unpack"}
+# -> ((__x.cfg.x, __x.cfg.y,), __dispatch(__x.data),)
+DataclassToTupleSchema = dict # dict[str, None | str | DataclassToTupleSchema]
+
+# Sentinel value in schema: field should be dynamically dispatched
+UNPACK = "unpack"
+
+# Thread-local cache for JIT-compiled per-class unpack functions
+_tls = threading.local()
+
+# Types known to be safe leaves (never contain dataclass instances)
+_KNOWN_LEAF_TYPES: set[type] = {int, float, str, bool, bytes, complex,
type(None)}
+
+
+def _is_known_leaf_type(tp: Any) -> bool:
+ """Check if a type is definitely a leaf (no dataclass content or
conversion needed).
+
+ Note: list is NOT a leaf because it must be converted to a tuple per the
+ unpack contract (matching dataclasses.astuple behavior).
+ """
+ if isinstance(tp, type):
+ return tp in _KNOWN_LEAF_TYPES
+ if tp is Ellipsis:
+ return True
+ origin = typing.get_origin(tp)
+ if origin is not None:
+ # list is NOT a leaf — must be converted to tuple
+ # tuple/dict/set/frozenset/Union are leaves if all args are leaves
+ if origin in _LEAF_CONTAINER_ORIGINS:
+ args = typing.get_args(tp)
+ return bool(args) and all(_is_known_leaf_type(a) for a in args)
+ return False
+
+
+def _classify_field_type(
+ field_type: Any, memo: set[type] | None = None
+) -> None | str | DataclassToTupleSchema:
+ """Classify a resolved field type into a schema entry.
+
+ Conservative: only returns None (leaf) when we are certain the type
+ cannot contain a dataclass. Otherwise returns UNPACK (dynamic dispatch).
+ """
+ if isinstance(field_type, str) or field_type is Any or field_type is
object:
+ return UNPACK
+ if dataclasses.is_dataclass(field_type) and isinstance(field_type, type):
+ # Guard against infinite recursion for self-referential dataclasses
+ if memo is not None and field_type in memo:
+ return UNPACK
+ return _extract_dataclass_to_tuple_schema(field_type, memo=memo)
+ # Known primitive types -> leaf
+ if isinstance(field_type, type) and field_type in _KNOWN_LEAF_TYPES:
+ return None
+ # Generic containers: check element types
+ # list always needs UNPACK (must be converted to tuple)
+ # tuple/dict/set/frozenset/Union are leaves if all args are known leaves
+ # Generic containers: list always UNPACK (must convert to tuple).
+ # tuple/dict/set/frozenset/Union are leaves only if all args are known
leaves.
+ # Everything else (unknown type) -> UNPACK (conservative).
+ origin = typing.get_origin(field_type)
+ if origin in _LEAF_CONTAINER_ORIGINS:
+ args = typing.get_args(field_type)
+ if args and all(_is_known_leaf_type(a) for a in args):
+ return None
+ return UNPACK
+
+
+def _compile_dataclass_to_tuple_schema(prefix: str, schema:
DataclassToTupleSchema) -> str:
+ """Compile a DataclassToTupleSchema into an inline tuple expression.
+
+ Parameters
+ ----------
+ prefix
+ The variable expression to unpack (e.g. "__x" or "__x.field").
+ schema
+ The schema dict mapping field names to:
+ - None: leaf, direct attribute access
+ - "unpack": dynamic dispatch via __dispatch() at runtime
+ - nested dict: recurse inline
+
+ Returns
+ -------
+ A string expression that evaluates to a tuple of the unpacked fields.
+
+ """
+ parts: list[str] = []
+ for field_name, sub_schema in schema.items():
+ field_expr = f"{prefix}.{field_name}"
+ if sub_schema is None:
+ parts.append(field_expr)
+ elif sub_schema == UNPACK:
+ parts.append(f"__dispatch({field_expr})")
+ else:
+ parts.append(_compile_dataclass_to_tuple_schema(field_expr,
sub_schema))
+ return "(" + ", ".join(parts) + (",)" if parts else ")")
+
+
+def _validate_dataclass_to_tuple_schema(schema: DataclassToTupleSchema) ->
None:
+ """Validate that a DataclassToTupleSchema contains only safe identifiers.
+
+ This is critical for security since field names are embedded directly
+ in generated code via exec(). The validation ensures:
+ - Keys are strings (type check)
+ - Keys pass str.isidentifier() — rejects any non-identifier characters
+ - Keys are not Python keywords — rejects control flow injection
+ - Values are only None, "unpack", or recursively-validated dicts
+
+ Combined with the hardcoded prefix ("__x") and restricted exec_globals,
+ this prevents any code injection through crafted field names.
+
+ """
+ if not isinstance(schema, dict):
+ raise TypeError(f"DataclassToTupleSchema must be a dict, got
{type(schema).__name__}")
+ for field_name, sub_schema in schema.items():
+ if not isinstance(field_name, str):
+ raise TypeError(f"Schema field name must be a string, got
{type(field_name).__name__}")
+ if not field_name.isidentifier():
+ raise ValueError(f"Schema field name {field_name!r} is not a valid
Python identifier")
+ if keyword.iskeyword(field_name):
+ raise ValueError(f"Schema field name {field_name!r} is a Python
keyword")
+ if sub_schema is not None and sub_schema != UNPACK:
+ _validate_dataclass_to_tuple_schema(sub_schema)
+
+
+def _extract_dataclass_to_tuple_schema(
+ cls: type, *, memo: set[type] | None = None
+) -> DataclassToTupleSchema:
+ """Extract a DataclassToTupleSchema from a dataclass class using type
annotations.
+
+ Classification per field (conservative: only leaf when certain):
+ - Known dataclass type -> nested schema (recurse inline)
+ - Known primitive type (int, float, str, bool, bytes, complex) -> None
(leaf)
+ - Non-list container with only known-leaf args (dict[str, float], tuple[int, str]) -> None (leaf); list[...] is always "unpack"
+ - Container with dataclass/unknown args (list[Config]) -> "unpack"
(dynamic dispatch)
+ - Any, object, unresolved string annotation -> "unpack" (dynamic dispatch)
+ - Unknown class -> "unpack" (dynamic dispatch)
+
+ Uses typing.get_type_hints() to resolve PEP 563 string annotations.
+ Uses memo set to prevent infinite recursion on self-referential
dataclasses.
+
+ """
+ if not dataclasses.is_dataclass(cls) or not isinstance(cls, type):
+ raise TypeError(f"Expected a dataclass class, got {cls!r}")
+ if memo is None:
+ memo = set()
+ memo.add(cls)
+ try:
+ type_hints = typing.get_type_hints(cls)
+ except (NameError, TypeError, AttributeError):
+ type_hints = {}
+ schema: DataclassToTupleSchema = {}
+ for f in dataclasses.fields(cls):
+ field_type = type_hints.get(f.name, f.type)
+ schema[f.name] = _classify_field_type(field_type, memo=memo)
+ return schema
+
+
+def unpack_dataclass_to_tuple(x: Any) -> Any:
+ """Fast recursively unpack a dataclass value to tuple representation.
+
+ - Dataclass instances are unpacked to tuples of their field values.
+ - Lists and tuples are recursed element-wise, returning a new container of the same type.
+ - Dicts are recursed on values, returning a new dict.
+ - All other values are returned as-is (leaf passthrough).
+
+ This function optimizes speed via JIT-compiling the conversion per
dataclass
+ class and caching it per-thread. It brings about 5-11x speedup vs
+ ``dataclasses.astuple`` and does not deep-copy leaf values.
+
+ Parameters
+ ----------
+ x
+ The value to unpack.
+
+ Returns
+ -------
+ The unpacked tuple representation, or ``x`` unchanged if it's a leaf.
+
+ """
+ try:
+ cache = _tls.cache
+ except AttributeError:
+ cache = _tls.cache = {}
+
+ cls = type(x)
+ fn = cache.get(cls)
+ if fn is not None:
+ return fn(x)
+
+ # Cache miss — classify the type
+ if dataclasses.is_dataclass(cls) and isinstance(cls, type):
+ schema = _extract_dataclass_to_tuple_schema(cls)
+ # Validate that all field names in the schema are safe Python
identifiers.
+ # This is critical: field names are embedded directly in the generated
code string.
+ # _validate_dataclass_to_tuple_schema ensures no code injection is
possible via
+ # crafted field names (isidentifier + iskeyword checks).
+ _validate_dataclass_to_tuple_schema(schema)
+ code_expr = _compile_dataclass_to_tuple_schema("__x", schema)
+ code = f"def __unpack(__x): return {code_expr}"
+ namespace: dict[str, Any] = {}
+ # exec_globals only exposes __dispatch (our own function), no other
capabilities.
+ exec(code, {"__dispatch": unpack_dataclass_to_tuple}, namespace)
+ fn = namespace["__unpack"]
+ cache[cls] = fn
+ return fn(x)
+ if isinstance(x, (list, tuple)):
+ return type(x)(unpack_dataclass_to_tuple(e) for e in x)
+ if isinstance(x, dict):
+ return {k: unpack_dataclass_to_tuple(v) for k, v in x.items()}
+ # True leaf — cache identity so next call is just dict.get + return
+ cache[cls] = _LEAF_IDENTITY
+ return x
+
+
+def _LEAF_IDENTITY(x: Any) -> Any:
+ """Identity function cached for known leaf types."""
+ return x
diff --git a/tests/python/utils/test_kwargs_wrapper.py
b/tests/python/utils/test_kwargs_wrapper.py
index 2b731b2..61362c1 100644
--- a/tests/python/utils/test_kwargs_wrapper.py
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -373,7 +373,7 @@ def test_optimized_default_types() -> None:
def test_map_dataclass_to_tuple() -> None:
- """Test that dataclass arguments are converted to tuples via
dataclasses.astuple."""
+ """Test map_dataclass_to_tuple in make_kwargs_wrapper."""
@dataclasses.dataclass
class Config:
@@ -388,7 +388,7 @@ def test_map_dataclass_to_tuple() -> None:
def target(*args: Any) -> tuple[Any, ...]:
return args
- # Basic: one dataclass arg converted to tuple
+ # Basic: one dataclass arg converted
wrapper = make_kwargs_wrapper(target, ["a", "cfg"],
map_dataclass_to_tuple=["cfg"])
result = wrapper(1, Config(x=10, y=20))
assert result == (1, (10, 20))
@@ -402,7 +402,7 @@ def test_map_dataclass_to_tuple() -> None:
result = wrapper(Config(x=1, y=2), Config(x=3, y=4))
assert result == ((1, 2), (3, 4))
- # Nested dataclass (astuple recurses)
+ # Nested dataclass (auto-recursion via type annotations)
wrapper = make_kwargs_wrapper(target, ["a", "nested"],
map_dataclass_to_tuple=["nested"])
result = wrapper(1, Nested(value=5, cfg=Config(x=10, y=20)))
assert result == (1, (5, (10, 20)))
@@ -412,7 +412,7 @@ def test_map_dataclass_to_tuple() -> None:
result = wrapper(1, Config(x=10, y=20), 3)
assert result == (1, (10, 20), 3)
- # With defaults: dataclass arg has a default
+ # With defaults
default_cfg = Config(x=0, y=0)
wrapper = make_kwargs_wrapper(
target, ["a", "cfg"], arg_defaults=(default_cfg,),
map_dataclass_to_tuple=["cfg"]
diff --git a/tests/python/utils/test_unpack_dataclass.py
b/tests/python/utils/test_unpack_dataclass.py
new file mode 100644
index 0000000..3a557ea
--- /dev/null
+++ b/tests/python/utils/test_unpack_dataclass.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import dataclasses
+import sys
+from typing import Any
+
+import pytest
+from tvm_ffi.utils.unpack_dataclass import (
+ _extract_dataclass_to_tuple_schema,
+ _validate_dataclass_to_tuple_schema,
+ unpack_dataclass_to_tuple,
+)
+
+
+# Module-level dataclass definitions for schema extraction tests.
+# Must be at module level so typing.get_type_hints() can resolve
+# cross-references with PEP 563 (from __future__ import annotations).
[email protected]
+class _ExtractConfig:
+ x: int
+ y: int
+
+
[email protected]
+class _ExtractNested:
+ value: int
+ cfg: _ExtractConfig
+
+
[email protected]
+class _ExtractWithAny:
+ data: Any
+ scale: int
+
+
[email protected]
+class _ExtractWithList:
+ items: list[_ExtractConfig]
+ scale: int
+
+
[email protected]
+class _ExtractWithLeafList:
+ values: list[int]
+ name: str
+
+
[email protected]
+class _ExtractWithDict:
+ mapping: dict[str, _ExtractConfig]
+ count: int
+
+
[email protected]
+class _ExtractWithLeafDict:
+ mapping: dict[str, int]
+ count: int
+
+
[email protected]
+class _ExtractWithTuple:
+ pair: tuple[_ExtractConfig, int]
+ flag: bool
+
+
[email protected]
+class _ExtractWithOptional:
+ value: int | None
+ name: str
+
+
[email protected]
+class _ExtractWithLeafListInt:
+ items: list[int]
+ scale: int
+
+
+def test_unpack_dataclass_to_tuple() -> None:
+ """Test unpack_dataclass_to_tuple JIT-compiled unpacking."""
+
+ @dataclasses.dataclass
+ class Config:
+ x: int
+ y: int
+
+ @dataclasses.dataclass
+ class Nested:
+ value: int
+ cfg: Config
+
+ @dataclasses.dataclass
+ class Deep:
+ nested: Nested
+ flag: bool
+
+ # Flat dataclass
+ assert unpack_dataclass_to_tuple(Config(x=1, y=2)) == (1, 2)
+
+ # Nested dataclass (auto-recurses based on type annotation)
+ assert unpack_dataclass_to_tuple(Nested(value=5, cfg=Config(x=10, y=20)))
== (5, (10, 20))
+
+ # Deep nesting
+ assert unpack_dataclass_to_tuple(
+ Deep(nested=Nested(value=5, cfg=Config(x=10, y=20)), flag=True)
+ ) == ((5, (10, 20)), True)
+
+ # Leaf passthrough
+ assert unpack_dataclass_to_tuple(42) == 42
+ assert unpack_dataclass_to_tuple("hello") == "hello"
+ assert unpack_dataclass_to_tuple(None) is None
+
+ # List recursion: list of dataclasses -> list of tuples (container type is preserved)
+ assert unpack_dataclass_to_tuple([Config(x=1, y=2), Config(x=3, y=4)]) ==
[(1, 2), (3, 4)]
+
+ # Tuple recursion
+ assert unpack_dataclass_to_tuple((Config(x=1, y=2), 5)) == ((1, 2), 5)
+
+ # Dict recursion (recurses values)
+ assert unpack_dataclass_to_tuple({"a": Config(x=1, y=2), "b": 3}) == {"a":
(1, 2), "b": 3}
+
+ # Leaf values are NOT copied (no deep copy)
+ class Holder:
+ pass
+
+ @dataclasses.dataclass
+ class WithObj:
+ obj: Any
+ val: int
+
+ h = Holder()
+ result = unpack_dataclass_to_tuple(WithObj(obj=h, val=1))
+ assert result == (h, 1)
+ assert result[0] is h # same object reference, no copy
+
+ # Dynamic dispatch: Any-typed field receives a dataclass at runtime
+ @dataclasses.dataclass
+ class WithAnyField:
+ data: Any
+ scale: int
+
+ # data is Any -> schema marks it as "unpack" -> __dispatch called at
runtime
+ # When data is a dataclass, it should be recursively unpacked
+ result = unpack_dataclass_to_tuple(WithAnyField(data=Config(x=1, y=2),
scale=3))
+ assert result == ((1, 2), 3)
+
+ # When data is a plain value, passthrough
+ result = unpack_dataclass_to_tuple(WithAnyField(data=42, scale=3))
+ assert result == (42, 3)
+
+ # When data is a list of dataclasses, recurse each element
+ result = unpack_dataclass_to_tuple(
+ WithAnyField(data=[Config(x=1, y=2), Config(x=3, y=4)], scale=5)
+ )
+ assert result == ([(1, 2), (3, 4)], 5)
+
+ # When data is a nested dataclass
+ result = unpack_dataclass_to_tuple(
+ WithAnyField(data=Nested(value=10, cfg=Config(x=1, y=2)), scale=5)
+ )
+ assert result == ((10, (1, 2)), 5)
+
+ # When data is a dict with dataclass values
+ result = unpack_dataclass_to_tuple(
+ WithAnyField(data={"a": Config(x=1, y=2), "b": Config(x=3, y=4)},
scale=5)
+ )
+ assert result == ({"a": (1, 2), "b": (3, 4)}, 5)
+
+ # Self-referential dataclass (linked list): should not infinite recurse
+ @dataclasses.dataclass
+ class Node:
+ value: int
+ next: Node | None
+
+ # Build a short linked list
+ node = Node(value=1, next=Node(value=2, next=None))
+ result = unpack_dataclass_to_tuple(node)
+ # Node is defined locally, so under PEP 563 get_type_hints() cannot resolve
+ # the 'Node | None' annotation; the fields fall back to UNPACK (dynamic dispatch).
+ # Dynamic dispatch recursively unpacks the nested Node at runtime.
+ assert result == (1, (2, None))
+
+
+def test_validate_dataclass_to_tuple_schema() -> None:
+ """Test internal schema validation."""
+ # Valid schemas
+ _validate_dataclass_to_tuple_schema({"x": None, "y": None})
+ _validate_dataclass_to_tuple_schema({"cfg": {"x": None, "y": None},
"scale": None})
+
+ # Invalid: not a dict
+ with pytest.raises(TypeError, match="must be a dict"):
+ _validate_dataclass_to_tuple_schema([1, 2]) # type: ignore[arg-type]
+
+ # Invalid: non-string key
+ with pytest.raises(TypeError, match="must be a string"):
+ _validate_dataclass_to_tuple_schema({123: None})
+
+ # Invalid: not a valid identifier
+ with pytest.raises(ValueError, match="not a valid Python identifier"):
+ _validate_dataclass_to_tuple_schema({"not-valid": None})
+
+ # Invalid: Python keyword
+ with pytest.raises(ValueError, match="is a Python keyword"):
+ _validate_dataclass_to_tuple_schema({"class": None})
+
+ # Invalid: nested schema with bad key
+ with pytest.raises(ValueError, match="not a valid Python identifier"):
+ _validate_dataclass_to_tuple_schema({"cfg": {"x": None, "y!": None}})
+
+
[email protected](
+ sys.version_info < (3, 10),
+ reason="list[X]/dict[X,Y]/int|None not evaluable by get_type_hints on
Python < 3.10",
+)
+def test_extract_dataclass_to_tuple_schema() -> None:
+ """Test schema extraction from dataclass types."""
+ # Flat: all known leaf types -> None
+ schema = _extract_dataclass_to_tuple_schema(_ExtractConfig)
+ assert schema == {"x": None, "y": None}
+
+ # Nested: known dataclass field -> nested schema
+ schema = _extract_dataclass_to_tuple_schema(_ExtractNested)
+ assert schema == {"value": None, "cfg": {"x": None, "y": None}}
+
+ # Any field -> "unpack" (dynamic dispatch)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithAny)
+ assert schema == {"data": "unpack", "scale": None}
+
+ # list[Config] -> "unpack" (container with dataclass element)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithList)
+ assert schema == {"items": "unpack", "scale": None}
+
+ # list[int] -> "unpack" (list must be converted to tuple per contract)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafList)
+ assert schema == {"values": "unpack", "name": None}
+
+ # list[int] standalone field also gets UNPACK
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafListInt)
+ assert schema == {"items": "unpack", "scale": None}
+
+ # dict[str, Config] -> "unpack" (container with dataclass value type)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithDict)
+ assert schema == {"mapping": "unpack", "count": None}
+
+ # dict[str, int] -> None (container with only known leaf types)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithLeafDict)
+ assert schema == {"mapping": None, "count": None}
+
+ # tuple[Config, int] -> "unpack" (tuple containing a dataclass)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithTuple)
+ assert schema == {"pair": "unpack", "flag": None}
+
+ # Optional[int] (int | None) -> None (Union of known leaves)
+ schema = _extract_dataclass_to_tuple_schema(_ExtractWithOptional)
+ assert schema == {"value": None, "name": None}
+
+ # Non-dataclass raises
+ with pytest.raises(TypeError, match="Expected a dataclass class"):
+ _extract_dataclass_to_tuple_schema(int)
+
+ # Locally-defined classes with built-in type annotations resolve fine
+ @dataclasses.dataclass
+ class LocalConfig:
+ x: int
+ y: int
+
+ schema = _extract_dataclass_to_tuple_schema(LocalConfig)
+ assert schema == {"x": None, "y": None}
+
+ # Locally-defined classes referencing other local classes can't be resolved
+ # by get_type_hints — all fields fall back to UNPACK
+ @dataclasses.dataclass
+ class LocalNested:
+ val: int
+ cfg: LocalConfig
+
+ schema = _extract_dataclass_to_tuple_schema(LocalNested)
+ # get_type_hints fails for the class -> all fields become "unpack"
+ assert schema == {"val": "unpack", "cfg": "unpack"}
diff --git a/tests/scripts/benchmark_kwargs_wrapper.py
b/tests/scripts/benchmark_kwargs_wrapper.py
index 185400c..56bff57 100644
--- a/tests/scripts/benchmark_kwargs_wrapper.py
+++ b/tests/scripts/benchmark_kwargs_wrapper.py
@@ -18,6 +18,7 @@
from __future__ import annotations
+import dataclasses
import time
from typing import Any
@@ -32,51 +33,67 @@ def target_func(*args: Any) -> None:
pass
[email protected]
+class Config:
+ x: int
+ y: int
+ z: int
+
+
def benchmark_kwargs_wrapper(repeat: int = 1000000) -> None:
- """Benchmark kwargs wrapper with integer arguments."""
- # Create test arguments
+ """Benchmark kwargs wrapper with integer arguments and dataclass
unpacking."""
x = 1
y = 2
z = 3
+ cfg = Config(x=x, y=y, z=z)
- # Create wrapper with two optional kwargs
wrapper = make_kwargs_wrapper(target_func, ["x", "y", "z"],
arg_defaults=(None, None))
+ wrapper_dc = make_kwargs_wrapper(target_func, ["cfg"],
map_dataclass_to_tuple=["cfg"])
+ # Warm up JIT cache
+ wrapper_dc(cfg)
- # Benchmark 1: Direct call to target function (baseline)
+ # Direct call (baseline)
start = time.time()
for _ in range(repeat):
target_func(x, y, z)
end = time.time()
print_speed("target_func(x, y, z)", (end - start) / repeat)
- # Benchmark 2: Wrapper with all positional arguments
+ # Wrapper with positional args
start = time.time()
for _ in range(repeat):
wrapper(x, y, z)
end = time.time()
print_speed("wrapper(x, y, z)", (end - start) / repeat)
- # Benchmark 3: Wrapper with positional + kwargs
- start = time.time()
- for _ in range(repeat):
- wrapper(x, y=y, z=z)
- end = time.time()
- print_speed("wrapper(x, y=y, z=z)", (end - start) / repeat)
-
- # Benchmark 4: Wrapper with all kwargs
+ # Wrapper with kwargs
start = time.time()
for _ in range(repeat):
wrapper(x=x, y=y, z=z)
end = time.time()
print_speed("wrapper(x=x, y=y, z=z)", (end - start) / repeat)
- # Benchmark 5: Wrapper with defaults
+ # Wrapper with defaults
start = time.time()
for _ in range(repeat):
wrapper(x)
end = time.time()
print_speed("wrapper(x) [with defaults]", (end - start) / repeat)
+ # Wrapper with dataclass unpack
+ start = time.time()
+ for _ in range(repeat):
+ wrapper_dc(cfg)
+ end = time.time()
+ print_speed("wrapper_dc(cfg) [map_dataclass_to_tuple]", (end - start) /
repeat)
+
+ # Manual unpack (best possible)
+ start = time.time()
+ for _ in range(repeat):
+ target_func(cfg.x, cfg.y, cfg.z)
+ end = time.time()
+ print_speed("target_func(cfg.x, cfg.y, cfg.z) [manual]", (end - start) /
repeat)
+
if __name__ == "__main__":
print("Benchmarking kwargs_wrapper overhead...")
diff --git a/tests/scripts/benchmark_unpack_dataclass.py
b/tests/scripts/benchmark_unpack_dataclass.py
new file mode 100644
index 0000000..453f5a6
--- /dev/null
+++ b/tests/scripts/benchmark_unpack_dataclass.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark unpack_dataclass_to_tuple vs dataclasses.astuple."""
+
+from __future__ import annotations
+
+import dataclasses
+import time
+from typing import Any
+
+from tvm_ffi.utils.unpack_dataclass import unpack_dataclass_to_tuple
+
+
+def print_speed(name: str, speed: float) -> None:
+ print(f"{name:<60} {speed} sec/call")
+
+
+def benchmark_unpack(unpack_fn: Any, val: Any, name: str, repeat: int = 1000000) -> None:
+ """Benchmark an unpack function on a single value."""
+ unpack_fn(val)
+ start = time.time()
+ for _ in range(repeat):
+ unpack_fn(val)
+ end = time.time()
+ print_speed(name, (end - start) / repeat)
+
+
+@dataclasses.dataclass
+class Config:
+ x: int
+ y: int
+
+
+@dataclasses.dataclass
+class Nested:
+ value: int
+ cfg: Config
+
+
+@dataclasses.dataclass
+class Wide:
+ a: int
+ b: int
+ c: int
+ d: int
+ e: int
+
+
+@dataclasses.dataclass
+class Deep:
+ nested: Nested
+ flag: bool
+
+
+@dataclasses.dataclass
+class WithList:
+ items: list[Config]
+ scale: int
+
+
+@dataclasses.dataclass
+class WithAny:
+ data: Any
+ scale: int
+
+
+@dataclasses.dataclass
+class ConfigAny:
+ x: Any
+ y: Any
+
+
+def noop(x: Any) -> Any:
+ return x
+
+
+if __name__ == "__main__":
+ cfg = Config(x=1, y=2)
+ nested = Nested(value=5, cfg=Config(x=10, y=20))
+ wide = Wide(a=1, b=2, c=3, d=4, e=5)
+ deep = Deep(nested=Nested(value=5, cfg=Config(x=10, y=20)), flag=True)
+ with_list = WithList(items=[Config(x=1, y=2), Config(x=3, y=4), Config(x=5, y=6)], scale=7)
+ with_any_dc = WithAny(data=Config(x=1, y=2), scale=3)
+ with_any_int = WithAny(data=42, scale=3)
+ astuple = dataclasses.astuple
+
+ # Correctness validation
+ assert unpack_dataclass_to_tuple(cfg) == astuple(cfg) == (1, 2)
+ assert unpack_dataclass_to_tuple(nested) == astuple(nested) == (5, (10, 20))
+ assert unpack_dataclass_to_tuple(wide) == astuple(wide) == (1, 2, 3, 4, 5)
+ assert unpack_dataclass_to_tuple(deep) == astuple(deep) == ((5, (10, 20)), True)
+ assert unpack_dataclass_to_tuple(with_list) == ([(1, 2), (3, 4), (5, 6)], 7)
+ assert unpack_dataclass_to_tuple(with_any_dc) == ((1, 2), 3)
+ assert unpack_dataclass_to_tuple(with_any_int) == (42, 3)
+ assert unpack_dataclass_to_tuple(42) == 42
+
+ print("Benchmarking unpack_dataclass_to_tuple vs dataclasses.astuple...")
+ print("-" * 90)
+ benchmark_unpack(noop, cfg, "noop(Config) [baseline]")
+ benchmark_unpack(noop, 42, "noop(int) [baseline]")
+ benchmark_unpack(unpack_dataclass_to_tuple, 42, "unpack_dataclass_to_tuple(int) [leaf]")
+ benchmark_unpack(unpack_dataclass_to_tuple, cfg, "unpack_dataclass_to_tuple(Config)")
+ benchmark_unpack(astuple, cfg, "dataclasses.astuple(Config)")
+ benchmark_unpack(unpack_dataclass_to_tuple, nested, "unpack_dataclass_to_tuple(Nested)")
+ benchmark_unpack(astuple, nested, "dataclasses.astuple(Nested)")
+ benchmark_unpack(unpack_dataclass_to_tuple, wide, "unpack_dataclass_to_tuple(Wide)")
+ benchmark_unpack(astuple, wide, "dataclasses.astuple(Wide)")
+ benchmark_unpack(unpack_dataclass_to_tuple, deep, "unpack_dataclass_to_tuple(Deep)")
+ benchmark_unpack(astuple, deep, "dataclasses.astuple(Deep)")
+ benchmark_unpack(unpack_dataclass_to_tuple, with_list, "unpack_dataclass_to_tuple(WithList)")
+ benchmark_unpack(astuple, with_list, "dataclasses.astuple(WithList)")
+ benchmark_unpack(unpack_dataclass_to_tuple, with_any_dc, "unpack(WithAny(Config)) [dynamic]")
+ benchmark_unpack(astuple, with_any_dc, "dataclasses.astuple(WithAny(Config))")
+ benchmark_unpack(unpack_dataclass_to_tuple, with_any_int, "unpack(WithAny(int)) [dynamic]")
+ benchmark_unpack(astuple, with_any_int, "dataclasses.astuple(WithAny(int))")
+ cfg_any = ConfigAny(x=1, y=2)
+ cfg_any_nested = ConfigAny(x=Config(x=1, y=2), y=3)
+ assert unpack_dataclass_to_tuple(cfg_any) == (1, 2)
+ assert unpack_dataclass_to_tuple(cfg_any_nested) == ((1, 2), 3)
+ benchmark_unpack(unpack_dataclass_to_tuple, cfg_any, "unpack(ConfigAny(1,2)) [all Any, leaf]")
+ benchmark_unpack(astuple, cfg_any, "dataclasses.astuple(ConfigAny(1,2))")
+ benchmark_unpack(
+ unpack_dataclass_to_tuple, cfg_any_nested, "unpack(ConfigAny(Config,3)) [Any, nested]"
+ )
+ benchmark_unpack(astuple, cfg_any_nested, "dataclasses.astuple(ConfigAny(Config,3))")
+ print("-" * 90)