This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 7055626  fix(dataclass): resolve cross-module forward refs (#559)
7055626 is described below

commit 7055626763b3e76c60842274286e10bde0f464ad
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 20 12:40:52 2026 -0700

    fix(dataclass): resolve cross-module forward refs (#559)
---
 python/tvm_ffi/dataclasses/py_class.py  | 53 ++++++++++++++++++++++++++-----
 tests/python/test_dataclass_py_class.py | 56 +++++++++++++++++++++++++++++++++
 2 files changed, 101 insertions(+), 8 deletions(-)

diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index 70ee25d..cf41abd 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -244,6 +244,32 @@ def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
     return methods if methods else None
 
 
+def _build_localns(cls: type, *, cross_module: bool = False) -> dict[str, Any]:
+    """Build the localns dict for resolving ``cls``'s annotations.
+
+    By default, includes only classes from ``cls.__module__``, preserving
+    standard Python name resolution semantics.  When ``cross_module=True``,
+    also includes classes from all other registered modules as a fallback
+    — this is needed when ``cls`` has a forward reference to a class in
+    another module that can't appear in ``cls.__module__``'s globals due
+    to a circular import (e.g. the target is imported only under
+    ``if TYPE_CHECKING:``).
+
+    Cross-module entries are added with ``setdefault`` so same-module
+    classes and the class itself always take precedence over foreign
+    classes with the same ``__name__``.
+    """
+    localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
+    localns[cls.__name__] = cls
+    if cross_module:
+        for mod_name, mod_classes in list(_PY_CLASS_BY_MODULE.items()):
+            if mod_name == cls.__module__:
+                continue
+            for name, klass in mod_classes.items():
+                localns.setdefault(name, klass)
+    return localns
+
+
 def _register_fields_into_type(
     cls: type,
     type_info: Any,
@@ -255,15 +281,23 @@ def _register_fields_into_type(
     Returns True on success, False if forward references are unresolved.
     """
     # Resolve string annotations to types; return False (defer) on NameError.
-    localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
-    localns[cls.__name__] = cls
+    #
+    # First try with module-scoped localns (standard Python name resolution).
+    # On NameError, retry with a cross-module localns that includes classes
+    # from every registered module — this handles circular imports where the
+    # target of a forward reference is imported only under TYPE_CHECKING and
+    # therefore never enters the declaring module's globals.
+    kwargs: dict[str, Any] = {"globalns": globalns, "localns": _build_localns(cls)}
+    if sys.version_info >= (3, 11):
+        kwargs["include_extras"] = True
     try:
-        kwargs: dict[str, Any] = {"globalns": globalns, "localns": localns}
-        if sys.version_info >= (3, 11):
-            kwargs["include_extras"] = True
         hints = typing.get_type_hints(cls, **kwargs)
     except (NameError, AttributeError):
-        return False
+        kwargs["localns"] = _build_localns(cls, cross_module=True)
+        try:
+            hints = typing.get_type_hints(cls, **kwargs)
+        except (NameError, AttributeError):
+            return False
 
     own_fields = _collect_own_fields(cls, hints, params["kw_only"], params["frozen"])
     py_methods = _collect_py_methods(cls)
@@ -332,8 +366,7 @@ def _flush_pending() -> None:
 
 def _raise_unresolved_forward_reference(cls: type, globalns: dict[str, Any]) -> None:
     """Raise :class:`TypeError` listing the annotations that cannot be resolved."""
-    localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
-    localns[cls.__name__] = cls
+    localns = _build_localns(cls, cross_module=True)
     unresolved: list[str] = []
     for name, ann_str in getattr(cls, "__annotations__", {}).items():
         if isinstance(ann_str, str):
@@ -354,6 +387,10 @@ def _make_temporary_init(
             try:
                 if not _register_fields_into_type(cls, type_info, globalns, params):
                     _raise_unresolved_forward_reference(cls, globalns)
+                # cls stays in _PENDING_CLASSES after phase-2 succeeds; drop it
+                # before _flush_pending so the loop doesn't hit the Cython-level
+                # "_register_fields already called" assertion on a second pass.
+                _PENDING_CLASSES[:] = [p for p in _PENDING_CLASSES if p.cls is not cls]
                 _flush_pending()
             except Exception:
                 # Remove from pending list and roll back so the type key can be reused.
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index 04f8773..acfabe5 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -759,6 +759,62 @@ class TestForwardReferences:
         assert obj.ref is not None
         assert obj.ref.value == 2
 
+    @requires_py310
+    def test_cross_module_forward_ref_via_circular_import(self) -> None:
+        """Forward ref to a class in another module that isn't in the declaring
+        module's globals (circular import / TYPE_CHECKING-gated) still resolves.
+
+        Mirrors the loom codegen case where ``weave_ir.TaskSpec`` has a field
+        ``body: tuple[Op, ...]`` and ``Op`` lives in ``ops`` — ``ops`` already
+        imports from ``weave_ir``, so ``Op`` cannot appear in ``weave_ir``'s
+        globals without introducing a circular import.
+        """
+        # Module B: the target of the forward ref.  Define it first so it's
+        # registered in ``_PY_CLASS_BY_MODULE`` under its own module key when
+        # module A's phase-2 runs.
+        ns_b: Dict[str, Any] = {
+            "__name__": "testing.cross_mod_b",
+            "py_class": py_class,
+            "Object": Object,
+            "_unique_key": _unique_key,
+        }
+        exec(
+            "from __future__ import annotations\n"
+            "@py_class(_unique_key('CrossChild'))\n"
+            "class CrossChild(Object):\n"
+            "    x: int\n",
+            ns_b,
+        )
+        child_cls = ns_b["CrossChild"]
+        assert child_cls.__module__ == "testing.cross_mod_b"
+
+        # Module A: references ``CrossChild`` by name but does NOT have it
+        # in its globals (simulating a ``TYPE_CHECKING``-gated import).
+        ns_a: Dict[str, Any] = {
+            "__name__": "testing.cross_mod_a",
+            "py_class": py_class,
+            "Object": Object,
+            "_unique_key": _unique_key,
+        }
+        exec(
+            "from __future__ import annotations\n"
+            "@py_class(_unique_key('CrossHolder'))\n"
+            "class Holder(Object):\n"
+            "    value: int\n"
+            "    child: CrossChild | None = None\n",
+            ns_a,
+        )
+        holder_cls = ns_a["Holder"]
+        assert holder_cls.__module__ == "testing.cross_mod_a"
+
+        # Instantiation forces phase-2 to run; the cross-module localns
+        # fallback should pick up CrossChild from module B even though it's
+        # neither in module A's globals nor in module A's per-module registry.
+        child = child_cls(x=7)
+        holder = holder_cls(value=1, child=child)
+        assert holder.child is not None
+        assert holder.child.x == 7
+
 
 # ###########################################################################
 # 15. User-defined dunder preservation

Reply via email to