This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c17111ed9 [Fix][Runtime][RPC] Fix remote tensor handle cleanup for
RPC return values (#19410)
5c17111ed9 is described below
commit 5c17111ed98bae916f88aea32b281ef67aa249c8
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Apr 16 08:39:39 2026 -0400
[Fix][Runtime][RPC] Fix remote tensor handle cleanup for RPC return values
(#19410)
This PR fixes RPC tensor cleanup for tensors returned from remote calls.
When a remote function returns a `Tensor`, the RPC protocol sends both:
- the remote backing data pointer
- the remote tensor object handle used for deletion
Previously, `TensorFromRemoteOpaqueHandle` stored only the data pointer and
called `FreeHandle(space_.data)` during local tensor destruction. That is
incorrect: `FreeHandle` is meant for remote object handles, not raw
data-space pointers.
This could lead to invalid cleanup behavior and crashes during teardown
in RPC workflows, including the cross-compilation + RPC tutorial
scenario reported in #18923.
This change:
- stores the remote tensor object handle in `RemoteSpace`
- calls `FreeHandle(remote_tensor_handle)` during tensor destruction
- keeps cleanup fault-tolerant: errors raised by `FreeHandle` (for example,
  because the remote connection is already closed) are ignored
---
python/tvm/rpc/testing.py | 1 +
src/runtime/rpc/rpc_module.cc | 11 ++-
src/runtime/rpc/rpc_session.h | 8 +++
tests/cpp/runtime/rpc_module_test.cc | 119 +++++++++++++++++++++++++++++++
tests/python/runtime/test_runtime_rpc.py | 23 +++---
5 files changed, 150 insertions(+), 12 deletions(-)
diff --git a/python/tvm/rpc/testing.py b/python/tvm/rpc/testing.py
index 4f2a0be32e..85735f41bc 100644
--- a/python/tvm/rpc/testing.py
+++ b/python/tvm/rpc/testing.py
@@ -21,6 +21,7 @@
import numpy as np
import tvm
+import tvm.testing
# RPC test functions to be registered for unit-tests purposes
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 13f2f0bb7c..34be3556a9 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -62,7 +62,15 @@ Tensor
TensorFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* hand
// the pointer to the remote space is passed in as the data pointer
tensor->data = &(space_);
}
- void FreeData(DLTensor* tensor) { space_.sess->FreeHandle(space_.data); }
+ void FreeData(DLTensor* tensor) {
+ if (space_.object_handle != nullptr) {
+ try {
+ space_.sess->FreeHandle(space_.object_handle);
+ } catch (const Error& e) {
+ // fault tolerance to remote close
+ }
+ }
+ }
private:
RemoteSpace space_;
@@ -70,6 +78,7 @@ Tensor
TensorFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* hand
RemoteSpace space;
space.sess = sess;
space.data = handle;
+ space.object_handle = remote_tensor_handle;
ffi::Shape shape(template_tensor->shape, template_tensor->shape +
template_tensor->ndim);
return Tensor::FromNDAlloc(RemoteSpaceAlloc(space), shape,
template_tensor->dtype, dev);
}
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index d7f629a025..307469f131 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -281,6 +281,14 @@ struct RemoteSpace {
void* data;
/*! \brief Reference to the underlying RPC session. */
std::shared_ptr<RPCSession> sess;
+ /*!
+ * \brief The remote Tensor object handle, if this RemoteSpace wraps a
returned Tensor.
+ *
+ * Returned RPC Tensors carry both the backing data pointer and a Tensor
object handle. The
+ * object handle must be released with FreeHandle so the remote side can
correctly decrement the
+ * Tensor refcount and free the backing storage when it is no longer shared.
+ */
+ void* object_handle{nullptr};
};
/*!
diff --git a/tests/cpp/runtime/rpc_module_test.cc
b/tests/cpp/runtime/rpc_module_test.cc
new file mode 100644
index 0000000000..f551e2985c
--- /dev/null
+++ b/tests/cpp/runtime/rpc_module_test.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/tensor.h>
+
+#include <memory>
+#include <string>
+
+#include "../../../src/runtime/rpc/rpc_session.h"
+
+namespace tvm {
+namespace runtime {
+
+Tensor TensorFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void*
handle,
+ DLTensor* template_tensor, Device dev,
+ void* remote_tensor_handle);
+
+namespace {
+
+class RecordingRPCSession final : public RPCSession {
+ public:
+ PackedFuncHandle GetFunction(const std::string& name) final { return
nullptr; }
+
+ void CallFunc(PackedFuncHandle func, ffi::PackedArgs args,
+ const FEncodeReturn& fencode_return) final {}
+
+ void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t
nbytes) final {}
+
+ void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t
nbytes) final {}
+
+ void FreeHandle(void* handle) final {
+ ++free_handle_calls;
+ last_freed_handle = handle;
+ if (throw_on_free) {
+ TVM_FFI_THROW(InternalError) << "simulated remote close";
+ }
+ }
+
+ DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing = false) final {
return nullptr; }
+
+ bool IsLocalSession() const final { return false; }
+
+ int free_handle_calls{0};
+ void* last_freed_handle{nullptr};
+ bool throw_on_free{false};
+};
+
+DLTensor MakeTemplateTensor() {
+ static int64_t shape[1] = {4};
+ DLTensor tensor{};
+ tensor.data = nullptr;
+ tensor.device = Device{kDLCPU, 0};
+ tensor.ndim = 1;
+ tensor.dtype = DataType::Float(32);
+ tensor.shape = shape;
+ tensor.strides = nullptr;
+ tensor.byte_offset = 0;
+ return tensor;
+}
+
+Device MakeRemoteDevice(const std::shared_ptr<RPCSession>& sess) {
+ return AddRPCSessionMask(Device{kDLCPU, 0}, sess->table_index());
+}
+
+} // namespace
+
+TEST(RPCTensorTest, ReturnedTensorFreesRemoteTensorHandle) {
+ auto sess = std::make_shared<RecordingRPCSession>();
+ DLTensor template_tensor = MakeTemplateTensor();
+ void* data_handle = reinterpret_cast<void*>(0x1234);
+ void* tensor_handle = reinterpret_cast<void*>(0x5678);
+
+ {
+ auto tensor = TensorFromRemoteOpaqueHandle(sess, data_handle,
&template_tensor,
+ MakeRemoteDevice(sess),
tensor_handle);
+ EXPECT_NE(tensor.defined(), false);
+ }
+
+ EXPECT_EQ(sess->free_handle_calls, 1);
+ EXPECT_EQ(sess->last_freed_handle, tensor_handle);
+ EXPECT_NE(sess->last_freed_handle, data_handle);
+}
+
+TEST(RPCTensorTest, ReturnedTensorDestructorIgnoresFreeHandleErrors) {
+ auto sess = std::make_shared<RecordingRPCSession>();
+ sess->throw_on_free = true;
+ DLTensor template_tensor = MakeTemplateTensor();
+ void* data_handle = reinterpret_cast<void*>(0x1234);
+ void* tensor_handle = reinterpret_cast<void*>(0x5678);
+
+ EXPECT_NO_THROW({
+ auto tensor = TensorFromRemoteOpaqueHandle(sess, data_handle,
&template_tensor,
+ MakeRemoteDevice(sess),
tensor_handle);
+ });
+ EXPECT_EQ(sess->free_handle_calls, 1);
+ EXPECT_EQ(sess->last_freed_handle, tensor_handle);
+}
+
+} // namespace runtime
+} // namespace tvm
diff --git a/tests/python/runtime/test_runtime_rpc.py
b/tests/python/runtime/test_runtime_rpc.py
index 5600c7f887..ec112e1023 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -22,6 +22,7 @@ import stat
import sys
import tempfile
import time
+import gc
import numpy as np
import pytest
@@ -386,22 +387,22 @@ def test_rpc_session_constructor_args():
@tvm.testing.requires_rpc
def test_rpc_return_tensor():
- # start server
- server = rpc.Server(key="x1")
- client = rpc.connect("127.0.0.1", server.port, key="x1")
-
- m = client.get_function("rpc.test.remote_return_nd")
- get_arr = m("get_arr")
- ref_count = m("ref_count")
- get_elem = m("get_elem")
- get_arr_elem = m("get_arr_elem")
-
- # array test
def run_arr_test():
+ server = rpc.Server(key="x1")
+ client = rpc.connect("127.0.0.1", server.port, key="x1")
+ m = client.get_function("rpc.test.remote_return_nd")
+ get_arr = m("get_arr")
+ get_elem = m("get_elem")
+ get_arr_elem = m("get_arr_elem")
+
arr = get_arr()
assert get_elem(0) == 0.0
assert get_arr_elem(arr, 0) == 0.0
+ del arr
+ gc.collect()
+ assert get_elem(0) == 0.0
+
run_arr_test()