This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7055626 fix(dataclass): resolve cross-module forward refs (#559)
7055626 is described below
commit 7055626763b3e76c60842274286e10bde0f464ad
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 20 12:40:52 2026 -0700
fix(dataclass): resolve cross-module forward refs (#559)
---
python/tvm_ffi/dataclasses/py_class.py | 53 ++++++++++++++++++++++++++-----
tests/python/test_dataclass_py_class.py | 56 +++++++++++++++++++++++++++++++++
2 files changed, 101 insertions(+), 8 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index 70ee25d..cf41abd 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -244,6 +244,32 @@ def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
return methods if methods else None
+def _build_localns(cls: type, *, cross_module: bool = False) -> dict[str, Any]:
+ """Build the localns dict for resolving ``cls``'s annotations.
+
+ By default, includes only classes from ``cls.__module__`` plus ``cls``
+ itself, preserving standard name resolution. With ``cross_module=True``,
+ classes from every other registered module are added as a fallback for
+ forward refs to classes that can't appear in ``cls.__module__``'s
+ globals due to a circular import (e.g. imported under TYPE_CHECKING).
+
+ Cross-module entries use ``setdefault``: same-module classes and ``cls``
+ take precedence over foreign classes with the same ``__name__``. If
+ several foreign modules share a name, whichever is encountered first
+ while iterating ``_PY_CLASS_BY_MODULE`` wins — the clash is not reported.
+ Returns a fresh dict; the registry itself is never mutated. """
+ localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
+ localns[cls.__name__] = cls  # cls always resolves to itself
+ if cross_module:
+ for mod_name, mod_classes in list(_PY_CLASS_BY_MODULE.items()):  # snapshot; presumably so registry changes can't break iteration — confirm
+ if mod_name == cls.__module__:
+ continue
+ for name, klass in mod_classes.items():
+ localns.setdefault(name, klass)  # first foreign definition of a name wins
+ return localns
+
+
def _register_fields_into_type(
cls: type,
type_info: Any,
@@ -255,15 +281,23 @@ def _register_fields_into_type(
Returns True on success, False if forward references are unresolved.
"""
# Resolve string annotations to types; return False (defer) on NameError.
- localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
- localns[cls.__name__] = cls
+ #
+ # First try with module-scoped localns (standard Python name resolution).
+ # On NameError, retry with a cross-module localns that includes classes
+ # from every registered module — this handles circular imports where the
+ # target of a forward reference is imported only under TYPE_CHECKING and
+ # therefore never enters the declaring module's globals.
+ kwargs: dict[str, Any] = {"globalns": globalns, "localns":
_build_localns(cls)}
+ if sys.version_info >= (3, 11):
+ kwargs["include_extras"] = True
try:
- kwargs: dict[str, Any] = {"globalns": globalns, "localns": localns}
- if sys.version_info >= (3, 11):
- kwargs["include_extras"] = True
hints = typing.get_type_hints(cls, **kwargs)
except (NameError, AttributeError):
- return False
+ kwargs["localns"] = _build_localns(cls, cross_module=True)
+ try:
+ hints = typing.get_type_hints(cls, **kwargs)
+ except (NameError, AttributeError):
+ return False
own_fields = _collect_own_fields(cls, hints, params["kw_only"],
params["frozen"])
py_methods = _collect_py_methods(cls)
@@ -332,8 +366,7 @@ def _flush_pending() -> None:
def _raise_unresolved_forward_reference(cls: type, globalns: dict[str, Any])
-> None:
"""Raise :class:`TypeError` listing the annotations that cannot be
resolved."""
- localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {}))
- localns[cls.__name__] = cls
+ localns = _build_localns(cls, cross_module=True)
unresolved: list[str] = []
for name, ann_str in getattr(cls, "__annotations__", {}).items():
if isinstance(ann_str, str):
@@ -354,6 +387,10 @@ def _make_temporary_init(
try:
if not _register_fields_into_type(cls, type_info, globalns,
params):
_raise_unresolved_forward_reference(cls, globalns)
+ # cls stays in _PENDING_CLASSES after phase-2 succeeds; drop it
+ # before _flush_pending so the loop doesn't hit the
Cython-level
+ # "_register_fields already called" assertion on a second pass.
+ _PENDING_CLASSES[:] = [p for p in _PENDING_CLASSES if p.cls is
not cls]
_flush_pending()
except Exception:
# Remove from pending list and roll back so the type key can
be reused.
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index 04f8773..acfabe5 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -759,6 +759,62 @@ class TestForwardReferences:
assert obj.ref is not None
assert obj.ref.value == 2
+ @requires_py310  # the ``CrossChild | None`` annotation is evaluated at runtime
+ def test_cross_module_forward_ref_via_circular_import(self) -> None:
+ """Forward ref to a class in another module that isn't in the declaring
+ module's globals (circular import / TYPE_CHECKING-gated) still resolves.
+
+ Mirrors the loom codegen case where ``weave_ir.TaskSpec`` has a field
+ ``body: tuple[Op, ...]`` and ``Op`` lives in ``ops`` — ``ops`` already
+ imports from ``weave_ir``, so ``Op`` cannot appear in ``weave_ir``'s
+ globals without introducing a circular import.
+ """
+ # Module B: the target of the forward ref. Define it first so it's
+ # registered in ``_PY_CLASS_BY_MODULE`` under its own module key when
+ # module A's phase-2 runs.
+ ns_b: Dict[str, Any] = {
+ "__name__": "testing.cross_mod_b",
+ "py_class": py_class,
+ "Object": Object,
+ "_unique_key": _unique_key,
+ }
+ exec(
+ "from __future__ import annotations\n"
+ "@py_class(_unique_key('CrossChild'))\n"
+ "class CrossChild(Object):\n"
+ " x: int\n",
+ ns_b,
+ )
+ child_cls = ns_b["CrossChild"]
+ assert child_cls.__module__ == "testing.cross_mod_b"
+
+ # Module A: references ``CrossChild`` by name but does NOT have it
+ # in its globals (simulating a ``TYPE_CHECKING``-gated import).
+ ns_a: Dict[str, Any] = {
+ "__name__": "testing.cross_mod_a",
+ "py_class": py_class,
+ "Object": Object,
+ "_unique_key": _unique_key,
+ }
+ exec(
+ "from __future__ import annotations\n"
+ "@py_class(_unique_key('CrossHolder'))\n"
+ "class Holder(Object):\n"
+ " value: int\n"
+ " child: CrossChild | None = None\n",
+ ns_a,
+ )
+ holder_cls = ns_a["Holder"]
+ assert holder_cls.__module__ == "testing.cross_mod_a"
+
+ # Instantiation forces phase-2 to run: the module-scoped lookup hits a
+ # NameError on CrossChild, and the cross-module localns retry resolves it
+ # from module B (it is in neither module A's globals nor its registry).
+ child = child_cls(x=7)
+ holder = holder_cls(value=1, child=child)
+ assert holder.child is not None
+ assert holder.child.x == 7
+
# ###########################################################################
# 15. User-defined dunder preservation