This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 02a344d feat: Added typed method registration for py_class (#567)
02a344d is described below
commit 02a344dbe6df51b8506194a2f56bea482274f90a
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Fri Apr 24 00:08:39 2026 -0700
feat: Added typed method registration for py_class (#567)
## Add typed method registration for py_class
feat(dataclasses): add @method decorator for FFI TypeMethod registration
Closes the gap where `@py_class`-decorated Python classes couldn't
expose user-defined methods to the FFI reflection table. Trait
references like `$method:NAME` in `__ffi_ir_traits__` now resolve
cleanly against Python-defined methods, on par with C++ methods defined via
`refl::ObjectDef<T>().def("name", ...)`.
## Architecture
- The `tvm_ffi.method` decorator sets `fn.__ffi_method__ = True` on the
decorated function. For a `staticmethod` it marks the wrapped `__func__`,
so both `@method @staticmethod` and `@staticmethod @method` work.
- `_collect_py_methods` in `py_class.py` widens beyond the existing
`_FFI_RECOGNIZED_METHODS` allowlist: it now also picks up any
callable in `cls.__dict__` carrying the `__ffi_method__` marker
via `_is_method_marked`. The allowlist still routes TypeAttrColumn
dunders (`__ffi_repr__`, `__ffi_ir_traits__`, etc.) through
`TVMFFITypeRegisterAttr` unchanged.
- Marked methods register through the existing TypeMethod path
(`TVMFFITypeRegisterMethod`) — same machinery the C++ side uses.
The C++ printer's `FindMethod` (`pyast_trait_print.cc:130-148`) walks
`info->methods[]` and discovers these entries identically to
C++-defined methods, including ancestor-walking for inheritance.
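- `@method` is rejected on:
  - the reserved `__ffi_*` prefix, including TypeAttrColumn names (would
    double-register as TypeAttrColumn) — `NameError` from
    `_validate_method_name` at class registration;
  - Python protocol dunders (`__len__`, `__iter__`, etc. — reserved
    for Python semantics) — also `NameError` from `_validate_method_name`;
  - `@classmethod` (the `cls` first arg breaks the packed-call
    convention) — `TypeError`, raised by `method` itself at decoration
    time for `@method @classmethod`, or by `_collect_py_methods` at class
    registration for `@classmethod @method`.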
- Identity-typed: `@method` returns its argument unchanged (no wrapping;
a `staticmethod` descriptor is returned as-is), so IDEs and type checkers
see the method's signature unmodified.
## Public Interfaces
- `tvm_ffi.method` (defined in `tvm_ffi.dataclasses.py_class`; re-exported from the top-level `tvm_ffi` package)
## Behavioral Changes
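None for existing classes, with one tightening: `_collect_py_methods` now
raises `TypeError` when an entry named in `_FFI_RECOGNIZED_METHODS` is
wrapped in `classmethod` (previously it was collected as-is). Otherwise,
only NEW `@method`-decorated callables land in `TVMFFITypeInfo.methods[]`.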
## Tests
- `tests/python/test_typed_method.py` (new, 15 tests): registration
shape, FFI Function callability, name-based lookup, validation errors.
- `tests/python/test_ir_traits.py` (5 new tests): end-to-end via
`pyast.to_python()` against `$method:` refs in
`AssignTraits.text_printer_post`. (This file is not listed in the
diffstat below.)
---
python/tvm_ffi/__init__.py | 2 +
python/tvm_ffi/dataclasses/py_class.py | 149 ++++++++++++--
tests/python/test_typed_method.py | 346 +++++++++++++++++++++++++++++++++
3 files changed, 479 insertions(+), 18 deletions(-)
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 77e95e5..b1872a7 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -72,6 +72,7 @@ if TYPE_CHECKING or not _is_config_mode():
from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Dict, List, Map
+ from .dataclasses.py_class import method
from .module import Module, system_lib, load_module
from .stream import StreamContext, get_raw_stream, use_raw_stream,
use_torch_stream
from .structural import (
@@ -156,6 +157,7 @@ __all__ = [
"get_raw_stream",
"init_ffi_api",
"load_module",
+ "method",
"register_error",
"register_global_func",
"register_object",
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index 0a44aa0..ac77c30 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -219,30 +219,143 @@ def _collect_own_fields( # noqa: PLR0912
return fields
-def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
- """Extract recognized FFI dunder methods and type attrs from the class
body.
+def method(fn: Any) -> Any:
+    """Flag a ``@py_class`` method for FFI TypeMethod registration.
+
+    Apply this to a plain instance method or a ``staticmethod`` in a
+    ``@py_class`` body. When the class is registered, every flagged
+    callable is added to the C-level ``TVMFFITypeInfo.methods[]`` table,
+    the same table that holds methods declared in C++ via
+    ``refl::ObjectDef<T>().def(...)``, so it can be found by name through
+    ``TypeInfo.methods``.
+
+    Example::
+
+        from tvm_ffi import Object, method
+        from tvm_ffi.dataclasses import py_class
+
+
+        @py_class("example.Node")
+        class Node(Object):
+            x: int
+
+            @method
+            def label(self) -> str:
+                return f"N({self.x})"
+
+
+        # The method is now in ``TypeInfo.methods`` and FFI-callable:
+        info = Node.__tvm_ffi_type_info__
+        fn = next(m.func for m in info.methods if m.name == "label")
+        fn(Node(x=7))  # -> "N(7)"
+
+    Args:
+        fn: A plain callable or a ``staticmethod`` descriptor.
+
+    Returns:
+        ``fn`` itself; no wrapper is created. The only change is the
+        ``__ffi_method__ = True`` attribute, written on the function (for
+        a ``staticmethod``, on its ``__func__``).
+
+    Raises:
+        TypeError: If ``fn`` is a ``classmethod`` (its ``cls``-first
+            binding does not fit the packed-call convention) or is not
+            callable.
+    """
+    if isinstance(fn, classmethod):
+        raise TypeError(
+            "@tvm_ffi.method: @classmethod is not supported for FFI "
+            "TypeMethod registration — the classmethod's ``cls`` first "
+            "arg does not match the packed-call convention. Use "
+            "@staticmethod or a plain instance method instead.",
+        )
+    if isinstance(fn, staticmethod):
+        # The descriptor stays in the class body untouched; the marker goes
+        # on the wrapped function, which is where _is_method_marked looks.
+        target = fn.__func__
+    elif callable(fn):
+        target = fn
+    else:
+        raise TypeError(
+            f"@tvm_ffi.method: expected a callable, got {type(fn).__name__}.",
+        )
+    target.__ffi_method__ = True
+    return fn
+
+
+def _is_method_marked(value: Any) -> bool:
+    """Report whether ``value`` carries the :func:`method` marker.
+
+    ``staticmethod`` and ``classmethod`` descriptors are looked through to
+    their ``__func__`` (classmethod included so that a marked function
+    wrapped in ``@classmethod`` is still detected and can be rejected by
+    the collector). Other callables are checked directly; non-callables
+    are never marked. Only a marker that is exactly ``True`` counts.
+    """
+    if isinstance(value, (staticmethod, classmethod)):
+        candidate = value.__func__
+    elif callable(value):
+        candidate = value
+    else:
+        return False
+    return getattr(candidate, "__ffi_method__", False) is True
- Only names listed in :data:`_FFI_RECOGNIZED_METHODS` are collected.
- Callables are collected with their ``is_static`` flag; non-callable
- values (e.g. ``__ffi_ir_traits__``) are collected as-is — the Cython
- layer routes them to ``TVMFFITypeRegisterAttr`` based on name.
- Returns a list of ``(name, value, is_static)`` tuples, or ``None``
- if no eligible entries were found.
+def _validate_method_name(cls: type, name: str) -> None:
+    """Raise if a ``@method``-marked attribute uses a reserved name.
+
+    Three namespaces are off limits, checked in this order:
+
+    * names in :data:`_FFI_TYPE_ATTR_NAMES` — TypeAttrColumn entries,
+      which belong on the class body without ``@method``;
+    * any other name carrying the reserved ``__ffi_`` prefix;
+    * Python protocol dunders (``__name__``-shaped), reserved for Python.
+
+    Raises:
+        NameError: If ``name`` falls into one of those namespaces.
+    """
+    if name in _FFI_TYPE_ATTR_NAMES:
+        reason = (
+            f"{name!r} is a TypeAttrColumn "
+            "name — define it directly on the class body without "
+            "``@method``; the FFI system routes it to "
+            "``TVMFFITypeRegisterAttr``."
+        )
+    elif name.startswith("__ffi_"):
+        reason = (
+            f"{name!r} starts with the "
+            "reserved ``__ffi_`` prefix. Pick a different name for your "
+            "``@method``-decorated method."
+        )
+    elif name.startswith("__") and name.endswith("__"):
+        reason = (
+            f"{name!r} is a Python protocol "
+            "dunder — these are reserved for Python semantics and cannot "
+            "be registered as FFI TypeMethods."
+        )
+    else:
+        return
+    raise NameError(f"@py_class({cls.__name__!r}): {reason}")
+
+
+def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
+    """Gather the FFI-registered entries declared in a ``@py_class`` body.
+
+    Only the class's own ``__dict__`` is scanned. An entry is collected
+    when either of these holds:
+
+    1. **TypeAttrColumn dunders** — its name is in
+       :data:`_FFI_RECOGNIZED_METHODS`. Callables (e.g. ``__ffi_repr__``)
+       and non-callable values (e.g. ``__ffi_ir_traits__``) are both
+       accepted; the Cython layer routes them to
+       ``TVMFFITypeRegisterAttr`` based on name.
+    2. **User TypeMethods** — the value carries the :func:`method`
+       marker (checked with ``_is_method_marked``). These are registered
+       via ``TVMFFITypeRegisterMethod`` so the method can be looked up by
+       name through ``TypeInfo.methods``. The decorator keeps the opt-in
+       next to the method body; there is no separate allowlist.
+
+    ``staticmethod`` entries are unwrapped to their ``__func__`` and
+    reported with ``is_static=True``; every other entry is reported as-is
+    with ``is_static=False``.
+
+    Returns:
+        The ``(name, value, is_static)`` tuples in ``__dict__`` order, or
+        :data:`None` when nothing was collected.
+
+    Raises:
+        NameError: If a marked entry uses a reserved name (``__ffi_*``,
+            TypeAttrColumn names, or Python protocol dunders), as decided
+            by ``_validate_method_name``.
+        TypeError: If any collected entry — marked or recognized by
+            name — is a ``classmethod``.
+    """
     methods: list[tuple[str, Any, bool]] = []
     for name, value in cls.__dict__.items():
-        if name not in _FFI_RECOGNIZED_METHODS:
+        marked = _is_method_marked(value)
+        if name not in _FFI_RECOGNIZED_METHODS and not marked:
             continue
-        if isinstance(value, staticmethod):
-            func = value.__func__
-            is_static = True
-        elif callable(value):
-            func = value
-            is_static = False
-        else:
-            func = value
-            is_static = False
+        if marked:
+            _validate_method_name(cls, name)
+        # Applies to marked methods and recognized TypeAttrColumn names
+        # alike: the packed-call convention puts the instance (``self``)
+        # in slot 0, but a classmethod descriptor binds slot 0 to the
+        # class. Because ``_is_method_marked`` looks through classmethod,
+        # ``@classmethod @method`` reaches this check too.
+        if isinstance(value, classmethod):
+            raise TypeError(
+                f"@py_class({cls.__name__!r}): {name!r} is wrapped by "
+                "@classmethod, which is incompatible with FFI "
+                "registration — the cls-first arg breaks the packed-call "
+                "convention. Use @staticmethod or a plain instance "
+                "method. If you wrote ``@classmethod @method``, swap to "
+                "``@staticmethod @method`` (or drop @classmethod).",
+            )
+        is_static = isinstance(value, staticmethod)
+        func = value.__func__ if is_static else value
         methods.append((name, func, is_static))
     return methods if methods else None
diff --git a/tests/python/test_typed_method.py
b/tests/python/test_typed_method.py
new file mode 100644
index 0000000..470c43e
--- /dev/null
+++ b/tests/python/test_typed_method.py
@@ -0,0 +1,346 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for ``@tvm_ffi.method`` — opt-in TypeMethod registration on
+``@py_class``-decorated classes.
+"""
+
+from __future__ import annotations
+
+import itertools
+from typing import Any
+
+import pytest
+from tvm_ffi import Object, method
+from tvm_ffi.core import TypeInfo
+from tvm_ffi.dataclasses import py_class
+
+_counter = itertools.count()
+
+
+def _unique_key(base: str) -> str:
+    """Build a never-repeating type key so each test registers a fresh class."""
+    serial = next(_counter)
+    return f"testing.method_dec.{base}_{serial}"
+
+
+def _find_method(info: TypeInfo, name: str) -> Any:
+    """Return the first ``info.methods`` entry named ``name``, else :data:`None`."""
+    for entry in info.methods:
+        if entry.name == name:
+            return entry
+    return None
+
+
+def _toy_method_resolve(obj: Any, ref: str, *args: Any, **kwargs: Any) -> Any:
+    """Resolve a ``"$method:NAME"`` reference against ``obj`` and call it.
+
+    ``NAME`` is looked up in ``type(obj).__tvm_ffi_type_info__.methods``
+    and the resulting FFI ``Function`` is invoked as
+    ``func(obj, *args, **kwargs)``. A successful return therefore shows
+    the ``@method`` -> ``TVMFFITypeRegisterMethod`` -> FFI-callable chain
+    is intact. The ``$method:`` prefix is a convention local to this test
+    module for demonstrating name-based lookup.
+
+    Raises:
+        ValueError: If ``ref`` does not start with ``$method:``.
+        LookupError: If ``NAME`` is not in the type's ``TypeInfo.methods``.
+    """
+    prefix = "$method:"
+    if not ref.startswith(prefix):
+        raise ValueError(f"Not a $method: ref: {ref!r}")
+    method_name = ref[len(prefix) :]
+    obj_type = type(obj)
+    entry = _find_method(
+        obj_type.__tvm_ffi_type_info__,  # ty: ignore[unresolved-attribute]
+        method_name,
+    )
+    if entry is None:
+        raise LookupError(
+            f"{obj_type.__name__}.{method_name}: not in TypeInfo.methods — "
+            "was the method decorated with ``@tvm_ffi.method``?",
+        )
+    return entry.func(obj, *args, **kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Registration — ``@method``-marked methods land in TypeInfo.methods
+# ---------------------------------------------------------------------------
+
+
+class TestMethodRegistration:
+    """A ``@method``-decorated callable is registered through
+    ``TVMFFITypeRegisterMethod`` and appears by name in
+    ``TypeInfo.methods``.
+    """
+
+    def test_instance_method_registered_and_ffi_callable(self) -> None:
+        """An instance method registers with ``is_static=False``, and its
+        FFI Function takes the instance as argument 0.
+        """
+
+        @py_class(_unique_key("Node"))
+        class Node(Object):
+            x: int
+
+            @method
+            def label(self) -> str:
+                return f"N({self.x})"
+
+        info = Node.__tvm_ffi_type_info__  # ty: ignore[unresolved-attribute]
+        entry = _find_method(info, "label")
+        assert entry is not None
+        assert entry.is_static is False
+        # Calling the FFI Function (not the Python attribute) checks that
+        # the entry really landed in the C-level method table.
+        node = Node(x=7)
+        assert entry.func(node) == "N(7)"
+
+    def test_staticmethod_registered_with_is_static_true(self) -> None:
+        """``@method`` stacked on ``@staticmethod`` marks the wrapped
+        function; ``_collect_py_methods`` unwraps it and records
+        ``is_static=True``.
+        """
+
+        @py_class(_unique_key("Nstat"))
+        class Nstat(Object):
+            x: int
+
+            @method
+            @staticmethod
+            def constant() -> int:
+                return 42
+
+        info = Nstat.__tvm_ffi_type_info__  # ty: ignore[unresolved-attribute]
+        entry = _find_method(info, "constant")
+        assert entry is not None
+        assert entry.is_static is True
+        assert entry.func() == 42
+
+    def test_multiple_methods_all_registered(self) -> None:
+        """Each ``@method``-marked callable shows up in ``info.methods``."""
+
+        @py_class(_unique_key("NodeMulti"))
+        class NodeMulti(Object):
+            x: int
+
+            @method
+            def kind(self) -> str:
+                return "multi"
+
+            @method
+            def double(self) -> int:
+                return self.x * 2
+
+            @method
+            def prefixed(self, p: str) -> str:
+                return f"{p}-{self.x}"
+
+        info = NodeMulti.__tvm_ffi_type_info__  # ty: ignore[unresolved-attribute]
+        registered = {entry.name for entry in info.methods}
+        assert {"kind", "double", "prefixed"} <= registered
+
+    def test_no_decorator_no_registration(self) -> None:
+        """A class-body function without ``@method`` stays a plain Python
+        attribute and is absent from ``TypeInfo.methods`` — registration
+        is strictly opt-in, so helper methods are never exposed by
+        accident.
+        """
+
+        @py_class(_unique_key("NodeBare"))
+        class NodeBare(Object):
+            x: int
+
+            def helper(self) -> int:  # intentionally not @method
+                return self.x
+
+        info = NodeBare.__tvm_ffi_type_info__  # ty: ignore[unresolved-attribute]
+        assert _find_method(info, "helper") is None
+
+    def test_python_attribute_still_callable(self) -> None:
+        """Registering a method leaves the Python attribute in place, so
+        ``instance.name(...)`` still works as usual.
+        """
+
+        @py_class(_unique_key("NodeKeep"))
+        class NodeKeep(Object):
+            x: int
+
+            @method
+            def doubled(self) -> int:
+                return self.x * 2
+
+        node = NodeKeep(x=5)
+        assert node.doubled() == 10
+
+
+# ---------------------------------------------------------------------------
+# End-to-end: name-based method lookup via ``TypeInfo.methods``
+# ---------------------------------------------------------------------------
+
+
+class TestNameBasedResolution:
+    """Look up a ``@method``-decorated method by name in
+    ``TypeInfo.methods`` and invoke it through the returned FFI Function,
+    the same lookup any name-based consumer of the method table performs.
+    """
+
+    def test_name_lookup_invokes_decorated_method(self) -> None:
+        """Resolving ``$method:label`` finds the entry, and calling it
+        returns the Python method's result.
+        """
+
+        @py_class(_unique_key("Op"))
+        class Op(Object):
+            kind: str
+
+            @method
+            def label(self) -> str:
+                return f"op:{self.kind}"
+
+        op = Op(kind="add")
+        assert _toy_method_resolve(op, "$method:label") == "op:add"
+
+    def test_name_lookup_threads_extra_args(self) -> None:
+        """Positional arguments after the instance are forwarded to the
+        FFI Function unchanged.
+        """
+
+        @py_class(_unique_key("PrologueOp"))
+        class PrologueOp(Object):
+            kind: str
+
+            @method
+            def print_prologue(self, printer: Any, frame: Any) -> str:
+                # Both extra args appear in the result, so dropping either
+                # on the way through would fail the assertion below.
+                return f"{printer}-{self.kind}-{frame}"
+
+        result = _toy_method_resolve(
+            PrologueOp(kind="add"), "$method:print_prologue", "PR", "FR"
+        )
+        assert result == "PR-add-FR"
+
+    def test_name_lookup_missing_method_raises_clearly(self) -> None:
+        """A method defined without ``@method`` is absent from
+        ``TypeInfo.methods``, so resolving it raises ``LookupError`` —
+        what a user sees after forgetting the decorator.
+        """
+
+        @py_class(_unique_key("OpMiss"))
+        class OpMiss(Object):
+            kind: str
+
+            def unmarked(self) -> str:  # deliberately not @method
+                return self.kind
+
+        target = OpMiss(kind="x")
+        with pytest.raises(LookupError, match=r"not in TypeInfo\.methods"):
+            _toy_method_resolve(target, "$method:unmarked")
+
+
+# ---------------------------------------------------------------------------
+# Validation — reserved names / wrong wrappers rejected at decoration
+# ---------------------------------------------------------------------------
+
+
+class TestMethodValidation:
+    """Misuse of ``@method`` is rejected — either by the decorator itself
+    or while ``@py_class`` registers the class — with a message naming
+    the offending class, name, or wrapper.
+    """
+
+    def test_rejects_classmethod(self) -> None:
+        """``@method`` applied on top of ``@classmethod`` raises
+        ``TypeError`` at decoration time: a classmethod binds the class,
+        not the instance, as its first argument. ``_Bad`` is a plain
+        class, showing the check happens before any ``@py_class`` work.
+        """
+        with pytest.raises(TypeError, match=r"@classmethod is not supported"):
+
+            class _Bad:
+                @method
+                @classmethod
+                def maker(cls) -> int:
+                    return 0
+
+    def test_rejects_classmethod_method_decorator_order_swap(self) -> None:
+        """``@classmethod @method`` slips past the decorator's own check:
+        ``@method`` runs first on the bare function, then ``classmethod``
+        wraps the already-marked function. ``_collect_py_methods`` must
+        reject the resulting descriptor with a clear ``TypeError``.
+        """
+        with pytest.raises(TypeError, match=r"wrapped by @classmethod"):
+
+            @py_class(_unique_key("CMOrderBad"))
+            class _CMOrderBad(Object):
+                x: int
+
+                @classmethod
+                @method
+                def maker(cls) -> int:
+                    return 0
+
+    def test_rejects_manually_marked_classmethod(self) -> None:
+        """Setting ``fn.__ffi_method__ = True`` by hand and then wrapping
+        the function in ``classmethod`` also bypasses the decorator. The
+        collector checks the descriptor type, so it fires no matter how
+        the marker was applied.
+        """
+
+        def _maker(cls: type) -> int:
+            return 0
+
+        _maker.__ffi_method__ = True  # ty: ignore[unresolved-attribute]
+        cm = classmethod(_maker)
+
+        with pytest.raises(TypeError, match=r"wrapped by @classmethod"):
+
+            @py_class(_unique_key("CMManualBad"))
+            class _CMManualBad(Object):
+                x: int
+                maker = cm
+
+    def test_rejects_non_callable(self) -> None:
+        """``@method`` applied to a non-callable value raises ``TypeError``."""
+        with pytest.raises(TypeError, match=r"expected a callable"):
+            method(42)
+
+    def test_rejects_reserved_ffi_prefix(self) -> None:
+        """The ``__ffi_`` prefix is reserved for the TypeAttrColumn
+        namespace, so a ``@method`` named ``__ffi_*`` is treated as a user
+        error and ``_collect_py_methods`` raises ``NameError``.
+        """
+        with pytest.raises(NameError, match=r"reserved ``__ffi_`` prefix"):
+
+            @py_class(_unique_key("RFfiPfx"))
+            class _RFfiPfx(Object):
+                x: int
+
+                @method
+                def __ffi_custom__(self) -> int:
+                    return 0
+
+    def test_rejects_typeattrcolumn_name(self) -> None:
+        """Decorating a TypeAttrColumn dunder with ``@method`` raises
+        ``NameError``: such names are already routed to
+        ``TVMFFITypeRegisterAttr`` and never become TypeMethods.
+        """
+        with pytest.raises(NameError, match=r"TypeAttrColumn"):
+
+            @py_class(_unique_key("RAttr"))
+            class _RAttr(Object):
+                x: int
+
+                @method
+                def __ffi_repr__(self, fn_repr: Any) -> str:
+                    return "r"
+
+    def test_rejects_python_protocol_dunder(self) -> None:
+        """Python protocol dunders (``__len__``, ``__iter__``, ...) are
+        reserved for Python semantics and cannot be FFI TypeMethods.
+        """
+        with pytest.raises(NameError, match=r"Python protocol dunder"):
+
+            @py_class(_unique_key("RDun"))
+            class _RDun(Object):
+                x: int
+
+                @method
+                def __len__(self) -> int:
+                    return 0