This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 29de9ab2e3 Add op support for slice_scatter (#18019)
29de9ab2e3 is described below
commit 29de9ab2e3dc669313eaf03fdd44327c411a2ff5
Author: kavin-mcw <[email protected]>
AuthorDate: Wed May 28 18:51:50 2025 +0530
Add op support for slice_scatter (#18019)
* Implement slice_scatter e2e
* Fix lint issue
---
include/tvm/relax/attrs/manipulate.h | 9 ++
.../frontend/torch/base_fx_graph_translator.py | 13 ++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 38 +++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 14 ++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/__init__.py | 1 +
python/tvm/topi/slice_scatter.py | 74 ++++++++++
src/relax/op/tensor/manipulate.cc | 155 +++++++++++++++++++++
src/relax/op/tensor/manipulate.h | 12 ++
.../relax/test_frontend_from_exported_program.py | 45 ++++++
tests/python/relax/test_frontend_from_fx.py | 45 ++++++
14 files changed, 411 insertions(+)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index 3a5e3951af..f8a6ddfe0a 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -229,6 +229,15 @@ struct ScatterNDAttrs : public
tvm::AttrsNode<ScatterNDAttrs> {
}
}; // struct ScatterNDAttrs
+/*! \brief Attributes used in slice_scatter operator */
+struct SliceScatterAttrs : public tvm::AttrsNode<SliceScatterAttrs> {
+  /*!
+   * \brief The axis of the input tensor along which `src` is written.
+   * Negative values appear to be accepted (struct-info inference range-checks
+   * against [-ndim, ndim) before normalizing) — confirm for other consumers.
+   */
+  int axis;
+
+  TVM_DECLARE_ATTRS(SliceScatterAttrs, "relax.attrs.SliceScatterAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe("the dimension to insert the slice into ");
+  }
+};  // struct SliceScatterAttrs
+
/*! \brief Attributes used in one_hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 50969e85a5..485b7c088a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1518,6 +1518,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.meshgrid(new_inputs,
indexing=indexing))
+ def _slice_scatter(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ src = args[1]
+ dim = args[2] if len(args) > 2 else node.kwargs.get("dim", 0)
+ start = args[3] if len(args) > 3 else node.kwargs.get("start", 0)
+ end = args[4] if len(args) > 4 else node.kwargs.get("end",
self.shape_of(input_tensor)[dim])
+ step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
+
+ return self.block_builder.emit(
+ relax.op.slice_scatter(input_tensor, src, start, end, step,
axis=dim)
+ )
+
def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index efa3de3a10..4e7c0bf324 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -493,6 +493,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
+ "slice_scatter.default": self._slice_scatter,
"sort.default": self._sort,
"split.Tensor": self._split,
"split_with_sizes.default": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 97a2b51e49..33abccbe5f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -909,6 +909,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"scatter": self._scatter,
"select": self._select,
"size": self._size,
+ "slice_scatter": self._slice_scatter,
"sort": self._sort,
"split": self._split,
"squeeze": self._squeeze,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0a2f0980fd..c4a5d2fd23 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -105,6 +105,7 @@ from .manipulate import (
reshape,
scatter_elements,
scatter_nd,
+ slice_scatter,
split,
squeeze,
stack,
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index b52aced59a..c71b19494a 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -786,6 +786,44 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr,
reduction: str = "updat
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type:
ignore
+def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis=0):
+ """Embeds the values of the src tensor into input at the given dimension.
+
+ Parameters
+ ----------
+ input_tensor: relax.Expr
+ The input tensor to be updated.
+
+ src: relax.Expr
+ The tensor to embed into input.
+
+ axis: int
+ The dimension to insert the slice into.
+
+ start:
+ The start index of where to insert the slice.
+
+ end:
+ The end index of where to insert the slice.
+
+ step:
+ The how many elements to skip in.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result tensor with the same shape as `data`.
+
+ """
+ if not isinstance(start, PrimValue):
+ start = PrimValue(start)
+ if not isinstance(end, PrimValue):
+ end = PrimValue(end)
+ if not isinstance(step, PrimValue):
+ step = PrimValue(step)
+ return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step)
+
+
def one_hot(
indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int,
axis: int = -1
) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 835be4bd4e..58abe434a2 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -263,6 +263,20 @@ def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.slice_scatter")
+def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr:
+
+ return bb.call_te(
+ topi.slice_scatter,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ call.args[3],
+ call.args[4],
+ call.attrs.axis,
+ )
+
+
@register_legalize("relax.one_hot")
def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
indices, on_value, off_value = call.args
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index b696d73031..92f84ce05c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@ from tvm.relax.op import (
sign,
sin,
sinh,
+ slice_scatter,
sort,
split,
sqrt,
@@ -854,6 +855,7 @@ __all__ = [
"sign",
"sin",
"sinh",
+ "slice_scatter",
"sort",
"split",
"square",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index fa4e98a89a..34ff213164 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@ from .broadcast import *
from .sort import *
from .scatter import *
from .scatter_elements import *
+from .slice_scatter import *
from .sparse_reshape import *
from .scan import *
from .einsum import *
diff --git a/python/tvm/topi/slice_scatter.py b/python/tvm/topi/slice_scatter.py
new file mode 100644
index 0000000000..d8772d0f5b
--- /dev/null
+++ b/python/tvm/topi/slice_scatter.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SliceScatter operator"""
+from tvm import topi
+from . import utils
+
+
+def slice_scatter(input_tensor, src, start, end, step, axis):
+ """
+ Scatters a slice of src into input along the given axis (SSA form).
+
+ Args:
+ input_tensor (te.Tensor): The input tensor to scatter into.
+ src (te.Tensor): The source tensor to scatter from.
+ start (int): The starting index of the slice.
+ end (int): The ending index of the slice.
+ step (int): The step size of the slice.
+ axis (int): The axis to scatter along.
+
+ Returns:
+ list[te.Tensor]: A list containing the output tensor with the slice
scattered.
+ """
+
+ dim_size_expr = input_tensor.shape[axis] # Expression for dimension size
+ dim_size = utils.get_const_int(dim_size_expr) # Dimension size (as
constant int)
+
+ if start == 0 and end == dim_size and step == 1:
+ return topi.identity(src)
+
+ mask = topi.full((dim_size,), "bool", True)
+ idx = topi.arange(start=0, stop=dim_size, step=1, dtype="int64")
+
+ if start != 0:
+ mask = topi.logical_and(mask, topi.greater_equal(idx, start))
+
+ if end != dim_size:
+ mask = topi.logical_and(mask, topi.less(idx, end))
+
+ if step != 1:
+ step_mask = topi.equal(topi.floor_mod(idx - start, step), 0)
+ mask = topi.logical_and(mask, step_mask)
+
+ mask_shape_base = [1] * len(input_tensor.shape)
+ mask_shape_base[axis] = dim_size
+ mask_shape = tuple(mask_shape_base)
+
+ mask_reshaped = topi.reshape(mask, mask_shape)
+
+ idx_new_pre = idx - start + (step - 1)
+ idx_new_div = topi.floor_divide(idx_new_pre, step)
+ idx_new = topi.clip(idx_new_div, 0, dim_size - 1)
+
+ temp = topi.take(src, idx_new, axis=axis)
+
+ mask_shape_expanded_base = list(input_tensor.shape)
+ mask_shape_expanded = tuple(mask_shape_expanded_base)
+
+ mask_expanded = topi.broadcast_to(mask_reshaped, mask_shape_expanded)
+
+ output = topi.where(mask_expanded, temp, input_tensor)
+
+ return [output]
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index e98ba946c5..f834bed253 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2448,6 +2448,161 @@ TVM_REGISTER_OP("relax.scatter_nd")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.slice_scatter */
+TVM_REGISTER_NODE_TYPE(SliceScatterAttrs);
+
+/*!
+ * \brief Build a relax.slice_scatter call.
+ * The scatter axis is stored in SliceScatterAttrs; the slice bounds are passed as
+ * PrimValue arguments so they may be symbolic.
+ */
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) {
+  auto attrs = make_object<SliceScatterAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("relax.slice_scatter");
+  return Call(op, {input, src, start, end, step}, Attrs(attrs), {});
+}
+
+TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter);
+
+/*!
+ * \brief Infer the struct info of relax.slice_scatter.
+ *
+ * Checks that:
+ *  - `input` (args[0]) and `src` (args[1]) are tensors;
+ *  - the axis lies in [-ndim, ndim);
+ *  - both tensors have the same rank and dtype;
+ *  - `start`, `end` and `step` (args[2..4]) are integer PrimValues.
+ * When both shapes are explicit it also checks that `src` matches `input` on every
+ * non-scatter axis, and that src's extent on the scatter axis equals
+ * floordiv(end - start + step - 1, step).
+ * The result carries the dtype, vdevice and (if known) shape of `input`.
+ */
+StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* src_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  // NOTE(review): attrs is not null-checked; assumes the call was built by slice_scatter().
+  auto* attrs = call->attrs.as<SliceScatterAttrs>();
+
+  // Raise a fatal diagnostic when `sinfo` is null, i.e. the argument is not a tensor.
+  auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr,
+                               String name) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the input " << name
+                                               << " to be a Tensor. However, the given one is "
+                                               << arg_expr->struct_info_->GetTypeKey());
+    }
+  };
+
+  diag_tensor_check(data_sinfo, call->args[0], "data");
+  diag_tensor_check(src_sinfo, call->args[1], "src");
+
+  // Unknown rank: no further checks are possible; keep only dtype and vdevice.
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  // Range-check the raw (possibly negative) axis against the data rank.
+  int ndim = data_sinfo->ndim;
+  int raw_axis = attrs->axis;
+  if (raw_axis < -ndim || raw_axis >= ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter requires the input axis to be in the range "
+                     << "[" << -ndim << ", " << ndim - 1 << "]. However, the input axis is "
+                     << raw_axis << ", while ndim is " << ndim);
+  }
+
+  // Ranks must match when src's rank is known (data's rank is known after the early
+  // return above, so the first condition always holds here).
+  if (!data_sinfo->IsUnknownNdim() && !src_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->ndim != src_sinfo->ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the data tensor to have the same rank as the "
+                          "src tensor. However, the given dimensions are "
+                       << "src: " << src_sinfo->ndim << ", data: " << data_sinfo->ndim);
+    }
+  }
+
+  // An unknown dtype on either side only warns; two known dtypes must be equal.
+  if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) {
+    auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) {
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "SliceScatter: Data type of " << name
+                     << " has not been specified for call node " << call
+                     << ". Assuming it is compatible.";
+      }
+    };
+    diag_dtype_warn(data_sinfo, "data");
+    diag_dtype_warn(src_sinfo, "src");
+  } else {
+    if (data_sinfo->dtype != src_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the input data to have the same type as "
+                          "src. However, the given types are "
+                       << "data: " << data_sinfo->dtype << ", src: " << src_sinfo->dtype);
+    }
+  }
+
+  // Unwrap a PrimValue argument, requiring an integer (signed or unsigned) dtype.
+  auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, std::string key) -> PrimExpr {
+    const auto* prim_value_node = arg_expr.as<PrimValueNode>();
+    if (prim_value_node == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects the `" << key << "` argument (" << arg_expr
+                       << ") to be a PrimValue, but got " << arg_expr->GetTypeKey());
+    }
+    const PrimExpr& prim_expr = prim_value_node->value;
+    if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects `" << key << "` (" << prim_expr
+                       << ") to be an integer PrimValue, but got dtype " << prim_expr.dtype());
+    }
+    return prim_expr;
+  };
+
+  PrimExpr start_val = get_prim_expr_from_arg(call->args[2], "start");
+  PrimExpr stop_val = get_prim_expr_from_arg(call->args[3], "end");
+  PrimExpr step_val = get_prim_expr_from_arg(call->args[4], "step");
+
+  // Only provable violations are rejected; symbolic bounds that might be valid pass.
+  if (analyzer->CanProve(step_val < 1)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter op requires the step (" << step_val << ") to be >= 1.");
+  }
+
+  if (analyzer->CanProve(stop_val < start_val)) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter op requires start (" << start_val
+                                             << ") <= end (" << stop_val << ").");
+  }
+
+  // NOTE(review): start/end are not checked against the data extent on `axis`;
+  // out-of-range bounds surface only as a src-shape mismatch below.
+  int axis = NormalizeAxis(call, ctx, ndim, attrs->axis);
+
+  const auto* data_shape_node = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* src_shape_node = src_sinfo->shape.as<ShapeExprNode>();
+
+  // Per-dimension checks run only when both shapes are explicit ShapeExprs.
+  if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) {
+    ICHECK_EQ(data_shape_node->values.size(), static_cast<size_t>(ndim))
+        << "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call;
+    ICHECK_EQ(src_shape_node->values.size(), static_cast<size_t>(src_sinfo->ndim))
+        << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call;
+
+    // Slice length along `axis`: ceil((end - start) / step) for step >= 1.
+    PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val);
+
+    // Every non-scatter axis must provably not differ between data and src.
+    for (int i = 0; i < ndim; i++) {
+      if (i != axis) {
+        if (analyzer->CanProve(data_shape_node->values[i] != src_shape_node->values[i])) {
+          ctx->ReportFatal(
+              Diagnostic::Error(call)
+              << "SliceScatter op requires the data tensor to have the same shape as the "
+                 "src tensor except at the scatter axis ("
+              << axis << "). Mismatch at dimension " << i << ". "
+              << "data shape: " << data_sinfo->GetShape().value()
+              << ", src shape: " << src_sinfo->GetShape().value());
+        }
+      }
+    }
+
+    // src's extent on the scatter axis must equal the slice length.
+    if (analyzer->CanProve(src_shape_node->values[axis] != num_elem)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the src tensor's dimension at scatter axis ("
+                       << axis << ") to match the number of elements in the slice. "
+                       << "Actual src dimension at axis " << axis << ": "
+                       << src_shape_node->values[axis]
+                       << ", Expected elements in slice (num_elem): " << num_elem);
+    }
+  }
+
+  // The output always has the struct info of `input`.
+  if (data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+// relax.slice_scatter: five positional inputs (input, src, start, end, step) plus
+// SliceScatterAttrs holding the scatter axis. The argument order here must match the
+// order used by the slice_scatter() builder and the legalizer.
+TVM_REGISTER_OP("relax.slice_scatter")
+    .set_attrs_type<SliceScatterAttrs>()
+    .set_num_inputs(5)
+    .add_argument("input", "Tensor", "The input tensor.")
+    .add_argument("src", "Tensor", "The source tensor to scatter.")
+    .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).")
+    .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).")
+    .add_argument("step", "PrimValue", "The step of the slice.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSliceScatter)
+    .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.one_hot */
TVM_REGISTER_NODE_TYPE(OneHotAttrs);
Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth,
int axis) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7d42b50838..cc15d5d4ab 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -273,6 +273,18 @@ Expr scatter_elements(Expr data, Expr indices, Expr
updates, int axis, String re
*/
Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
+/*!
+ * \brief Embeds the values of the src tensor into input at the given dimension.
+ * \param input The input tensor to be updated.
+ * \param src The tensor to embed into input.
+ * \param axis The dimension to insert the slice into.
+ * \param start The start index (inclusive) of where to insert the slice.
+ * \param end The end index (exclusive) of where to insert the slice.
+ * \param step The stride between consecutive positions written along `axis`.
+ * \return The computed result tensor with the same shape as `input`.
+ */
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step);
+
/*!
* \brief Returns a one-hot tensor.
* \param indices The indices to set to `on_value`.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index aaaf7e6eac..e6f75372d1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3923,6 +3923,51 @@ def test_select_slice():
verify_model(Slice2(), example_args, {}, expected2)
+def test_slice_scatter():
+ class SliceScatter1(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=1, start=1, end=7,
step=2)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+ b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((8, 8, 10, 10), dtype="float32") =
R.slice_scatter(
+ a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2),
axis=1
+ )
+ gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class SliceScatter2(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=0, start=0, end=6,
step=1)
+
+ @I.ir_module
+ class expected2:
+ @R.function
+ def main(
+ a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+ a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1),
axis=0
+ )
+ gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32),
torch.randn(8, 3, 10, 10))
+ verify_model(SliceScatter1(), example_args, {}, expected1)
+
+ example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6,
16))
+ verify_model(SliceScatter2(), example_args, {}, expected2)
+
+
def test_split():
class Chunk(Module):
def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 789c5649e6..f33b550858 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5033,6 +5033,51 @@ def test_scatter():
verify_model(Scatter(), input_info, {}, expected)
+def test_slice_scatter():
+ class SliceScatter1(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=1, start=1, end=7,
step=2)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+ b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((8, 8, 10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((8, 8, 10, 10), dtype="float32") =
R.slice_scatter(
+ a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2),
axis=1
+ )
+ gv: R.Tensor((8, 8, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class SliceScatter2(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=0, start=0, end=6,
step=1)
+
+ @I.ir_module
+ class expected2:
+ @R.function
+ def main(
+ a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16),
dtype="float32")
+ ) -> R.Tensor((8, 16), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+ a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1),
axis=0
+ )
+ gv: R.Tensor((8, 16), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(
+ SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10),
"float32")], {}, expected1
+ )
+
+ verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16),
"float32")], {}, expected2)
+
+
def test_masked_scatter():
class MaskedScatter1(Module):
def forward(self, data, mask, src):