This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d705cb2de4 [Relax][IR] Skip in-place multiply when two operands are
views of the same tensor (#19644)
d705cb2de4 is described below
commit d705cb2de4aef1ed8a5594eaf92f7c597e77ff8d
Author: ConvolutedDog <[email protected]>
AuthorDate: Sun May 31 23:50:18 2026 +0800
[Relax][IR] Skip in-place multiply when two operands are views of the same
tensor (#19644)
This PR fixes https://github.com/apache/tvm/issues/19577.
In that issue, the IRModule before any pass is applied looks like:
```
%x: Tensor[(4,), float32] // function param
with R.dataflow():
%lv = expand_dims(%x, axis=1) // (4, 1)
%lv1 = expand_dims(%x, axis=1) // (4, 1) second call, new Var
%lv2 = multiply(%lv, %lv1) // (4, 1)
%lv3 = concat(%lv2, %lv1, axis=1) // (4, 2)
...
```
When a user manually applies the `DataflowUseInplaceCalls` pass, the
pass rewrites the statement `%lv2 = multiply(%lv, %lv1)` into
`%lv = multiply(%lv, %lv1); %lv3 = concat(%lv, %lv1, axis=1)`, which
reuses the %lv buffer to avoid allocating extra storage.
But this rewrite changes the contents of the %lv buffer, and in the
LLVM-generated code %lv1 shares the same storage as %lv, so by the time
`%lv3 = concat(%lv, %lv1, axis=1)` executes, the contents of %lv1 have
also been overwritten with `multiply(%lv, %lv1)`. So the failure is due to
the storage shared by different views of the same tensor %x.
At run time, %lv1 holds `x^2` instead of `x` after `multiply`.
`concat` reads %lv1 for the right column, so its result is
[[1,1],[4,4],[9,9],[16,16]] instead of [[1,1],[4,2],[9,3],[16,4]] (in the
correct result the left column is `x^2` and the right column stays `x`).
Change: View-like ops (expand_dims, squeeze, reshape, permute_dims,
flatten, nn.batch_flatten, memory.view, memory.ensure_zero_offset) take
the input's alias set in alias analysis instead of a new id, so %lv and
%lv1 share alias ids with %x. The pass then rejects in-place rewriting of
`multiply(%lv, %lv1)`: %lv and %lv1 are different vars but their alias ids
intersect, so no operand may be reused in-place.
---
include/tvm/runtime/tensor.h | 20 ++
src/relax/transform/dataflow_inplace.cc | 67 ++++-
src/runtime/tensor.cc | 30 ++-
tests/python/relax/test_dataflow_inplace.py | 390 ++++++++++++++++++++++++++++
4 files changed, 505 insertions(+), 2 deletions(-)
diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h
index 33a78a48d6..d3497c8ff7 100644
--- a/include/tvm/runtime/tensor.h
+++ b/include/tvm/runtime/tensor.h
@@ -183,6 +183,26 @@ class Tensor : public tvm::ffi::Tensor {
*/
TVM_RUNTIME_DLL static void CopyFromBytes(const DLTensor* to, void* from,
size_t nbytes,
TVMStreamHandle stream = nullptr);
+
+ /*!
+ * \brief Check if two tensors share the same underlying storage.
+ *
+ * This detects runtime storage aliasing (e.g. views from CreateView, etc.)
but does
+ * not imply either tensor was created by CreateView.
+ *
+ * \param a The first tensor.
+ * \param b The second tensor.
+ * \return True if the tensors share the same storage.
+ */
+ TVM_RUNTIME_DLL static bool IsStorageShared(const DLTensor* a, const
DLTensor* b);
+
+ /*!
+ * \brief Tensor overload of IsStorageShared.
+ * \param a The first tensor.
+ * \param b The second tensor.
+ * \return True if the tensors share the same storage.
+ */
+ static bool IsStorageShared(const Tensor& a, const Tensor& b);
};
/*!
diff --git a/src/relax/transform/dataflow_inplace.cc
b/src/relax/transform/dataflow_inplace.cc
index 8072ee5d14..c3ed7ef0b6 100644
--- a/src/relax/transform/dataflow_inplace.cc
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -39,6 +39,67 @@
namespace tvm {
namespace relax {
+// Ops that may return a tensor sharing storage with the first argument.
+// These ops have been verified to share storage with the first argument in
+// tests/python/relax/test_dataflow_inplace.py.
+bool IsViewMemoryOp(const OpNode* op_node) {
+ // TODO: Consider to add more ops that may return a tensor sharing storage
with
+ // the first argument in the future.
+ static const std::unordered_set<std::string> kViewOps = {
+ "relax.expand_dims", "relax.squeeze",
+ "relax.reshape", "relax.permute_dims",
+ "relax.flatten", "relax.nn.batch_flatten",
+ "relax.memory.view", "relax.memory.ensure_zero_offset",
+ };
+ return kViewOps.count(op_node->name);
+}
+
+// Look up alias ids for a call argument (only Var args are expected in
dataflow blocks).
+std::unordered_set<int> GetVarAliasSetFromExpr(
+ const Expr& arg, const std::unordered_map<Var, std::unordered_set<int>>&
alias_sets) {
+ if (auto* var_node = arg.as<VarNode>()) {
+ Var var = ffi::GetRef<Var>(var_node);
+ if (!alias_sets.count(var)) {
+ return {-1};
+ }
+ return alias_sets.at(var);
+ }
+ return {-1};
+}
+
+// In-place on arg `candidate` is invalid if another distinct operand may
alias the same
+// storage (e.g. two expand_dims views of x bound to different vars). Reject
on any shared
+// alias id; -1 in the other operand's set does not skip checking other ids.
Same var twice
+// (e.g. add(z, z)) is allowed.
+bool InplaceArgDisjointFromOtherCallArgs(
+ const CallNode* call_node, int candidate,
+ const std::unordered_map<Var, std::unordered_set<int>>& alias_sets) {
+ const auto* cand_var_node = call_node->args[candidate].as<VarNode>();
+ if (!cand_var_node) {
+ return false;
+ }
+ auto cand_set = GetVarAliasSetFromExpr(call_node->args[candidate],
alias_sets);
+ if (cand_set.count(-1)) {
+ return false;
+ }
+ for (size_t j = 0; j < call_node->args.size(); j++) {
+ if (static_cast<int>(j) == candidate) {
+ continue;
+ }
+ const Expr& other_arg = call_node->args[j];
+ if (other_arg.same_as(call_node->args[candidate])) {
+ continue;
+ }
+ auto other_set = GetVarAliasSetFromExpr(other_arg, alias_sets);
+ for (int alias_idx : other_set) {
+ if (cand_set.count(alias_idx)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
// Perform liveness analysis on a dataflow block, returning a map of vars to
// pairs of indices (the liveness interval, from the starting index to the end
index).
// A starting index of -1 means the var is defined before the block starts and
an end index
@@ -274,6 +335,9 @@ class AliasAnalyzer {
} else {
ret.insert(get_fresh_idx());
}
+ } else if (IsViewMemoryOp(op_node) && !call_node->args.empty()) {
+ // View-like ops may share storage with their input (and with other
views of it).
+ return GetAliasSet(call_node->args[0], bound_var);
} else {
// We are assuming most op calls return fresh values.
// We may have to track more exceptions
@@ -654,7 +718,8 @@ FindInplaceOpportunities(const DataflowBlock& block, const
ffi::Array<Var>& inpu
std::unordered_set<int> remove_candidates;
for (auto candidate : candidates) {
if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map,
currently_live,
- call_node->args[candidate], i)) {
+ call_node->args[candidate], i) ||
+ !InplaceArgDisjointFromOtherCallArgs(call_node, candidate,
alias_sets)) {
remove_candidates.insert(candidate);
}
}
diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc
index 2b694b1742..887d576537 100644
--- a/src/runtime/tensor.cc
+++ b/src/runtime/tensor.cc
@@ -29,6 +29,8 @@
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/tensor.h>
+#include <algorithm>
+
#include "../support/base64.h"
#include "../support/bytes_io.h"
#include "tvm/runtime/data_type.h"
@@ -217,6 +219,30 @@ Tensor Tensor::CopyTo(const Device& dev,
ffi::Optional<ffi::String> mem_scope) c
return ret;
}
+inline char* StorageBegin(const DLTensor* tensor) {
+ TVM_FFI_ICHECK(tensor != nullptr);
+ return static_cast<char*>(tensor->data) + tensor->byte_offset;
+}
+
+inline char* StorageEnd(const DLTensor* tensor) {
+ TVM_FFI_ICHECK(tensor != nullptr);
+ return StorageBegin(tensor) + ffi::GetDataSize(*tensor);
+}
+
+bool Tensor::IsStorageShared(const DLTensor* a, const DLTensor* b) {
+ TVM_FFI_ICHECK(a != nullptr && b != nullptr);
+ if (a->device.device_type != b->device.device_type ||
+ a->device.device_id != b->device.device_id) {
+ return false;
+ }
+ return StorageBegin(a) == StorageBegin(b) && StorageEnd(a) == StorageEnd(b);
+}
+
+bool Tensor::IsStorageShared(const Tensor& a, const Tensor& b) {
+ TVM_FFI_ICHECK(a.defined() && b.defined());
+ return IsStorageShared(a.operator->(), b.operator->());
+}
+
void Tensor::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle
stream) {
size_t from_size = ffi::GetDataSize(*from);
size_t to_size = ffi::GetDataSize(*to);
@@ -270,5 +296,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("runtime.TVMTensorCopyToBytes",
[](DLTensor* arr, void* data, size_t nbytes) {
Tensor::CopyToBytes(arr, data, nbytes); })
.def("runtime.TVMTensorCopyFromTo",
- [](DLTensor* from, DLTensor* to) { Tensor::CopyFromTo(from, to); });
+ [](DLTensor* from, DLTensor* to) { Tensor::CopyFromTo(from, to); })
+ .def("runtime.TVMTensorIsStorageShared",
+ [](Tensor a, Tensor b) { return Tensor::IsStorageShared(a, b); });
}
diff --git a/tests/python/relax/test_dataflow_inplace.py
b/tests/python/relax/test_dataflow_inplace.py
index 61791b2b32..1b23e14482 100644
--- a/tests/python/relax/test_dataflow_inplace.py
+++ b/tests/python/relax/test_dataflow_inplace.py
@@ -18,9 +18,12 @@
import numpy as np
+import pytest
+import torch
import tvm
from tvm import relax, testing
+from tvm.relax import VMInstrumentReturnKind
from tvm.relax.testing.transform import (
dataflow_alias_analysis,
dataflow_inplace_analysis,
@@ -643,5 +646,392 @@ def test_dynamic_mismatch():
tvm.ir.assert_structural_equal(new_mod, DynamicMistmatchTestCase)
+class TestViewOpSharedStorageAndNoInplace:
+ storage_ptr_x_1d = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
+ storage_ptr_x_2d = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+ storage_ptr_x_squeeze = np.array([[[1.0], [2.0], [3.0], [4.0]]],
dtype=np.float32)
+ storage_ptr_x_ensure_zero_offset = np.array([[1.0], [2.0], [3.0], [4.0]],
dtype=np.float32)
+
+ @I.ir_module
+ class _SharedStorageExpandDimsModule:
+ @R.function
+ def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4, 1),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 1), dtype="float32") = R.expand_dims(x,
axis=[1])
+ lv1: R.Tensor((4, 1), dtype="float32") = R.expand_dims(x,
axis=[1])
+ gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageSqueezeModule:
+ @R.function
+ def main(x: R.Tensor((1, 4, 1), dtype="float32")) -> R.Tensor((4, 1),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 1), dtype="float32") = R.squeeze(x, axis=[0])
+ lv1: R.Tensor((4, 1), dtype="float32") = R.squeeze(x, axis=[0])
+ gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageReshapeModule:
+ @R.function
+ def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4, 1),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 1), dtype="float32") = R.reshape(x, (4, 1))
+ lv1: R.Tensor((4, 1), dtype="float32") = R.reshape(x, (4, 1))
+ gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStoragePermuteDimsModule:
+ @R.function
+ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((4, 1),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 1), dtype="float32") = R.permute_dims(x,
axes=[1, 0])
+ lv1: R.Tensor((4, 1), dtype="float32") = R.permute_dims(x,
axes=[1, 0])
+ gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageViewModule:
+ @R.function
+ def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((1, 4),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 4), dtype="float32") = R.memory.view(
+ x, R.shape([1, 4]), R.tuple(), R.tuple()
+ )
+ lv1: R.Tensor((1, 4), dtype="float32") = R.memory.view(
+ x, R.shape([1, 4]), R.tuple(), R.tuple()
+ )
+ gv: R.Tensor((1, 4), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageBatchFlattenModule:
+ @R.function
+ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 4), dtype="float32") = R.nn.batch_flatten(x)
+ lv1: R.Tensor((1, 4), dtype="float32") = R.nn.batch_flatten(x)
+ gv: R.Tensor((1, 4), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageFlattenModule:
+ @R.function
+ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((4,),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="float32") = R.flatten(x)
+ lv1: R.Tensor((4,), dtype="float32") = R.flatten(x)
+ gv: R.Tensor((4,), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _SharedStorageEnsureZeroOffsetModule:
+ @R.function
+ def main(x: R.Tensor((4, 1), dtype="float32")) -> R.Tensor((4, 1),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 1), dtype="float32") =
R.memory.ensure_zero_offset(x)
+ lv1: R.Tensor((4, 1), dtype="float32") =
R.memory.ensure_zero_offset(x)
+ gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class _IndependentReluModule:
+ """Just a testcase to verify that non-view ops do not share storage."""
+
+ @R.function
+ def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,),
dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="float32") = R.nn.relu(x)
+ lv1: R.Tensor((4,), dtype="float32") = R.nn.relu(x)
+ gv: R.Tensor((4,), dtype="float32") = R.add(lv, lv1)
+ R.output(gv)
+ return gv
+
+ @classmethod
+ def _capture_op_tensors(cls, mod, input_nps, op_substr):
+ """Capture TVM tensors passed to VM calls whose name contains
op_substr."""
+ captures = []
+
+ def instrument(func, name, before_run, ret_value, *args):
+ del func, ret_value
+ if not before_run:
+ return VMInstrumentReturnKind.NO_OP
+ if op_substr not in name.lower():
+ return VMInstrumentReturnKind.NO_OP
+ tensor_args = [arg for arg in args if isinstance(arg,
tvm.runtime.Tensor)]
+ if not tensor_args:
+ return VMInstrumentReturnKind.NO_OP
+ captures.append({"call_name": name, "tensors": tensor_args})
+ return VMInstrumentReturnKind.NO_OP
+
+ if isinstance(input_nps, np.ndarray):
+ input_nps = [input_nps]
+
+ ex = relax.build(mod, tvm.target.Target("llvm"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ vm.set_instrument(instrument)
+ vm["main"](*(tvm.runtime.tensor(arr, tvm.cpu()) for arr in input_nps))
+ return captures
+
+ @pytest.mark.parametrize(
+ "mod,input_nps,op_substr,expect_same_storage",
+ [
+ pytest.param(
+ _SharedStorageExpandDimsModule,
+ [storage_ptr_x_1d],
+ "add",
+ True,
+ id="shared_storage_expand_dims",
+ ),
+ pytest.param(
+ _SharedStorageSqueezeModule,
+ [storage_ptr_x_squeeze],
+ "add",
+ True,
+ id="shared_storage_squeeze",
+ ),
+ pytest.param(
+ _SharedStorageReshapeModule,
+ [storage_ptr_x_1d],
+ "add",
+ True,
+ id="shared_storage_reshape",
+ ),
+ pytest.param(
+ _SharedStoragePermuteDimsModule,
+ [storage_ptr_x_2d],
+ "add",
+ True,
+ id="shared_storage_permute_dims",
+ ),
+ pytest.param(
+ _SharedStorageFlattenModule,
+ [storage_ptr_x_2d],
+ "add",
+ True,
+ id="shared_storage_flatten",
+ ),
+ pytest.param(
+ _SharedStorageBatchFlattenModule,
+ [storage_ptr_x_2d],
+ "add",
+ True,
+ id="shared_storage_batch_flatten",
+ ),
+ pytest.param(
+ _SharedStorageViewModule,
+ [storage_ptr_x_1d],
+ "add",
+ True,
+ id="shared_storage_memory_view",
+ ),
+ pytest.param(
+ _SharedStorageEnsureZeroOffsetModule,
+ [storage_ptr_x_ensure_zero_offset],
+ "add",
+ True,
+ id="shared_storage_ensure_zero_offset",
+ ),
+ pytest.param(
+ _IndependentReluModule,
+ [storage_ptr_x_1d],
+ "add",
+ False,
+ id="independent_storage_relu",
+ ),
+ ],
+ )
+ def test_tensor_storage_ptr_extraction(self, mod, input_nps, op_substr,
expect_same_storage):
+ """Validate runtime storage overlap/sharing via VM instrumentation."""
+ storage_shared =
tvm.get_global_func("runtime.TVMTensorIsStorageShared")
+ captures = self._capture_op_tensors(mod, input_nps, op_substr)
+ assert len(captures), f"VM instrumentation did not see a {op_substr}
call."
+ assert len(captures) == 1, f"VM instrumentation should see exactly one
{op_substr} call."
+ cap = captures[0]
+ assert len(cap["tensors"]) == 3, (
+ f"VM instrumentation should see three {op_substr} tensor operands."
+ )
+ tensor_a, tensor_b = cap["tensors"][0], cap["tensors"][1]
+ call_name = cap["call_name"]
+ if expect_same_storage:
+ assert storage_shared(tensor_a, tensor_b), (
+ f"{mod.__name__}: operands should share the same storage (call
{call_name!r})"
+ )
+ else:
+ assert not storage_shared(tensor_a, tensor_b), (
+ f"{mod.__name__}: operands must not share storage (call
{call_name!r})"
+ )
+
+ @staticmethod
+ def _emit_duplicate_view(op, x):
+ if op == "relax.expand_dims":
+ a = relax.op.expand_dims(x, axis=1)
+ b = relax.op.expand_dims(x, axis=1)
+ elif op == "relax.squeeze":
+ a = relax.op.squeeze(x, axis=[0])
+ b = relax.op.squeeze(x, axis=[0])
+ elif op == "relax.reshape":
+ a = relax.op.reshape(x, (4, 1))
+ b = relax.op.reshape(x, (4, 1))
+ elif op == "relax.permute_dims":
+ a = relax.op.permute_dims(x, axes=[1, 0])
+ b = relax.op.permute_dims(x, axes=[1, 0])
+ elif op == "relax.memory.view":
+ a = relax.op.memory.view(x, (4, 1))
+ b = relax.op.memory.view(x, (4, 1))
+ elif op == "relax.memory.ensure_zero_offset":
+ a = relax.op.memory.ensure_zero_offset(x)
+ b = relax.op.memory.ensure_zero_offset(x)
+ elif op == "relax.flatten":
+ a = relax.op.flatten(x)
+ b = relax.op.flatten(x)
+ elif op == "relax.nn.batch_flatten":
+ a = relax.op.nn.batch_flatten(x)
+ b = relax.op.nn.batch_flatten(x)
+ else:
+ raise ValueError(op)
+ return a, b
+
+ @staticmethod
+ def _concat_axis_for_view_op(op):
+ if op == "relax.flatten":
+ return 0
+ return 1
+
+ @classmethod
+ def _build_module(cls, op):
+ if op == "relax.expand_dims":
+ x_sinfo = relax.TensorStructInfo((4,), "float32")
+ elif op == "relax.squeeze":
+ x_sinfo = relax.TensorStructInfo((1, 4, 1), "float32")
+ elif op == "relax.reshape":
+ x_sinfo = relax.TensorStructInfo((4,), "float32")
+ elif op == "relax.permute_dims":
+ x_sinfo = relax.TensorStructInfo((1, 4), "float32")
+ elif op == "relax.memory.view":
+ x_sinfo = relax.TensorStructInfo((4,), "float32")
+ elif op == "relax.memory.ensure_zero_offset":
+ x_sinfo = relax.TensorStructInfo((4, 1), "float32")
+ elif op in ("relax.flatten", "relax.nn.batch_flatten"):
+ x_sinfo = relax.TensorStructInfo((1, 4), "float32")
+ else:
+ raise ValueError(op)
+
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", x_sinfo)
+ concat_axis = cls._concat_axis_for_view_op(op)
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ a_expr, b_expr = cls._emit_duplicate_view(op, x)
+ a = bb.emit(a_expr)
+ b = bb.emit(b_expr)
+ prod = bb.emit(relax.op.multiply(a, b))
+ out = bb.emit(relax.op.concat([prod, b], axis=concat_axis))
+ gv = bb.emit_output(out)
+ bb.emit_func_output(gv)
+ return bb.finalize()
+
+ @classmethod
+ def _input_for_view_op(cls, op):
+ if op == "relax.squeeze":
+ return cls.storage_ptr_x_squeeze
+ if op == "relax.memory.ensure_zero_offset":
+ return cls.storage_ptr_x_ensure_zero_offset
+ if op in ("relax.permute_dims", "relax.flatten",
"relax.nn.batch_flatten"):
+ return cls.storage_ptr_x_2d
+ return cls.storage_ptr_x_1d
+
+ @staticmethod
+ def _torch_duplicate_view(x, op):
+ if op == "relax.expand_dims":
+ return x.unsqueeze(1)
+ if op == "relax.squeeze":
+ return x.squeeze(0)
+ if op == "relax.reshape":
+ return x.reshape(4, 1)
+ if op == "relax.permute_dims":
+ return x.permute(1, 0)
+ if op == "relax.memory.view":
+ return x.reshape(4, 1)
+ if op == "relax.memory.ensure_zero_offset":
+ return x
+ if op == "relax.flatten":
+ return x.flatten()
+ if op == "relax.nn.batch_flatten":
+ # TVM: ndim==2 input keeps shape (1, 4).
+ return x
+ raise ValueError(op)
+
+ @classmethod
+ def _expected_for_view_op(cls, op):
+ x = torch.from_numpy(np.asarray(cls._input_for_view_op(op),
dtype=np.float32))
+ a = cls._torch_duplicate_view(x, op)
+ b = cls._torch_duplicate_view(x, op)
+ prod = a * b
+ concat_axis = cls._concat_axis_for_view_op(op)
+ return torch.cat([prod, b], dim=concat_axis).numpy()
+
+ @pytest.mark.parametrize(
+ "view_op",
+ (
+ # Keep this list in sync with IsViewMemoryOp() in
+ # src/relax/transform/dataflow_inplace.cc
+ "relax.expand_dims",
+ "relax.squeeze",
+ "relax.reshape",
+ "relax.permute_dims",
+ "relax.flatten",
+ "relax.nn.batch_flatten",
+ "relax.memory.view",
+ "relax.memory.ensure_zero_offset",
+ ),
+ )
+ def test_no_inplace_when_view_ops_share_input(self, view_op):
+ mod = self._build_module(view_op)
+ func = mod["main"]
+ block = func.body.blocks[0]
+ params = list(func.params)
+
+ alias_sets, _ = dataflow_alias_analysis(block, params)
+ a_var = block.bindings[0].var
+ b_var = block.bindings[1].var
+ assert alias_sets[a_var] & alias_sets[b_var], (
+ f"{view_op}: duplicate views should share alias sets, but got "
+ f"{alias_sets[a_var]} and {alias_sets[b_var]}"
+ )
+
+ _, exact_match = dataflow_inplace_analysis(block, params, mod)
+ assert exact_match == [], f"{view_op}: expected no in-place
opportunities"
+
+ x_np = self._input_for_view_op(view_op).copy()
+ mod_inplace = DataflowUseInplaceCalls()(mod)
+ tvm.ir.assert_structural_equal(mod_inplace, mod)
+
+ storage_shared =
tvm.get_global_func("runtime.TVMTensorIsStorageShared")
+ captures = self._capture_op_tensors(mod_inplace, x_np, "multiply")
+ assert captures, f"{view_op}: VM instrumentation did not see a
multiply call."
+ cap = next(c for c in captures if len(c["tensors"]) >= 2)
+ tensor_a, tensor_b = cap["tensors"][0], cap["tensors"][1]
+ assert storage_shared(tensor_a, tensor_b), (
+ f"{view_op}: multiply operands should share the same storage at
runtime "
+ f"(call {cap['call_name']!r})"
+ )
+
+ ex = relax.build(mod_inplace, tvm.target.Target("llvm"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ out = vm["main"](tvm.runtime.tensor(x_np, tvm.cpu()))
+ np.testing.assert_allclose(out.numpy(),
self._expected_for_view_op(view_op))
+
+
if __name__ == "__main__":
testing.main()