This is an automated email from the ASF dual-hosted git repository.

cyx-6 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new bc64ba3  [REFACTOR] Optimized Python callback path via 
TVMFFIPyCallbackArgSetter (#569)
bc64ba3 is described below

commit bc64ba3e981ae44a542ed4e7590320b15e530f57
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Apr 25 12:57:01 2026 -0400

    [REFACTOR] Optimized Python callback path via TVMFFIPyCallbackArgSetter 
(#569)
    
    This PR optimizes the Python callback path, specifically for callbacks that
    expect their tensor arguments as torch.Tensor instances.
    
    `tvm_ffi.convert_func(pyfunc, tensor_cls=None)` is provided. If
    `tensor_cls`
    is provided, the closure threads its `__dlpack_c_exchange_api__` capsule
    into the callback so tensor args are delivered as `tensor_cls` instances
    (e.g., `torch.Tensor`) directly from the C-level setter, with no per-arg
    Python-level conversion and no Python-side torch DLPack conversion overhead.
    
    Example:
    
    ```python
    import torch
    import tvm_ffi
    
    def callback(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> None:
        # a, b, c arrive as torch.Tensor, converted in C via the closure's
        # DLPack exchange API — no from_dlpack() needed inside the callback.
        torch.add(a, b, out=c)
    
    f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
    ```
    
    | Variant                               | Per-call |
    |---------------------------------------|----------|
    | pycallback[tensor_cls=torch.Tensor]   | ~640 ns  |
    | pycallback+from_dlpack (explicit x3)  | ~9.1 us  |
---
 docs/reference/python/index.rst                |   1 +
 python/tvm_ffi/__init__.py                     |   3 +-
 python/tvm_ffi/_convert.py                     |  69 +++-
 python/tvm_ffi/core.pyi                        |   2 +-
 python/tvm_ffi/cython/base.pxi                 |  42 ++-
 python/tvm_ffi/cython/core.pyx                 |   1 +
 python/tvm_ffi/cython/error.pxi                |  12 +
 python/tvm_ffi/cython/function.pxi             |  68 +---
 python/tvm_ffi/cython/object.pxi               |   4 +-
 python/tvm_ffi/cython/pycallback.pxi           | 285 ++++++++++++++
 python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 502 +++++++++++++++++++++----
 python/tvm_ffi/cython/type_info.pxi            |   3 -
 tests/python/test_function.py                  |  65 ++++
 tests/scripts/benchmark_pycallback.py          | 111 ++++++
 14 files changed, 1018 insertions(+), 150 deletions(-)

diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 93144d9..a59a587 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -134,6 +134,7 @@ Misc
   access_path.AccessPath
   access_path.AccessStep
   convert
+  convert_func
   ObjectConvertible
 
 .. (Experimental) Dataclasses
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index b1872a7..012de31 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -67,7 +67,7 @@ if TYPE_CHECKING or not _is_config_mode():
     )
     from ._dtype import dtype
     from .core import Object, ObjectConvertible, Function, CAny, CContainerBase
-    from ._convert import convert
+    from ._convert import convert, convert_func
     from .error import register_error
     from ._tensor import Device, device, DLDeviceType
     from ._tensor import from_dlpack, Tensor, Shape
@@ -146,6 +146,7 @@ __all__ = [
     "__version_tuple__",
     "access_path",
     "convert",
+    "convert_func",
     "cpp",
     "dataclasses",
     "device",
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 49e7931..c04688c 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 import ctypes
 from numbers import Number
 from types import ModuleType
-from typing import Any
+from typing import Any, Callable
 
 from . import _dtype, container, core
 
@@ -138,3 +138,70 @@ def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
     else:
         # in this case, it is an opaque python object
         return core._convert_to_opaque_object(value)
+
+
+def convert_func(
+    pyfunc: Callable[..., Any],
+    tensor_cls: type | None = None,
+) -> Any:
+    """Convert a Python callable to an FFI :py:class:`~tvm_ffi.Function`.
+
+    This is the callable-specific sibling of :py:func:`tvm_ffi.convert`.
+    It accepts one extra argument, ``tensor_cls``, that lets the caller
+    specify how tensor arguments should be delivered to the Python
+    callable when the resulting :py:class:`Function` is invoked from C++.
+    :py:func:`tvm_ffi.convert` has no such knob — it always produces a
+    :py:class:`Function` whose callback receives ``tvm_ffi.Tensor`` for
+    tensor args.
+
+    Parameters
+    ----------
+    pyfunc : Callable
+        The Python callable to wrap.
+    tensor_cls : type, optional
+        The class whose instances the callback should receive for tensor
+        args. The class must expose a ``__dlpack_c_exchange_api__``
+        :py:class:`PyCapsule`; its capsule is threaded into the callback
+        closure so tensor args are converted at the C level (via the
+        DLPack exchange API) before the Python callback body runs — this
+        is significantly faster than calling ``torch.from_dlpack(x)`` (or
+        equivalent) inside the callback. Raises :py:class:`TypeError` if
+        ``tensor_cls`` does not expose the attribute.
+
+        When ``tensor_cls`` is ``None``, ``convert_func`` behaves like the
+        callable branch of :py:func:`tvm_ffi.convert`.
+
+    Returns
+    -------
+    Function
+        The wrapped FFI function.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import torch
+        import tvm_ffi
+
+        # Without tensor_cls: same as tvm_ffi.convert(pyfunc) — the callback
+        # receives tvm_ffi.Tensor for tensor args.
+        f = tvm_ffi.convert_func(lambda x: x + 1)
+        assert isinstance(f, tvm_ffi.Function)
+
+
+        # With tensor_cls=torch.Tensor: the callback receives torch.Tensor
+        # directly; the DLPack conversion happens in C before the body runs.
+        def callback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            return a + b
+
+
+        g = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
+
+    See Also
+    --------
+    :py:func:`tvm_ffi.convert` :
+        Generic value-to-FFI conversion. Use this when you don't need to
+        specify ``tensor_cls``.
+
+    """
+    return core._convert_to_ffi_func(pyfunc, tensor_cls=tensor_cls)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 90071c7..892a3b7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -210,7 +210,7 @@ def _register_global_func(
     name: str, pyfunc: Callable[..., Any] | Function, override: bool
 ) -> Function: ...
 def _get_global_func(name: str, allow_missing: bool) -> Function | None: ...
-def _convert_to_ffi_func(pyfunc: Callable[..., Any]) -> Function: ...
+def _convert_to_ffi_func(pyfunc: Callable[..., Any], tensor_cls: type | None = 
...) -> Function: ...
 def _convert_to_opaque_object(pyobject: Any) -> OpaquePyObject: ...
 def _print_debug_info() -> None: ...
 def _register_py_class(parent_type_info: TypeInfo, type_key: str, type_cls: 
type) -> TypeInfo: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index b851e1e..c5c28a1 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -366,10 +366,8 @@ cdef extern from "tvm_ffi_python_helpers.h":
         int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,  
PyObject* py_arg, TVMFFIAny* out) except -1
         const DLPackExchangeAPI* dlpack_c_exchange_api
 
-    ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, 
TVMFFIPyArgSetter* out) except -1
     # The main call function
     int TVMFFIPyFuncCall(
-        TVMFFIPyArgSetterFactory setter_factory,
         void* chandle,
         PyObject* py_arg_tuple,
         TVMFFIAny* result,
@@ -379,7 +377,6 @@ cdef extern from "tvm_ffi_python_helpers.h":
     ) except -1
 
     int TVMFFIPyConstructorCall(
-        TVMFFIPyArgSetterFactory setter_factory,
         void* chandle,
         PyObject* py_arg_tuple,
         TVMFFIAny* result,
@@ -388,7 +385,6 @@ cdef extern from "tvm_ffi_python_helpers.h":
     ) except -1
 
     int TVMFFIPyCallFieldSetter(
-        TVMFFIPyArgSetterFactory setter_factory,
         void* field_setter,
         int64_t field_flags,
         void* field_ptr,
@@ -397,20 +393,18 @@ cdef extern from "tvm_ffi_python_helpers.h":
     ) except -1
 
     int TVMFFIPyPyObjectToFFIAny(
-        TVMFFIPyArgSetterFactory setter_factory,
         PyObject* py_arg,
         TVMFFIAny* out,
         int* c_api_ret_code
     ) except -1
 
     int TVMFFIPySetArgumentGenericDispatcher(
-        TVMFFIPyArgSetterFactory setter_factory,
         TVMFFIPyCallContext* ctx,
         PyObject* py_arg,
         TVMFFIAny* out
     ) except -1
 
-    size_t TVMFFIPyGetDispatchMapSize() noexcept
+    size_t TVMFFIPyGetArgDispatchMapSize() noexcept
 
     void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, 
TVMFFIObjectHandle arg) noexcept
     void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) 
noexcept
@@ -420,6 +414,40 @@ cdef extern from "tvm_ffi_python_helpers.h":
     int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, 
