This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new b4525d2  fix(dataclass): collect py_class fields from plain bases (#575)
b4525d2 is described below

commit b4525d2bb0e1c0190d3d1bdf46711ee5dd1c4889
Author: Junru Shao <[email protected]>
AuthorDate: Thu Apr 30 12:11:14 2026 -0700

    fix(dataclass): collect py_class fields from plain bases (#575)
    
    ## Summary
    
    Fixes `@py_class` field collection when a decorated class inherits
    annotated fields through one or more non-`@py_class` Python bases.
    
    Previously, `typing.get_type_hints(cls)` resolved inherited annotations,
    but the field collector only iterated over annotations declared directly
    on `cls`. A class such as `Add(BaseBinOp)` where `BaseBinOp` is a plain
    Python subclass of a registered FFI class would therefore register no
    `lhs` / `rhs` fields on `Add`, and construction failed with an
    unexpected keyword argument error.
    
    This PR updates the collector to walk the class's MRO up to, but not
    including, the nearest registered FFI parent, collecting annotations
    from non-registered Python bases as fields owned by the new type. The
    nearest registered parent remains the TypeInfo boundary, so fields
    already registered on `@py_class` ancestors are not duplicated.
    
    The field merge now follows the review feedback from Gemini:
    
    - collection iterates owner classes from base to child,
    - `KW_ONLY` state is carried across the flattened non-FFI owner list,
    - fields are merged by name in an insertion-order-preserving dictionary,
    so an override keeps the original field position while adopting the
    most-derived metadata.
    
    It also updates class attribute installation to only preserve attributes
    defined directly on the registered class. This prevents inherited plain
    Python defaults or `Field` sentinels from shadowing the reflected
    property descriptor installed for the child type.
    
    ## Behavioral Impact
    
    - `@py_class` subclasses now accept fields declared on unregistered
    Python parents between the child and its registered FFI ancestor.
    - Field metadata from those plain parents is preserved, including
    `ClassVar`, `KW_ONLY`, and `init=False` behavior.
    - Python multiple inheritance through non-FFI bases is covered as long
    as the class has a single registered FFI lineage.
    - More-derived field overrides preserve the first field position while
    replacing metadata and type information.
    - No public API or ABI is added or removed.
    
    ## Tests
    
    - `python -m ruff check tests/python/test_dataclass_py_class.py
    python/tvm_ffi/dataclasses/py_class.py python/tvm_ffi/registry.py`
    - `ty check tests/python/test_dataclass_py_class.py`
    - `python -m pytest tests/python/test_dataclass_py_class.py -q`
    - `python -m pytest tests/python/test_dataclass_enum.py
    tests/python/test_dataclass_common.py
    tests/python/test_dataclass_c_class.py
    tests/python/test_dataclass_init.py -q`
    
    All checks passed locally.
    
    Note: the standalone `bug.py` reproducer is not present in this checkout
    after the latest amendment, but the same construction path is covered by
    `test_collects_fields_from_non_py_class_parent`.
---
 python/tvm_ffi/dataclasses/py_class.py  |  82 ++++++++++++++------
 python/tvm_ffi/registry.py              |   2 +-
 tests/python/test_dataclass_py_class.py | 130 +++++++++++++++++++++++++++++++-
 3 files changed, 188 insertions(+), 26 deletions(-)

diff --git a/python/tvm_ffi/dataclasses/py_class.py 
b/python/tvm_ffi/dataclasses/py_class.py
index ac77c30..b9cacd3 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 import sys
 import typing
 from collections.abc import Callable
+from copy import copy
 from dataclasses import dataclass
 from typing import Any, ClassVar, TypeVar
 
@@ -136,12 +137,38 @@ def _rollback_registration(cls: type, type_info: Any) -> 
None:
 # ---------------------------------------------------------------------------
 
 
+def _own_annotations(cls: type) -> dict[str, Any]:
+    """Return annotations declared directly on *cls*."""
+    # Python 3.14+ (PEP 749): annotations are lazily evaluated via
+    # __annotate__ and no longer stored directly in __dict__.  getattr()
+    # triggers evaluation and returns per-class annotations correctly.
+    # On Python < 3.14, getattr() follows MRO and returns *parent*
+    # annotations when the child has none — use __dict__ to avoid that.
+    if sys.version_info >= (3, 14):
+        return getattr(cls, "__annotations__", {})
+    return cls.__dict__.get("__annotations__", {})
+
+
+def _field_owner_classes(cls: type) -> list[type]:
+    """Classes whose annotations become this type's own fields."""
+    registered_parent = next(
+        (b for b in cls.__mro__[1:] if "__tvm_ffi_type_info__" in b.__dict__), 
object
+    )
+    represented = set(registered_parent.__mro__)
+    return [
+        b
+        for b in reversed(cls.__mro__)
+        if b is not object and b not in represented and _own_annotations(b)
+    ]
+
+
 def _collect_own_fields(  # noqa: PLR0912
     cls: type,
+    owner: type,
     hints: dict[str, Any],
-    decorator_kw_only: bool,
+    kw_only_active: bool,
     decorator_frozen: bool,
-) -> list[Field]:
+) -> tuple[list[Field], bool]:
     """Parse own annotations into :class:`Field` objects.
 
     - Skips ``ClassVar`` annotations.
@@ -152,16 +179,7 @@ def _collect_own_fields(  # noqa: PLR0912
     - Resolves ``hash=None`` to follow ``compare``.
     """
     fields: list[Field] = []
-    kw_only_active = decorator_kw_only
-    # Python 3.14+ (PEP 749): annotations are lazily evaluated via
-    # __annotate__ and no longer stored directly in __dict__.  getattr()
-    # triggers evaluation and returns per-class annotations correctly.
-    # On Python < 3.14, getattr() follows MRO and returns *parent*
-    # annotations when the child has none — use __dict__ to avoid that.
-    if sys.version_info >= (3, 14):
-        own_annotations: dict[str, str] = getattr(cls, "__annotations__", {})
-    else:
-        own_annotations = cls.__dict__.get("__annotations__", {})
+    own_annotations = _own_annotations(owner)
 
     for name in own_annotations:
         resolved_type = hints.get(name)
@@ -176,7 +194,7 @@ def _collect_own_fields(  # noqa: PLR0912
         # KW_ONLY sentinel
         if resolved_type is KW_ONLY:
             kw_only_active = True
-            if name in cls.__dict__:
+            if owner is cls and name in cls.__dict__:
                 try:
                     delattr(cls, name)
                 except AttributeError:
@@ -184,14 +202,14 @@ def _collect_own_fields(  # noqa: PLR0912
             continue
 
         # Extract Field from class dict (inline of _pop_field_from_class)
-        class_val = cls.__dict__.get(name, MISSING)
+        class_val = owner.__dict__.get(name, MISSING)
         if isinstance(class_val, Field):
-            f = class_val
+            f = class_val if owner is cls else copy(class_val)
         elif class_val is not MISSING:
             f = field(default=class_val)
         else:
             f = field()
-        if class_val is not MISSING:
+        if owner is cls and class_val is not MISSING:
             try:
                 delattr(cls, name)
             except AttributeError:
@@ -216,7 +234,7 @@ def _collect_own_fields(  # noqa: PLR0912
 
         fields.append(f)
 
-    return fields
+    return fields, kw_only_active
 
 
 def method(fn: Any) -> Any:
@@ -415,7 +433,20 @@ def _register_fields_into_type(
         except (NameError, AttributeError):
             return False
 
-    own_fields = _collect_own_fields(cls, hints, params["kw_only"], 
params["frozen"])
+    fields_map: dict[str, Field] = {}
+    kw_only_active = params["kw_only"]
+    for owner in _field_owner_classes(cls):
+        owner_fields, kw_only_active = _collect_own_fields(
+            cls,
+            owner,
+            hints,
+            kw_only_active,
+            params["frozen"],
+        )
+        for f in owner_fields:
+            assert f.name is not None
+            fields_map[f.name] = f
+    own_fields = list(fields_map.values())
     py_methods = _collect_py_methods(cls)
 
     # Register fields and type-level structural eq/hash kind with the C layer.
@@ -483,13 +514,16 @@ def _flush_pending() -> None:
 def _raise_unresolved_forward_reference(cls: type, globalns: dict[str, Any]) 
-> None:
     """Raise :class:`TypeError` listing the annotations that cannot be 
resolved."""
     localns = _build_localns(cls, cross_module=True)
+    owners = _field_owner_classes(cls)
+    localns.update({owner.__name__: owner for owner in owners})
     unresolved: list[str] = []
-    for name, ann_str in getattr(cls, "__annotations__", {}).items():
-        if isinstance(ann_str, str):
-            try:
-                eval(ann_str, globalns, localns)
-            except NameError:
-                unresolved.append(f"{name}: {ann_str}")
+    for owner in owners:
+        for name, ann_str in _own_annotations(owner).items():
+            if isinstance(ann_str, str):
+                try:
+                    eval(ann_str, globalns, localns)
+                except NameError:
+                    unresolved.append(f"{name}: {ann_str}")
     raise TypeError(
         f"Cannot instantiate {cls.__name__}: unresolved forward references: 
{unresolved}"
     )
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 0d0208a..07c24ee 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -399,7 +399,7 @@ def _install_init(cls: type, type_info: TypeInfo) -> None:
 def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
     for field in type_info.fields:
         name = field.name
-        if not hasattr(type_cls, name):  # skip already defined attributes
+        if name not in type_cls.__dict__:  # skip attributes defined directly 
on this class
             setattr(type_cls, name, field.as_property(type_cls))
     has_ffi_init = False
     for method in type_info.methods:
diff --git a/tests/python/test_dataclass_py_class.py 
b/tests/python/test_dataclass_py_class.py
index 1b70489..423965e 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -32,7 +32,7 @@ from tvm_ffi import core
 from tvm_ffi._dunder import _install_dataclass_dunders
 from tvm_ffi._ffi_api import DeepCopy, RecursiveEq, RecursiveHash, ReprPrint
 from tvm_ffi.core import MISSING, Object, TypeInfo, TypeSchema, 
_to_py_class_value
-from tvm_ffi.dataclasses import KW_ONLY, Field, IntEnum, StrEnum, entry, 
field, py_class
+from tvm_ffi.dataclasses import KW_ONLY, Field, IntEnum, StrEnum, entry, 
field, fields, py_class
 from tvm_ffi.registry import _add_class_attrs
 from tvm_ffi.testing import TestObjectBase as _TestObjectBase
 from tvm_ffi.testing.testing import requires_py310
@@ -739,6 +739,134 @@ class TestInheritance:
         assert obj.b == 2
         assert obj.c == 3
 
+    def test_collects_fields_from_non_py_class_parent(self) -> None:
+        @py_class(_unique_key("NPCNode"))
+        class Node(Object):
+            x: int
+
+        class BaseBinOp(Node):
+            lhs: int
+            rhs: int
+
+        @py_class(_unique_key("NPCAdd"))
+        class Add(BaseBinOp):
+            pass
+
+        obj = Add(lhs=1, rhs=2, x=0)  # ty: ignore[unknown-argument]
+        assert obj.x == 0
+        assert obj.lhs == 1
+        assert obj.rhs == 2
+        assert [f.name for f in fields(Add)] == ["x", "lhs", "rhs"]
+        assert [f.name for f in _get_type_info(Add).fields] == ["lhs", "rhs"]
+
+    def test_collects_non_py_class_parent_field_options(self) -> None:
+        @py_class(_unique_key("NPCOptNode"))
+        class Node(Object):
+            x: int
+
+        class BaseOp(Node):
+            kind: ClassVar[str] = "binop"
+            lhs: int
+            hidden: int = field(default=99, init=False)
+            _: KW_ONLY
+            rhs: int
+
+        @py_class(_unique_key("NPCOptAdd"))
+        class Add(BaseOp):
+            scale: int
+
+        obj = Add(0, 1, rhs=3, scale=2)  # ty: 
ignore[parameter-already-assigned,unknown-argument]
+        assert obj.x == 0
+        assert obj.lhs == 1
+        assert obj.scale == 2
+        assert obj.rhs == 3
+        assert obj.hidden == 99
+        assert [f.name for f in fields(Add)] == ["x", "lhs", "hidden", "rhs", 
"scale"]
+        assert [f.name for f in _get_type_info(Add).fields] == [
+            "lhs",
+            "hidden",
+            "rhs",
+            "scale",
+        ]
+        with pytest.raises(TypeError):
+            Add(0, 1, 2, rhs=3)  # ty: 
ignore[too-many-positional-arguments,unknown-argument]
+
+    def test_registered_parent_non_py_class_fields_not_duplicated(self) -> 
None:
+        @py_class(_unique_key("NPCDedupNode"))
+        class Node(Object):
+            x: int
+
+        class BaseBinOp(Node):
+            lhs: int
+            rhs: int
+
+        @py_class(_unique_key("NPCDedupAdd"))
+        class Add(BaseBinOp):
+            op_id: int
+
+        @py_class(_unique_key("NPCDedupWeightedAdd"))
+        class WeightedAdd(Add):
+            weight: int
+
+        obj = WeightedAdd(x=0, lhs=1, rhs=2, op_id=3, weight=4)  # ty: 
ignore[unknown-argument]
+        assert obj.x == 0
+        assert obj.lhs == 1
+        assert obj.rhs == 2
+        assert obj.op_id == 3
+        assert obj.weight == 4
+        assert [f.name for f in fields(WeightedAdd)] == ["x", "lhs", "rhs", 
"op_id", "weight"]
+        assert [f.name for f in _get_type_info(WeightedAdd).fields] == 
["weight"]
+
+    def test_multiple_non_py_class_parents_single_ffi_lineage(self) -> None:
+        @py_class(_unique_key("MROBase"))
+        class Base(Object):
+            x: int
+
+        class NonFFIClassC(Base):
+            c: int
+
+        class NonFFIClassA:
+            a: int
+
+        class NonFFIClassB:
+            b: int
+
+        @py_class(_unique_key("MROChild"))
+        class Child(NonFFIClassA, NonFFIClassB, NonFFIClassC):
+            y: int
+
+        obj = Child(x=0, c=3, b=2, a=1, y=4)  # ty: ignore[unknown-argument]
+        assert obj.x == 0
+        assert obj.a == 1
+        assert obj.b == 2
+        assert obj.c == 3
+        assert obj.y == 4
+        assert [f.name for f in fields(Child)] == ["x", "c", "b", "a", "y"]
+        assert [f.name for f in _get_type_info(Child).fields] == ["c", "b", 
"a", "y"]
+
+    def test_non_py_class_parent_override_preserves_field_position(self) -> 
None:
+        @py_class(_unique_key("OverrideBase"))
+        class Base(Object):
+            x: int
+
+        class Parent(Base):
+            a: int
+            b: int
+            c: int
+
+        @py_class(_unique_key("OverrideChild"))
+        class Child(Parent):
+            b: str
+            d: int
+
+        obj = Child(x=0, a=1, b="two", c=3, d=4)  # ty: 
ignore[unknown-argument]
+        assert obj.b == "two"
+        assert [f.name for f in fields(Child)] == ["x", "a", "b", "c", "d"]
+        own_fields = _get_type_info(Child).fields
+        assert [f.name for f in own_fields] == ["a", "b", "c", "d"]
+        assert own_fields[1].dataclass_field is not None
+        assert own_fields[1].dataclass_field.type is str
+
 
 # ###########################################################################
 # 14. Forward references / deferred resolution

Reply via email to