This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fbbbae994d [Relax][Frontend][TFLite] Add SCATTER_ND operator for Relax TFLite (#19490)
fbbbae994d is described below

commit fbbbae994d9f9fd1ba3b99c9ed8c5212dd912425
Author: Bana <[email protected]>
AuthorDate: Fri May 1 14:23:31 2026 +0300

    [Relax][Frontend][TFLite] Add SCATTER_ND operator for Relax TFLite (#19490)
    
    This PR adds support for the `SCATTER_ND` operator in the Relax TFLite
    frontend.
    
    ### Key Changes:
    
    - Added handler `convert_scatter_nd` to parse `indices` and `updates`.
    - Explicitly handles static vs dynamic shape tensor extraction via
    `to_int_list` and `relax.op.tensor_to_shape`.
    - Uses `relax.op.zeros` to initialize the zero-filled base tensor with the
    dtype of `updates`.
    - Mapped `SCATTER_ND` to the corresponding `relax.op.scatter_nd()`
    target.
    - Registered the translator into `convert_map` and provided the matching
    unit test in test_frontend_tflite.py.
    
    ### Testing:
    Passed unit tests
    ```
    pytest tests/python/relax/test_frontend_tflite.py::test_scatter_nd
    ```
    
    Related to #19412
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 26 ++++++++++++++++++++++
 tests/python/relax/test_frontend_tflite.py         | 15 +++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 05bda6816b..e45f569856 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -218,6 +218,7 @@ class OperatorConverter:
             "RSQRT": functools.partial(self._convert_unary_elemwise, relax_op=_op.rsqrt),
             "REVERSE_SEQUENCE": self.convert_reverse_sequence,
             "REVERSE_V2": self.convert_reverse_v2,
+            "SCATTER_ND": self.convert_scatter_nd,
             "SELECT": self.convert_select,
             "SHAPE": self.convert_shape,
             "SIN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sin),
@@ -2557,6 +2558,31 @@ class OperatorConverter:
         out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
         return out
 
+    def convert_scatter_nd(self, op):
+        """Convert TFLite SCATTER_ND: scatter ``updates`` into a zero tensor of ``shape``."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "SCATTER_ND should have 3 input tensors"
+        indices = self.get_tensor_expr(input_tensors[0])
+        updates = self.get_tensor_expr(input_tensors[1])
+        shape_tensor = input_tensors[2]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "SCATTER_ND should have 1 output tensor"
+        output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type())
+
+        if self.has_expr(shape_tensor.tensor_idx):
+            shape_expr = self.get_expr(shape_tensor.tensor_idx)
+            shape_expr = self.bb.normalize(relax.op.astype(shape_expr, "int64"))
+            shape = self.bb.emit(relax.op.tensor_to_shape(shape_expr))
+        else:
+            shape = to_int_list(self.get_tensor_value(shape_tensor))
+
+        # Indices keep TFLite's (..., index_depth) layout, which relax.op.scatter_nd also
+        # uses. TFLite sums updates that share an index, so scatter with "add" into zeros;
+        # "update" would keep only one of the duplicates.
+        data = relax.op.zeros(shape, output_dtype)
+        return relax.op.scatter_nd(data, indices, updates, "add")
+
     def convert_select(self, op):
         """Convert TFLite SELECT"""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 418a7665a7..beef66e09b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1707,6 +1707,23 @@ def test_networks(net, shape):
     verify(concrete_func)
 
 
+@pytest.mark.parametrize("static_shape", [False, True])
+def test_scatter_nd(static_shape):
+    class Model(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(4, 1), dtype=tf.int32),
+                tf.TensorSpec(shape=(4,), dtype=tf.float32),
+                tf.TensorSpec(shape=(1,), dtype=tf.int32),
+            ]
+        )
+        def func(self, indices, updates, shape):
+            # A constant shape exercises the converter's static-shape branch.
+            return tf.scatter_nd(indices, updates, [8] if static_shape else shape)
+
+    verify(Model)
+
+
 def test_batch_matmul():
     class BatchMatMul(tf.Module):
         @tf.function(

Reply via email to