PyObject* arg, TVMFFIAny* out) except -1
     int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, 
PyObject* arg, TVMFFIAny* out) except -1
     int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, 
PyObject* arg, TVMFFIAny* out) except -1
+
+    # Callback arg setter types — view-based AnyView -> PyObject conversion
+    # used by the C++ -> Python callback path (PyCallback).
+    ctypedef int 
(*TVMFFIPyCallbackArgSetterCallback)(TVMFFIPyCallbackArgSetter* handle,
+                                                      const DLPackExchangeAPI* 
api,
+                                                      const TVMFFIAny* arg,
+                                                      PyObject** out) except -1
+
+    ctypedef struct TVMFFIPyCallbackArgSetter:
+        TVMFFIPyCallbackArgSetterCallback func
+
+    # Built-in C callback arg setters for POD types.
+    int TVMFFIPyCallbackArgSetterNone_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                       const TVMFFIAny*, PyObject** out) 
except -1
+    int TVMFFIPyCallbackArgSetterBool_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                       const TVMFFIAny* arg, PyObject** out) 
except -1
+    int TVMFFIPyCallbackArgSetterInt_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                      const TVMFFIAny* arg, PyObject** out) 
except -1
+    int TVMFFIPyCallbackArgSetterFloat_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                        const TVMFFIAny* arg, PyObject** out) 
except -1
+    int TVMFFIPyCallbackArgSetterSmallStr_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                           const TVMFFIAny* arg, PyObject** 
out) except -1
+    int TVMFFIPyCallbackArgSetterSmallBytes_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                             const TVMFFIAny* arg, PyObject** 
out) except -1
+
+    int TVMFFIPyCallback(void* context, const TVMFFIAny* packed_args,
+                         int32_t num_args, TVMFFIAny* result) noexcept
+
+    # Closure + convert helper for the PyCallback path. Returns the raw FFI rc;
+    # callers use CHECK_CALL to translate the TLS FFI error into a Python 
exception.
+    int TVMFFIPyConvertPyCallback(PyObject* callable,
+                                  const DLPackExchangeAPI* dlpack_api,
+                                  TVMFFIObjectHandle* out_handle) noexcept
+
     # MLIRPackedSafeCall
     void* TVMFFIPyMLIRPackedSafeCallCreate(void 
(*mlir_packed_safe_call)(void**) noexcept, PyObject* keep_alive_object)
     int TVMFFIPyMLIRPackedSafeCallInvoke(void* self, const TVMFFIAny* args, 
int32_t num_args, TVMFFIAny* rv)
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index 2c0a161..b025a48 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -38,4 +38,5 @@ include "./tensor.pxi"
 _register_object_by_index(kTVMFFITensor, Tensor)
 include "./function.pxi"
 _register_object_by_index(kTVMFFIFunction, Function)
+include "./pycallback.pxi"
 include "./pyclass_type_converter.pxi"
diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi
index 6f8159c..cbd1194 100644
--- a/python/tvm_ffi/cython/error.pxi
+++ b/python/tvm_ffi/cython/error.pxi
@@ -150,6 +150,18 @@ def _convert_to_ffi_error(error: BaseException) -> Error:
         return Error(kind, message, py_backtrace)
 
 
+cdef public int TVMFFICyErrorSetRaisedFromPyError(PyObject* py_err) noexcept:
+    """Set the last FFI error from a Python exception.
+
+    Parameters
+    ----------
+    py_err : PyObject*
+        The Python exception to set as the last FFI error.
+    """
+    set_last_ffi_error(<object>py_err)
+    return -1
+
+
 cdef inline int CHECK_CALL(int ret) except -2:
     """Check the return code of the C API function call"""
     if ret == 0:
diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index 48aaf5c..f613fda 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -120,7 +120,7 @@ cdef inline int ConstructorCall(void* constructor_handle,
     result.type_index = kTVMFFINone
     result.v_int64 = 0
     TVMFFIPyConstructorCall(
-        TVMFFIPyArgSetterFactory_, constructor_handle, py_arg_tuple, &result, 
&c_api_ret_code,
+        constructor_handle, py_arg_tuple, &result, &c_api_ret_code,
         parent_ctx
     )
     CHECK_CALL(c_api_ret_code)
@@ -556,9 +556,8 @@ cdef int TVMFFIPyArgSetterCallable_(
     PyObject* py_arg, TVMFFIAny* out
 ) except -1:
     """Setter for Callable"""
-    cdef object arg = <object>py_arg
     cdef TVMFFIObjectHandle chandle
-    _convert_to_ffi_func_handle(arg, &chandle)
+    CHECK_CALL(TVMFFIPyConvertPyCallback(py_arg, NULL, &chandle))
     out.type_index = TVMFFIObjectGetTypeIndex(chandle)
     out.v_ptr = chandle
     TVMFFIPyPushTempFFIObject(ctx, chandle)
@@ -725,16 +724,14 @@ cdef int TVMFFIPyArgSetterFFIValueProtocol_(
     # keep alive the python object since this is a temporary object
     # we must push to extra temp py objects stack to avoid overflow the temp 
py objects stack
     TVMFFIPyPushExtraTempPyObject(ctx, ffi_value_py_obj_ptr)
-    return TVMFFIPySetArgumentGenericDispatcher(
-        TVMFFIPyArgSetterFactory_, ctx, ffi_value_py_obj_ptr, out
-    )
+    return TVMFFIPySetArgumentGenericDispatcher(ctx, ffi_value_py_obj_ptr, out)
 
 
 cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
 cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
 
 
-cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) 
except -1:
+cdef public int TVMFFICyArgSetterFactory(PyObject* value, TVMFFIPyArgSetter* 
out) except -1:
     """
     Factory function that creates an argument setter for a given Python 
argument type.
     """
@@ -950,7 +947,6 @@ cdef class Function(CObject):
         result.type_index = kTVMFFINone
         result.v_int64 = 0
         TVMFFIPyFuncCall(
-            TVMFFIPyArgSetterFactory_,
             (<CObject>self).chandle, <PyObject*>args,
             &result,
             &c_api_ret_code,
@@ -1108,56 +1104,6 @@ def _get_global_func(name: str, allow_missing: bool):
     raise ValueError("Cannot find global function %s" % name)
 
 
-cdef int tvm_ffi_callback(void* context,
-                          const TVMFFIAny* packed_args,
-                          int32_t num_args,
-                          TVMFFIAny* result) noexcept with gil:
-    cdef list pyargs
-    cdef TVMFFIAny temp_result
-    cdef int c_api_ret_code
-    cdef object local_pyfunc = <object>(context)
-    pyargs = []
-
-    for i in range(num_args):
-        CHECK_CALL(TVMFFIAnyViewToOwnedAny(&packed_args[i], &temp_result))
-        pyargs.append(make_ret(temp_result))
-
-    try:
-        rv = local_pyfunc(*pyargs)
-        TVMFFIPyPyObjectToFFIAny(
-            TVMFFIPyArgSetterFactory_,
-            <PyObject*>rv,
-            result,
-            &c_api_ret_code
-        )
-        return c_api_ret_code
-    except Exception as err:
-        set_last_ffi_error(err)
-        return -1
-
-
-cdef inline int _convert_to_ffi_func_handle(
-    object pyfunc, TVMFFIObjectHandle* out_handle
-) except -1:
-    """Convert a python function to TVM FFI function handle"""
-    Py_INCREF(pyfunc)
-    CHECK_CALL(TVMFFIFunctionCreate(
-        <void*>(pyfunc),
-        tvm_ffi_callback,
-        TVMFFIPyObjectDeleter,
-        out_handle))
-    return 0
-
-
-def _convert_to_ffi_func(object pyfunc: Callable[..., Any]) -> Function:
-    """Convert a python function to TVM FFI function"""
-    cdef TVMFFIObjectHandle chandle
-    _convert_to_ffi_func_handle(pyfunc, &chandle)
-    ret = Function.__new__(Function)
-    (<CObject>ret).chandle = chandle
-    return ret
-
-
 cdef inline int _convert_to_opaque_object_handle(
     object pyobject, TVMFFIObjectHandle* out_handle
 ) except -1:
@@ -1201,9 +1147,9 @@ def _testing_drop_last_ref_without_thread_state() -> None:
 
 
 def _print_debug_info() -> None:
-    """Get the size of the dispatch map"""
-    cdef size_t size = TVMFFIPyGetDispatchMapSize()
-    print(f"TVMFFIPyGetDispatchMapSize: {size}")
+    """Get the size of the arg dispatch map"""
+    cdef size_t size = TVMFFIPyGetArgDispatchMapSize()
+    print(f"TVMFFIPyGetArgDispatchMapSize: {size}")
 
 
 cdef Function _OBJECT_FROM_JSON_GRAPH_STR = 
_get_global_func("ffi.FromJSONGraphString", True)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 4b8c52d..803ead2 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -200,7 +200,7 @@ cdef class CContainerBase(CObject):
     # 1. __dlpack_c_exchange_api__ (e.g. torch.Tensor) — points to a
     #    static struct in the framework's C++ runtime.  The source
     #    type is kept alive by _DISPATCH_TYPE_KEEP_ALIVE (set in
-    #    TVMFFIPyArgSetterFactory_), which prevents module unloading.
+    #    TVMFFICyArgSetterFactory), which prevents module unloading.
     #
     # 2. GetTorchFallbackExchangeAPI() — returns the address of a
     #    module-level Cython static; lives for the entire process.
@@ -758,7 +758,6 @@ def _register_type_attr(type_index: int32_t, attr_key: str, 
value: object) -> No
     temp.type_index = kTVMFFINone
     temp.v_int64 = 0
     TVMFFIPyPyObjectToFFIAny(
-        TVMFFIPyArgSetterFactory_,
         <PyObject*>value,
         &temp,
         &c_api_ret_code,
@@ -818,7 +817,6 @@ cdef class CAny:
         temp.type_index = kTVMFFINone
         temp.v_int64 = 0
         TVMFFIPyPyObjectToFFIAny(
-            TVMFFIPyArgSetterFactory_,
             <PyObject*>value,
             &temp,
             &c_api_ret_code
diff --git a/python/tvm_ffi/cython/pycallback.pxi 
b/python/tvm_ffi/cython/pycallback.pxi
new file mode 100644
index 0000000..fd5f686
--- /dev/null
+++ b/python/tvm_ffi/cython/pycallback.pxi
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ----------------------------------------------------------------------------
+#  Cython-side implementation of the C++ -> Python callback path.
+#
+#  The C++ side (tvm_ffi_python_helpers.h + TVMFFIPyCallManager::PyCallback)
+#  calls into the functions defined here via function pointers registered in
+#  callback_arg_dispatch_table_. Each per-type setter takes a borrowed
+#  AnyView and produces a new-reference PyObject*.
+#
+# The caller will grab the thread state before calling into each individual 
setter.
+#
+#  This file also hosts `_convert_to_ffi_func`, the Cython entry point that
+#  wraps a Python callable as a FFI Function backed by a TVMFFIPyCallback
+#  closure (see TVMFFIPyConvertPyCallback in the header).
+# ----------------------------------------------------------------------------
+
+
+cdef int TVMFFIPyCallbackArgSetterTensor_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFITensor -> ffi.Tensor or torch.Tensor 
(via DLPack).
+
+    The DLPack branch is inlined rather than delegated to
+    ``make_tensor_from_chandle`` so we can pass a borrowed chandle:
+    ``TensorObj::ToDLPackVersioned`` incs internally, so the inc/dec pair
+    that ``make_tensor_from_chandle`` requires on an owned chandle is pure
+    waste here. The non-DLPack branch upgrades to owned and reuses
+    ``make_tensor_from_chandle`` for consistency with the RValueRef path.
+    """
+    cdef TVMFFIObjectHandle chandle = arg.v_ptr
+    cdef DLManagedTensorVersioned* dlpack
+    cdef void* py_obj
+
+    if api != NULL and api.managed_tensor_to_py_object_no_sync != NULL:
+        if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0:
+            try:
+                api.managed_tensor_to_py_object_no_sync(dlpack, &py_obj)
+            except Exception:
+                dlpack.deleter(dlpack)
+                raise
+            # py_obj already holds +1 from the DLPack import; transfer to 
caller.
+            out[0] = <PyObject*>py_obj
+            return 0
+    # Non-DLPack path: upgrade borrowed -> owned, wrap via 
make_tensor_from_chandle.
+    TVMFFIObjectIncRef(chandle)
+    obj = make_tensor_from_chandle(chandle, NULL)
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterObject_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for generic static object types (type_index >= 
kTVMFFIStaticObjectBegin)."""
+    cdef TVMFFIObjectHandle chandle = arg.v_ptr
+    TVMFFIObjectIncRef(chandle)
+    try:
+        obj = make_ret_object(arg[0])
+        if api != NULL and isinstance(obj, CContainerBase):
+            (<CContainerBase>obj)._dlpack_exchange_api = api
+    except BaseException:
+        TVMFFIObjectDecRef(chandle)
+        raise
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterOpaquePyObject_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIOpaquePyObject -> underlying Python 
object.
+
+    Inlined equivalent of `make_ret_opaque_object`: reads the cell's Python
+    handle directly, skipping the throwaway OpaquePyObject wrapper that
+    would otherwise be created just to extract the handle. `arg` is
+    borrowed, but the cell stays alive for the callback's duration, so the
+    handle is safe to read without inc'ing the chandle.
+    """
+    cdef PyObject* py_handle = 
<PyObject*>TVMFFIOpaqueObjectGetCellPtr(arg.v_ptr).handle
+    cdef object obj = <object>py_handle
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterOpaquePtr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIOpaquePtr -> ctypes.c_void_p."""
+    obj = ctypes_handle(arg.v_ptr)
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDataType_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIDataType -> DataType."""
+    obj = make_ret_dtype(arg[0])
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDevice_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIDevice -> Device."""
+    obj = make_ret_device(arg[0])
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIDLTensorPtr -> ffi.Tensor (via DLTensor 
pointer)."""
+    obj = make_ret_dltensor(arg[0])
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterRValueRef_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIObjectRValueRef.
+
+    For RValueRef, ``arg.v_ptr`` is an ``Object**`` (address of the caller's
+    mutable slot holding the moved chandle), not the chandle itself. We read
+    the slot, null it out to prevent a double-move, and wrap WITHOUT inc'ing
+    (the move already gave us the +1).
+    """
+    cdef TVMFFIObjectHandle* slot_ptr = <TVMFFIObjectHandle*>arg.v_ptr
+    cdef TVMFFIObjectHandle chandle = slot_ptr[0]
+    if chandle == NULL:
+        raise ValueError("RValueRef already moved")
+    # mark as moved before constructing wrappers (so error paths don't 
double-move)
+    slot_ptr[0] = NULL
+    cdef int32_t actual_type_index = TVMFFIObjectGetTypeIndex(chandle)
+    cdef TVMFFIAny synthesized
+    synthesized.type_index = actual_type_index
+    synthesized.zero_padding = 0
+    synthesized.v_int64 = 0
+    synthesized.v_ptr = chandle
+    try:
+        if actual_type_index == kTVMFFITensor:
+            obj = make_tensor_from_chandle(chandle, api)
+        else:
+            obj = make_ret_object(synthesized)
+            if api != NULL and isinstance(obj, CContainerBase):
+                (<CContainerBase>obj)._dlpack_exchange_api = api
+    except BaseException:
+        # Caller's moved +1 needs to be released on error.
+        TVMFFIObjectDecRef(chandle)
+        raise
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
+                                                 TVMFFIPyCallbackArgSetter* 
out) except -1:
+    """Factory that creates callback arg setters for a given type index.
+
+    POD setters live in tvm_ffi_python_helpers.h (header-inline);
+    object-bearing setters are the Cython functions above.
+    """
+    if type_index >= kTVMFFIStaticObjectBegin:
+        if type_index == kTVMFFITensor:
+            out.func = TVMFFIPyCallbackArgSetterTensor_
+        elif type_index == kTVMFFIOpaquePyObject:
+            out.func = TVMFFIPyCallbackArgSetterOpaquePyObject_
+        else:
+            out.func = TVMFFIPyCallbackArgSetterObject_
+        return 0
+    if type_index == kTVMFFINone:
+        out.func = TVMFFIPyCallbackArgSetterNone_
+    elif type_index == kTVMFFIBool:
+        out.func = TVMFFIPyCallbackArgSetterBool_
+    elif type_index == kTVMFFIInt:
+        out.func = TVMFFIPyCallbackArgSetterInt_
+    elif type_index == kTVMFFIFloat:
+        out.func = TVMFFIPyCallbackArgSetterFloat_
+    elif type_index == kTVMFFISmallStr:
+        out.func = TVMFFIPyCallbackArgSetterSmallStr_
+    elif type_index == kTVMFFISmallBytes:
+        out.func = TVMFFIPyCallbackArgSetterSmallBytes_
+    elif type_index == kTVMFFIOpaquePtr:
+        out.func = TVMFFIPyCallbackArgSetterOpaquePtr_
+    elif type_index == kTVMFFIDataType:
+        out.func = TVMFFIPyCallbackArgSetterDataType_
+    elif type_index == kTVMFFIDevice:
+        out.func = TVMFFIPyCallbackArgSetterDevice_
+    elif type_index == kTVMFFIDLTensorPtr:
+        out.func = TVMFFIPyCallbackArgSetterDLTensorPtr_
+    elif type_index == kTVMFFIObjectRValueRef:
+        out.func = TVMFFIPyCallbackArgSetterRValueRef_
+    elif type_index == kTVMFFIByteArrayPtr:
+        raise ValueError("Callback arg cannot be ByteArrayPtr")
+    elif type_index == kTVMFFIRawStr:
+        raise ValueError("Callback arg cannot be RawStr")
+    else:
+        raise ValueError("Unhandled type index %d" % type_index)
+    return 0
+
+
+def _convert_to_ffi_func(
+    object pyfunc: Callable[..., Any],
+    object tensor_cls: object = None,
+) -> Function:
+    """Convert a python function to a TVM FFI Function.
+
+    Parameters
+    ----------
+    pyfunc : Callable
+        The Python callable to wrap. Incref'd into a TVMFFIPyCallbackClosure.
+    tensor_cls : type, optional
+        If given, its ``__dlpack_c_exchange_api__`` capsule is threaded into 
the
+        closure and used to convert tensor arguments into ``tensor_cls``
+        instances before the callback body runs.
+
+    Returns
+    -------
+    Function
+        The wrapped FFI function.
+    """
+    cdef TVMFFIObjectHandle chandle
+    cdef const DLPackExchangeAPI* dlpack_api = NULL
+    if tensor_cls is not None:
+        if not hasattr(tensor_cls, "__dlpack_c_exchange_api__"):
+            raise TypeError(
+                f"tensor_cls {tensor_cls!r} must expose 
__dlpack_c_exchange_api__"
+            )
+        _get_dlpack_exchange_api(
+            tensor_cls.__dlpack_c_exchange_api__, &dlpack_api
+        )
+    CHECK_CALL(TVMFFIPyConvertPyCallback(<PyObject*>pyfunc, dlpack_api, 
&chandle))
+    ret = Function.__new__(Function)
+    (<CObject>ret).chandle = chandle
+    return ret
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h 
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 9c923af..79494d5 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -39,6 +39,7 @@
 #include <cstring>
 #include <exception>
 #include <iostream>
+#include <memory>
 #include <unordered_map>
 #include <vector>
 
@@ -63,6 +64,31 @@ class TVMFFIPyWithAttachedThreadState {
   PyGILState_STATE gstate_;
 };
 
+/*!
+ * \brief Closure state carried as the resource handle for an FFI function that
+ *        wraps a Python callable and optional exchange api for tensor 
handling.
+ *
+ * Created by TVMFFIPyConvertPyCallback and released by
+ * TVMFFIPyCallbackClosure::Deleter when the FFI function is destroyed.
+ */
+struct TVMFFIPyCallbackClosure {
+  /*! \brief Strong reference to the Python callable. */
+  PyObject* callable;
+  /*! \brief Optional DLPack exchange API used when constructing Tensor 
returns. */
+  const DLPackExchangeAPI* dlpack_exchange_api;
+  /*!
+   * \brief Deleter registered with TVMFFIFunctionCreate. Runs on FFI function 
destroy.
+   *
+   * Releases the closure's strong Python reference and frees the closure.
+   */
+  static void Deleter(void* context) noexcept {
+    TVMFFIPyWithAttachedThreadState thread_state;
+    auto* closure = static_cast<TVMFFIPyCallbackClosure*>(context);
+    Py_DecRef(closure->callable);
+    delete closure;
+  }
+};
+
 /*!
  * \brief Thread-local call stack used by TVMFFIPyCallContext.
  */
@@ -88,6 +114,9 @@ class TVMFFIPyCallStack {
   }
 };
 
+//---------------------------------------------------------------------------------------------
+// Support for Python -> FFI function calls.
+//---------------------------------------------------------------------------------------------
 /*!
  * \brief Context for each ffi call to track the stream, device and temporary 
arguments.
  */
@@ -246,24 +275,156 @@ int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, 
TVMFFIPyCallContext*, PyObject* a
 }
 
 
