This is an automated email from the ASF dual-hosted git repository.
cyx-6 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new bc64ba3 [REFACTOR] Optimized Python callback path via
TVMFFIPyCallbackArgSetter (#569)
bc64ba3 is described below
commit bc64ba3e981ae44a542ed4e7590320b15e530f57
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Apr 25 12:57:01 2026 -0400
[REFACTOR] Optimized Python callback path via TVMFFIPyCallbackArgSetter
(#569)
This PR optimizes the Python callback path, specifically for callbacks whose
arguments are expected to be torch.Tensor instances.
It adds `tvm_ffi.convert_func(pyfunc, tensor_cls=None)`. If `tensor_cls`
is provided, the closure threads its `__dlpack_c_exchange_api__` capsule
into the callback so tensor args are delivered as `tensor_cls` instances
(e.g., `torch.Tensor`) directly from the C-level setter. This removes the
per-arg Python-level conversion and avoids the overhead of calling
`torch.from_dlpack` from Python.
Example:
```python
import torch
import tvm_ffi
def callback(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> None:
    # a, b, c arrive as torch.Tensor, converted in C via the closure's
    # DLPack exchange API — no from_dlpack() needed inside the callback.
    torch.add(a, b, out=c)
f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
```
| Variant | Per-call |
|---------------------------------------|----------|
| pycallback[tensor_cls=torch.Tensor] | ~640 ns |
| pycallback+from_dlpack (explicit x3) | ~9.1 us |
---
docs/reference/python/index.rst | 1 +
python/tvm_ffi/__init__.py | 3 +-
python/tvm_ffi/_convert.py | 69 +++-
python/tvm_ffi/core.pyi | 2 +-
python/tvm_ffi/cython/base.pxi | 42 ++-
python/tvm_ffi/cython/core.pyx | 1 +
python/tvm_ffi/cython/error.pxi | 12 +
python/tvm_ffi/cython/function.pxi | 68 +---
python/tvm_ffi/cython/object.pxi | 4 +-
python/tvm_ffi/cython/pycallback.pxi | 285 ++++++++++++++
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 502 +++++++++++++++++++++----
python/tvm_ffi/cython/type_info.pxi | 3 -
tests/python/test_function.py | 65 ++++
tests/scripts/benchmark_pycallback.py | 111 ++++++
14 files changed, 1018 insertions(+), 150 deletions(-)
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 93144d9..a59a587 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -134,6 +134,7 @@ Misc
access_path.AccessPath
access_path.AccessStep
convert
+ convert_func
ObjectConvertible
.. (Experimental) Dataclasses
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index b1872a7..012de31 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -67,7 +67,7 @@ if TYPE_CHECKING or not _is_config_mode():
)
from ._dtype import dtype
from .core import Object, ObjectConvertible, Function, CAny, CContainerBase
- from ._convert import convert
+ from ._convert import convert, convert_func
from .error import register_error
from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
@@ -146,6 +146,7 @@ __all__ = [
"__version_tuple__",
"access_path",
"convert",
+ "convert_func",
"cpp",
"dataclasses",
"device",
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 49e7931..c04688c 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import ctypes
from numbers import Number
from types import ModuleType
-from typing import Any
+from typing import Any, Callable
from . import _dtype, container, core
@@ -138,3 +138,70 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
else:
# in this case, it is an opaque python object
return core._convert_to_opaque_object(value)
+
+
+def convert_func(
+ pyfunc: Callable[..., Any],
+ tensor_cls: type | None = None,
+) -> Any:
+ """Convert a Python callable to an FFI :py:class:`~tvm_ffi.Function`.
+
+ This is the callable-specific sibling of :py:func:`tvm_ffi.convert`.
+ It accepts one extra argument, ``tensor_cls``, that lets the caller
+ specify how tensor arguments should be delivered to the Python
+ callable when the resulting :py:class:`Function` is invoked from C++.
+ :py:func:`tvm_ffi.convert` has no such knob — it always produces a
+ :py:class:`Function` whose callback receives ``tvm_ffi.Tensor`` for
+ tensor args.
+
+ Parameters
+ ----------
+ pyfunc : Callable
+ The Python callable to wrap.
+ tensor_cls : type, optional
+ The class whose instances the callback should receive for tensor
+ args. The class must expose a ``__dlpack_c_exchange_api__``
+ :py:class:`PyCapsule`; its capsule is threaded into the callback
+ closure so tensor args are converted at the C level (via the
+ DLPack exchange API) before the Python callback body runs — this
+ is significantly faster than calling ``torch.from_dlpack(x)`` (or
+ equivalent) inside the callback. Raises :py:class:`TypeError` if
+ ``tensor_cls`` does not expose the attribute.
+
+ When ``tensor_cls`` is ``None``, ``convert_func`` behaves like the
+ callable branch of :py:func:`tvm_ffi.convert`.
+
+ Returns
+ -------
+ Function
+ The wrapped FFI function.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import torch
+ import tvm_ffi
+
+ # Without tensor_cls: same as tvm_ffi.convert(pyfunc) — the callback
+ # receives tvm_ffi.Tensor for tensor args.
+ f = tvm_ffi.convert_func(lambda x: x + 1)
+ assert isinstance(f, tvm_ffi.Function)
+
+
+ # With tensor_cls=torch.Tensor: the callback receives torch.Tensor
+ # directly; the DLPack conversion happens in C before the body runs.
+ def callback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ return a + b
+
+
+ g = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.convert` :
+ Generic value-to-FFI conversion. Use this when you don't need to
+ specify ``tensor_cls``.
+
+ """
+ return core._convert_to_ffi_func(pyfunc, tensor_cls=tensor_cls)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 90071c7..892a3b7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -210,7 +210,7 @@ def _register_global_func(
name: str, pyfunc: Callable[..., Any] | Function, override: bool
) -> Function: ...
def _get_global_func(name: str, allow_missing: bool) -> Function | None: ...
-def _convert_to_ffi_func(pyfunc: Callable[..., Any]) -> Function: ...
+def _convert_to_ffi_func(pyfunc: Callable[..., Any], tensor_cls: type | None =
...) -> Function: ...
def _convert_to_opaque_object(pyobject: Any) -> OpaquePyObject: ...
def _print_debug_info() -> None: ...
def _register_py_class(parent_type_info: TypeInfo, type_key: str, type_cls:
type) -> TypeInfo: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index b851e1e..c5c28a1 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -366,10 +366,8 @@ cdef extern from "tvm_ffi_python_helpers.h":
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out) except -1
const DLPackExchangeAPI* dlpack_c_exchange_api
- ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value,
TVMFFIPyArgSetter* out) except -1
# The main call function
int TVMFFIPyFuncCall(
- TVMFFIPyArgSetterFactory setter_factory,
void* chandle,
PyObject* py_arg_tuple,
TVMFFIAny* result,
@@ -379,7 +377,6 @@ cdef extern from "tvm_ffi_python_helpers.h":
) except -1
int TVMFFIPyConstructorCall(
- TVMFFIPyArgSetterFactory setter_factory,
void* chandle,
PyObject* py_arg_tuple,
TVMFFIAny* result,
@@ -388,7 +385,6 @@ cdef extern from "tvm_ffi_python_helpers.h":
) except -1
int TVMFFIPyCallFieldSetter(
- TVMFFIPyArgSetterFactory setter_factory,
void* field_setter,
int64_t field_flags,
void* field_ptr,
@@ -397,20 +393,18 @@ cdef extern from "tvm_ffi_python_helpers.h":
) except -1
int TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory setter_factory,
PyObject* py_arg,
TVMFFIAny* out,
int* c_api_ret_code
) except -1
int TVMFFIPySetArgumentGenericDispatcher(
- TVMFFIPyArgSetterFactory setter_factory,
TVMFFIPyCallContext* ctx,
PyObject* py_arg,
TVMFFIAny* out
) except -1
- size_t TVMFFIPyGetDispatchMapSize() noexcept
+ size_t TVMFFIPyGetArgDispatchMapSize() noexcept
void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx,
TVMFFIObjectHandle arg) noexcept
void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg)
noexcept
@@ -420,6 +414,40 @@ cdef extern from "tvm_ffi_python_helpers.h":
int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
+
+ # Callback arg setter types — view-based AnyView -> PyObject conversion
+ # used by the C++ -> Python callback path (PyCallback).
+ ctypedef int
(*TVMFFIPyCallbackArgSetterCallback)(TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI*
api,
+ const TVMFFIAny* arg,
+ PyObject** out) except -1
+
+ ctypedef struct TVMFFIPyCallbackArgSetter:
+ TVMFFIPyCallbackArgSetterCallback func
+
+ # Built-in C callback arg setters for POD types.
+ int TVMFFIPyCallbackArgSetterNone_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny*, PyObject** out)
except -1
+ int TVMFFIPyCallbackArgSetterBool_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
except -1
+ int TVMFFIPyCallbackArgSetterInt_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
except -1
+ int TVMFFIPyCallbackArgSetterFloat_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
except -1
+ int TVMFFIPyCallbackArgSetterSmallStr_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject**
out) except -1
+ int TVMFFIPyCallbackArgSetterSmallBytes_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject**
out) except -1
+
+ int TVMFFIPyCallback(void* context, const TVMFFIAny* packed_args,
+ int32_t num_args, TVMFFIAny* result) noexcept
+
+ # Closure + convert helper for the PyCallback path. Returns the raw FFI rc;
+ # callers use CHECK_CALL to translate the TLS FFI error into a Python
exception.
+ int TVMFFIPyConvertPyCallback(PyObject* callable,
+ const DLPackExchangeAPI* dlpack_api,
+ TVMFFIObjectHandle* out_handle) noexcept
+
# MLIRPackedSafeCall
void* TVMFFIPyMLIRPackedSafeCallCreate(void
(*mlir_packed_safe_call)(void**) noexcept, PyObject* keep_alive_object)
int TVMFFIPyMLIRPackedSafeCallInvoke(void* self, const TVMFFIAny* args,
int32_t num_args, TVMFFIAny* rv)
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index 2c0a161..b025a48 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -38,4 +38,5 @@ include "./tensor.pxi"
_register_object_by_index(kTVMFFITensor, Tensor)
include "./function.pxi"
_register_object_by_index(kTVMFFIFunction, Function)
+include "./pycallback.pxi"
include "./pyclass_type_converter.pxi"
diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi
index 6f8159c..cbd1194 100644
--- a/python/tvm_ffi/cython/error.pxi
+++ b/python/tvm_ffi/cython/error.pxi
@@ -150,6 +150,18 @@ def _convert_to_ffi_error(error: BaseException) -> Error:
return Error(kind, message, py_backtrace)
+cdef public int TVMFFICyErrorSetRaisedFromPyError(PyObject* py_err) noexcept:
+ """Set the last FFI error from a Python exception.
+
+ Parameters
+ ----------
+ py_err : PyObject*
+ The Python exception to set as the last FFI error.
+ """
+ set_last_ffi_error(<object>py_err)
+ return -1
+
+
cdef inline int CHECK_CALL(int ret) except -2:
"""Check the return code of the C API function call"""
if ret == 0:
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index 48aaf5c..f613fda 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -120,7 +120,7 @@ cdef inline int ConstructorCall(void* constructor_handle,
result.type_index = kTVMFFINone
result.v_int64 = 0
TVMFFIPyConstructorCall(
- TVMFFIPyArgSetterFactory_, constructor_handle, py_arg_tuple, &result,
&c_api_ret_code,
+ constructor_handle, py_arg_tuple, &result, &c_api_ret_code,
parent_ctx
)
CHECK_CALL(c_api_ret_code)
@@ -556,9 +556,8 @@ cdef int TVMFFIPyArgSetterCallable_(
PyObject* py_arg, TVMFFIAny* out
) except -1:
"""Setter for Callable"""
- cdef object arg = <object>py_arg
cdef TVMFFIObjectHandle chandle
- _convert_to_ffi_func_handle(arg, &chandle)
+ CHECK_CALL(TVMFFIPyConvertPyCallback(py_arg, NULL, &chandle))
out.type_index = TVMFFIObjectGetTypeIndex(chandle)
out.v_ptr = chandle
TVMFFIPyPushTempFFIObject(ctx, chandle)
@@ -725,16 +724,14 @@ cdef int TVMFFIPyArgSetterFFIValueProtocol_(
# keep alive the python object since this is a temporary object
# we must push to extra temp py objects stack to avoid overflow the temp
py objects stack
TVMFFIPyPushExtraTempPyObject(ctx, ffi_value_py_obj_ptr)
- return TVMFFIPySetArgumentGenericDispatcher(
- TVMFFIPyArgSetterFactory_, ctx, ffi_value_py_obj_ptr, out
- )
+ return TVMFFIPySetArgumentGenericDispatcher(ctx, ffi_value_py_obj_ptr, out)
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
-cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out)
except -1:
+cdef public int TVMFFICyArgSetterFactory(PyObject* value, TVMFFIPyArgSetter*
out) except -1:
"""
Factory function that creates an argument setter for a given Python
argument type.
"""
@@ -950,7 +947,6 @@ cdef class Function(CObject):
result.type_index = kTVMFFINone
result.v_int64 = 0
TVMFFIPyFuncCall(
- TVMFFIPyArgSetterFactory_,
(<CObject>self).chandle, <PyObject*>args,
&result,
&c_api_ret_code,
@@ -1108,56 +1104,6 @@ def _get_global_func(name: str, allow_missing: bool):
raise ValueError("Cannot find global function %s" % name)
-cdef int tvm_ffi_callback(void* context,
- const TVMFFIAny* packed_args,
- int32_t num_args,
- TVMFFIAny* result) noexcept with gil:
- cdef list pyargs
- cdef TVMFFIAny temp_result
- cdef int c_api_ret_code
- cdef object local_pyfunc = <object>(context)
- pyargs = []
-
- for i in range(num_args):
- CHECK_CALL(TVMFFIAnyViewToOwnedAny(&packed_args[i], &temp_result))
- pyargs.append(make_ret(temp_result))
-
- try:
- rv = local_pyfunc(*pyargs)
- TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory_,
- <PyObject*>rv,
- result,
- &c_api_ret_code
- )
- return c_api_ret_code
- except Exception as err:
- set_last_ffi_error(err)
- return -1
-
-
-cdef inline int _convert_to_ffi_func_handle(
- object pyfunc, TVMFFIObjectHandle* out_handle
-) except -1:
- """Convert a python function to TVM FFI function handle"""
- Py_INCREF(pyfunc)
- CHECK_CALL(TVMFFIFunctionCreate(
- <void*>(pyfunc),
- tvm_ffi_callback,
- TVMFFIPyObjectDeleter,
- out_handle))
- return 0
-
-
-def _convert_to_ffi_func(object pyfunc: Callable[..., Any]) -> Function:
- """Convert a python function to TVM FFI function"""
- cdef TVMFFIObjectHandle chandle
- _convert_to_ffi_func_handle(pyfunc, &chandle)
- ret = Function.__new__(Function)
- (<CObject>ret).chandle = chandle
- return ret
-
-
cdef inline int _convert_to_opaque_object_handle(
object pyobject, TVMFFIObjectHandle* out_handle
) except -1:
@@ -1201,9 +1147,9 @@ def _testing_drop_last_ref_without_thread_state() -> None:
def _print_debug_info() -> None:
- """Get the size of the dispatch map"""
- cdef size_t size = TVMFFIPyGetDispatchMapSize()
- print(f"TVMFFIPyGetDispatchMapSize: {size}")
+ """Get the size of the arg dispatch map"""
+ cdef size_t size = TVMFFIPyGetArgDispatchMapSize()
+ print(f"TVMFFIPyGetArgDispatchMapSize: {size}")
cdef Function _OBJECT_FROM_JSON_GRAPH_STR =
_get_global_func("ffi.FromJSONGraphString", True)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 4b8c52d..803ead2 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -200,7 +200,7 @@ cdef class CContainerBase(CObject):
# 1. __dlpack_c_exchange_api__ (e.g. torch.Tensor) — points to a
# static struct in the framework's C++ runtime. The source
# type is kept alive by _DISPATCH_TYPE_KEEP_ALIVE (set in
- # TVMFFIPyArgSetterFactory_), which prevents module unloading.
+ # TVMFFICyArgSetterFactory), which prevents module unloading.
#
# 2. GetTorchFallbackExchangeAPI() — returns the address of a
# module-level Cython static; lives for the entire process.
@@ -758,7 +758,6 @@ def _register_type_attr(type_index: int32_t, attr_key: str,
value: object) -> No
temp.type_index = kTVMFFINone
temp.v_int64 = 0
TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory_,
<PyObject*>value,
&temp,
&c_api_ret_code,
@@ -818,7 +817,6 @@ cdef class CAny:
temp.type_index = kTVMFFINone
temp.v_int64 = 0
TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory_,
<PyObject*>value,
&temp,
&c_api_ret_code
diff --git a/python/tvm_ffi/cython/pycallback.pxi
b/python/tvm_ffi/cython/pycallback.pxi
new file mode 100644
index 0000000..fd5f686
--- /dev/null
+++ b/python/tvm_ffi/cython/pycallback.pxi
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ----------------------------------------------------------------------------
+# Cython-side implementation of the C++ -> Python callback path.
+#
+# The C++ side (tvm_ffi_python_helpers.h + TVMFFIPyCallManager::PyCallback)
+# calls into the functions defined here via function pointers registered in
+# callback_arg_dispatch_table_. Each per-type setter takes a borrowed
+# AnyView and produces a new-reference PyObject*.
+#
+# The caller will grab the thread state before calling into each individual
setter.
+#
+# This file also hosts `_convert_to_ffi_func`, the Cython entry point that
+# wraps a Python callable as an FFI Function backed by a TVMFFIPyCallback
+# closure (see TVMFFIPyConvertPyCallback in the header).
+# ----------------------------------------------------------------------------
+
+
+cdef int TVMFFIPyCallbackArgSetterTensor_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFITensor -> ffi.Tensor or torch.Tensor
(via DLPack).
+
+ The DLPack branch is inlined rather than delegated to
+ ``make_tensor_from_chandle`` so we can pass a borrowed chandle:
+ ``TensorObj::ToDLPackVersioned`` incs internally, so the inc/dec pair
+ that ``make_tensor_from_chandle`` requires on an owned chandle is pure
+ waste here. The non-DLPack branch upgrades to owned and reuses
+ ``make_tensor_from_chandle`` for consistency with the RValueRef path.
+ """
+ cdef TVMFFIObjectHandle chandle = arg.v_ptr
+ cdef DLManagedTensorVersioned* dlpack
+ cdef void* py_obj
+
+ if api != NULL and api.managed_tensor_to_py_object_no_sync != NULL:
+ if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0:
+ try:
+ api.managed_tensor_to_py_object_no_sync(dlpack, &py_obj)
+ except Exception:
+ dlpack.deleter(dlpack)
+ raise
+ # py_obj already holds +1 from the DLPack import; transfer to
caller.
+ out[0] = <PyObject*>py_obj
+ return 0
+ # Non-DLPack path: upgrade borrowed -> owned, wrap via
make_tensor_from_chandle.
+ TVMFFIObjectIncRef(chandle)
+ obj = make_tensor_from_chandle(chandle, NULL)
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterObject_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for generic static object types (type_index >=
kTVMFFIStaticObjectBegin)."""
+ cdef TVMFFIObjectHandle chandle = arg.v_ptr
+ TVMFFIObjectIncRef(chandle)
+ try:
+ obj = make_ret_object(arg[0])
+ if api != NULL and isinstance(obj, CContainerBase):
+ (<CContainerBase>obj)._dlpack_exchange_api = api
+ except BaseException:
+ TVMFFIObjectDecRef(chandle)
+ raise
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterOpaquePyObject_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIOpaquePyObject -> underlying Python
object.
+
+ Inlined equivalent of `make_ret_opaque_object`: reads the cell's Python
+ handle directly, skipping the throwaway OpaquePyObject wrapper that
+ would otherwise be created just to extract the handle. `arg` is
+ borrowed, but the cell stays alive for the callback's duration, so the
+ handle is safe to read without inc'ing the chandle.
+ """
+ cdef PyObject* py_handle =
<PyObject*>TVMFFIOpaqueObjectGetCellPtr(arg.v_ptr).handle
+ cdef object obj = <object>py_handle
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterOpaquePtr_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIOpaquePtr -> ctypes.c_void_p."""
+ obj = ctypes_handle(arg.v_ptr)
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDataType_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIDataType -> DataType."""
+ obj = make_ret_dtype(arg[0])
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDevice_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIDevice -> Device."""
+ obj = make_ret_device(arg[0])
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIDLTensorPtr -> ffi.Tensor (via DLTensor
pointer)."""
+ obj = make_ret_dltensor(arg[0])
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterRValueRef_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIObjectRValueRef.
+
+ For RValueRef, ``arg.v_ptr`` is an ``Object**`` (address of the caller's
+ mutable slot holding the moved chandle), not the chandle itself. We read
+ the slot, null it out to prevent a double-move, and wrap WITHOUT inc'ing
+ (the move already gave us the +1).
+ """
+ cdef TVMFFIObjectHandle* slot_ptr = <TVMFFIObjectHandle*>arg.v_ptr
+ cdef TVMFFIObjectHandle chandle = slot_ptr[0]
+ if chandle == NULL:
+ raise ValueError("RValueRef already moved")
+ # mark as moved before constructing wrappers (so error paths don't
double-move)
+ slot_ptr[0] = NULL
+ cdef int32_t actual_type_index = TVMFFIObjectGetTypeIndex(chandle)
+ cdef TVMFFIAny synthesized
+ synthesized.type_index = actual_type_index
+ synthesized.zero_padding = 0
+ synthesized.v_int64 = 0
+ synthesized.v_ptr = chandle
+ try:
+ if actual_type_index == kTVMFFITensor:
+ obj = make_tensor_from_chandle(chandle, api)
+ else:
+ obj = make_ret_object(synthesized)
+ if api != NULL and isinstance(obj, CContainerBase):
+ (<CContainerBase>obj)._dlpack_exchange_api = api
+ except BaseException:
+ # Caller's moved +1 needs to be released on error.
+ TVMFFIObjectDecRef(chandle)
+ raise
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
+ TVMFFIPyCallbackArgSetter*
out) except -1:
+ """Factory that creates callback arg setters for a given type index.
+
+ POD setters live in tvm_ffi_python_helpers.h (header-inline);
+ object-bearing setters are the Cython functions above.
+ """
+ if type_index >= kTVMFFIStaticObjectBegin:
+ if type_index == kTVMFFITensor:
+ out.func = TVMFFIPyCallbackArgSetterTensor_
+ elif type_index == kTVMFFIOpaquePyObject:
+ out.func = TVMFFIPyCallbackArgSetterOpaquePyObject_
+ else:
+ out.func = TVMFFIPyCallbackArgSetterObject_
+ return 0
+ if type_index == kTVMFFINone:
+ out.func = TVMFFIPyCallbackArgSetterNone_
+ elif type_index == kTVMFFIBool:
+ out.func = TVMFFIPyCallbackArgSetterBool_
+ elif type_index == kTVMFFIInt:
+ out.func = TVMFFIPyCallbackArgSetterInt_
+ elif type_index == kTVMFFIFloat:
+ out.func = TVMFFIPyCallbackArgSetterFloat_
+ elif type_index == kTVMFFISmallStr:
+ out.func = TVMFFIPyCallbackArgSetterSmallStr_
+ elif type_index == kTVMFFISmallBytes:
+ out.func = TVMFFIPyCallbackArgSetterSmallBytes_
+ elif type_index == kTVMFFIOpaquePtr:
+ out.func = TVMFFIPyCallbackArgSetterOpaquePtr_
+ elif type_index == kTVMFFIDataType:
+ out.func = TVMFFIPyCallbackArgSetterDataType_
+ elif type_index == kTVMFFIDevice:
+ out.func = TVMFFIPyCallbackArgSetterDevice_
+ elif type_index == kTVMFFIDLTensorPtr:
+ out.func = TVMFFIPyCallbackArgSetterDLTensorPtr_
+ elif type_index == kTVMFFIObjectRValueRef:
+ out.func = TVMFFIPyCallbackArgSetterRValueRef_
+ elif type_index == kTVMFFIByteArrayPtr:
+ raise ValueError("Callback arg cannot be ByteArrayPtr")
+ elif type_index == kTVMFFIRawStr:
+ raise ValueError("Callback arg cannot be RawStr")
+ else:
+ raise ValueError("Unhandled type index %d" % type_index)
+ return 0
+
+
+def _convert_to_ffi_func(
+ object pyfunc: Callable[..., Any],
+ object tensor_cls: object = None,
+) -> Function:
+ """Convert a python function to a TVM FFI Function.
+
+ Parameters
+ ----------
+ pyfunc : Callable
+ The Python callable to wrap. Incref'd into a TVMFFIPyCallbackClosure.
+ tensor_cls : type, optional
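+ If given, its ``__dlpack_c_exchange_api__`` capsule is threaded into the
+ closure and used to deliver tensor arguments to the callback as
+ ``tensor_cls`` instances, converted at the C level via the DLPack
+ exchange API before the Python callback body runs.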
+
+ Returns
+ -------
+ Function
+ The wrapped FFI function.
+ """
+ cdef TVMFFIObjectHandle chandle
+ cdef const DLPackExchangeAPI* dlpack_api = NULL
+ if tensor_cls is not None:
+ if not hasattr(tensor_cls, "__dlpack_c_exchange_api__"):
+ raise TypeError(
+ f"tensor_cls {tensor_cls!r} must expose
__dlpack_c_exchange_api__"
+ )
+ _get_dlpack_exchange_api(
+ tensor_cls.__dlpack_c_exchange_api__, &dlpack_api
+ )
+ CHECK_CALL(TVMFFIPyConvertPyCallback(<PyObject*>pyfunc, dlpack_api,
&chandle))
+ ret = Function.__new__(Function)
+ (<CObject>ret).chandle = chandle
+ return ret
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 9c923af..79494d5 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -39,6 +39,7 @@
#include <cstring>
#include <exception>
#include <iostream>
+#include <memory>
#include <unordered_map>
#include <vector>
@@ -63,6 +64,31 @@ class TVMFFIPyWithAttachedThreadState {
PyGILState_STATE gstate_;
};
+/*!
+ * \brief Closure state carried as the resource handle for an FFI function that
+ * wraps a Python callable and optional exchange api for tensor
handling.
+ *
+ * Created by TVMFFIPyConvertPyCallback and released by
+ * TVMFFIPyCallbackClosure::Deleter when the FFI function is destroyed.
+ */
+struct TVMFFIPyCallbackClosure {
+ /*! \brief Strong reference to the Python callable. */
+ PyObject* callable;
+ /*! \brief Optional DLPack exchange API used when constructing Tensor
returns. */
+ const DLPackExchangeAPI* dlpack_exchange_api;
+ /*!
+ * \brief Deleter registered with TVMFFIFunctionCreate. Runs on FFI function
destroy.
+ *
+ * Releases the closure's strong Python reference and frees the closure.
+ */
+ static void Deleter(void* context) noexcept {
+ TVMFFIPyWithAttachedThreadState thread_state;
+ auto* closure = static_cast<TVMFFIPyCallbackClosure*>(context);
+ Py_DecRef(closure->callable);
+ delete closure;
+ }
+};
+
/*!
* \brief Thread-local call stack used by TVMFFIPyCallContext.
*/
@@ -88,6 +114,9 @@ class TVMFFIPyCallStack {
}
};
+//---------------------------------------------------------------------------------------------
+// Support for Python -> FFI function calls.
+//---------------------------------------------------------------------------------------------
/*!
* \brief Context for each ffi call to track the stream, device and temporary
arguments.
*/
@@ -246,24 +275,156 @@ int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*,
TVMFFIPyCallContext*, PyObject* a
}
//---------------------------------------------------------------------------------------------
-// The following section contains the dispatcher logic for function calling
+// Support for PyCallback function calls.
//---------------------------------------------------------------------------------------------
+
/*!
- * \brief Factory function that creates an argument setter for a given Python argument type.
+ * \brief Context for a C -> Python callback call.
*
- * This factory function analyzes a Python argument and creates an appropriate setter
- * that can convert Python objects of the same type to C arguments for TVM FFI calls.
- * The setter will be cached for future use for setting argument of the same type.
+ * Owns a temporary PyObject* array that holds arguments converted from the
+ * packed FFI call. Space is first taken from the thread-local args_stack on
+ * TVMFFIPyCallStack; if insufficient, we fall back to the heap.
*
- * \param arg The Python argument value used as a type example.
- * \param out Output parameter that receives the created argument setter.
- * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ * Unlike TVMFFIPyCallContext::~TVMFFIPyCallContext, this destructor does NOT
+ * attach a thread state — callers are expected to already hold one
+ * (e.g. via TVMFFIPyWithAttachedThreadState at the top of the callback).
*
- * \note This is a callback function supplied by the caller. The factory must satisfy
- * the invariance that the same setter can be used for other arguments with
- * the same type as the provided example argument.
+ * The destructor also decrefs every PyObject* pushed into py_args[0 ..
+ * num_active_py_args-1], tracking the pushed count via `num_active_py_args`.
+ *
+ * The type is non-copyable: it releases its args_stack reservation (or heap
+ * array) and the live Python references exactly once, when its scope ends.
*/
-typedef int (*TVMFFIPyArgSetterFactory)(PyObject* arg, TVMFFIPyArgSetter* out);
+class TVMFFIPyCallbackContext {
+ public:
+  /*! \brief The temporary PyObject* slots for Python call arguments. */
+  PyObject** py_args = nullptr;
+  /*! \brief How many slots have a live reference and need decref on cleanup. */
+  int32_t num_active_py_args = 0;
+  /*! \brief Number of total argument slots allocated. */
+  int32_t num_args = 0;
+
+  /*!
+   * \brief Reserve `num_args` PyObject* slots for one callback invocation.
+   * \param call_stack The thread-local call stack whose args_stack is borrowed first.
+   * \param num_args Number of argument slots to reserve.
+   */
+  TVMFFIPyCallbackContext(TVMFFIPyCallStack* call_stack, int32_t num_args)
+      : num_args(num_args), call_stack_(call_stack) {
+    static_assert(sizeof(TVMFFIAny) % sizeof(PyObject*) == 0);
+    // slots needed in the unit of TVMFFIAny (rounded up)
+    int64_t slots_needed =
+        (static_cast<int64_t>(num_args) * sizeof(PyObject*) + sizeof(TVMFFIAny) - 1) /
+        sizeof(TVMFFIAny);
+    old_args_stack_top_ = call_stack->args_stack_top;
+    if (call_stack->args_stack_top + slots_needed <=
+        static_cast<int64_t>(call_stack->args_stack.size())) {
+      py_args =
+          reinterpret_cast<PyObject**>(call_stack->args_stack.data() + call_stack->args_stack_top);
+      call_stack->args_stack_top += slots_needed;
+    } else {
+      // Not enough room left on the thread-local stack: fall back to a heap array.
+      heap_ptr_ = new PyObject*[num_args];
+      py_args = heap_ptr_;
+    }
+  }
+
+  // A copy would pop the stack reservation / free the heap array and decref the
+  // live arguments a second time, so the context is strictly scope-bound.
+  TVMFFIPyCallbackContext(const TVMFFIPyCallbackContext&) = delete;
+  TVMFFIPyCallbackContext& operator=(const TVMFFIPyCallbackContext&) = delete;
+
+  ~TVMFFIPyCallbackContext() {
+    // caller must already hold an attached thread state; do NOT re-attach.
+    // only the first num_active_py_args slots hold live (non-null) references.
+    for (int32_t i = 0; i < num_active_py_args; ++i) {
+      Py_DecRef(py_args[i]);
+    }
+    if (heap_ptr_ == nullptr) {
+      // stack-backed: restore the previous top (assumes LIFO nesting of stack users)
+      call_stack_->args_stack_top = old_args_stack_top_;
+    } else {
+      delete[] heap_ptr_;
+    }
+  }
+
+ private:
+  TVMFFIPyCallStack* call_stack_ = nullptr;
+  int64_t old_args_stack_top_ = 0;
+  PyObject** heap_ptr_ = nullptr;
+};
+
+/*!
+ * \brief A callback arg setter entry registered to handle efficient callback argument conversion.
+ *
+ * Entries are cached per FFI type index by TVMFFIPyCallManager (see
+ * callback_arg_dispatch_table_) and populated lazily via TVMFFICyCallbackArgSetterFactory.
+ */
+struct TVMFFIPyCallbackArgSetter {
+ /*!
+ * \brief Callback type that converts a borrowed TVMFFIAny (AnyView) to a new-reference PyObject*.
+ * \param handle Pointer to the TVMFFIPyCallbackArgSetter (for per-type state; the
+ * struct currently carries only `func`).
+ * \param dlpack_exchange_api The DLPack exchange API (may be NULL).
+ * \param arg The TVMFFIAny value to convert (borrowed; setter must inc if it transfers
+ * ownership).
+ * \param out Output: a new-reference PyObject*.
+ * \return 0 on success, -1 on failure (PyErr set).
+ */
+ int (*func)(TVMFFIPyCallbackArgSetter* handle, const DLPackExchangeAPI*
dlpack_exchange_api,
+ const TVMFFIAny* arg, PyObject** out);
+};
+
+// Common callback arg setters that can be quickly implemented in C++ and used by the Cython factory.
+// Note that PyErr is propagated back, so we just need to return -1 on failure.
+int TVMFFIPyCallbackArgSetterNone_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny*, PyObject** out) noexcept {
+ // The payload is ignored; the caller receives a new reference to the Py_None singleton.
+ Py_IncRef(Py_None);
+ *out = Py_None;
+ return 0;
+}
+
+int TVMFFIPyCallbackArgSetterBool_(TVMFFIPyCallbackArgSetter*, const DLPackExchangeAPI*,
+                                   const TVMFFIAny* arg, PyObject** out) noexcept {
+  // Compare the int64 payload against zero instead of casting it to `long`: on
+  // LLP64 platforms (e.g. Windows) `long` is 32-bit, so the cast truncates and a
+  // nonzero payload whose low 32 bits are zero would become False.
+  *out = PyBool_FromLong(arg->v_int64 != 0);
+  return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterInt_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
noexcept {
+ // Box the int64 payload as a new Python int reference; -1 (PyErr set) on failure.
+ *out = PyLong_FromLongLong(arg->v_int64);
+ return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterFloat_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
noexcept {
+ // Box the float64 payload as a new Python float reference; -1 (PyErr set) on failure.
+ *out = PyFloat_FromDouble(arg->v_float64);
+ return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterSmallStr_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
noexcept {
+ // Read the small-string payload straight out of the Any and decode it as UTF-8
+ // into a new str; a decoding failure yields -1 with PyErr set.
+ TVMFFIByteArray ba = TVMFFISmallBytesGetContentByteArray(arg);
+ *out = PyUnicode_DecodeUTF8(ba.data, static_cast<Py_ssize_t>(ba.size),
nullptr);
+ return (*out != nullptr) ? 0 : -1;
+}
+
+int TVMFFIPyCallbackArgSetterSmallBytes_(TVMFFIPyCallbackArgSetter*, const
DLPackExchangeAPI*,
+ const TVMFFIAny* arg, PyObject** out)
noexcept {
+ // Copy the small-bytes payload read from the Any into a new Python bytes object.
+ TVMFFIByteArray ba = TVMFFISmallBytesGetContentByteArray(arg);
+ *out = PyBytes_FromStringAndSize(ba.data, static_cast<Py_ssize_t>(ba.size));
+ return (*out != nullptr) ? 0 : -1;
+}
+
+///--------------------------------------------------------------------------------
+/// Declaring functions defined in Cython to be invoked by the C++ implementation.
+/// In all cases PyErr is propagated back, so we just need to return -1 on failure.
+///--------------------------------------------------------------------------------
+/*!
+ * \brief Set the error raised from Python to the FFI side.
+ * \param py_err The Python error to be set.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ */
+__PYX_EXTERN_C int TVMFFICyErrorSetRaisedFromPyError(PyObject* py_err);
+/*!
+ * \brief Create an argument setter for a given Python argument type.
+ * \param arg The Python argument to be set.
+ * \param out The output argument setter.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ * \note The call manager caches the produced setter per Python type, so it must be
+ *       valid for every argument of the same type as `arg`, not just `arg` itself.
+ */
+__PYX_EXTERN_C int TVMFFICyArgSetterFactory(PyObject* arg, TVMFFIPyArgSetter*
out);
+/*!
+ * \brief Create a callback arg setter for a given type index.
+ * \param type_index The type index of the argument.
+ * \param out The output callback arg setter.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ * \note The call manager caches the produced setter per type index on the calling thread.
+ */
+__PYX_EXTERN_C int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
+ TVMFFIPyCallbackArgSetter*
out);
+//---------------------------------------------------------------------------------------------
+// The function call manager section
+//---------------------------------------------------------------------------------------------
/*!
* \brief A manager class that handles python ffi calls.
@@ -280,7 +441,6 @@ class TVMFFIPyCallManager {
}
/*!
* \brief Call a function with a variable number of arguments
- * \param setter_factory The factory function to create the setter
* \param func_handle The handle of the function to call
* \param py_arg_tuple The arguments to the function
* \param result The result of the function
@@ -290,9 +450,8 @@ class TVMFFIPyCallManager {
* \return 0 on when there is no python error, -1 on python error
* \note When an error happens on FFI side, we should return 0 and set
c_api_ret_code
*/
- TVM_FFI_INLINE int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void*
func_handle,
- PyObject* py_arg_tuple, TVMFFIAny* result, int*
c_api_ret_code,
- bool release_gil,
+ TVM_FFI_INLINE int FuncCall(void* func_handle, PyObject* py_arg_tuple,
TVMFFIAny* result,
+ int* c_api_ret_code, bool release_gil,
const DLPackExchangeAPI**
optional_out_ctx_dlpack_api) {
int64_t num_args = PyTuple_Size(py_arg_tuple);
if (num_args == -1) return -1;
@@ -303,7 +462,7 @@ class TVMFFIPyCallManager {
for (int64_t i = 0; i < num_args; ++i) {
PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
TVMFFIAny* c_arg = ctx.packed_args + i;
- if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+ if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
}
TVMFFIStreamHandle prev_stream = nullptr;
DLPackManagedTensorAllocator prev_tensor_allocator = nullptr;
@@ -371,7 +530,6 @@ class TVMFFIPyCallManager {
*
* This function will also not release the GIL since constructor call is
usually cheap.
*
- * \param setter_factory The factory function to create the setter
* \param func_handle The handle of the constructor to call
* \param py_arg_tuple The arguments to the constructor
* \param result The result of the constructor
@@ -379,9 +537,8 @@ class TVMFFIPyCallManager {
* \param parent_ctx The parent call context to
* \return 0 on success, -1 on failure
*/
- TVM_FFI_INLINE int ConstructorCall(TVMFFIPyArgSetterFactory setter_factory,
void* func_handle,
- PyObject* py_arg_tuple, TVMFFIAny*
result, int* c_api_ret_code,
- TVMFFIPyCallContext* parent_ctx) {
+ TVM_FFI_INLINE int ConstructorCall(void* func_handle, PyObject*
py_arg_tuple, TVMFFIAny* result,
+ int* c_api_ret_code, TVMFFIPyCallContext*
parent_ctx) {
int64_t num_args = PyTuple_Size(py_arg_tuple);
if (num_args == -1) return -1;
try {
@@ -391,7 +548,7 @@ class TVMFFIPyCallManager {
for (int64_t i = 0; i < num_args; ++i) {
PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
TVMFFIAny* c_arg = ctx.packed_args + i;
- if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+ if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
}
c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args,
num_args, result);
// propagate the call context to the parent context
@@ -415,13 +572,12 @@ class TVMFFIPyCallManager {
}
}
- TVM_FFI_INLINE int SetField(TVMFFIPyArgSetterFactory setter_factory, void*
field_setter,
- int64_t field_flags, void* field_ptr, PyObject*
py_arg,
- int* c_api_ret_code) {
+ TVM_FFI_INLINE int SetField(void* field_setter, int64_t field_flags, void*
field_ptr,
+ PyObject* py_arg, int* c_api_ret_code) {
try {
TVMFFIPyCallContext ctx(&call_stack_, 1);
TVMFFIAny* c_arg = ctx.packed_args;
- if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+ if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
if (!(field_flags & kTVMFFIFieldFlagBitSetterIsFunctionObj)) {
auto setter = reinterpret_cast<TVMFFIFieldSetter>(field_setter);
c_api_ret_code[0] = (*setter)(field_ptr, c_arg);
@@ -443,12 +599,11 @@ class TVMFFIPyCallManager {
}
}
- TVM_FFI_INLINE int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory,
PyObject* py_arg,
- TVMFFIAny* out, int* c_api_ret_code) {
+ TVM_FFI_INLINE int PyObjectToFFIAny(PyObject* py_arg, TVMFFIAny* out, int*
c_api_ret_code) {
try {
TVMFFIPyCallContext ctx(&call_stack_, 1);
TVMFFIAny* c_arg = ctx.packed_args;
- if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+ if (SetArgument(&ctx, py_arg, c_arg) != 0) return -1;
c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out);
return 0;
} catch (const std::exception& ex) {
@@ -460,14 +615,12 @@ class TVMFFIPyCallManager {
/*!
* \brief Set an py_arg to out.
- * \param setter_factory The factory function to create the setter
* \param ctx The call context
* \param py_arg The python argument to be set
* \param out The output argument
* \return 0 on success, -1 on failure
*/
- TVM_FFI_INLINE int SetArgument(TVMFFIPyArgSetterFactory setter_factory,
TVMFFIPyCallContext* ctx,
- PyObject* py_arg, TVMFFIAny* out) {
+ TVM_FFI_INLINE int SetArgument(TVMFFIPyCallContext* ctx, PyObject* py_arg,
TVMFFIAny* out) {
PyTypeObject* py_type = Py_TYPE(py_arg);
// pre-zero the output argument, modulo the type index
out->type_index = kTVMFFINone;
@@ -475,8 +628,8 @@ class TVMFFIPyCallManager {
out->v_int64 = 0;
// find the pre-cached setter
// This class is thread-local, so we don't need to worry about race
condition
- auto it = dispatch_map_.find(py_type);
- if (it != dispatch_map_.end()) {
+ auto it = arg_dispatch_map_.find(py_type);
+ if (it != arg_dispatch_map_.end()) {
TVMFFIPyArgSetter setter = it->second;
// if error happens, propagate it back
if (setter(ctx, py_arg, out) != 0) return -1;
@@ -484,37 +637,190 @@ class TVMFFIPyCallManager {
// no dispatch found, query and create a new one.
TVMFFIPyArgSetter setter;
// propagate python error back
- if (setter_factory(py_arg, &setter) != 0) {
+ if (TVMFFICyArgSetterFactory(py_arg, &setter) != 0) {
return -1;
}
// update dispatch table
- dispatch_map_.emplace(py_type, setter);
+ arg_dispatch_map_.emplace(py_type, setter);
if (setter(ctx, py_arg, out) != 0) return -1;
}
return 0;
}
/*!
- * \brief Get the size of the dispatch map
- * \return The size of the dispatch map
+ * \brief Get the size of the arg dispatch map
+ * \return The number of Python-type -> TVMFFIPyArgSetter entries cached by this
+ * thread's manager for the Python -> FFI argument path.
+ */
+ size_t GetArgDispatchMapSize() const { return arg_dispatch_map_.size(); }
+
+  /*!
+   * \brief Convert a borrowed TVMFFIAny (AnyView) into a new-reference PyObject*.
+   * \param dlpack_exchange_api The DLPack exchange API (may be NULL).
+   * \param arg The borrowed TVMFFIAny to convert.
+   * \param out The output PyObject* (a new reference on success).
+   * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
*/
- size_t GetDispatchMapSize() const { return dispatch_map_.size(); }
+  TVM_FFI_INLINE int SetPyCallbackArg(const DLPackExchangeAPI* dlpack_exchange_api,
+                                      const TVMFFIAny* arg, PyObject** out) {
+    size_t type_index = static_cast<size_t>(arg->type_index);
+    // Mirror of SetArgument for the C++ -> Python callback path: each per-type
+    // callback arg setter is responsible for its own refcount.
+    // hot path: cached hit (a negative index wraps to a huge value and misses here)
+    if (type_index < callback_arg_dispatch_table_.size() &&
+        callback_arg_dispatch_table_[type_index].func != nullptr) {
+      TVMFFIPyCallbackArgSetter* setter = &callback_arg_dispatch_table_[type_index];
+      return setter->func(setter, dlpack_exchange_api, arg, out);
+    }
+    // cold path: valid FFI type indices are non-negative; reject anything else so the
+    // size_t arithmetic below cannot wrap (type_index + 1 == 0 would shrink the table).
+    if (arg->type_index < 0) {
+      PyErr_SetString(PyExc_TypeError, "Invalid negative FFI type index in callback argument");
+      return -1;
+    }
+    // Build the setter in a local first so that (a) a failing factory never leaves a
+    // half-written table slot, and (b) no pointer into the table is held across the
+    // factory call, which may run Python code that re-enters this manager and grows
+    // (reallocates) the table.
+    TVMFFIPyCallbackArgSetter new_setter{nullptr};
+    if (TVMFFICyCallbackArgSetterFactory(static_cast<int32_t>(type_index), &new_setter) != 0) {
+      return -1;
+    }
+    if (new_setter.func == nullptr) {
+      PyErr_SetString(PyExc_RuntimeError, "Callback arg setter factory returned a null setter");
+      return -1;
+    }
+    if (type_index >= callback_arg_dispatch_table_.size()) {
+      // initialize empty entries with nullptr
+      callback_arg_dispatch_table_.resize(type_index + 1, TVMFFIPyCallbackArgSetter{nullptr});
+    }
+    TVMFFIPyCallbackArgSetter* setter = &callback_arg_dispatch_table_[type_index];
+    *setter = new_setter;
+    return setter->func(setter, dlpack_exchange_api, arg, out);
+  }
+
+ /*!
+ * \brief Python Callback function entry
+ *
+ * Attaches a Python thread state, converts each packed FFI argument into a new
+ * Python reference, invokes the closure's callable via vectorcall, and converts
+ * the Python result back into an owned FFI Any.
+ *
+ * \param context The TVMFFIPyCallbackClosure* holding the Python callable
+ * and optional DLPack exchange API.
+ * \param packed_args The packed FFI arguments.
+ * \param num_args Number of arguments.
+ * \param result Output FFI result.
+ * \return 0 on success, -1 on error. On a Python-side error the pending exception
+ * is forwarded to the FFI error slot via ForwardPyErrorToFFI.
+ */
+ TVM_FFI_INLINE int PyCallback(void* context, const TVMFFIAny* packed_args,
int32_t num_args,
+ TVMFFIAny* result) noexcept {
+ TVMFFIPyWithAttachedThreadState thread_state;
+ auto* closure = static_cast<TVMFFIPyCallbackClosure*>(context);
+ // Wrap the body in try/catch so any C++ exception raised by the stack
+ // allocator (TVMFFIPyCallbackContext / TVMFFIPyCallContext), dispatch
+ // table resize in SetPyCallbackArg, or unordered_map::emplace in
+ // SetArgument is converted into a PyErr + FFI error instead of
+ // triggering std::terminate via the noexcept contract.
+ try {
+ TVMFFIPyCallbackContext cb_ctx(&call_stack_, num_args);
+ // Step 1: Convert each packed arg (borrowed AnyView) to a PyObject*
+ for (int32_t i = 0; i < num_args; ++i) {
+ if (SetPyCallbackArg(closure->dlpack_exchange_api, &packed_args[i],
&cb_ctx.py_args[i]) !=
+ 0) {
+ ForwardPyErrorToFFI();
+ return -1;
+ }
+ // must set active arguments count to ensure correct recycling
+ cb_ctx.num_active_py_args = i + 1;
+ }
+ // Step 2: Call the Python function via vectorcall. Wrap py_result in
+ // a RAII guard so its +1 is released on every exit path, including
+ // the C++ exception path (e.g., bad_alloc from ret_ctx construction
+ // or SetArgument's emplace).
+#if PY_VERSION_HEX >= 0x03090000
+ PyObject* py_result_raw = PyObject_Vectorcall(closure->callable,
cb_ctx.py_args,
+
static_cast<size_t>(num_args), nullptr);
+#else
+ // backward compatibility for Python 3.8
+ PyObject* py_result_raw = _PyObject_Vectorcall(closure->callable,
cb_ctx.py_args,
+
static_cast<size_t>(num_args), nullptr);
+#endif
+ struct PyResultGuard {
+ PyObject* p;
+ ~PyResultGuard() {
+ if (p != nullptr) Py_DecRef(p);
+ }
+ } py_result{py_result_raw};
+ if (py_result.p == Py_None) {
+ // fast path for Py_None: write an FFI None directly, skipping SetArgument
+ result->type_index = kTVMFFINone;
+ result->zero_padding = 0;
+ result->v_int64 = 0;
+ return 0;
+ } else if (py_result.p != nullptr) {
+ // normal return
+ // Use SetArgument on a temporary view slot, then promote to owned.
+ // Note: SetArgument only borrows py_result's chandle into `view`; it
+ // does NOT inc. py_result must stay alive until AFTER
+ // TVMFFIAnyViewToOwnedAny has promoted the view to an owned ref,
+ // otherwise dec'ing py_result first could free the underlying object
+ // (e.g. if py_result owns the last ref to a freshly created tensor).
+ // The guard's destructor runs AFTER the return value is computed.
+ TVMFFIPyCallContext ret_ctx(&call_stack_, 1);
+ TVMFFIAny* view = ret_ctx.packed_args;
+ if (SetArgument(&ret_ctx, py_result.p, view) != 0) {
+ ForwardPyErrorToFFI();
+ return -1;
+ }
+ // TLS FFI error set on failure.
+ return TVMFFIAnyViewToOwnedAny(view, result);
+ } else {
+ // vectorcall failed
+ ForwardPyErrorToFFI();
+ return -1;
+ }
+ } catch (const std::exception& ex) {
+ // very rare, catch c++ exception and set python error
+ // NOTE(review): only std::exception is caught here; any other C++ exception
+ // type would still escape the noexcept function and reach std::terminate.
+ PyErr_SetString(PyExc_RuntimeError, ex.what());
+ ForwardPyErrorToFFI();
+ return -1;
+ }
+ }
+
+  /*!
+   * \brief Fetch the current Python exception and forward it to
+   * TVMFFICyErrorSetRaisedFromPyError, then clear the Python error indicator.
+   *
+   * This helper correctly extracts the exception *value* (not just the type
+   * returned by PyErr_Occurred()) so that set_last_ffi_error can access the
+   * message and traceback.
+   *
+   * If no Python error is pending this is a no-op on every Python version. If
+   * forwarding itself fails, the secondary Python error is cleared, so the
+   * indicator is never left set when control returns to the FFI caller.
+   */
+  static void ForwardPyErrorToFFI() noexcept {
+#if PY_VERSION_HEX >= 0x030C0000
+    // Python 3.12+: PyErr_Fetch / PyErr_NormalizeException are deprecated.
+    // PyErr_GetRaisedException returns an already-normalized exception
+    // instance and clears the indicator. Traceback is attached as usual.
+    PyObject* pvalue = PyErr_GetRaisedException();
+    if (pvalue == nullptr) return;
+    if (TVMFFICyErrorSetRaisedFromPyError(pvalue) != 0) {
+      PyErr_Clear();
+    }
+    Py_DecRef(pvalue);
+#else
+    // Python 3.8 - 3.11 (3.8 is still supported; see the _PyObject_Vectorcall fallback).
+    PyObject *ptype, *pvalue, *ptraceback;
+    PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+    if (ptype == nullptr) {
+      // Nothing pending: mirror the 3.12+ branch instead of passing NULL downstream.
+      return;
+    }
+    PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
+    if (pvalue != nullptr) {
+      if (ptraceback != nullptr) {
+        PyException_SetTraceback(pvalue, ptraceback);
+      }
+      if (TVMFFICyErrorSetRaisedFromPyError(pvalue) != 0) {
+        PyErr_Clear();
+      }
+    }
+    // Py_DecRef is NULL-safe, so these are fine even if normalization left a slot empty.
+    Py_DecRef(ptype);
+    Py_DecRef(pvalue);
+    Py_DecRef(ptraceback);
+#endif
+  }
private:
TVMFFIPyCallManager() {
static constexpr size_t kDefaultDispatchCapacity = 32;
- dispatch_map_.reserve(kDefaultDispatchCapacity);
+ arg_dispatch_map_.reserve(kDefaultDispatchCapacity);
+ // Pre-allocate callback arg dispatch table for static type indices.
+ // resize() value-initializes every entry, so each slot starts with func == nullptr
+ // and is filled lazily by SetPyCallbackArg; larger indices grow the table on demand.
+ static constexpr size_t kDefaultCallbackArgDispatchCapacity = 128;
+ callback_arg_dispatch_table_.resize(kDefaultCallbackArgDispatchCapacity);
}
- // internal dispacher
- std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> dispatch_map_;
+ // internal arg dispatch map: Python type -> argument setter (Python -> FFI path)
+ std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> arg_dispatch_map_;
// call stack
TVMFFIPyCallStack call_stack_;
+ // callback arg setter dispatch table indexed by type_index (view-based path
+ // used by PyCallback; see TVMFFIPyCallbackArgSetter docs above)
+ std::vector<TVMFFIPyCallbackArgSetter> callback_arg_dispatch_table_;
};
/*!
* \brief Call a function with a variable number of arguments
- * \param setter_factory The factory function to create the setter
* \param func_handle The handle of the function to call
* \param py_arg_tuple The arguments to the function
* \param result The result of the function
@@ -523,13 +829,11 @@ class TVMFFIPyCallManager {
* \param out_ctx_dlpack_api The DLPack exchange API to be used for the result
* \return 0 on success, nonzero on failure
*/
-TVM_FFI_INLINE int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory,
void* func_handle,
- PyObject* py_arg_tuple, TVMFFIAny* result,
int* c_api_ret_code,
- bool release_gil = true,
+TVM_FFI_INLINE int TVMFFIPyFuncCall(void* func_handle, PyObject* py_arg_tuple,
TVMFFIAny* result,
+ int* c_api_ret_code, bool release_gil =
true,
const DLPackExchangeAPI**
out_ctx_dlpack_api = nullptr) {
+ // Forward to this thread's manager; uncached Python argument types are resolved
+ // there via TVMFFICyArgSetterFactory. release_gil defaults to true and
+ // out_ctx_dlpack_api defaults to nullptr.
- return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory,
func_handle, py_arg_tuple,
- result, c_api_ret_code,
release_gil,
- out_ctx_dlpack_api);
+ return TVMFFIPyCallManager::ThreadLocal()->FuncCall(
+ func_handle, py_arg_tuple, result, c_api_ret_code, release_gil,
out_ctx_dlpack_api);
}
/*!
@@ -542,7 +846,6 @@ TVM_FFI_INLINE int
TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, voi
*
* This function will also not release the GIL since constructor call is
usually cheap.
*
- * \param setter_factory The factory function to create the setter
* \param func_handle The handle of the function to call
* \param py_arg_tuple The arguments to the constructor
* \param result The result of the constructor
@@ -552,17 +855,15 @@ TVM_FFI_INLINE int
TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, voi
* \param out_dlpack_exporter The DLPack exporter to be used for the result
* \return 0 on success, nonzero on failure
+ * \note NOTE(review): `out_dlpack_exporter` above does not name any parameter of
+ *       this function; it likely should describe `parent_ctx` — confirm and update.
*/
-TVM_FFI_INLINE int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory
setter_factory,
- void* func_handle, PyObject*
py_arg_tuple,
+TVM_FFI_INLINE int TVMFFIPyConstructorCall(void* func_handle, PyObject*
py_arg_tuple,
TVMFFIAny* result, int*
c_api_ret_code,
TVMFFIPyCallContext* parent_ctx) {
+ // Forward to this thread's manager, which propagates the call context to parent_ctx.
- return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall(
- setter_factory, func_handle, py_arg_tuple, result, c_api_ret_code,
parent_ctx);
+ return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall(func_handle,
py_arg_tuple, result,
+ c_api_ret_code,
parent_ctx);
}
/*!
* \brief Set a field of a FFI object
- * \param setter_factory The factory function to create the setter
* \param field_setter The field setter (function pointer or FunctionObj
handle)
* \param field_flags The field flags (to dispatch between function pointer
and FunctionObj)
* \param field_ptr The pointer to the field
@@ -570,49 +871,104 @@ TVM_FFI_INLINE int
TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_facto
* \param c_api_ret_code The return code of the function
* \return 0 on success, nonzero on failure
*/
-TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory
setter_factory,
- void* field_setter, int64_t
field_flags, void* field_ptr,
+TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(void* field_setter, int64_t
field_flags, void* field_ptr,
PyObject* py_arg, int*
c_api_ret_code) {
+ // The manager converts py_arg through its per-type setter cache, then invokes
+ // field_setter as a function pointer or a FunctionObj depending on field_flags.
- return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory,
field_setter, field_flags,
- field_ptr, py_arg,
c_api_ret_code);
+ return TVMFFIPyCallManager::ThreadLocal()->SetField(field_setter,
field_flags, field_ptr, py_arg,
+ c_api_ret_code);
}
/*!
* \brief Set an python argument to a FFI Any using the generic dispatcher in
call manager
- * \param setter_factory The factory function to create the setter
* \param ctx The call context
* \param py_arg_tvm_ffi_value The python argument to be set using the
__tvm_ffi_value__ protocol
* \param out The output argument
* \return 0 on success, nonzero on failure
*/
-TVM_FFI_INLINE int
TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyArgSetterFactory setter_factory,
- TVMFFIPyCallContext*
ctx,
+TVM_FFI_INLINE int TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyCallContext*
ctx,
PyObject*
py_arg_tvm_ffi_value,
TVMFFIAny* out) {
+ // Dispatch through this thread's per-Python-type setter cache; types not seen yet
+ // are resolved (and cached) via TVMFFICyArgSetterFactory.
- return TVMFFIPyCallManager::ThreadLocal()->SetArgument(setter_factory, ctx,
py_arg_tvm_ffi_value,
- out);
+ return TVMFFIPyCallManager::ThreadLocal()->SetArgument(ctx,
py_arg_tvm_ffi_value, out);
}
/*!
* \brief Convert a Python object to a FFI Any
- * \param setter_factory The factory function to create the setter
* \param py_arg The python argument to be set
* \param out The output argument
* \param c_api_ret_code The return code of the function
* \return 0 on success, nonzero on failure
*/
-TVM_FFI_INLINE int TVMFFIPyPyObjectToFFIAny(TVMFFIPyArgSetterFactory
setter_factory,
- PyObject* py_arg, TVMFFIAny* out,
int* c_api_ret_code) {
- return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(setter_factory,
py_arg, out,
- c_api_ret_code);
+TVM_FFI_INLINE int TVMFFIPyPyObjectToFFIAny(PyObject* py_arg, TVMFFIAny* out,
int* c_api_ret_code) {
+ // The manager sets a temporary view and promotes it with TVMFFIAnyViewToOwnedAny,
+ // so `out` receives an owned Any that the caller must release.
+ return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(py_arg, out,
c_api_ret_code);
+}
+
+/*!
+ * \brief Get the size of the arg dispatch map
+ * \return The number of Python-type -> arg setter entries cached by the calling
+ * thread's TVMFFIPyCallManager (the Python -> FFI argument path only).
+ */
+TVM_FFI_INLINE size_t TVMFFIPyGetArgDispatchMapSize() {
+ return TVMFFIPyCallManager::ThreadLocal()->GetArgDispatchMapSize();
}
+//---------------------------------------------------------------------------------------------
+// Free function wrapper for the Python callback path.
+// Mirrors the pattern of TVMFFIPyFuncCall / TVMFFIPyConstructorCall: a
top-level
+// TVM_FFI_INLINE free function that forwards to the thread-local manager.
+//---------------------------------------------------------------------------------------------
+
/*!
- * \brief Get the size of the dispatch map
- * \return The size of the dispatch map
+ * \brief C-callable Python callback entry point (TVMFFISafeCallType shape).
+ *
+ * Forwards to TVMFFIPyCallManager::ThreadLocal()->PyCallback. Designed to be
+ * installed as the safe_call pointer for FFI functions that wrap a Python
+ * callable.
+ *
+ * \note The `context` argument is interpreted as a TVMFFIPyCallbackClosure*
+ * by the manager (see TVMFFIPyConvertPyCallback).
+ * \return 0 on success, -1 on failure (the manager forwards the Python exception
+ * to the FFI error slot before returning).
*/
-TVM_FFI_INLINE size_t TVMFFIPyGetDispatchMapSize() {
- return TVMFFIPyCallManager::ThreadLocal()->GetDispatchMapSize();
+TVM_FFI_INLINE int TVMFFIPyCallback(void* context, const TVMFFIAny*
packed_args, int32_t num_args,
+ TVMFFIAny* result) noexcept {
+ // ThreadLocal() runs before PyCallback attaches a Python thread state; presumably it
+ // only touches C++ thread-local storage — confirm if that ever changes.
+ return TVMFFIPyCallManager::ThreadLocal()->PyCallback(context, packed_args,
num_args, result);
+}
+
+/*!
+ * \brief Create an FFI function handle from a Python callable + optional DLPack exchange API.
+ *
+ * Allocates a TVMFFIPyCallbackClosure on the heap, IncRefs the callable, and
+ * registers it with the FFI function-creation API using TVMFFIPyCallback as the
+ * safe-call entry point and TVMFFIPyCallbackClosure::Deleter as the deleter.
+ *
+ * Returns the raw FFI return code (TLS FFI error set on failure). The Cython
+ * caller uses CHECK_CALL to translate it into a Python exception. The exception
+ * to this is closure allocation failure, which returns -1 with a Python
+ * MemoryError set instead of a TLS FFI error.
+ *
+ * \param callable The Python callable to wrap. Must be non-NULL (not checked here).
+ * \param dlpack_api Optional DLPack exchange API. May be NULL.
+ * \param out_handle Destination for the new FFI function handle.
+ * \return The return code from TVMFFIFunctionCreate (0 on success), or -1 with
+ * PyErr set if the closure could not be allocated.
+ * \note NOTE(review): confirm that the caller's CHECK_CALL surfaces a pending Python
+ * error for that -1 path rather than reading an (unset) TLS FFI error.
+ */
+TVM_FFI_INLINE int TVMFFIPyConvertPyCallback(PyObject* callable,
+ const DLPackExchangeAPI*
dlpack_api,
+ TVMFFIObjectHandle* out_handle)
noexcept {
+ // Use nothrow new: plain `new` can throw std::bad_alloc, which in this
+ // noexcept function would trigger std::terminate. On allocation failure,
+ // set PyErr and return -1 so the Cython caller's CHECK_CALL surfaces it.
+ auto* raw = new (std::nothrow) TVMFFIPyCallbackClosure{callable, dlpack_api};
+ if (raw == nullptr) {
+ PyErr_NoMemory();
+ return -1;
+ }
+ // The callable's +1 is owned by the closure; TVMFFIPyCallbackClosure::Deleter
+ // is responsible for Py_DecRef on destruction. By wiring the same Deleter as
+ // the unique_ptr deleter, the failure path below (unique_ptr unwind) runs
+ // the same cleanup as the success path (invoked by the FFI runtime).
+ Py_IncRef(callable);
+ std::unique_ptr<TVMFFIPyCallbackClosure, void (*)(void*)> closure(
+ raw, &TVMFFIPyCallbackClosure::Deleter);
+ int rc = TVMFFIFunctionCreate(closure.get(), &TVMFFIPyCallback,
&TVMFFIPyCallbackClosure::Deleter,
+ out_handle);
+ // On success, transfer ownership to the FFI function; on failure, let
+ // unique_ptr unwind via Deleter (decrefs the callable, frees the closure).
+ // NOTE(review): this assumes TVMFFIFunctionCreate does not itself invoke the
+ // deleter when it fails; if it does, the unwind here would release twice — confirm.
+ if (rc == 0) closure.release();
+ return rc;
}
/*!
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 86b9e65..d7f39be 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -54,7 +54,6 @@ cdef class FieldSetter:
cdef int c_api_ret_code
cdef void* field_ptr = (<char*>(<CObject>obj).chandle) + self.offset
TVMFFIPyCallFieldSetter(
- TVMFFIPyArgSetterFactory_,
self.setter,
self.flags,
field_ptr,
@@ -936,7 +935,6 @@ cdef _register_one_field(
default_obj = py_field.default_factory
if default_obj is not MISSING:
TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory_,
<PyObject*>default_obj,
&default_any,
&c_api_ret_code
@@ -1148,7 +1146,6 @@ cdef _register_py_methods(int32_t type_index, list
py_methods, frozenset type_at
# Convert Python object -> TVMFFIAny
TVMFFIPyPyObjectToFFIAny(
- TVMFFIPyArgSetterFactory_,
<PyObject*>func,
&func_any,
&c_api_ret_code,
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 686fc08..3f673c1 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -25,6 +25,13 @@ import numpy as np
import pytest
import tvm_ffi
+try:
+ import torch
+except ImportError:
+ torch = None # ty: ignore[invalid-assignment]
+
+# torch is optional. Tests that need torch's DLPack C exchange API are skipped
+# via a skipif marker when torch, or that attribute, is unavailable.
+_HAS_TORCH_DLPACK_API = torch is not None and hasattr(torch.Tensor,
"__dlpack_c_exchange_api__")
+
def test_echo() -> None:
fecho = tvm_ffi.get_global_func("testing.echo")
@@ -402,3 +409,61 @@ def test_function_with_value_protocol() -> None:
nested_value_protocol = ValueProtocol([ValueProtocol(1), ValueProtocol(2),
ValueProtocol(3)])
assert tuple(fecho(nested_value_protocol)) == (1, 2, 3)
+
+
+def test_convert_func_tensor_cls_missing_attribute() -> None:
+ """Passing a class without __dlpack_c_exchange_api__ raises TypeError."""
+
+ class DummyTensor:
+ pass
+
+ # A user-defined class lacking the attribute is rejected by convert_func itself,
+ # before any callback is invoked.
+ with pytest.raises(TypeError, match="__dlpack_c_exchange_api__"):
+ tvm_ffi.convert_func(lambda x: x, tensor_cls=DummyTensor)
+
+ # So is a builtin type such as ``object``.
+ with pytest.raises(TypeError, match="__dlpack_c_exchange_api__"):
+ tvm_ffi.convert_func(lambda x: x, tensor_cls=object)
+
+def test_convert_func_raises_propagates() -> None:
+ """An exception raised inside the callback propagates out to the caller."""
+
+ def raises(x: int) -> None:
+ raise ValueError(f"boom {x}")
+
+ f = tvm_ffi.convert_func(raises)
+ # The Python -> FFI -> Python round trip must preserve both the exception type
+ # and its message, which embeds the argument value passed through the FFI call.
+ with pytest.raises(ValueError, match="boom 42"):
+ f(42)
+
+
+@pytest.mark.skipif(
+    not _HAS_TORCH_DLPACK_API,
+    reason="torch.Tensor.__dlpack_c_exchange_api__ not available",
+)
+def test_convert_func_with_torch_tensor_cls() -> None:
+    """tensor_cls=torch.Tensor delivers torch.Tensor instances to the callback.
+
+    Asserts the type *inside* the callback (which runs on the C++ -> Python
+    side of the FFI boundary) — the return value's Python type depends on
+    the outer caller's conversion path, so we verify shape survives the
+    round-trip rather than isinstance on the return.
+    """
+    calls = 0
+
+    # NOTE(review): relies on `Any` being imported from typing at module top; that
+    # import is outside the visible hunks — confirm it exists.
+    def callback(a: Any, b: Any, c: Any) -> Any:
+        nonlocal calls
+        calls += 1
+        assert isinstance(a, torch.Tensor)
+        assert isinstance(b, torch.Tensor)
+        assert isinstance(c, torch.Tensor)
+        assert list(a.shape) == [2]
+        assert list(b.shape) == [3]
+        assert list(c.shape) == [4]
+        return b
+
+    f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
+    a = torch.zeros(2)
+    b = torch.ones(3)
+    c = torch.full((4,), 2.0)
+    out = f(a, b, c)
+    # Exactly one invocation proves the in-callback asserts actually ran.
+    assert calls == 1
+    assert tuple(out.shape) == (3,)
diff --git a/tests/scripts/benchmark_pycallback.py
b/tests/scripts/benchmark_pycallback.py
new file mode 100644
index 0000000..74b97f4
--- /dev/null
+++ b/tests/scripts/benchmark_pycallback.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark C++ -> Python callback overhead with 3 torch.Tensor arguments.
+
+Both variants are invoked by the same C++ ``invoke_n`` loop so the per-call
+cost reflects only the callback-arg conversion path:
+
+1. ``convert_func(cb, tensor_cls=torch.Tensor)`` — the DLPack exchange API is
+   threaded into the closure, so each tensor arg is converted to a
+   ``torch.Tensor`` by the C-level callback arg setter before the callback runs.
+2. ``convert_func(cb)`` — callback receives an ``ffi.Tensor`` and calls
+ ``torch.from_dlpack(x)`` explicitly inside the callback body for each arg.
+
+Arguments are 3 x ``torch.zeros(1, device="cuda:0")``.
+"""
+
+from __future__ import annotations
+
+import time
+
+import torch
+import tvm_ffi
+import tvm_ffi.cpp
+from benchmark_dlpack import print_speed
+
+# Native benchmark driver: calls ``callback(a, b, c)`` ``n`` times from a C++
+# loop. Both variants below go through this same loop, so the C++ -> Python
+# call overhead is shared and only the callback-side argument handling differs.
+_INVOKE_N_CPP_SOURCE = r"""
+#include <tvm/ffi/function.h>
+
+void invoke_n(tvm::ffi::Function callback, int64_t n,
+ tvm::ffi::AnyView a, tvm::ffi::AnyView b, tvm::ffi::AnyView c) {
+ for (int64_t i = 0; i < n; ++i) {
+ callback(a, b, c);
+ }
+}
+"""
+
+
+def _load_invoke_n() -> object:
+ mod = tvm_ffi.cpp.load_inline(
+ name="benchmark_pycallback_invoke_n",
+ cpp_sources=_INVOKE_N_CPP_SOURCE,
+ functions=["invoke_n"],
+ )
+ return mod.get_function("invoke_n")
+
+
+def bench_pycallback_tensor_cls_torch(invoke_n, a, b, c, repeat: int) -> None:
# noqa: ANN001
+ """convert_func(cb, tensor_cls=torch.Tensor): callback sees torch.Tensor
directly."""
+
+ def cb(_a, _b, _c) -> None: # noqa: ANN001
+ pass
+
+ callback = tvm_ffi.convert_func(cb, tensor_cls=torch.Tensor)
+ invoke_n(callback, 10, a, b, c)
+ start = time.time()
+ invoke_n(callback, repeat, a, b, c)
+ end = time.time()
+ print_speed("pycallback[tensor_cls=torch.Tensor]", (end - start) / repeat)
+
+
+def bench_pycallback_from_dlpack(invoke_n, a, b, c, repeat: int) -> None: #
noqa: ANN001
+ """convert_func(cb): callback receives ffi.Tensor, does
torch.from_dlpack(x) explicitly."""
+
+ def cb(_a, _b, _c) -> None: # noqa: ANN001
+ torch.from_dlpack(_a)
+ torch.from_dlpack(_b)
+ torch.from_dlpack(_c)
+
+ callback = tvm_ffi.convert_func(cb)
+ invoke_n(callback, 10, a, b, c)
+ start = time.time()
+ invoke_n(callback, repeat, a, b, c)
+ end = time.time()
+ print_speed("pycallback+from_dlpack", (end - start) / repeat)
+
+
+def main() -> None:
+ if not hasattr(torch.Tensor, "__dlpack_c_exchange_api__"):
+ raise SystemExit("torch.Tensor.__dlpack_c_exchange_api__ not
available")
+
+ repeat = 10000
+ invoke_n = _load_invoke_n()
+ a = torch.zeros(1, device="cuda:0")
+ b = torch.zeros(1, device="cuda:0")
+ c = torch.zeros(1, device="cuda:0")
+
+ print("---------------------------------------------------")
+ print("Benchmark C++ -> Python callback with 3 torch.Tensor args")
+ print('Arguments: 3 x torch.zeros(1, device="cuda:0")')
+ print("---------------------------------------------------")
+ bench_pycallback_tensor_cls_torch(invoke_n, a, b, c, repeat)
+ bench_pycallback_from_dlpack(invoke_n, a, b, c, repeat)
+ print("---------------------------------------------------")
+
+
+if __name__ == "__main__":
+ main()