This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8873a4c8a5 [Relax][Frontend][TFLite] Add segment operator mappings (#19491)
8873a4c8a5 is described below

commit 8873a4c8a504cc1523a997f55cb3e6e3d9bb0759
Author: HoYi <[email protected]>
AuthorDate: Sun May 3 13:39:24 2026 +0800

    [Relax][Frontend][TFLite] Add segment operator mappings (#19491)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for the following segment
    operators from #19412:
    
      - `SEGMENT_SUM`
      - `UNSORTED_SEGMENT_MIN`
      - `UNSORTED_SEGMENT_PROD`
    
    These operators are lowered through `relax.op.scatter_nd` with the
    corresponding reduction modes.
    
    ## Changes
    
    ### TFLite Frontend
    
    1. Add TFLite converter mappings for segment operators:
       - `SEGMENT_SUM` -> `scatter_nd(..., reduction="add")`
       - `UNSORTED_SEGMENT_MIN` -> `scatter_nd(..., reduction="min")`
       - `UNSORTED_SEGMENT_PROD` -> `scatter_nd(..., reduction="mul")`
    
    2. Add shared segment lowering logic:
       - Convert `segment_ids` into scatter indices via `expand_dims`.
       - Build the output shape from `num_segments` (unsorted ops) or from the
         constant `segment_ids` (`SEGMENT_SUM`).
       - Initialize the scatter base tensor with the reduction identity
         (0 for add, 1 for mul, the dtype maximum for min).
    
    ### Tests
    
    Add TFLite frontend tests for:
    
    - `test_segment_sum`
    - `test_unsorted_segment_min`
    - `test_unsorted_segment_prod`
    
    Each test verifies the imported Relax IR lowers to `R.scatter_nd` with
    the expected reduction mode and base tensor initialization.
    
    ## Testing
    
    All targeted tests pass:
    
    ```bash
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py::test_scatter_nd \
      tests/python/relax/test_frontend_tflite.py::test_segment_sum \
      tests/python/relax/test_frontend_tflite.py::test_unsorted_segment_min \
      tests/python/relax/test_frontend_tflite.py::test_unsorted_segment_prod \
      -q
    ```
    
    ## References
    
    - Issue #19412: TFLite Relax frontend operator support tracking
    - Related PR #19490: Adds SCATTER_ND support
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 124 +++++++++++++++++++++
 tests/python/relax/test_frontend_tflite.py         |  95 ++++++++++++++++
 2 files changed, 219 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index ebfbcacf9c..8d112b91d6 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -223,6 +223,9 @@ class OperatorConverter:
             "SCATTER_ND": self.convert_scatter_nd,
             "SELECT": self.convert_select,
             "SELECT_V2": self.convert_select,
+            "SEGMENT_SUM": functools.partial(
+                self._convert_segment_op, op_name="SEGMENT_SUM", reduction="add"
+            ),
             "SHAPE": self.convert_shape,
             "SIN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sin),
             "SLICE": self.convert_slice,
@@ -246,6 +249,12 @@ class OperatorConverter:
             "TRANSPOSE_CONV": self.convert_transpose_conv,
             "TRANSPOSE": self.convert_transpose,
             "UNPACK": self.convert_unpack,
+            "UNSORTED_SEGMENT_MIN": functools.partial(
+                self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
+            ),
+            "UNSORTED_SEGMENT_PROD": functools.partial(
+                self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
+            ),
             # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
             "WHERE": self.convert_select,
             "ZEROS_LIKE": self.convert_zeros_like,
@@ -2586,6 +2595,121 @@ class OperatorConverter:
         data = relax.op.zeros(shape, updates_dtype)
         return relax.op.scatter_nd(data, indices, updates, "update")
 
+    def _get_segment_scatter_base(self, output_shape, output_dtype, reduction):
+        """Create the identity base tensor for scatter-based segment reductions.
+
+        Output rows that no segment id targets keep this value: 0 for "add", 1 for "mul"
+        and the largest value of ``output_dtype`` for "min".
+        """
+        if reduction == "add":
+            return relax.op.zeros(output_shape, output_dtype)
+        if reduction == "mul":
+            return relax.op.full(output_shape, relax.const(1, output_dtype), output_dtype)
+        if reduction == "min":
+            np_dtype = np.dtype(output_dtype)
+            if np.issubdtype(np_dtype, np.floating):
+                identity = np.finfo(np_dtype).max
+            elif np.issubdtype(np_dtype, np.integer):
+                identity = np.iinfo(np_dtype).max
+            else:
+                raise tvm.error.OpNotImplemented(
+                    f"UNSORTED_SEGMENT_MIN does not support output dtype {output_dtype}."
+                )
+            return relax.op.full(output_shape, relax.const(identity, output_dtype), output_dtype)
+
+        raise ValueError(f"Unsupported segment reduction mode: {reduction}")
+
+    def _get_segment_num_segments(self, op_name, input_tensors):
+        """Return the static segment count, i.e. the size of output axis 0.
+
+        SEGMENT_SUM uses max(segment_ids) + 1 over its constant segment_ids (negative ids
+        are rejected by _convert_segment_op before this is called). The UNSORTED_SEGMENT_*
+        ops use their constant num_segments input.
+        """
+        if op_name == "SEGMENT_SUM":
+            segment_ids_tensor = input_tensors[1]
+            if self.has_expr(segment_ids_tensor.tensor_idx):
+                raise tvm.error.OpNotImplemented(
+                    "TFLite SEGMENT_SUM with runtime segment_ids is not supported, "
+                    "because TFLite does not encode a reliable output segment count."
+                )
+            segment_ids = self.get_tensor_value(segment_ids_tensor)
+            return int(np.max(segment_ids)) + 1 if segment_ids.size else 0
+
+        num_segments_tensor = input_tensors[2]
+        if self.has_expr(num_segments_tensor.tensor_idx):
+            raise tvm.error.OpNotImplemented(
+                f"TFLite {op_name} with runtime num_segments is not supported."
+            )
+        num_segments_value = self.get_tensor_value(num_segments_tensor)
+        assert num_segments_value.size == 1, f"{op_name} num_segments should be a scalar tensor"
+        num_segments = int(num_segments_value.item())
+        assert num_segments >= 0, f"{op_name} num_segments should be non-negative"
+        return num_segments
+
+    def _convert_segment_op(self, op, op_name, reduction):
+        """Convert TFLite segment ops through relax.op.scatter_nd.
+
+        SEGMENT_SUM takes (data, segment_ids); UNSORTED_SEGMENT_MIN/PROD take
+        (data, segment_ids, num_segments). Each slice of data at a position of segment_ids
+        is reduced into the output row named by the id at that position, using the
+        scatter_nd ``reduction`` mode ("add", "min" or "mul").
+        """
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        expected_inputs = 2 if op_name == "SEGMENT_SUM" else 3
+        assert len(input_tensors) == expected_inputs, (
+            f"{op_name} should have {expected_inputs} input tensors"
+        )
+
+        data_tensor = input_tensors[0]
+        segment_ids_tensor = input_tensors[1]
+        for t in input_tensors:
+            assert not t.qnn_params, "Quantized input is not expected."
+
+        segment_ids_type = segment_ids_tensor.tensor.Type()
+        assert segment_ids_type in (TensorType.INT32, TensorType.INT64)
+        if op_name != "SEGMENT_SUM":
+            num_segments_type = input_tensors[2].tensor.Type()
+            assert num_segments_type in (TensorType.INT32, TensorType.INT64)
+        # Kept for the upper-bound check below; stays None for runtime segment_ids.
+        segment_ids_value = None
+        if not self.has_expr(segment_ids_tensor.tensor_idx):
+            segment_ids_value = self.get_tensor_value(segment_ids_tensor)
+            if np.any(segment_ids_value < 0):
+                raise tvm.error.OpNotImplemented(
+                    f"TFLite {op_name} with negative segment ids is not supported."
+                )
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, f"{op_name} should have 1 output tensor"
+        output_tensor = output_tensors[0]
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+        data_shape = to_int_list(self.get_tensor_shape(data_tensor))
+        segment_ids_shape = to_int_list(self.get_tensor_shape(segment_ids_tensor))
+        segment_ids_rank = len(segment_ids_shape)
+        assert data_shape[:segment_ids_rank] == segment_ids_shape, (
+            f"{op_name} requires segment_ids shape to match a prefix of data shape"
+        )
+        num_segments = self._get_segment_num_segments(op_name, input_tensors)
+        # A constant id >= num_segments has no row in the output tensor; reject it here
+        # instead of emitting a scatter_nd index that is out of range.
+        if segment_ids_value is not None and np.any(segment_ids_value >= num_segments):
+            raise tvm.error.OpNotImplemented(
+                f"TFLite {op_name} with segment ids >= num_segments is not supported."
+            )
+        output_shape = [num_segments] + data_shape[segment_ids_rank:]
+
+        data = self.get_tensor_expr(data_tensor)
+        segment_ids = self.get_tensor_expr(segment_ids_tensor)
+        # Indices of shape (*segment_ids.shape, 1): each id selects a row along output axis 0.
+        indices = relax.op.expand_dims(segment_ids, axis=[segment_ids_rank])
+
+        base = self._get_segment_scatter_base(output_shape, output_dtype, reduction)
+        return relax.op.scatter_nd(base, indices, data, reduction)
+
     def convert_select(self, op):
         """Convert TFLite SELECT"""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index c5531ccf73..a2d2612232 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1783,6 +1783,101 @@ def test_scatter_nd():
     verify(Model)
 
 
+def test_segment_sum():
+    """SEGMENT_SUM lowers to scatter_nd with add reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.SegmentSum(
+                data=data, segment_ids=tf.constant([0, 0, 1, 2], dtype=tf.int32)
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.zeros(R.shape([3, 2]), dtype="float32")
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([0, 0, 1, 2], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="add"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
+def test_unsorted_segment_min():
+    """UNSORTED_SEGMENT_MIN lowers to scatter_nd with min reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.UnsortedSegmentMin(
+                data=data,
+                segment_ids=tf.constant([2, 0, 2, 1], dtype=tf.int32),
+                num_segments=tf.constant(3, dtype=tf.int32),
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.full(
+                    R.shape([3, 2]), R.const(np.finfo(np.float32).max, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([2, 0, 2, 1], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="min"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
+def test_unsorted_segment_prod():
+    """UNSORTED_SEGMENT_PROD lowers to scatter_nd with mul reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.UnsortedSegmentProd(
+                data=data,
+                segment_ids=tf.constant([1, 0, 1, 2], dtype=tf.int32),
+                num_segments=tf.constant(3, dtype=tf.int32),
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.full(
+                    R.shape([3, 2]), R.const(1, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([1, 0, 1, 2], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="mul"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
 def test_batch_matmul():
     class BatchMatMul(tf.Module):
         @tf.function(

Reply via email to