//---------------------------------------------------------------------------------------------
-// The following section contains the dispatcher logic for function calling
+// Support for PyCallback function calls.
 
//---------------------------------------------------------------------------------------------
+
 /*!
- * \brief Factory function that creates an argument setter for a given Python 
argument type.
+ * \brief Context for a C -> Python callback call.
  *
- * This factory function analyzes a Python argument and creates an appropriate 
setter
- * that can convert Python objects of the same type to C arguments for TVM FFI 
calls.
- * The setter will be cached for future use for setting argument of the same 
type.
+ * Owns a temporary PyObject* array that holds arguments converted from the
+ * packed FFI call. Space is first taken from the thread-local args_stack on
+ * TVMFFIPyCallStack; if insufficient, we fall back to the heap.
  *
- * \param arg The Python argument value used as a type example.
- * \param out Output parameter that receives the created argument setter.
- * \return 0 on success, -1 on failure. PyError should be set if -1 is 
returned.
+ * Unlike TVMFFIPyCallContext::~TVMFFIPyCallContext, this destructor does NOT
+ * attach a thread state — callers are expected to already hold one
+ * (e.g. via TVMFFIPyWithAttachedThreadState at the top of the callback).
  *
- * \note This is a callback function supplied by the caller. The factory must 
satisfy
- *       the invariance that the same setter can be used for other arguments 
with
- *       the same type as the provided example argument.
+ * The destructor also decrefs every PyObject* pushed into py_args[0 ..
+ * num_active_py_args-1], tracking the pushed count via `num_active_py_args`.
  */
-typedef int (*TVMFFIPyArgSetterFactory)(PyObject* arg, TVMFFIPyArgSetter* out);
+class TVMFFIPyCallbackContext {
+ public:
+  /*! \brief The temporary PyObject* slots for Python call arguments. */
+  PyObject** py_args = nullptr;
+  /*! \brief How many slots have a live reference and need decref on cleanup. 
*/
+  int32_t num_active_py_args = 0;
+  /*! \brief Number of total argument slots allocated. */
+  int32_t num_args = 0;
+
+  TVMFFIPyCallbackContext(TVMFFIPyCallStack* call_stack, int32_t num_args)
+      : num_args(num_args), call_stack_(call_stack) {
+    static_assert(sizeof(TVMFFIAny) % sizeof(PyObject*) == 0);
+    // slots needed in the unit of TVMFFIAny
+    int64_t slots_needed =
+        (static_cast<int64_t>(num_args) * sizeof(PyObject*) + 
sizeof(TVMFFIAny) - 1) /
+        sizeof(TVMFFIAny);
+    old_args_stack_top_ = call_stack->args_stack_top;
+    if (call_stack->args_stack_top + slots_needed <=
+        static_cast<int64_t>(call_stack->args_stack.size())) {
+      py_args =
+          reinterpret_cast<PyObject**>(call_stack->args_stack.data() + 
call_stack->args_stack_top);
+      call_stack->args_stack_top += slots_needed;
+    } else {
+      heap_ptr_ = new PyObject*[num_args];
+      py_args = heap_ptr_;
+    }
+  }
+
+  ~TVMFFIPyCallbackContext() {
+    // caller must already hold an attached thread state; do NOT re-attach.
+    // we ensure that all the pyargs are not null
+    for (int32_t i = 0; i < num_active_py_args; ++i) {
+      Py_DecRef(py_args[i]);
+    }
+    if (heap_ptr_ == nullptr) {
+      call_stack_->args_stack_top = old_args_stack_top_;
+    } else {
+      delete[] heap_ptr_;
+    }
+  }
+
+ private:
+  TVMFFIPyCallStack* call_stack_ = nullptr;
+  int64_t old_args_stack_top_ = 0;
+  PyObject** heap_ptr_ = nullptr;
+};
+
+/*!
+ * \brief A callback arg setter entry registered to handle efficient callback 
argument conversion.
+ */
+struct TVMFFIPyCallbackArgSetter {
+  /*!
+   * \brief Callback type that converts a borrowed TVMFFIAny (AnyView) to a 
new-reference PyObject*.
+   * \param handle Pointer to the TVMFFIPyCallbackArgSetter (for per-type 
state).
+   * \param dlpack_exchange_api The DLPack exchange API (may be NULL).
+   * \param arg The TVMFFIAny value to convert (borrowed; setter must inc if 
it transfers
+   * ownership).
+   * \param out Output: a new-reference PyObject*.
+   * \return 0 on success, -1 on failure (PyErr set).
+   */
+  int (*func)(TVMFFIPyCallbackArgSetter* handle, const DLPackExchangeAPI* 
dlpack_exchange_api,
+              const TVMFFIAny* arg, PyObject** out);
+};
+
+// common callback arg setters that can be quickly implemented in C++ and used 
by cython factory
+// note that PyErr is propagated back so we just need to return -1 on failure.
+int TVMFFIPyCallbackArgSetterNone_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                   const TVMFFIAny*, PyObject** out) noexcept {
+  Py_IncRef(Py_None);
+  *out = Py_None;
+  return 0;
+}
+
+int TVMFFIPyCallbackArgSetterBool_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                   const TVMFFIAny* arg, PyObject** out) 
noexcept {
+  *out = PyBool_FromLong(static_cast<long>(arg->v_int64));
+  return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterInt_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                  const TVMFFIAny* arg, PyObject** out) 
noexcept {
+  *out = PyLong_FromLongLong(arg->v_int64);
+  return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterFloat_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                    const TVMFFIAny* arg, PyObject** out) 
noexcept {
+  *out = PyFloat_FromDouble(arg->v_float64);
+  return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterSmallStr_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                       const TVMFFIAny* arg, PyObject** out) 
noexcept {
+  TVMFFIByteArray ba = TVMFFISmallBytesGetContentByteArray(arg);
+  *out = PyUnicode_DecodeUTF8(ba.data, static_cast<Py_ssize_t>(ba.size), 
nullptr);
+  return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterSmallBytes_(TVMFFIPyCallbackArgSetter*, const 
DLPackExchangeAPI*,
+                                         const TVMFFIAny* arg, PyObject** out) 
noexcept {
+  TVMFFIByteArray ba = TVMFFISmallBytesGetContentByteArray(arg);
+  *out = PyBytes_FromStringAndSize(ba.data, static_cast<Py_ssize_t>(ba.size));
+  return (*out != nullptr) ? 0 : -1;
+}
+
+///--------------------------------------------------------------------------------
+/// Declaring functions defined in Cython to be invoked by the C++ 
implementation.
+/// in all cases PyErr is propagated back so we just need to return -1 on 
failure.
+///--------------------------------------------------------------------------------
+/*!
+ * \brief Set the error raised from Python to the FFI side.
+ * \param py_err The Python error to be set.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is 
returned.
+ */
+__PYX_EXTERN_C int TVMFFICyErrorSetRaisedFromPyError(PyObject* py_err);
+/*!
+ * \brief Create an argument setter for a given Python argument type.
+ * \param arg The Python argument to be set.
+ * \param out The output argument setter.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is 
returned.
+ */
+__PYX_EXTERN_C int TVMFFICyArgSetterFactory(PyObject* arg, TVMFFIPyArgSetter* 
out);
+/*!
+ * \brief Create a callback arg setter for a given type index.
+ * \param type_index The type index of the argument.
+ * \param out The output callback arg setter.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is 
returned.
+ */
+__PYX_EXTERN_C int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
+                                                    TVMFFIPyCallbackArgSetter* 
out);
+//---------------------------------------------------------------------------------------------
+// The function call manager section
+//---------------------------------------------------------------------------------------------
 
 /*!
  * \brief A manager class that handles python ffi calls.
@@ -280,7 +441,6 @@ class TVMFFIPyCallManager {
   }
   /*!
    * \brief Call a function with a variable number of arguments
-   * \param setter_factory The factory function to create the setter
    * \param func_handle The handle of the function to call
    * \param py_arg_tuple The arguments to the function
    * \param result The result of the function
@@ -290,9 +450,8 @@ class TVMFFIPyCallManager {
    * \return 0 on when there is no python error, -1 on python error
    * \note When an error happens on FFI side, we should return 0 and set 
c_api_ret_code
    */
