This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 29de9ab2e3 Add op support for slice_scatter (#18019)
29de9ab2e3 is described below

commit 29de9ab2e3dc669313eaf03fdd44327c411a2ff5
Author: kavin-mcw <[email protected]>
AuthorDate: Wed May 28 18:51:50 2025 +0530

    Add op support for slice_scatter (#18019)
    
    * Implement slice_scatter e2e
    
    * Fix lint issue
---
 include/tvm/relax/attrs/manipulate.h               |   9 ++
 .../frontend/torch/base_fx_graph_translator.py     |  13 ++
 .../frontend/torch/exported_program_translator.py  |   1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  38 +++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |  14 ++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 python/tvm/topi/__init__.py                        |   1 +
 python/tvm/topi/slice_scatter.py                   |  74 ++++++++++
 src/relax/op/tensor/manipulate.cc                  | 155 +++++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   |  12 ++
 .../relax/test_frontend_from_exported_program.py   |  45 ++++++
 tests/python/relax/test_frontend_from_fx.py        |  45 ++++++
 14 files changed, 411 insertions(+)

diff --git a/include/tvm/relax/attrs/manipulate.h 
b/include/tvm/relax/attrs/manipulate.h
index 3a5e3951af..f8a6ddfe0a 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -229,6 +229,15 @@ struct ScatterNDAttrs : public 
tvm::AttrsNode<ScatterNDAttrs> {
   }
 };  // struct ScatterNDAttrs
 
+/*! \brief Attributes used in slice_scatter operator */
+struct SliceScatterAttrs : public tvm::AttrsNode<SliceScatterAttrs> {
+  /*! \brief The scatter axis; negative values count from the last dimension. */
+  int axis;
+
+  TVM_DECLARE_ATTRS(SliceScatterAttrs, "relax.attrs.SliceScatterAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe(
+        "The dimension to insert the slice into. Negative values count from the last dimension.");
+  }
+};  // struct SliceScatterAttrs
+
 /*! \brief Attributes used in one_hot operator */
 struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
   int depth;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 50969e85a5..485b7c088a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1518,6 +1518,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self.block_builder.emit(relax.op.meshgrid(new_inputs, 
indexing=indexing))
 
