This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch handle-rawstr-bytearrayptr-callback
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git

commit 945238998c5379f9d930e8f778dba23e0d71a0ef
Author: tqchen <[email protected]>
AuthorDate: Sun Apr 26 16:43:12 2026 +0000

    [FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback args path
    
    Bug: when C++ invoked a Python callback with non-owning RawStr or
    ByteArrayPtr arguments, TVMFFICyCallbackArgSetterFactory raised
    ValueError for both type indices, so the call failed before the
    Python callback ever ran.
    
    Fix: add two new Cython setter functions —
    TVMFFIPyCallbackArgSetterRawStr_ and TVMFFIPyCallbackArgSetterByteArrayPtr_
    — and register them in TVMFFICyCallbackArgSetterFactory. Each setter
    copies the non-owning C-side view into an owned Python str (UTF-8
    decode) or bytes (byte copy) before the callback is invoked.
    
    Why not change make_ret: the return-value path (C-side ffi::Function
    invocation) never carries kTVMFFIRawStr / kTVMFFIByteArrayPtr because
    C++ normalises to owned variants (ffi::String / ffi::Bytes) before
    returning. The existing `raise ValueError` branches in make_ret
    therefore remain valid guards against a return type that cannot occur.
    
    Test: a single regression test in tests/python/test_function.py
    (test_callback_rawstr_and_bytearrayptr_args) that uses cpp.load_inline
    to build a C++ shim with two trampolines. invoke_with_raw_str passes
    a string literal as a const char* (TypeTraits packs it as
    kTVMFFIRawStr); invoke_with_byte_array_ptr passes a pointer to a
    stack-allocated TVMFFIByteArray (TypeTraits packs it as
    kTVMFFIByteArrayPtr). The Python test asserts that the callback
    receives a str / bytes with the expected content.
---
 python/tvm_ffi/cython/pycallback.pxi | 41 ++++++++++++++++++++++++--
 tests/python/test_function.py        | 56 ++++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/python/tvm_ffi/cython/pycallback.pxi b/python/tvm_ffi/cython/pycallback.pxi
index fd5f686..59dfdb6 100644
--- a/python/tvm_ffi/cython/pycallback.pxi
+++ b/python/tvm_ffi/cython/pycallback.pxi
@@ -162,6 +162,43 @@ cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_(
     return 0
 
 
+cdef int TVMFFIPyCallbackArgSetterRawStr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIRawStr -> Python str (UTF-8 decode).
+
+    ``arg.v_c_str`` is a non-owning ``const char*`` pointer into C-side storage
+    that remains valid for the duration of the callback invocation.  We copy the
+    contents into a Python ``str`` immediately, so there is no dangling-pointer
+    concern after the setter returns; invalid UTF-8 raises UnicodeDecodeError.
+    """
+    obj = arg.v_c_str.decode("utf-8")
+    Py_INCREF(obj)  # out[0] keeps its own reference after local ``obj`` is dropped
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterByteArrayPtr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIByteArrayPtr -> Python bytes.
+
+    ``arg.v_ptr`` is a non-owning ``TVMFFIByteArray*`` pointer into C-side
+    storage that is valid only for the duration of the callback invocation.
+    ``bytearray_to_bytes`` copies it into a new Python ``bytes`` right away.
+    """
+    obj = bytearray_to_bytes(<TVMFFIByteArray*>arg.v_ptr)
+    Py_INCREF(obj)  # out[0] keeps its own reference after local ``obj`` is dropped
+    out[0] = <PyObject*>obj
+    return 0
+
+
 cdef int TVMFFIPyCallbackArgSetterRValueRef_(
     TVMFFIPyCallbackArgSetter* handle,
     const DLPackExchangeAPI* api,
@@ -241,9 +278,9 @@ cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
     elif type_index == kTVMFFIObjectRValueRef:
         out.func = TVMFFIPyCallbackArgSetterRValueRef_
     elif type_index == kTVMFFIByteArrayPtr:
-        raise ValueError("Callback arg cannot be ByteArrayPtr")
+        out.func = TVMFFIPyCallbackArgSetterByteArrayPtr_
     elif type_index == kTVMFFIRawStr:
-        raise ValueError("Callback arg cannot be RawStr")
+        out.func = TVMFFIPyCallbackArgSetterRawStr_
     else:
         raise ValueError("Unhandled type index %d" % type_index)
     return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 3f673c1..9ee6a8b 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -24,6 +24,7 @@ from typing import Any
 import numpy as np
 import pytest
 import tvm_ffi
+import tvm_ffi.cpp
 
 try:
     import torch
@@ -467,3 +468,58 @@ def test_convert_func_with_torch_tensor_cls() -> None:
     out = f(a, b, c)
     assert calls == 1
     assert tuple(out.shape) == (3,)
+
+
+def test_callback_rawstr_and_bytearrayptr_args() -> None:
+    """Regression: C++ -> Python callback with kTVMFFIRawStr / 
kTVMFFIByteArrayPtr args.
+
+    When C++ invokes a Python callback with non-owning RawStr or ByteArrayPtr
+    arg shapes, the callback arg setter must materialise a Python str / bytes
+    directly rather than hitting the ``raise ValueError`` guard that formerly
+    existed in ``TVMFFICyCallbackArgSetterFactory``.
+
+    Two trampolines are compiled via cpp.load_inline:
+    - ``invoke_with_raw_str(callback)``    — calls ``callback("hello rawstr")``
+      using a C-string literal, which the TypeTraits pack as kTVMFFIRawStr.
+    - ``invoke_with_byte_array_ptr(callback)`` — calls ``callback(&byte_arr)``
+      where ``byte_arr`` is a ``TVMFFIByteArray`` on the stack, packed as
+      kTVMFFIByteArrayPtr.
+    """
+    mod = tvm_ffi.cpp.load_inline(
+        name="test_callback_rawstr_bytearrayptr",
+        cpp_sources="""
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/c_api.h>
+
+using namespace tvm::ffi;
+
+void invoke_with_raw_str(Function callback) {
+    // Passing a string literal packs as kTVMFFIRawStr (const char* 
TypeTraits).
+    callback("hello rawstr");
+}
+
+void invoke_with_byte_array_ptr(Function callback) {
+    // Passing a TVMFFIByteArray* packs as kTVMFFIByteArrayPtr.
+    static const char kData[] = "hello bytearrayptr";
+    TVMFFIByteArray byte_arr{kData, sizeof(kData) - 1};
+    callback(&byte_arr);
+}
+""",
+        functions=["invoke_with_raw_str", "invoke_with_byte_array_ptr"],
+    )
+
+    # --- kTVMFFIRawStr path ---
+    str_received: list[Any] = []
+    str_cb = tvm_ffi.convert(lambda x: str_received.append(x))
+    mod.invoke_with_raw_str(str_cb)
+    assert len(str_received) == 1
+    assert isinstance(str_received[0], str), f"expected str, got 
{type(str_received[0])}"
+    assert str_received[0] == "hello rawstr"
+
+    # --- kTVMFFIByteArrayPtr path ---
+    bytes_received: list[Any] = []
+    bytes_cb = tvm_ffi.convert(lambda x: bytes_received.append(x))
+    mod.invoke_with_byte_array_ptr(bytes_cb)
+    assert len(bytes_received) == 1
+    assert isinstance(bytes_received[0], bytes), f"expected bytes, got 
{type(bytes_received[0])}"
+    assert bytes_received[0] == b"hello bytearrayptr"

Reply via email to