-  TVM_FFI_INLINE int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* 
func_handle,
-                              PyObject* py_arg_tuple, TVMFFIAny* result, int* 
c_api_ret_code,
-                              bool release_gil,
+  TVM_FFI_INLINE int FuncCall(void* func_handle, PyObject* py_arg_tuple, 
TVMFFIAny* result,
+                              int* c_api_ret_code, bool release_gil,
                               const DLPackExchangeAPI** 
optional_out_ctx_dlpack_api) {
     int64_t num_args = PyTuple_Size(py_arg_tuple);
     if (num_args == -1) return -1;
@@ -303,7 +462,7 @@ class TVMFFIPyCallManager {
       for (int64_t i = 0; i < num_args; ++i) {
         PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
         TVMFFIAny* c_arg = ctx.packed_args + i;
-        if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+        if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
       }
       TVMFFIStreamHandle prev_stream = nullptr;
       DLPackManagedTensorAllocator prev_tensor_allocator = nullptr;
@@ -371,7 +530,6 @@ class TVMFFIPyCallManager {
    *
    * This function will also not release  the GIL since constructor call is 
usually cheap.
    *
-   * \param setter_factory The factory function to create the setter
    * \param func_handle The handle of the constructor to call
    * \param py_arg_tuple The arguments to the constructor
    * \param result The result of the constructor
@@ -379,9 +537,8 @@ class TVMFFIPyCallManager {
    * \param parent_ctx The parent call context to
    * \return 0 on success, -1 on failure
    */
-  TVM_FFI_INLINE int ConstructorCall(TVMFFIPyArgSetterFactory setter_factory, 
void* func_handle,
-                                     PyObject* py_arg_tuple, TVMFFIAny* 
result, int* c_api_ret_code,
-                                     TVMFFIPyCallContext* parent_ctx) {
+  TVM_FFI_INLINE int ConstructorCall(void* func_handle, PyObject* 
py_arg_tuple, TVMFFIAny* result,
+                                     int* c_api_ret_code, TVMFFIPyCallContext* 
parent_ctx) {
     int64_t num_args = PyTuple_Size(py_arg_tuple);
     if (num_args == -1) return -1;
     try {
@@ -391,7 +548,7 @@ class TVMFFIPyCallManager {
       for (int64_t i = 0; i < num_args; ++i) {
         PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
         TVMFFIAny* c_arg = ctx.packed_args + i;
-        if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+        if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
       }
       c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, 
num_args, result);
       // propagate the call context to the parent context
@@ -415,13 +572,12 @@ class TVMFFIPyCallManager {
     }
   }
 
-  TVM_FFI_INLINE int SetField(TVMFFIPyArgSetterFactory setter_factory, void* 
field_setter,
-                              int64_t field_flags, void* field_ptr, PyObject* 
py_arg,
-                              int* c_api_ret_code) {
+  TVM_FFI_INLINE int SetField(void* field_setter, int64_t field_flags, void* 
field_ptr,
+                              PyObject* py_arg, int* c_api_ret_code) {
     try {
       TVMFFIPyCallContext ctx(&call_stack_, 1);
       TVMFFIAny* c_arg = ctx.packed_args;
-      if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+      if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
       if (!(field_flags & kTVMFFIFieldFlagBitSetterIsFunctionObj)) {
         auto setter = reinterpret_cast<TVMFFIFieldSetter>(field_setter);
         c_api_ret_code[0] = (*setter)(field_ptr, c_arg);
@@ -443,12 +599,11 @@ class TVMFFIPyCallManager {
     }
   }
 
-  TVM_FFI_INLINE int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, 
PyObject* py_arg,
-                                      TVMFFIAny* out, int* c_api_ret_code) {
+  TVM_FFI_INLINE int PyObjectToFFIAny(PyObject* py_arg, TVMFFIAny* out, int* 
c_api_ret_code) {
     try {
       TVMFFIPyCallContext ctx(&call_stack_, 1);
       TVMFFIAny* c_arg = ctx.packed_args;
-      if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+      if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
       c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out);
       return 0;
     } catch (const std::exception& ex) {
@@ -460,14 +615,12 @@ class TVMFFIPyCallManager {
 
   /*!
    * \brief Set an py_arg to out.
-   * \param setter_factory The factory function to create the setter
    * \param ctx The call context
    * \param py_arg The python argument to be set
    * \param out The output argument
    * \return 0 on success, -1 on failure
    */
-  TVM_FFI_INLINE int SetArgument(TVMFFIPyArgSetterFactory setter_factory, 
TVMFFIPyCallContext* ctx,
-                                 PyObject* py_arg, TVMFFIAny* out) {
+  TVM_FFI_INLINE int SetArgument(TVMFFIPyCallContext* ctx, PyObject* py_arg, 
TVMFFIAny* out) {
     PyTypeObject* py_type = Py_TYPE(py_arg);
     // pre-zero the output argument, modulo the type index
     out->type_index = kTVMFFINone;
@@ -475,8 +628,8 @@ class TVMFFIPyCallManager {
     out->v_int64 = 0;
     // find the pre-cached setter
     // This class is thread-local, so we don't need to worry about race 
condition
-    auto it = dispatch_map_.find(py_type);
-    if (it != dispatch_map_.end()) {
+    auto it = arg_dispatch_map_.find(py_type);
+    if (it != arg_dispatch_map_.end()) {
       TVMFFIPyArgSetter setter = it->second;
       // if error happens, propagate it back
       if (setter(ctx, py_arg, out) != 0) return -1;
@@ -484,37 +637,190 @@ class TVMFFIPyCallManager {
       // no dispatch found, query and create a new one.
       TVMFFIPyArgSetter setter;
       // propagate python error back
-      if (setter_factory(py_arg, &setter) != 0) {
+      if (TVMFFICyArgSetterFactory(py_arg, &setter) != 0) {
         return -1;
       }
       // update dispatch table
-      dispatch_map_.emplace(py_type, setter);
+      arg_dispatch_map_.emplace(py_type, setter);
       if (setter(ctx, py_arg, out) != 0) return -1;
     }
     return 0;
   }
 
   /*!
-   * \brief Get the size of the dispatch map
-   * \return The size of the dispatch map
+   * \brief Get the size of the arg dispatch map
+   * \return The size of the arg dispatch map
+   */
+  size_t GetArgDispatchMapSize() const { return arg_dispatch_map_.size(); }
+
+  /*!
+   * \brief Convert a borrowed TVMFFIAny (AnyView) into a new-reference 
PyObject*.
+   * \param dlpack_exchange_api The DLPack exchange API (may be NULL).
+   * \param arg The borrowed TVMFFIAny to convert.
+   * \param py_arg The output PyObject*.
+   * \return 0 on success, -1 on failure. PyError should be set if -1 is 
returned.
    */
-  size_t GetDispatchMapSize() const { return dispatch_map_.size(); }
+  TVM_FFI_INLINE int SetPyCallbackArg(const DLPackExchangeAPI* 
dlpack_exchange_api,
+                                      const TVMFFIAny* arg, PyObject** out) {
+    size_t type_index = static_cast<size_t>(arg->type_index);
+    // Mirror of SetArgument for the C++ -> Python callback path: each per-type
+    // callback arg setter is responsible for its own refcount.
+    // hot path: cached hit
+    if (type_index < callback_arg_dispatch_table_.size() &&
+        callback_arg_dispatch_table_[type_index].func != nullptr) {
+      TVMFFIPyCallbackArgSetter* setter = 
&callback_arg_dispatch_table_[type_index];
+      return setter->func(setter, dlpack_exchange_api, arg, out);
+    }
+    // cold path: grow and populate via factory
+    if (type_index >= callback_arg_dispatch_table_.size()) {
+      // initialize empty entries with nullptr
+      callback_arg_dispatch_table_.resize(type_index + 1, 
TVMFFIPyCallbackArgSetter{nullptr});
+    }
+    TVMFFIPyCallbackArgSetter* setter = 
&callback_arg_dispatch_table_[type_index];
+    if (TVMFFICyCallbackArgSetterFactory(static_cast<int32_t>(type_index), 
setter) != 0) {
+      return -1;
+    }
+    return setter->func(setter, dlpack_exchange_api, arg, out);
+  }
+
+  /*!
+   * \brief Python Callback function entry
+   *
+   * \param context The TVMFFIPyCallbackClosure* holding the Python callable
+   *                and optional DLPack exchange API.
+   * \param packed_args The packed FFI arguments.
+   * \param num_args Number of arguments.
+   * \param result Output FFI result.
+   * \return 0 on success, -1 on error.
+   */
+  TVM_FFI_INLINE int PyCallback(void* context, const TVMFFIAny* packed_args, 
int32_t num_args,
+                                TVMFFIAny* result) noexcept {
+    TVMFFIPyWithAttachedThreadState thread_state;
+    auto* closure = static_cast<TVMFFIPyCallbackClosure*>(context);
+    // Wrap the body in try/catch so any C++ exception raised by the stack
+    // allocator (TVMFFIPyCallbackContext / TVMFFIPyCallContext), dispatch
+    // table resize in SetPyCallbackArg, or unordered_map::emplace in
+    // SetArgument is converted into a PyErr + FFI error instead of
+    // triggering std::terminate via the noexcept contract.
+    try {
+      TVMFFIPyCallbackContext cb_ctx(&call_stack_, num_args);
+      // Step 1: Convert each packed arg (borrowed AnyView) to a PyObject*
+      for (int32_t i = 0; i < num_args; ++i) {
+        if (SetPyCallbackArg(closure->dlpack_exchange_api, &packed_args[i], 
&cb_ctx.py_args[i]) !=
+            0) {
+          ForwardPyErrorToFFI();
+          return -1;
+        }
+        // must set active arguments count to ensure correct recycling
+        cb_ctx.num_active_py_args = i + 1;
+      }
+      // Step 2: Call the Python function via vectorcall. Wrap py_result in
+      // a RAII guard so its +1 is released on every exit path, including
+      // the C++ exception path (e.g., bad_alloc from ret_ctx construction
+      // or SetArgument's emplace).
+#if PY_VERSION_HEX >= 0x03090000
+      PyObject* py_result_raw = PyObject_Vectorcall(closure->callable, 
cb_ctx.py_args,
+                                                    
static_cast<size_t>(num_args), nullptr);
+#else
+      // backward compatibility for Python 3.8
+      PyObject* py_result_raw = _PyObject_Vectorcall(closure->callable, 
cb_ctx.py_args,
+                                                     
static_cast<size_t>(num_args), nullptr);
+#endif
+      struct PyResultGuard {
+        PyObject* p;
+        ~PyResultGuard() {
+          if (p != nullptr) Py_DecRef(p);
+        }
+      } py_result{py_result_raw};
+      if (py_result.p == Py_None) {
+        // fast path for Py_None
+        result->type_index = kTVMFFINone;
+        result->zero_padding = 0;
+        result->v_int64 = 0;
+        return 0;
+      } else if (py_result.p != nullptr) {
+        // normal return
+        // Use SetArgument on a temporary view slot, then promote to owned.
+        // Note: SetArgument only borrows py_result's chandle into `view`; it
+        // does NOT inc. py_result must stay alive until AFTER
+        // TVMFFIAnyViewToOwnedAny has promoted the view to an owned ref,
+        // otherwise dec'ing py_result first could free the underlying object
+        // (e.g. if py_result owns the last ref to a freshly created tensor).
+        // The guard's destructor runs AFTER the return value is computed.
+        TVMFFIPyCallContext ret_ctx(&call_stack_, 1);
+        TVMFFIAny* view = ret_ctx.packed_args;
+        if (SetArgument(&ret_ctx, py_result.p, view) != 0) {
+          ForwardPyErrorToFFI();
+          return -1;
+        }
+        // TLS FFI error set on failure.
+        return TVMFFIAnyViewToOwnedAny(view, result);
+      } else {
+        // vectorcall failed
+        ForwardPyErrorToFFI();
+        return -1;
+      }
+    } catch (const std::exception& ex) {
+      // very rare, catch c++ exception and set python error
+      PyErr_SetString(PyExc_RuntimeError, ex.what());
+      ForwardPyErrorToFFI();
+      return -1;
+    }
+  }
+
+  /*!
+   * \brief Fetch the current Python exception and forward it to
+   *        TVMFFICyErrorSetRaisedFromPyError, then clear the Python error 
indicator.
+   *
+   * This helper correctly extracts the exception *value* (not just the type
+   * returned by PyErr_Occurred()) so that set_last_ffi_error can access the
+   * message and traceback.
+   */
+  static void ForwardPyErrorToFFI() noexcept {
+#if PY_VERSION_HEX >= 0x030C0000
+    // Python 3.12+: PyErr_Fetch / PyErr_NormalizeException are deprecated.
+    // PyErr_GetRaisedException returns an already-normalized exception
+    // instance and clears the indicator. Traceback is attached as usual.
+    PyObject* pvalue = PyErr_GetRaisedException();
+    if (pvalue != nullptr) {
+      TVMFFICyErrorSetRaisedFromPyError(pvalue);
+      Py_DecRef(pvalue);
+    }
+#else
+    // Python 3.11 and earlier: fall back to PyErr_Fetch / PyErr_NormalizeException.
+    PyObject *ptype, *pvalue, *ptraceback;
+    PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+    PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
+    if (ptraceback != nullptr) {
+      PyException_SetTraceback(pvalue, ptraceback);
+    }
+    TVMFFICyErrorSetRaisedFromPyError(pvalue);
+    Py_DecRef(ptype);
+    Py_DecRef(pvalue);
+    Py_DecRef(ptraceback);
+#endif
+  }
 
  private:
   TVMFFIPyCallManager() {
     static constexpr size_t kDefaultDispatchCapacity = 32;
-    dispatch_map_.reserve(kDefaultDispatchCapacity);
+    arg_dispatch_map_.reserve(kDefaultDispatchCapacity);
+    // Pre-allocate callback arg dispatch table for static type indices
+    static constexpr size_t kDefaultCallbackArgDispatchCapacity = 128;
+    callback_arg_dispatch_table_.resize(kDefaultCallbackArgDispatchCapacity);
   }
 
-  // internal dispacher
-  std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> dispatch_map_;
+  // internal arg dispatch map: type -> argument setter
+  std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> arg_dispatch_map_;
   // call stack
   TVMFFIPyCallStack call_stack_;
+  // callback arg setter dispatch table indexed by type_index (view-based path
+  // used by PyCallback; see TVMFFIPyCallbackArgSetter docs above)
+  std::vector<TVMFFIPyCallbackArgSetter> callback_arg_dispatch_table_;
 };
 
 /*!
  * \brief Call a function with a variable number of arguments
- * \param setter_factory The factory function to create the setter
  * \param func_handle The handle of the function to call
  * \param py_arg_tuple The arguments to the function
  * \param result The result of the function
@@ -523,13 +829,11 @@ class TVMFFIPyCallManager {
  * \param out_ctx_dlpack_api The DLPack exchange API to be used for the result
  * \return 0 on success, nonzero on failure
  */
-TVM_FFI_INLINE int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, 
void* func_handle,
-                                    PyObject* py_arg_tuple, TVMFFIAny* result, 
int* c_api_ret_code,
-                                    bool release_gil = true,
+TVM_FFI_INLINE int TVMFFIPyFuncCall(void* func_handle, PyObject* py_arg_tuple, 
TVMFFIAny* result,
+                                    int* c_api_ret_code, bool release_gil = 
true,
                                     const DLPackExchangeAPI** 
out_ctx_dlpack_api = nullptr) {
-  return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, 
func_handle, py_arg_tuple,
-                                                      result, c_api_ret_code, 
release_gil,
-                                                      out_ctx_dlpack_api);
+  return TVMFFIPyCallManager::ThreadLocal()->FuncCall(
+      func_handle, py_arg_tuple, result, c_api_ret_code, release_gil, 
out_ctx_dlpack_api);
 }
 
 /*!
@@ -542,7 +846,6 @@ TVM_FFI_INLINE int 
TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, voi
  *
  * This function will also not release the GIL since constructor call is 
usually cheap.
  *
- * \param setter_factory The factory function to create the setter
  * \param func_handle The handle of the function to call
  * \param py_arg_tuple The arguments to the constructor
  * \param result The result of the constructor
@@ -552,17 +855,15 @@ TVM_FFI_INLINE int 
TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, voi
  * \param out_dlpack_exporter The DLPack exporter to be used for the result
  * \return 0 on success, nonzero on failure
  */
-TVM_FFI_INLINE int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory 
setter_factory,
-                                           void* func_handle, PyObject* 
py_arg_tuple,
+TVM_FFI_INLINE int TVMFFIPyConstructorCall(void* func_handle, PyObject* 
py_arg_tuple,
                                            TVMFFIAny* result, int* 
c_api_ret_code,
                                            TVMFFIPyCallContext* parent_ctx) {
-  return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall(
-      setter_factory, func_handle, py_arg_tuple, result, c_api_ret_code, 
parent_ctx);
+  return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall(func_handle, 
py_arg_tuple, result,
+                                                             c_api_ret_code, 
parent_ctx);
 }
 
 /*!
  * \brief Set a field of a FFI object
- * \param setter_factory The factory function to create the setter
  * \param field_setter The field setter (function pointer or FunctionObj 
handle)
  * \param field_flags The field flags (to dispatch between function pointer 
and FunctionObj)
  * \param field_ptr The pointer to the field
@@ -570,49 +871,104 @@ TVM_FFI_INLINE int 
TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_facto
  * \param c_api_ret_code The return code of the function
  * \return 0 on success, nonzero on failure
  */
-TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory 
setter_factory,
-                                           void* field_setter, int64_t 
field_flags, void* field_ptr,
+TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(void* field_setter, int64_t 
field_flags, void* field_ptr,
                                            PyObject* py_arg, int* 
c_api_ret_code) {
-  return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory, 
field_setter, field_flags,
-                                                      field_ptr, py_arg, 
c_api_ret_code);
+  return TVMFFIPyCallManager::ThreadLocal()->SetField(field_setter, 
field_flags, field_ptr, py_arg,
+                                                      c_api_ret_code);
 }
 
 /*!
  * \brief Set an python argument to a FFI Any using the generic dispatcher in 
call manager
- * \param setter_factory The factory function to create the setter
  * \param ctx The call context
  * \param py_arg_tvm_ffi_value The python argument to be set using the 
__tvm_ffi_value__ protocol
  * \param out The output argument
  * \return 0 on success, nonzero on failure
  */
-TVM_FFI_INLINE int 
TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyArgSetterFactory setter_factory,
-                                                        TVMFFIPyCallContext* 
ctx,
+TVM_FFI_INLINE int TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyCallContext* 
ctx,
                                                         PyObject* 
py_arg_tvm_ffi_value,
                                                         TVMFFIAny* out) {
-  return TVMFFIPyCallManager::ThreadLocal()->SetArgument(setter_factory, ctx, 
py_arg_tvm_ffi_value,
-                                                         out);
+  return TVMFFIPyCallManager::ThreadLocal()->SetArgument(ctx, 
py_arg_tvm_ffi_value, out);
 }
 
 /*!
  * \brief Convert a Python object to a FFI Any
- * \param setter_factory The factory function to create the setter
  * \param py_arg The python argument to be set
  * \param out The output argument
  * \param c_api_ret_code The return code of the function
  * \return 0 on success, nonzero on failure
  */
-TVM_FFI_INLINE int TVMFFIPyPyObjectToFFIAny(TVMFFIPyArgSetterFactory 
setter_factory,
-                                            PyObject* py_arg, TVMFFIAny* out, 
int* c_api_ret_code) {
-  return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(setter_factory, 
py_arg, out,
-                                                              c_api_ret_code);
+TVM_FFI_INLINE int TVMFFIPyPyObjectToFFIAny(PyObject* py_arg, TVMFFIAny* out, 
int* c_api_ret_code) {
+  return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(py_arg, out, 
c_api_ret_code);
+}
+
+/*!
+ * \brief Get the size of the arg dispatch map
+ * \return The size of the arg dispatch map
+ */
+TVM_FFI_INLINE size_t TVMFFIPyGetArgDispatchMapSize() {
+  return TVMFFIPyCallManager::ThreadLocal()->GetArgDispatchMapSize();
 }
 
+//---------------------------------------------------------------------------------------------
+// Free function wrapper for the Python callback path.
+// Mirrors the pattern of TVMFFIPyFuncCall / TVMFFIPyConstructorCall: a 
top-level
+// TVM_FFI_INLINE free function that forwards to the thread-local manager.
+//---------------------------------------------------------------------------------------------
+
 /*!
- * \brief Get the size of the dispatch map
- * \return The size of the dispatch map
+ * \brief C-callable Python callback entry point (TVMFFISafeCallType shape).
+ *
+ * Forwards to TVMFFIPyCallManager::ThreadLocal()->PyCallback. Designed to be
+ * installed as the safe_call pointer for FFI functions that wrap a Python
+ * callable.
+ *
+ * \note The `context` argument is interpreted as a TVMFFIPyCallbackClosure*
+ *       by the manager (see TVMFFIPyConvertPyCallback).
  */
-TVM_FFI_INLINE size_t TVMFFIPyGetDispatchMapSize() {
-  return TVMFFIPyCallManager::ThreadLocal()->GetDispatchMapSize();
+TVM_FFI_INLINE int TVMFFIPyCallback(void* context, const TVMFFIAny* 
packed_args, int32_t num_args,
+                                    TVMFFIAny* result) noexcept {
+  return TVMFFIPyCallManager::ThreadLocal()->PyCallback(context, packed_args, 
num_args, result);
+}
+
+/*!
+ * \brief Create an FFI function handle from a Python callable + optional 
DLPack exchange API.
+ *
+ * Allocates a TVMFFIPyCallbackClosure on the heap, IncRefs the callable, and
+ * registers it with the FFI function-creation API using TVMFFIPyCallback as 
the
+ * safe-call entry point and TVMFFIPyCallbackClosure::Deleter as the deleter.
+ *
+ * Returns the raw FFI return code (TLS FFI error set on failure). The Cython
+ * caller uses CHECK_CALL to translate it into a Python exception.
+ *
+ * \param callable The Python callable to wrap. Must be non-NULL.
+ * \param dlpack_api Optional DLPack exchange API. May be NULL.
+ * \param out_handle Destination for the new FFI function handle.
+ * \return TVMFFIFunctionCreate's return code (0 on success), or -1 if allocation fails.
+ */
+TVM_FFI_INLINE int TVMFFIPyConvertPyCallback(PyObject* callable,
+                                             const DLPackExchangeAPI* 
dlpack_api,
+                                             TVMFFIObjectHandle* out_handle) 
noexcept {
+  // Use nothrow new: plain `new` can throw std::bad_alloc, which in this
+  // noexcept function would trigger std::terminate. On allocation failure,
+  // set PyErr and return -1 so the Cython caller's CHECK_CALL surfaces it.
+  auto* raw = new (std::nothrow) TVMFFIPyCallbackClosure{callable, dlpack_api};
+  if (raw == nullptr) {
+    PyErr_NoMemory();
+    return -1;
+  }
+  // The callable's +1 is owned by the closure; 
TVMFFIPyCallbackClosure::Deleter
+  // is responsible for Py_DecRef on destruction. By wiring the same Deleter as
+  // the unique_ptr deleter, the failure path below (unique_ptr unwind) runs
+  // the same cleanup as the success path (invoked by the FFI runtime).
+  Py_IncRef(callable);
+  std::unique_ptr<TVMFFIPyCallbackClosure, void (*)(void*)> closure(
+      raw, &TVMFFIPyCallbackClosure::Deleter);
+  int rc = TVMFFIFunctionCreate(closure.get(), &TVMFFIPyCallback, 
&TVMFFIPyCallbackClosure::Deleter,
+                                out_handle);
+  // On success, transfer ownership to the FFI function; on failure, let
+  // unique_ptr unwind via Deleter (decrefs the callable, frees the closure).
+  if (rc == 0) closure.release();
+  return rc;
 }
 
 /*!
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index 86b9e65..d7f39be 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -54,7 +54,6 @@ cdef class FieldSetter:
         cdef int c_api_ret_code
         cdef void* field_ptr = (<char*>(<CObject>obj).chandle) + self.offset
         TVMFFIPyCallFieldSetter(
-            TVMFFIPyArgSetterFactory_,
             self.setter,
             self.flags,
             field_ptr,
@@ -936,7 +935,6 @@ cdef _register_one_field(
         default_obj = py_field.default_factory
     if default_obj is not MISSING:
         TVMFFIPyPyObjectToFFIAny(
-            TVMFFIPyArgSetterFactory_,
             <PyObject*>default_obj,
             &default_any,
             &c_api_ret_code
@@ -1148,7 +1146,6 @@ cdef _register_py_methods(int32_t type_index, list py_methods, frozenset type_at
 
             # Convert Python object -> TVMFFIAny
             TVMFFIPyPyObjectToFFIAny(
-                TVMFFIPyArgSetterFactory_,
                 <PyObject*>func,
                 &func_any,
                 &c_api_ret_code,
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 686fc08..3f673c1 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -25,6 +25,13 @@ import numpy as np
 import pytest
 import tvm_ffi
 
+try:
+    import torch
+except ImportError:
+    torch = None  # ty: ignore[invalid-assignment]
+
+_HAS_TORCH_DLPACK_API = torch is not None and hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
+
 
 def test_echo() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
@@ -402,3 +409,61 @@ def test_function_with_value_protocol() -> None:
 
     nested_value_protocol = ValueProtocol([ValueProtocol(1), ValueProtocol(2), ValueProtocol(3)])
     assert tuple(fecho(nested_value_protocol)) == (1, 2, 3)
+
+
+def test_convert_func_tensor_cls_missing_attribute() -> None:
+    """Passing a class without __dlpack_c_exchange_api__ raises TypeError."""
+
+    class DummyTensor:
+        pass
+
+    with pytest.raises(TypeError, match="__dlpack_c_exchange_api__"):
+        tvm_ffi.convert_func(lambda x: x, tensor_cls=DummyTensor)
+
+    with pytest.raises(TypeError, match="__dlpack_c_exchange_api__"):
+        tvm_ffi.convert_func(lambda x: x, tensor_cls=object)
+
+
+def test_convert_func_raises_propagates() -> None:
+    """An exception raised inside the callback propagates out to the caller."""
+
+    def raises(x: int) -> None:
+        raise ValueError(f"boom {x}")
+
+    f = tvm_ffi.convert_func(raises)
+    with pytest.raises(ValueError, match="boom 42"):
+        f(42)
+
+
+@pytest.mark.skipif(
+    not _HAS_TORCH_DLPACK_API,
+    reason="torch.Tensor.__dlpack_c_exchange_api__ not available",
+)
+def test_convert_func_with_torch_tensor_cls() -> None:
+    """tensor_cls=torch.Tensor delivers torch.Tensor instances to the callback.
+
+    Asserts the type *inside* the callback (which runs on the C++ -> Python
+    side of the FFI boundary) — the return value's Python type depends on
+    the outer caller's conversion path, so we verify shape survives the
+    round-trip rather than isinstance on the return.
+    """
+    calls = 0
+
+    def callback(a: Any, b: Any, c: Any) -> Any:
+        nonlocal calls
+        calls += 1
+        assert isinstance(a, torch.Tensor)
+        assert isinstance(b, torch.Tensor)
+        assert isinstance(c, torch.Tensor)
+        assert list(a.shape) == [2]
+        assert list(b.shape) == [3]
+        assert list(c.shape) == [4]
+        return b
+
+    f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
+    a = torch.zeros(2)
+    b = torch.ones(3)
+    c = torch.full((4,), 2.0)
+    out = f(a, b, c)
+    assert calls == 1
+    assert tuple(out.shape) == (3,)
diff --git a/tests/scripts/benchmark_pycallback.py b/tests/scripts/benchmark_pycallback.py
new file mode 100644
index 0000000..74b97f4
--- /dev/null
+++ b/tests/scripts/benchmark_pycallback.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark C++ -> Python callback overhead with 3 torch.Tensor arguments.
+
+Both variants are invoked by the same C++ ``invoke_n`` loop so the per-call
+cost reflects only the callback-arg conversion path:
+
+1. ``convert_func(cb, tensor_cls=torch.Tensor)`` — the DLPack exchange API is
+   threaded into the closure, so each tensor arg is converted to a
+   ``torch.Tensor`` by the C-level callback arg setter before the callback runs.
+2. ``convert_func(cb)`` — callback receives an ``ffi.Tensor`` and calls
+   ``torch.from_dlpack(x)`` explicitly inside the callback body for each arg.
+
+Arguments are 3 x ``torch.zeros(1, device="cuda:0")``.
+"""
+
+from __future__ import annotations
+
+import time
+
+import torch
+import tvm_ffi
+import tvm_ffi.cpp
+from benchmark_dlpack import print_speed
+
+_INVOKE_N_CPP_SOURCE = r"""
+#include <tvm/ffi/function.h>
+
+void invoke_n(tvm::ffi::Function callback, int64_t n,
+              tvm::ffi::AnyView a, tvm::ffi::AnyView b, tvm::ffi::AnyView c) {
+    for (int64_t i = 0; i < n; ++i) {
+        callback(a, b, c);
+    }
+}
+"""
+
+
+def _load_invoke_n() -> object:
+    mod = tvm_ffi.cpp.load_inline(
+        name="benchmark_pycallback_invoke_n",
+        cpp_sources=_INVOKE_N_CPP_SOURCE,
+        functions=["invoke_n"],
+    )
+    return mod.get_function("invoke_n")
+
+
+def bench_pycallback_tensor_cls_torch(invoke_n, a, b, c, repeat: int) -> None:  # noqa: ANN001
+    """convert_func(cb, tensor_cls=torch.Tensor): callback sees torch.Tensor directly."""
+
+    def cb(_a, _b, _c) -> None:  # noqa: ANN001
+        pass
+
+    callback = tvm_ffi.convert_func(cb, tensor_cls=torch.Tensor)
+    invoke_n(callback, 10, a, b, c)
+    start = time.time()
+    invoke_n(callback, repeat, a, b, c)
+    end = time.time()
+    print_speed("pycallback[tensor_cls=torch.Tensor]", (end - start) / repeat)
+
+
+def bench_pycallback_from_dlpack(invoke_n, a, b, c, repeat: int) -> None:  # noqa: ANN001
+    """convert_func(cb): callback receives ffi.Tensor, does torch.from_dlpack(x) explicitly."""
+
+    def cb(_a, _b, _c) -> None:  # noqa: ANN001
+        torch.from_dlpack(_a)
+        torch.from_dlpack(_b)
+        torch.from_dlpack(_c)
+
+    callback = tvm_ffi.convert_func(cb)
+    invoke_n(callback, 10, a, b, c)
+    start = time.time()
+    invoke_n(callback, repeat, a, b, c)
+    end = time.time()
+    print_speed("pycallback+from_dlpack", (end - start) / repeat)
+
+
+def main() -> None:
+    if not hasattr(torch.Tensor, "__dlpack_c_exchange_api__"):
+        raise SystemExit("torch.Tensor.__dlpack_c_exchange_api__ not available")
+
+    repeat = 10000
+    invoke_n = _load_invoke_n()
+    a = torch.zeros(1, device="cuda:0")
+    b = torch.zeros(1, device="cuda:0")
+    c = torch.zeros(1, device="cuda:0")
+
+    print("---------------------------------------------------")
+    print("Benchmark C++ -> Python callback with 3 torch.Tensor args")
+    print('Arguments: 3 x torch.zeros(1, device="cuda:0")')
+    print("---------------------------------------------------")
+    bench_pycallback_tensor_cls_torch(invoke_n, a, b, c, repeat)
+    bench_pycallback_from_dlpack(invoke_n, a, b, c, repeat)
+    print("---------------------------------------------------")
+
+
+if __name__ == "__main__":
+    main()

Reply via email to