This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b4525d2 fix(dataclass): collect py_class fields from plain bases
(#575)
b4525d2 is described below
commit b4525d2bb0e1c0190d3d1bdf46711ee5dd1c4889
Author: Junru Shao <[email protected]>
AuthorDate: Thu Apr 30 12:11:14 2026 -0700
fix(dataclass): collect py_class fields from plain bases (#575)
## Summary
Fixes `@py_class` field collection when a decorated class inherits
annotated fields through one or more non-`@py_class` Python bases.
Previously, `typing.get_type_hints(cls)` resolved inherited annotations,
but the field collector only iterated over annotations declared directly
on `cls`. A class such as `Add(BaseBinOp)`, where `BaseBinOp` is a plain
Python subclass of a registered FFI class, would therefore register no
`lhs` / `rhs` fields on `Add`, and constructing it would fail with an
"unexpected keyword argument" error.
This PR updates the collector to walk the current class's MRO up to (but
not including) the nearest registered FFI parent, collecting annotations
from non-registered Python bases as fields owned by the new type. The
nearest registered parent remains the TypeInfo boundary, so fields already
registered on `@py_class` ancestors are not duplicated.
The field merge now follows the review feedback from Gemini:
- collection iterates owner classes from base to child,
- `KW_ONLY` state is carried across the flattened non-FFI owner list,
- fields are merged by name into an insertion-order-preserving dictionary,
so an override keeps the original field position while adopting the
most-derived metadata.
It also changes class attribute installation so that a reflected property
is skipped only when the attribute is defined directly on the registered
class (previously, any inherited attribute caused it to be skipped). This
prevents inherited plain Python defaults or `Field` sentinels from
shadowing the reflected property descriptor installed for the child type.
## Behavioral Impact
- `@py_class` subclasses now accept fields declared on unregistered
Python parents between the child and its registered FFI ancestor.
- Field metadata from those plain parents is preserved, including
`ClassVar`, `KW_ONLY`, and `init=False` behavior.
- Python multiple inheritance through non-FFI bases is covered as long
as the class has a single registered FFI lineage.
- When a more-derived class overrides a field, the field keeps its original
position while its metadata and type information are replaced.
- No public API or ABI is added or removed.
## Tests
- `python -m ruff check tests/python/test_dataclass_py_class.py
python/tvm_ffi/dataclasses/py_class.py python/tvm_ffi/registry.py`
- `ty check tests/python/test_dataclass_py_class.py`
- `python -m pytest tests/python/test_dataclass_py_class.py -q`
- `python -m pytest tests/python/test_dataclass_enum.py
tests/python/test_dataclass_common.py
tests/python/test_dataclass_c_class.py
tests/python/test_dataclass_init.py -q`
All checks passed locally.
Note: the standalone `bug.py` reproducer is not present in this checkout
after the latest amendment, but the same construction path is covered by
`test_collects_fields_from_non_py_class_parent`.
---
python/tvm_ffi/dataclasses/py_class.py | 82 ++++++++++++++------
python/tvm_ffi/registry.py | 2 +-
tests/python/test_dataclass_py_class.py | 130 +++++++++++++++++++++++++++++++-
3 files changed, 188 insertions(+), 26 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index ac77c30..b9cacd3 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import sys
import typing
from collections.abc import Callable
+from copy import copy
from dataclasses import dataclass
from typing import Any, ClassVar, TypeVar
@@ -136,12 +137,38 @@ def _rollback_registration(cls: type, type_info: Any) ->
None:
# ---------------------------------------------------------------------------
+def _own_annotations(cls: type) -> dict[str, Any]:
+ """Return annotations declared directly on *cls*."""
+ # Python 3.14+ (PEP 749): annotations are lazily evaluated via
+ # __annotate__ and no longer stored directly in __dict__. getattr()
+ # triggers evaluation and returns per-class annotations correctly.
+ # On Python < 3.14, getattr() follows MRO and returns *parent*
+ # annotations when the child has none — use __dict__ to avoid that.
+ if sys.version_info >= (3, 14):
+ return getattr(cls, "__annotations__", {})
+ return cls.__dict__.get("__annotations__", {})
+
+
+def _field_owner_classes(cls: type) -> list[type]:
+ """Classes whose annotations become this type's own fields."""
+ registered_parent = next(
+ (b for b in cls.__mro__[1:] if "__tvm_ffi_type_info__" in b.__dict__),
object
+ )
+ represented = set(registered_parent.__mro__)
+ return [
+ b
+ for b in reversed(cls.__mro__)
+ if b is not object and b not in represented and _own_annotations(b)
+ ]
+
+
def _collect_own_fields( # noqa: PLR0912
cls: type,
+ owner: type,
hints: dict[str, Any],
- decorator_kw_only: bool,
+ kw_only_active: bool,
decorator_frozen: bool,
-) -> list[Field]:
+) -> tuple[list[Field], bool]:
"""Parse own annotations into :class:`Field` objects.
- Skips ``ClassVar`` annotations.
@@ -152,16 +179,7 @@ def _collect_own_fields( # noqa: PLR0912
- Resolves ``hash=None`` to follow ``compare``.
"""
fields: list[Field] = []
- kw_only_active = decorator_kw_only
- # Python 3.14+ (PEP 749): annotations are lazily evaluated via
- # __annotate__ and no longer stored directly in __dict__. getattr()
- # triggers evaluation and returns per-class annotations correctly.
- # On Python < 3.14, getattr() follows MRO and returns *parent*
- # annotations when the child has none — use __dict__ to avoid that.
- if sys.version_info >= (3, 14):
- own_annotations: dict[str, str] = getattr(cls, "__annotations__", {})
- else:
- own_annotations = cls.__dict__.get("__annotations__", {})
+ own_annotations = _own_annotations(owner)
for name in own_annotations:
resolved_type = hints.get(name)
@@ -176,7 +194,7 @@ def _collect_own_fields( # noqa: PLR0912
# KW_ONLY sentinel
if resolved_type is KW_ONLY:
kw_only_active = True
- if name in cls.__dict__:
+ if owner is cls and name in cls.__dict__:
try:
delattr(cls, name)
except AttributeError:
@@ -184,14 +202,14 @@ def _collect_own_fields( # noqa: PLR0912
continue
# Extract Field from class dict (inline of _pop_field_from_class)
- class_val = cls.__dict__.get(name, MISSING)
+ class_val = owner.__dict__.get(name, MISSING)
if isinstance(class_val, Field):
- f = class_val
+ f = class_val if owner is cls else copy(class_val)
elif class_val is not MISSING:
f = field(default=class_val)
else:
f = field()
- if class_val is not MISSING:
+ if owner is cls and class_val is not MISSING:
try:
delattr(cls, name)
except AttributeError:
@@ -216,7 +234,7 @@ def _collect_own_fields( # noqa: PLR0912
fields.append(f)
- return fields
+ return fields, kw_only_active
def method(fn: Any) -> Any:
@@ -415,7 +433,20 @@ def _register_fields_into_type(
except (NameError, AttributeError):
return False
- own_fields = _collect_own_fields(cls, hints, params["kw_only"],
params["frozen"])
+ fields_map: dict[str, Field] = {}
+ kw_only_active = params["kw_only"]
+ for owner in _field_owner_classes(cls):
+ owner_fields, kw_only_active = _collect_own_fields(
+ cls,
+ owner,
+ hints,
+ kw_only_active,
+ params["frozen"],
+ )
+ for f in owner_fields:
+ assert f.name is not None
+ fields_map[f.name] = f
+ own_fields = list(fields_map.values())
py_methods = _collect_py_methods(cls)
# Register fields and type-level structural eq/hash kind with the C layer.
@@ -483,13 +514,16 @@ def _flush_pending() -> None:
def _raise_unresolved_forward_reference(cls: type, globalns: dict[str, Any])
-> None:
"""Raise :class:`TypeError` listing the annotations that cannot be
resolved."""
localns = _build_localns(cls, cross_module=True)
+ owners = _field_owner_classes(cls)
+ localns.update({owner.__name__: owner for owner in owners})
unresolved: list[str] = []
- for name, ann_str in getattr(cls, "__annotations__", {}).items():
- if isinstance(ann_str, str):
- try:
- eval(ann_str, globalns, localns)
- except NameError:
- unresolved.append(f"{name}: {ann_str}")
+ for owner in owners:
+ for name, ann_str in _own_annotations(owner).items():
+ if isinstance(ann_str, str):
+ try:
+ eval(ann_str, globalns, localns)
+ except NameError:
+ unresolved.append(f"{name}: {ann_str}")
raise TypeError(
f"Cannot instantiate {cls.__name__}: unresolved forward references:
{unresolved}"
)
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 0d0208a..07c24ee 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -399,7 +399,7 @@ def _install_init(cls: type, type_info: TypeInfo) -> None:
def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
for field in type_info.fields:
name = field.name
- if not hasattr(type_cls, name): # skip already defined attributes
+ if name not in type_cls.__dict__: # skip attributes defined directly
on this class
setattr(type_cls, name, field.as_property(type_cls))
has_ffi_init = False
for method in type_info.methods:
diff --git a/tests/python/test_dataclass_py_class.py
b/tests/python/test_dataclass_py_class.py
index 1b70489..423965e 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -32,7 +32,7 @@ from tvm_ffi import core
from tvm_ffi._dunder import _install_dataclass_dunders
from tvm_ffi._ffi_api import DeepCopy, RecursiveEq, RecursiveHash, ReprPrint
from tvm_ffi.core import MISSING, Object, TypeInfo, TypeSchema,
_to_py_class_value
-from tvm_ffi.dataclasses import KW_ONLY, Field, IntEnum, StrEnum, entry,
field, py_class
+from tvm_ffi.dataclasses import KW_ONLY, Field, IntEnum, StrEnum, entry,
field, fields, py_class
from tvm_ffi.registry import _add_class_attrs
from tvm_ffi.testing import TestObjectBase as _TestObjectBase
from tvm_ffi.testing.testing import requires_py310
@@ -739,6 +739,134 @@ class TestInheritance:
assert obj.b == 2
assert obj.c == 3
+ def test_collects_fields_from_non_py_class_parent(self) -> None:
+ @py_class(_unique_key("NPCNode"))
+ class Node(Object):
+ x: int
+
+ class BaseBinOp(Node):
+ lhs: int
+ rhs: int
+
+ @py_class(_unique_key("NPCAdd"))
+ class Add(BaseBinOp):
+ pass
+
+ obj = Add(lhs=1, rhs=2, x=0) # ty: ignore[unknown-argument]
+ assert obj.x == 0
+ assert obj.lhs == 1
+ assert obj.rhs == 2
+ assert [f.name for f in fields(Add)] == ["x", "lhs", "rhs"]
+ assert [f.name for f in _get_type_info(Add).fields] == ["lhs", "rhs"]
+
+ def test_collects_non_py_class_parent_field_options(self) -> None:
+ @py_class(_unique_key("NPCOptNode"))
+ class Node(Object):
+ x: int
+
+ class BaseOp(Node):
+ kind: ClassVar[str] = "binop"
+ lhs: int
+ hidden: int = field(default=99, init=False)
+ _: KW_ONLY
+ rhs: int
+
+ @py_class(_unique_key("NPCOptAdd"))
+ class Add(BaseOp):
+ scale: int
+
+ obj = Add(0, 1, rhs=3, scale=2) # ty:
ignore[parameter-already-assigned,unknown-argument]
+ assert obj.x == 0
+ assert obj.lhs == 1
+ assert obj.scale == 2
+ assert obj.rhs == 3
+ assert obj.hidden == 99
+ assert [f.name for f in fields(Add)] == ["x", "lhs", "hidden", "rhs",
"scale"]
+ assert [f.name for f in _get_type_info(Add).fields] == [
+ "lhs",
+ "hidden",
+ "rhs",
+ "scale",
+ ]
+ with pytest.raises(TypeError):
+ Add(0, 1, 2, rhs=3) # ty:
ignore[too-many-positional-arguments,unknown-argument]
+
+ def test_registered_parent_non_py_class_fields_not_duplicated(self) ->
None:
+ @py_class(_unique_key("NPCDedupNode"))
+ class Node(Object):
+ x: int
+
+ class BaseBinOp(Node):
+ lhs: int
+ rhs: int
+
+ @py_class(_unique_key("NPCDedupAdd"))
+ class Add(BaseBinOp):
+ op_id: int
+
+ @py_class(_unique_key("NPCDedupWeightedAdd"))
+ class WeightedAdd(Add):
+ weight: int
+
+ obj = WeightedAdd(x=0, lhs=1, rhs=2, op_id=3, weight=4) # ty:
ignore[unknown-argument]
+ assert obj.x == 0
+ assert obj.lhs == 1
+ assert obj.rhs == 2
+ assert obj.op_id == 3
+ assert obj.weight == 4
+ assert [f.name for f in fields(WeightedAdd)] == ["x", "lhs", "rhs",
"op_id", "weight"]
+ assert [f.name for f in _get_type_info(WeightedAdd).fields] ==
["weight"]
+
+ def test_multiple_non_py_class_parents_single_ffi_lineage(self) -> None:
+ @py_class(_unique_key("MROBase"))
+ class Base(Object):
+ x: int
+
+ class NonFFIClassC(Base):
+ c: int
+
+ class NonFFIClassA:
+ a: int
+
+ class NonFFIClassB:
+ b: int
+
+ @py_class(_unique_key("MROChild"))
+ class Child(NonFFIClassA, NonFFIClassB, NonFFIClassC):
+ y: int
+
+ obj = Child(x=0, c=3, b=2, a=1, y=4) # ty: ignore[unknown-argument]
+ assert obj.x == 0
+ assert obj.a == 1
+ assert obj.b == 2
+ assert obj.c == 3
+ assert obj.y == 4
+ assert [f.name for f in fields(Child)] == ["x", "c", "b", "a", "y"]
+ assert [f.name for f in _get_type_info(Child).fields] == ["c", "b",
"a", "y"]
+
+ def test_non_py_class_parent_override_preserves_field_position(self) ->
None:
+ @py_class(_unique_key("OverrideBase"))
+ class Base(Object):
+ x: int
+
+ class Parent(Base):
+ a: int
+ b: int
+ c: int
+
+ @py_class(_unique_key("OverrideChild"))
+ class Child(Parent):
+ b: str
+ d: int
+
+ obj = Child(x=0, a=1, b="two", c=3, d=4) # ty:
ignore[unknown-argument]
+ assert obj.b == "two"
+ assert [f.name for f in fields(Child)] == ["x", "a", "b", "c", "d"]
+ own_fields = _get_type_info(Child).fields
+ assert [f.name for f in own_fields] == ["a", "b", "c", "d"]
+ assert own_fields[1].dataclass_field is not None
+ assert own_fields[1].dataclass_field.type is str
+
# ###########################################################################
# 14. Forward references / deferred resolution