+    def _slice_scatter(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.slice_scatter`` into ``relax.op.slice_scatter``.
+
+        ``start``/``end`` may be ``None``, meaning the beginning/end of ``dim``.
+        A negative ``dim`` is normalized to a non-negative axis. When the extent
+        of ``dim`` is static, python-int ``start``/``end`` values are wrapped
+        if negative and clamped to ``[0, extent]``, so that the bounds passed to
+        the relax op are in range.
+        """
+        from tvm import tir  # local import: only needed to detect a static extent
+
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        src = args[1]
+        dim = args[2] if len(args) > 2 else node.kwargs.get("dim", 0)
+        start = args[3] if len(args) > 3 else node.kwargs.get("start", 0)
+        end = args[4] if len(args) > 4 else node.kwargs.get("end", None)
+        step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
+
+        # Emit a canonical, non-negative axis attribute.
+        if dim < 0:
+            dim += input_tensor.struct_info.ndim
+        dim_size = self.shape_of(input_tensor)[dim]
+        static_size = int(dim_size) if isinstance(dim_size, tir.IntImm) else None
+
+        def _normalize_index(index, default):
+            # None selects the default bound. Python-int indices are wrapped and
+            # clamped only when the axis extent is known statically; symbolic
+            # indices are passed through unchanged.
+            if index is None:
+                return default
+            if isinstance(index, int) and static_size is not None:
+                if index < 0:
+                    index += static_size
+                return min(max(index, 0), static_size)
+            return index
+
+        start = _normalize_index(start, 0)
+        end = _normalize_index(end, dim_size)
+
+        return self.block_builder.emit(
+            relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim)
+        )
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index efa3de3a10..4e7c0bf324 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -493,6 +493,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
+            "slice_scatter.default": self._slice_scatter,
             "sort.default": self._sort,
             "split.Tensor": self._split,
             "split_with_sizes.default": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 97a2b51e49..33abccbe5f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -909,6 +909,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "scatter": self._scatter,
             "select": self._select,
             "size": self._size,
+            "slice_scatter": self._slice_scatter,
             "sort": self._sort,
             "split": self._split,
             "squeeze": self._squeeze,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0a2f0980fd..c4a5d2fd23 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -105,6 +105,7 @@ from .manipulate import (
     reshape,
     scatter_elements,
     scatter_nd,
+    slice_scatter,
     split,
     squeeze,
     stack,
diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index b52aced59a..c71b19494a 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -786,6 +786,44 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, 
reduction: str = "updat
     return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: 
ignore
 
 
+def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis=0):
+    """Embed the values of `src` into `input_tensor` along a strided slice of one axis.
+
+    The result equals `input_tensor`, except at the positions
+    ``start, start + step, ...`` (strictly below ``end``) along ``axis``. Those
+    positions are filled, in order, from `src`.
+
+    Parameters
+    ----------
+    input_tensor: relax.Expr
+        The input tensor to be updated.
+
+    src: relax.Expr
+        The tensor to embed into `input_tensor`. Requirements:
+
+        - It has the same rank and dtype as `input_tensor`.
+        - It matches `input_tensor` on every axis except ``axis``.
+        - It has ``ceil((end - start) / step)`` elements along ``axis``.
+
+    start: Union[int, PrimExpr, PrimValue]
+        The start index (inclusive) of the slice.
+
+    end: Union[int, PrimExpr, PrimValue]
+        The end index (exclusive) of the slice.
+
+    step: Union[int, PrimExpr, PrimValue]
+        The distance between consecutive slice indices. Must be >= 1.
+
+    axis: int
+        The dimension to insert the slice into. Negative values count from the
+        last dimension.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result tensor with the same shape as `input_tensor`.
+    """
+    # The slice bounds are passed as call arguments, so wrap plain ints / PrimExprs.
+    if not isinstance(start, PrimValue):
+        start = PrimValue(start)
+    if not isinstance(end, PrimValue):
+        end = PrimValue(end)
+    if not isinstance(step, PrimValue):
+        step = PrimValue(step)
+    return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step)  # type: ignore
+
+
 def one_hot(
     indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, 
axis: int = -1
 ) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 835be4bd4e..58abe434a2 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -263,6 +263,20 @@ def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.slice_scatter")
+def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr:
+    # Operands are (input, src, start, end, step); the scatter axis is an attribute.
+    data, src, start, end, step = call.args
+    return bb.call_te(topi.slice_scatter, data, src, start, end, step, call.attrs.axis)
+
+
 @register_legalize("relax.one_hot")
 def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
     indices, on_value, off_value = call.args
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index b696d73031..92f84ce05c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@ from tvm.relax.op import (
     sign,
     sin,
     sinh,
+    slice_scatter,
     sort,
     split,
     sqrt,
@@ -854,6 +855,7 @@ __all__ = [
     "sign",
     "sin",
     "sinh",
+    "slice_scatter",
     "sort",
     "split",
     "square",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index fa4e98a89a..34ff213164 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@ from .broadcast import *
 from .sort import *
 from .scatter import *
 from .scatter_elements import *
+from .slice_scatter import *
 from .sparse_reshape import *
 from .scan import *
 from .einsum import *
diff --git a/python/tvm/topi/slice_scatter.py b/python/tvm/topi/slice_scatter.py
new file mode 100644
index 0000000000..d8772d0f5b
--- /dev/null
+++ b/python/tvm/topi/slice_scatter.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SliceScatter operator"""
+from tvm import topi
+from . import utils
+
+
+def slice_scatter(input_tensor, src, start, end, step, axis):
+    """
+    Scatters a slice of src into input along the given axis (SSA form).
+
+    Args:
+        input_tensor (te.Tensor): The input tensor to scatter into. Its extent
+            along ``axis`` must be a compile-time constant.
+        src (te.Tensor): The source tensor to scatter from. It matches
+            ``input_tensor`` on every axis except ``axis``, and it holds one
+            entry per selected position along ``axis``.
+        start (int or PrimExpr): The starting index (inclusive) of the slice.
+        end (int or PrimExpr): The ending index (exclusive) of the slice.
+        step (int or PrimExpr): The step size of the slice.
+        axis (int): The axis to scatter along. Negative values count from the
+            last dimension.
+
+    Returns:
+        list[te.Tensor]: A single-element list containing the output tensor,
+        which has the shape of ``input_tensor``.
+    """
+    ndim = len(input_tensor.shape)
+    if axis < 0:
+        axis += ndim
+
+    # The 1-D mask below is materialized with length dim_size, so the scatter
+    # axis must be static.
+    dim_size = utils.get_const_int(input_tensor.shape[axis])
+
+    if start == 0 and end == dim_size and step == 1:
+        # The slice covers the whole axis, so src replaces the input entirely.
+        # Wrapped in a list to match the general path's return type.
+        return [topi.identity(src)]
+
+    idx = topi.arange(start=0, stop=dim_size, step=1, dtype="int64")
+
+    # mask[i] is True exactly where position i along `axis` receives a src value.
+    mask = topi.full((dim_size,), "bool", True)
+    if start != 0:
+        mask = topi.logical_and(mask, topi.greater_equal(idx, start))
+    if end != dim_size:
+        mask = topi.logical_and(mask, topi.less(idx, end))
+    if step != 1:
+        step_mask = topi.equal(topi.floor_mod(idx - start, step), 0)
+        mask = topi.logical_and(mask, step_mask)
+
+    # Broadcast the 1-D mask across every non-scatter dimension.
+    mask_shape = [1] * ndim
+    mask_shape[axis] = dim_size
+    mask_expanded = topi.broadcast_to(
+        topi.reshape(mask, tuple(mask_shape)), tuple(input_tensor.shape)
+    )
+
+    # Output position idx reads src row ceil((idx - start) / step), clamped into
+    # src's own range along `axis`. Rows read for unmasked positions are
+    # discarded by the `where` below.
+    src_index = topi.floor_divide(idx - start + (step - 1), step)
+    src_index = topi.clip(src_index, 0, src.shape[axis] - 1)
+    gathered = topi.take(src, src_index, axis=axis)
+
+    return [topi.where(mask_expanded, gathered, input_tensor)]
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index e98ba946c5..f834bed253 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2448,6 +2448,161 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.slice_scatter */
+TVM_REGISTER_NODE_TYPE(SliceScatterAttrs);
+
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) {
+  auto attrs = make_object<SliceScatterAttrs>();
+  // Only the axis is an attribute; start/end/step are passed as PrimValue call arguments.
+  attrs->axis = axis;
+  static const Op& op = Op::Get("relax.slice_scatter");
+  return Call(op, {input, src, start, end, step}, Attrs(attrs), {});
+}
+
+TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter);
+
+/*!
+ * \brief Infer the struct info of relax.slice_scatter.
+ *
+ * Both `input` and `src` must be tensors. The following are reported as errors:
+ * - Rank or dtype mismatches between `input` and `src`. These checks are skipped
+ *   when the rank or dtype is unknown; an unknown dtype only triggers a warning.
+ * - An axis outside [-ndim, ndim).
+ * - Non-integer start/end/step.
+ * - A provably non-positive step, or a provable end < start.
+ *
+ * When both shapes are known, `src` must also match `input` on every non-scatter
+ * axis and have ceil((end - start) / step) elements on the scatter axis.
+ * The result carries the shape (or ndim), dtype and vdevice of `input`.
+ */
+StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& 
ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* src_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  auto* attrs = call->attrs.as<SliceScatterAttrs>();
+
+  // Reports a fatal diagnostic when an operand is not tensor-typed.
+  auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& 
arg_expr,
+                               String name) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the 
input " << name
+                                               << " to be a Tensor. However, 
the given one is "
+                                               << 
arg_expr->struct_info_->GetTypeKey());
+    }
+  };
+
+  diag_tensor_check(data_sinfo, call->args[0], "data");
+  diag_tensor_check(src_sinfo, call->args[1], "src");
+
+  // Without a known input rank nothing else can be checked; return an unknown-rank tensor.
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, 
data_sinfo->vdevice);
+  }
+
+  int ndim = data_sinfo->ndim;
+  int raw_axis = attrs->axis;
+  if (raw_axis < -ndim || raw_axis >= ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter requires the input axis to be in the 
range "
+                     << "[" << -ndim << ", " << ndim - 1 << "]. However, the 
input axis is "
+                     << raw_axis << ", while ndim is " << ndim);
+  }
+
+  // The input rank is known here (early return above); the rank check only runs if src's is too.
+  if (!data_sinfo->IsUnknownNdim() && !src_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->ndim != src_sinfo->ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the data tensor to have 
the same rank as the "
+                          "src tensor. However, the given dimensions are "
+                       << "src: " << src_sinfo->ndim << ", data: " << 
data_sinfo->ndim);
+    }
+  }
+
+  // An unknown dtype only produces a warning; two known dtypes must match exactly.
+  if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) {
+    auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) 
{
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "SliceScatter: Data type of " << name
+                     << " has not been specified for call node " << call
+                     << ". Assuming it is compatible.";
+      }
+    };
+    diag_dtype_warn(data_sinfo, "data");
+    diag_dtype_warn(src_sinfo, "src");
+  } else {
+    if (data_sinfo->dtype != src_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the input data to have the 
same type as "
+                          "src. However, the given types are "
+                       << "data: " << data_sinfo->dtype << ", src: " << 
src_sinfo->dtype);
+    }
+  }
+
+  // Extracts the PrimExpr from a slice-bound argument, requiring an integer PrimValue.
+  auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, 
std::string key) -> PrimExpr {
+    const auto* prim_value_node = arg_expr.as<PrimValueNode>();
+    if (prim_value_node == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects the `" << key << "` argument 
(" << arg_expr
+                       << ") to be a PrimValue, but got " << 
arg_expr->GetTypeKey());
+    }
+    const PrimExpr& prim_expr = prim_value_node->value;
+    if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects `" << key << "` (" << prim_expr
+                       << ") to be an integer PrimValue, but got dtype " << 
prim_expr.dtype());
+    }
+    return prim_expr;
+  };
+
+  PrimExpr start_val = get_prim_expr_from_arg(call->args[2], "start");
+  PrimExpr stop_val = get_prim_expr_from_arg(call->args[3], "end");
+  PrimExpr step_val = get_prim_expr_from_arg(call->args[4], "step");
+
+  // Bounds may be symbolic, so only violations the analyzer can prove are rejected.
+  if (analyzer->CanProve(step_val < 1)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter op requires the step (" << step_val << 
") to be >= 1.");
+  }
+
+  if (analyzer->CanProve(stop_val < start_val)) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter op requires 
start (" << start_val
+                                             << ") <= end (" << stop_val << 
").");
+  }
+
+  int axis = NormalizeAxis(call, ctx, ndim, attrs->axis);
+
+  const auto* data_shape_node = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* src_shape_node = src_sinfo->shape.as<ShapeExprNode>();
+
+  // Shape checks only run when both shapes are explicit ShapeExprs.
+  if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) {
+    ICHECK_EQ(data_shape_node->values.size(), static_cast<size_t>(ndim))
+        << "Internal error: data_shape_node rank mismatch with 
data_sinfo->ndim for call " << call;
+    ICHECK_EQ(src_shape_node->values.size(), 
static_cast<size_t>(src_sinfo->ndim))
+        << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim 
for call " << call;
+
+    // Number of positions in range(start, end, step): ceil((end - start) / step).
+    // NOTE(review): `end` is used as given, with no clamping to the axis extent.
+    // An `end` past the extent therefore demands a larger src than the positions
+    // actually written; callers presumably must pass in-range bounds — confirm.
+    PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - 
PrimExpr(1)), step_val);
+
+    for (int i = 0; i < ndim; i++) {
+      if (i != axis) {
+        if (analyzer->CanProve(data_shape_node->values[i] != 
src_shape_node->values[i])) {
+          ctx->ReportFatal(
+              Diagnostic::Error(call)
+              << "SliceScatter op requires the data tensor to have the same 
shape as the "
+                 "src tensor except at the scatter axis ("
+              << axis << "). Mismatch at dimension " << i << ". "
+              << "data shape: " << data_sinfo->GetShape().value()
+              << ", src shape: " << src_sinfo->GetShape().value());
+        }
+      }
+    }
+
+    if (analyzer->CanProve(src_shape_node->values[axis] != num_elem)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the src tensor's dimension 
at scatter axis ("
+                       << axis << ") to match the number of elements in the 
slice. "
+                       << "Actual src dimension at axis " << axis << ": "
+                       << src_shape_node->values[axis]
+                       << ", Expected elements in slice (num_elem): " << 
num_elem);
+    }
+  }
+
+  // The output keeps the input's shape (or just its ndim), dtype and vdevice.
+  if (data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, 
data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, 
data_sinfo->vdevice);
+}
+
+// Operand order (input, src, start, end, step) matches the Call built by slice_scatter();
+// the scatter axis is carried by SliceScatterAttrs. The op is registered as pure.
+TVM_REGISTER_OP("relax.slice_scatter")
+    .set_attrs_type<SliceScatterAttrs>()
+    .set_num_inputs(5)
+    .add_argument("input", "Tensor", "The input tensor.")
+    .add_argument("src", "Tensor", "The source tensor to scatter.")
+    .add_argument("start", "PrimValue", "The starting index of the slice 
(inclusive).")
+    .add_argument("end", "PrimValue", "The ending index of the slice 
(exclusive).")
+    .add_argument("step", "PrimValue", "The step of the slice.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoSliceScatter)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.one_hot */
 TVM_REGISTER_NODE_TYPE(OneHotAttrs);
 Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, 
int axis) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7d42b50838..cc15d5d4ab 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -273,6 +273,18 @@ Expr scatter_elements(Expr data, Expr indices, Expr 
updates, int axis, String re
  */
 Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
 
+/*!
+ * \brief Embeds the values of the src tensor into input along a strided slice of one axis.
+ * \param input The input tensor to be updated.
+ * \param src The tensor to embed into input. It has ceil((end - start) / step)
+ *        elements along `axis` and matches `input` on every other axis.
+ * \param axis The dimension to insert the slice into. Negative values count from the last
+ *        dimension.
+ * \param start The start index (inclusive) of the slice.
+ * \param end The end index (exclusive) of the slice.
+ * \param step The distance between consecutive slice indices; must be >= 1.
+ * \return The computed result tensor with the same shape as `input`.
+ */
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step);
+
 /*!
  * \brief Returns a one-hot tensor.
  * \param indices The indices to set to `on_value`.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index aaaf7e6eac..e6f75372d1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3923,6 +3923,51 @@ def test_select_slice():
     verify_model(Slice2(), example_args, {}, expected2)
 
 
+def test_slice_scatter():
+    class SliceScatter1(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2)
+
+    @I.ir_module
+    class expected1:
+        @R.function
+        def main(
+            a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+            b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1
+                )
+                gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class SliceScatter2(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1)
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32")
+        ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0
+                )
+                gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Strided scatter into a middle axis: positions 1, 3, 5 of dim 1 come from src.
+    example_args = (
+        torch.randn(8, 8, 10, 10, dtype=torch.float32),
+        torch.randn(8, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(SliceScatter1(), example_args, {}, expected1)
+
+    # Contiguous scatter into the leading rows of dim 0.
+    example_args = (
+        torch.randn(8, 16, dtype=torch.float32),
+        torch.randn(6, 16, dtype=torch.float32),
+    )
+    verify_model(SliceScatter2(), example_args, {}, expected2)
+
+
 def test_split():
     class Chunk(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 789c5649e6..f33b550858 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5033,6 +5033,51 @@ def test_scatter():
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_slice_scatter():
+    class SliceScatter1(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2)
+
+    @I.ir_module
+    class expected1:
+        @R.function
+        def main(
+            a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+            b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((8, 8, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1
+                )
+                gv: R.Tensor((8, 8, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class SliceScatter2(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1)
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32")
+        ) -> R.Tensor((8, 16), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0
+                )
+                gv: R.Tensor((8, 16), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Strided scatter into a middle axis: positions 1, 3, 5 of dim 1 come from src.
+    verify_model(
+        SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10), "float32")], {}, expected1
+    )
+
+    # Contiguous scatter into the leading rows of dim 0.
+    verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16), "float32")], {}, expected2)
+
+
 def test_masked_scatter():
     class MaskedScatter1(Module):
         def forward(self, data, mask, src):

Reply via email to