This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fbbbae994d [Relax][Frontend][TFLite] Add SCATTER_ND operator for Relax
TFLite (#19490)
fbbbae994d is described below
commit fbbbae994d9f9fd1ba3b99c9ed8c5212dd912425
Author: Bana <[email protected]>
AuthorDate: Fri May 1 14:23:31 2026 +0300
[Relax][Frontend][TFLite] Add SCATTER_ND operator for Relax TFLite (#19490)
This PR adds support for the `SCATTER_ND` operator in the Relax TFLite
frontend.
### Key Changes:
- Added handler `convert_scatter_nd` to parse `indices` and `updates`.
- Handles both a constant `shape` tensor (read statically via
`to_int_list`) and a runtime `shape` tensor (converted via
`relax.op.tensor_to_shape`).
- Uses `relax.op.zeros` to create the zero-initialized base tensor, using
the dtype declared for the operator's output tensor.
- Mapped `SCATTER_ND` to the corresponding `relax.op.scatter_nd()`
target.
- Registered the translator into `convert_map` and provided the matching
unit test in test_frontend_tflite.py.
### Testing:
The new unit test passes:
```
pytest tests/python/relax/test_frontend_tflite.py::test_scatter_nd
```
Related to #19412
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 26 ++++++++++++++++++++++
tests/python/relax/test_frontend_tflite.py | 15 +++++++++++++
2 files changed, 41 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 05bda6816b..e45f569856 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -218,6 +218,7 @@ class OperatorConverter:
"RSQRT": functools.partial(self._convert_unary_elemwise,
relax_op=_op.rsqrt),
"REVERSE_SEQUENCE": self.convert_reverse_sequence,
"REVERSE_V2": self.convert_reverse_v2,
+ "SCATTER_ND": self.convert_scatter_nd,
"SELECT": self.convert_select,
"SHAPE": self.convert_shape,
"SIN": functools.partial(self._convert_unary_elemwise,
relax_op=_op.sin),
@@ -2557,6 +2558,31 @@ class OperatorConverter:
out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
return out
def convert_scatter_nd(self, op):
    """Convert TFLite SCATTER_ND.

    TFLite ``SCATTER_ND(indices, updates, shape)`` writes ``updates`` into a
    zero-filled tensor of shape ``shape`` at the positions listed in
    ``indices``.

    ``indices`` is forwarded unchanged. TFLite keeps each index tuple along
    the last axis (``indices.shape[-1]`` is the index depth), which is the
    layout ``relax.op.scatter_nd`` takes (ONNX ScatterND semantics). The
    index-depth-first transpose the Relay frontend applied for
    ``relay.scatter_nd`` must not be repeated here: Relax's legalization of
    ``scatter_nd`` performs the conversion to the TOPI layout itself, so
    transposing in the frontend would apply it twice.

    Parameters
    ----------
    op : tflite.Operator
        The SCATTER_ND operator. Inputs are ``indices``, ``updates`` and
        ``shape`` (in that order); it has exactly one output.

    Returns
    -------
    relax.Expr
        ``relax.op.scatter_nd`` in "update" mode applied to a zero tensor of
        the requested shape.
    """
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 3, "SCATTER_ND should have 3 input tensors"
    indices = self.get_tensor_expr(input_tensors[0])
    updates = self.get_tensor_expr(input_tensors[1])
    shape_tensor = input_tensors[2]

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "SCATTER_ND should have 1 output tensor"
    # The zero base tensor takes the dtype the graph declares for the output.
    out_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type())

    if self.has_expr(shape_tensor.tensor_idx):
        # `shape` is computed at runtime: cast the (possibly int32) tensor to
        # int64 and turn it into a Relax shape value.
        shape_expr = self.get_expr(shape_tensor.tensor_idx)
        shape_expr = self.bb.normalize(relax.op.astype(shape_expr, "int64"))
        shape = self.bb.emit(relax.op.tensor_to_shape(shape_expr))
    else:
        # `shape` is a constant buffer in the model: read it as a static list.
        shape = to_int_list(self.get_tensor_value(shape_tensor))

    data = relax.op.zeros(shape, out_dtype)
    return relax.op.scatter_nd(data, indices, updates, "update")
+
def convert_select(self, op):
"""Convert TFLite SELECT"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 418a7665a7..beef66e09b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1707,6 +1707,21 @@ def test_networks(net, shape):
verify(concrete_func)
def test_scatter_nd():
    """SCATTER_ND whose `shape` operand is a runtime input (dynamic-shape path)."""

    class Model(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(4, 1), dtype=tf.int32),
                tf.TensorSpec(shape=(4,), dtype=tf.float32),
                tf.TensorSpec(shape=(1,), dtype=tf.int32),
            ]
        )
        def func(self, indices, updates, shape):
            return tf.scatter_nd(indices, updates, shape)

    verify(Model)


def test_scatter_nd_const_shape():
    """SCATTER_ND with constant `indices` and `shape` (static-shape path).

    Here the output shape is read from a constant tensor in the model, so the
    zero base tensor built by the converter has a fully known shape and the
    `indices` layout handed to Relax is checked against it.
    """

    class Model(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(4,), dtype=tf.float32)])
        def func(self, updates):
            indices = tf.constant([[4], [3], [1], [7]], dtype=tf.int32)
            shape = tf.constant([8], dtype=tf.int32)
            return tf.scatter_nd(indices, updates, shape)

    verify(Model)
+
+
def test_batch_matmul():
class BatchMatMul(tf.Module):
@tf.function(