This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch handle-rawstr-bytearrayptr-callback
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
commit 945238998c5379f9d930e8f778dba23e0d71a0ef
Author: tqchen <[email protected]>
AuthorDate: Sun Apr 26 16:43:12 2026 +0000

    [FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback args path

    Bug: when C++ invokes a Python callback with non-owning RawStr or
    ByteArrayPtr arg shapes, TVMFFICyCallbackArgSetterFactory raised
    ValueError for both type indices, crashing the callback immediately.

    Fix: add two new Cython setter functions —
    TVMFFIPyCallbackArgSetterRawStr_ and
    TVMFFIPyCallbackArgSetterByteArrayPtr_ — and register them in
    TVMFFICyCallbackArgSetterFactory. Each setter converts the non-owning
    C-side view to a Python str (UTF-8 decode) or bytes (memcopy) before
    returning control to the callback.

    Why not change make_ret: the return-value path (C-side ffi::Function
    invocation) never carries kTVMFFIRawStr / kTVMFFIByteArrayPtr because
    C++ normalises to owned variants (ffi::String / ffi::Bytes) before
    returning. The existing raise ValueError arms in make_ret remain
    correct guards against an impossible return-side type.

    Test: single regression in tests/python/test_function.py
    (test_callback_rawstr_and_bytearrayptr_args) that uses cpp.load_inline
    to build a C++ shim with two trampolines. invoke_with_raw_str passes a
    string literal as const char* (TypeTraits packs it as kTVMFFIRawStr);
    invoke_with_byte_array_ptr passes a stack TVMFFIByteArray* (TypeTraits
    packs it as kTVMFFIByteArrayPtr). The Python test asserts that the
    callback receives a str / bytes with the expected content.
--- python/tvm_ffi/cython/pycallback.pxi | 41 ++++++++++++++++++++++++-- tests/python/test_function.py | 56 ++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/python/tvm_ffi/cython/pycallback.pxi b/python/tvm_ffi/cython/pycallback.pxi index fd5f686..59dfdb6 100644 --- a/python/tvm_ffi/cython/pycallback.pxi +++ b/python/tvm_ffi/cython/pycallback.pxi @@ -162,6 +162,43 @@ cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_( return 0 +cdef int TVMFFIPyCallbackArgSetterRawStr_( + TVMFFIPyCallbackArgSetter* handle, + const DLPackExchangeAPI* api, + const TVMFFIAny* arg, + PyObject** out +) except -1: + """Callback arg setter for kTVMFFIRawStr -> Python str (UTF-8 decode). + + ``arg.v_c_str`` is a non-owning ``const char*`` pointer into C-side storage + that remains valid for the duration of the callback invocation. We copy the + contents into a Python ``str`` immediately, so there is no dangling-pointer + concern after the setter returns. + """ + obj = arg.v_c_str.decode("utf-8") + Py_INCREF(obj) + out[0] = <PyObject*>obj + return 0 + + +cdef int TVMFFIPyCallbackArgSetterByteArrayPtr_( + TVMFFIPyCallbackArgSetter* handle, + const DLPackExchangeAPI* api, + const TVMFFIAny* arg, + PyObject** out +) except -1: + """Callback arg setter for kTVMFFIByteArrayPtr -> Python bytes. + + ``arg.v_ptr`` is a non-owning ``TVMFFIByteArray*`` pointer into C-side + storage valid for the callback's lifetime. ``bytearray_to_bytes`` copies + the raw bytes into a new Python ``bytes`` object immediately. 
+ """ + obj = bytearray_to_bytes(<TVMFFIByteArray*>arg.v_ptr) + Py_INCREF(obj) + out[0] = <PyObject*>obj + return 0 + + cdef int TVMFFIPyCallbackArgSetterRValueRef_( TVMFFIPyCallbackArgSetter* handle, const DLPackExchangeAPI* api, @@ -241,9 +278,9 @@ cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index, elif type_index == kTVMFFIObjectRValueRef: out.func = TVMFFIPyCallbackArgSetterRValueRef_ elif type_index == kTVMFFIByteArrayPtr: - raise ValueError("Callback arg cannot be ByteArrayPtr") + out.func = TVMFFIPyCallbackArgSetterByteArrayPtr_ elif type_index == kTVMFFIRawStr: - raise ValueError("Callback arg cannot be RawStr") + out.func = TVMFFIPyCallbackArgSetterRawStr_ else: raise ValueError("Unhandled type index %d" % type_index) return 0 diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 3f673c1..9ee6a8b 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -24,6 +24,7 @@ from typing import Any import numpy as np import pytest import tvm_ffi +import tvm_ffi.cpp try: import torch @@ -467,3 +468,58 @@ def test_convert_func_with_torch_tensor_cls() -> None: out = f(a, b, c) assert calls == 1 assert tuple(out.shape) == (3,) + + +def test_callback_rawstr_and_bytearrayptr_args() -> None: + """Regression: C++ -> Python callback with kTVMFFIRawStr / kTVMFFIByteArrayPtr args. + + When C++ invokes a Python callback with non-owning RawStr or ByteArrayPtr + arg shapes, the callback arg setter must materialise a Python str / bytes + directly rather than hitting the ``raise ValueError`` guard that formerly + existed in ``TVMFFICyCallbackArgSetterFactory``. + + Two trampolines are compiled via cpp.load_inline: + - ``invoke_with_raw_str(callback)`` — calls ``callback("hello rawstr")`` + using a C-string literal, which the TypeTraits pack as kTVMFFIRawStr. 
+ - ``invoke_with_byte_array_ptr(callback)`` — calls ``callback(&byte_arr)`` + where ``byte_arr`` is a ``TVMFFIByteArray`` on the stack, packed as + kTVMFFIByteArrayPtr. + """ + mod = tvm_ffi.cpp.load_inline( + name="test_callback_rawstr_bytearrayptr", + cpp_sources=""" +#include <tvm/ffi/function.h> +#include <tvm/ffi/c_api.h> + +using namespace tvm::ffi; + +void invoke_with_raw_str(Function callback) { + // Passing a string literal packs as kTVMFFIRawStr (const char* TypeTraits). + callback("hello rawstr"); +} + +void invoke_with_byte_array_ptr(Function callback) { + // Passing a TVMFFIByteArray* packs as kTVMFFIByteArrayPtr. + static const char kData[] = "hello bytearrayptr"; + TVMFFIByteArray byte_arr{kData, sizeof(kData) - 1}; + callback(&byte_arr); +} +""", + functions=["invoke_with_raw_str", "invoke_with_byte_array_ptr"], + ) + + # --- kTVMFFIRawStr path --- + str_received: list[Any] = [] + str_cb = tvm_ffi.convert(lambda x: str_received.append(x)) + mod.invoke_with_raw_str(str_cb) + assert len(str_received) == 1 + assert isinstance(str_received[0], str), f"expected str, got {type(str_received[0])}" + assert str_received[0] == "hello rawstr" + + # --- kTVMFFIByteArrayPtr path --- + bytes_received: list[Any] = [] + bytes_cb = tvm_ffi.convert(lambda x: bytes_received.append(x)) + mod.invoke_with_byte_array_ptr(bytes_cb) + assert len(bytes_received) == 1 + assert isinstance(bytes_received[0], bytes), f"expected bytes, got {type(bytes_received[0])}" + assert bytes_received[0] == b"hello bytearrayptr"
