This is an automated email from the ASF dual-hosted git repository.

cyx-6 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e472871  [FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback args path (#573)
e472871 is described below

commit e472871900b80f1407d6880d103be4a874bb4084
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Apr 26 13:42:40 2026 -0400

    [FIX] Handle kTVMFFIRawStr / kTVMFFIByteArrayPtr in callback args path (#573)
    
    ## Summary
    
    - **Bug**: When C++ invokes a Python callback with non-owning `RawStr`
    or `ByteArrayPtr` arg shapes, `TVMFFICyCallbackArgSetterFactory` (in
    `pycallback.pxi`) raised `ValueError` for both type indices, crashing
    the callback immediately.
    
    - **Fix**: Add two new Cython setter functions —
    `TVMFFIPyCallbackArgSetterRawStr_` and
    `TVMFFIPyCallbackArgSetterByteArrayPtr_` — and register them in
    `TVMFFICyCallbackArgSetterFactory`. Each setter converts its non-owning
    C-side view to an owned Python object before returning: `RawStr` becomes a
    `str` (UTF-8 decode) and `ByteArrayPtr` becomes a `bytes` (byte copy).
    Both non-owning views remain valid for the callback's duration, so there
    is no dangling-pointer concern.
    
    - **Why not change `make_ret`**: The return-value path (C-side
    `ffi::Function` invocation) never carries `kTVMFFIRawStr` /
    `kTVMFFIByteArrayPtr` because C++ normalises to owned variants
    (`ffi::String` / `ffi::Bytes`) before returning. The existing `raise
    ValueError` arms in `make_ret` remain correct guards against an
    impossible return-side type, exactly like the `kTVMFFIObjectRValueRef`
    arm.
    
    - **Test**: One regression test in `tests/python/test_function.py`
    (`test_callback_rawstr_and_bytearrayptr_args`) uses
    `tvm_ffi.cpp.load_inline` to build a C++ shim with two trampolines:
    - `invoke_with_raw_str(callback)` — passes a C-string literal as `const
    char*`, which `TypeTraits<const char*>::CopyToAnyView` packs as
    `kTVMFFIRawStr` (with `v_c_str`).
    - `invoke_with_byte_array_ptr(callback)` — passes a pointer to a
    stack-allocated `TVMFFIByteArray`, which
    `TypeTraits<TVMFFIByteArray*>::CopyToAnyView` packs as
    `kTVMFFIByteArrayPtr` (with `v_ptr`).
    Both paths are exercised in a single test body; the callback asserts it
    receives a `str` / `bytes` with the expected content.
    
    ## Files changed
    
    - `python/tvm_ffi/cython/pycallback.pxi` — two new setter functions +
    factory registration
    - `tests/python/test_function.py` — one regression test + `import
    tvm_ffi.cpp`
    
    ## Test plan
    
    - [ ] `uv run pytest -vvs tests/python/test_function.py` — 20 passed
    (was 19)
    - [ ] `uv run pytest tests/python` — 2303 passed, 5 skipped, 2 xfailed
    (3 pre-existing ROCm failures unrelated to this change)
    - [ ] `pre-commit run --files python/tvm_ffi/cython/pycallback.pxi
    tests/python/test_function.py` — all hooks passed
---
 python/tvm_ffi/cython/pycallback.pxi | 41 +++++++++++++++++++++++++++--
 tests/python/test_function.py        | 51 ++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/python/tvm_ffi/cython/pycallback.pxi b/python/tvm_ffi/cython/pycallback.pxi
index fd5f686..59dfdb6 100644
--- a/python/tvm_ffi/cython/pycallback.pxi
+++ b/python/tvm_ffi/cython/pycallback.pxi
@@ -162,6 +162,43 @@ cdef int TVMFFIPyCallbackArgSetterDLTensorPtr_(
     return 0
 
 
+cdef int TVMFFIPyCallbackArgSetterRawStr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIRawStr -> Python str (UTF-8 decode).
+
+    ``arg.v_c_str`` is a non-owning ``const char*`` pointer into C-side storage
+    that remains valid for the duration of the callback invocation.  We copy the
+    contents into a Python ``str`` immediately, so there is no dangling-pointer
+    concern after the setter returns.
+    """
+    obj = arg.v_c_str.decode("utf-8")
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
+cdef int TVMFFIPyCallbackArgSetterByteArrayPtr_(
+    TVMFFIPyCallbackArgSetter* handle,
+    const DLPackExchangeAPI* api,
+    const TVMFFIAny* arg,
+    PyObject** out
+) except -1:
+    """Callback arg setter for kTVMFFIByteArrayPtr -> Python bytes.
+
+    ``arg.v_ptr`` is a non-owning ``TVMFFIByteArray*`` pointer into C-side
+    storage that remains valid for the duration of the callback invocation.
+    ``bytearray_to_bytes`` copies the raw bytes into a new ``bytes`` object.
+    """
+    obj = bytearray_to_bytes(<TVMFFIByteArray*>arg.v_ptr)
+    Py_INCREF(obj)
+    out[0] = <PyObject*>obj
+    return 0
+
+
 cdef int TVMFFIPyCallbackArgSetterRValueRef_(
     TVMFFIPyCallbackArgSetter* handle,
     const DLPackExchangeAPI* api,
@@ -241,9 +278,9 @@ cdef public int TVMFFICyCallbackArgSetterFactory(int32_t type_index,
     elif type_index == kTVMFFIObjectRValueRef:
         out.func = TVMFFIPyCallbackArgSetterRValueRef_
     elif type_index == kTVMFFIByteArrayPtr:
-        raise ValueError("Callback arg cannot be ByteArrayPtr")
+        out.func = TVMFFIPyCallbackArgSetterByteArrayPtr_
     elif type_index == kTVMFFIRawStr:
-        raise ValueError("Callback arg cannot be RawStr")
+        out.func = TVMFFIPyCallbackArgSetterRawStr_
     else:
         raise ValueError("Unhandled type index %d" % type_index)
     return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 3f673c1..3629118 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -24,6 +24,7 @@ from typing import Any
 import numpy as np
 import pytest
 import tvm_ffi
+import tvm_ffi.cpp
 
 try:
     import torch
@@ -467,3 +468,53 @@ def test_convert_func_with_torch_tensor_cls() -> None:
     out = f(a, b, c)
     assert calls == 1
     assert tuple(out.shape) == (3,)
+
+
+def test_callback_rawstr_and_bytearrayptr_args() -> None:
+    """Regression: C++ -> Python callback with kTVMFFIRawStr / kTVMFFIByteArrayPtr args.
+
+    When C++ invokes a Python callback with non-owning RawStr or ByteArrayPtr
+    arg shapes, the callback arg setter must materialise a Python str / bytes
+    directly rather than hitting the ``raise ValueError`` guard that formerly
+    existed in ``TVMFFICyCallbackArgSetterFactory``.
+
+    Two trampolines are compiled via cpp.load_inline:
+    - ``invoke_with_raw_str(callback)``    — calls ``callback("hello rawstr")``
+      using a C-string literal, which the TypeTraits pack as kTVMFFIRawStr.
+    - ``invoke_with_byte_array_ptr(callback)`` — calls ``callback(&byte_arr)``
+      where ``byte_arr`` is a ``TVMFFIByteArray`` on the stack, packed as
+      kTVMFFIByteArrayPtr.
+    """
+    mod = tvm_ffi.cpp.load_inline(
+        name="test_callback_rawstr_bytearrayptr",
+        cpp_sources=r"""
+            void invoke_with_raw_str(tvm::ffi::Function callback) {
+              // Passing a string literal packs as kTVMFFIRawStr (const char* TypeTraits).
+              callback("hello rawstr");
+            }
+
+            void invoke_with_byte_array_ptr(tvm::ffi::Function callback) {
+              // Passing a TVMFFIByteArray* packs as kTVMFFIByteArrayPtr.
+              static const char kData[] = "hello bytearrayptr";
+              TVMFFIByteArray byte_arr{kData, sizeof(kData) - 1};
+              callback(&byte_arr);
+            }
+        """,
+        functions=["invoke_with_raw_str", "invoke_with_byte_array_ptr"],
+    )
+
+    # --- kTVMFFIRawStr path ---
+    str_received: list[Any] = []
+    str_cb = tvm_ffi.convert(lambda x: str_received.append(x))
+    mod.invoke_with_raw_str(str_cb)
+    assert len(str_received) == 1
+    assert isinstance(str_received[0], str), f"expected str, got {type(str_received[0])}"
+    assert str_received[0] == "hello rawstr"
+
+    # --- kTVMFFIByteArrayPtr path ---
+    bytes_received: list[Any] = []
+    bytes_cb = tvm_ffi.convert(lambda x: bytes_received.append(x))
+    mod.invoke_with_byte_array_ptr(bytes_cb)
+    assert len(bytes_received) == 1
+    assert isinstance(bytes_received[0], bytes), f"expected bytes, got {type(bytes_received[0])}"
+    assert bytes_received[0] == b"hello bytearrayptr"

Reply via email to