This is an automated email from the ASF dual-hosted git repository.
cyx-6 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e472871 [FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback
args path (#573)
e472871 is described below
commit e472871900b80f1407d6880d103be4a874bb4084
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Apr 26 13:42:40 2026 -0400
[FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback args path
(#573)
## Summary
- **Bug**: When C++ invokes a Python callback with non-owning `RawStr`
or `ByteArrayPtr` arg shapes, `TVMFFICyCallbackArgSetterFactory` (in
`pycallback.pxi`) raised `ValueError` for both type indices, crashing
the callback immediately.
- **Fix**: Add two new Cython setter functions —
`TVMFFIPyCallbackArgSetterRawStr_` and
`TVMFFIPyCallbackArgSetterByteArrayPtr_` — and register them in
`TVMFFICyCallbackArgSetterFactory`. Each setter converts the non-owning
C-side view to a Python `str` (UTF-8 decode) or `bytes` (memcopy) before
returning. Both non-owning views remain valid for the callback's
duration, so there is no dangling-pointer concern.
- **Why not change `make_ret`**: The return-value path (C-side
`ffi::Function` invocation) never carries `kTVMFFIRawStr` /
`kTVMFFIByteArrayPtr` because C++ normalises to owned variants
(`ffi::String` / `ffi::Bytes`) before returning. The existing `raise
ValueError` arms in `make_ret` remain correct guards against an
impossible return-side type, exactly like the `kTVMFFIObjectRValueRef`
arm.
- **Test**: One regression test in `tests/python/test_function.py`
(`test_callback_rawstr_and_bytearrayptr_args`) uses
`tvm_ffi.cpp.load_inline` to build a C++ shim with two trampolines:
- `invoke_with_raw_str(callback)` — passes a C-string literal as `const
char*`, which `TypeTraits<const char*>::CopyToAnyView` packs as
`kTVMFFIRawStr` (with `v_c_str`).
- `invoke_with_byte_array_ptr(callback)` — passes a stack
`TVMFFIByteArray*`, which `TypeTraits<TVMFFIByteArray*>::CopyToAnyView`
packs as `kTVMFFIByteArrayPtr` (with `v_ptr`).
Both paths are exercised in a single test body; the callback asserts it
receives a `str` / `bytes` with the expected content.
## Files changed
- `python/tvm_ffi/cython/pycallback.pxi` — two new setter functions +
factory registration
- `tests/python/test_function.py` — one regression test + `import
tvm_ffi.cpp`
## Test plan
- [ ] `uv run pytest -vvs tests/python/test_function.py` — 20 passed
(was 19)
- [ ] `uv run pytest tests/python` — 2303 passed, 5 skipped, 2 xfailed
(3 pre-existing ROCm failures unrelated to this change)
- [ ] `pre-commit run --files python/tvm_ffi/cython/pycallback.pxi
tests/python/test_function.py` — all hooks passed
---
python/tvm_ffi/cython/pycallback.pxi | 41 +++++++++++++++++++++++++++--
tests/python/test_function.py | 51 ++++++++++++++++++++++++++++++++++++
2 files changed, 90 insertions(+), 2 deletions(-)
diff --git a/python/tvm_ffi/cython/pycallback.pxi b/python/tvm_ffi/cython/pycallback.pxi
index fd5f686..59dfdb6 100644
--- a/python/tvm_ffi/cython/pycallback.pxi
+++ b/python/tvm_ffi/cython/pycallback.pxi
@@ -162,6 +162,43 @@ cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_(
return 0
+cdef int TVMFFIPyCallbackArgSetterRawStr_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIRawStr -> Python str (UTF-8 decode).
+
+ ``arg.v_c_str`` is a non-owning ``const char*`` pointer into C-side storage
+ that remains valid for the duration of the callback invocation. We copy the
+ contents into a Python ``str`` immediately, so there is no dangling-pointer
+ concern after the setter returns.
+ """
+ obj = arg.v_c_str.decode("utf-8")
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterByteArrayPtr_(
+ TVMFFIPyCallbackArgSetter* handle,
+ const DLPackExchangeAPI* api,
+ const TVMFFIAny* arg,
+ PyObject** out
+) except -1:
+ """Callback arg setter for kTVMFFIByteArrayPtr -> Python bytes.
+
+ ``arg.v_ptr`` is a non-owning ``TVMFFIByteArray*`` pointer into C-side
+ storage valid for the callback's lifetime. ``bytearray_to_bytes`` copies
+ the raw bytes into a new Python ``bytes`` object immediately.
+ """
+ obj = bytearray_to_bytes(<TVMFFIByteArray*>arg.v_ptr)
+ Py_INCREF(obj)
+ out[0] = <PyObject*>obj
+ return 0
+
+
cdef int TVMFFIPyCallbackArgSetterRValueRef_(
TVMFFIPyCallbackArgSetter* handle,
const DLPackExchangeAPI* api,
@@ -241,9 +278,9 @@ cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
elif type_index == kTVMFFIObjectRValueRef:
out.func = TVMFFIPyCallbackArgSetterRValueRef_
elif type_index == kTVMFFIByteArrayPtr:
- raise ValueError("Callback arg cannot be ByteArrayPtr")
+ out.func = TVMFFIPyCallbackArgSetterByteArrayPtr_
elif type_index == kTVMFFIRawStr:
- raise ValueError("Callback arg cannot be RawStr")
+ out.func = TVMFFIPyCallbackArgSetterRawStr_
else:
raise ValueError("Unhandled type index %d" % type_index)
return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 3f673c1..3629118 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -24,6 +24,7 @@ from typing import Any
import numpy as np
import pytest
import tvm_ffi
+import tvm_ffi.cpp
try:
import torch
@@ -467,3 +468,53 @@ def test_convert_func_with_torch_tensor_cls() -> None:
out = f(a, b, c)
assert calls == 1
assert tuple(out.shape) == (3,)
+
+
+def test_callback_rawstr_and_bytearrayptr_args() -> None:
+ """Regression: C++ -> Python callback with kTVMFFIRawStr / kTVMFFIByteArrayPtr args.
+
+ When C++ invokes a Python callback with non-owning RawStr or ByteArrayPtr
+ arg shapes, the callback arg setter must materialise a Python str / bytes
+ directly rather than hitting the ``raise ValueError`` guard that formerly
+ existed in ``TVMFFICyCallbackArgSetterFactory``.
+
+ Two trampolines are compiled via cpp.load_inline:
+ - ``invoke_with_raw_str(callback)`` — calls ``callback("hello rawstr")``
+ using a C-string literal, which the TypeTraits pack as kTVMFFIRawStr.
+ - ``invoke_with_byte_array_ptr(callback)`` — calls ``callback(&byte_arr)``
+ where ``byte_arr`` is a ``TVMFFIByteArray`` on the stack, packed as
+ kTVMFFIByteArrayPtr.
+ """
+ mod = tvm_ffi.cpp.load_inline(
+ name="test_callback_rawstr_bytearrayptr",
+ cpp_sources=r"""
+ void invoke_with_raw_str(tvm::ffi::Function callback) {
+ // Passing a string literal packs as kTVMFFIRawStr (const char* TypeTraits).
+ callback("hello rawstr");
+ }
+
+ void invoke_with_byte_array_ptr(tvm::ffi::Function callback) {
+ // Passing a TVMFFIByteArray* packs as kTVMFFIByteArrayPtr.
+ static const char kData[] = "hello bytearrayptr";
+ TVMFFIByteArray byte_arr{kData, sizeof(kData) - 1};
+ callback(&byte_arr);
+ }
+ """,
+ functions=["invoke_with_raw_str", "invoke_with_byte_array_ptr"],
+ )
+
+ # --- kTVMFFIRawStr path ---
+ str_received: list[Any] = []
+ str_cb = tvm_ffi.convert(lambda x: str_received.append(x))
+ mod.invoke_with_raw_str(str_cb)
+ assert len(str_received) == 1
+ assert isinstance(str_received[0], str), f"expected str, got {type(str_received[0])}"
+ assert str_received[0] == "hello rawstr"
+
+ # --- kTVMFFIByteArrayPtr path ---
+ bytes_received: list[Any] = []
+ bytes_cb = tvm_ffi.convert(lambda x: bytes_received.append(x))
+ mod.invoke_with_byte_array_ptr(bytes_cb)
+ assert len(bytes_received) == 1
+ assert isinstance(bytes_received[0], bytes), f"expected bytes, got {type(bytes_received[0])}"
+ assert bytes_received[0] == b"hello bytearrayptr"