This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8873a4c8a5 [Relax][Frontend][TFLite] Add segment operator mappings
(#19491)
8873a4c8a5 is described below
commit 8873a4c8a504cc1523a997f55cb3e6e3d9bb0759
Author: HoYi <[email protected]>
AuthorDate: Sun May 3 13:39:24 2026 +0800
[Relax][Frontend][TFLite] Add segment operator mappings (#19491)
## Summary
This PR adds Relax TFLite frontend support for the following segment
operators tracked in issue #19412:
- `SEGMENT_SUM`
- `UNSORTED_SEGMENT_MIN`
- `UNSORTED_SEGMENT_PROD`
These operators are lowered through `relax.op.scatter_nd` with the
corresponding reduction modes.
## Changes
### TFLite Frontend
1. Add TFLite converter mappings for segment operators:
- `SEGMENT_SUM` -> `scatter_nd(..., reduction="add")`
- `UNSORTED_SEGMENT_MIN` -> `scatter_nd(..., reduction="min")`
- `UNSORTED_SEGMENT_PROD` -> `scatter_nd(..., reduction="mul")`
2. Add shared segment lowering logic:
- Convert `segment_ids` into scatter indices via `expand_dims`.
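- Build the output shape from a constant `num_segments` (unsorted ops)
or, for `SEGMENT_SUM`, from `max(segment_ids) + 1` of constant
`segment_ids`. Runtime (non-constant) values and negative segment ids
are rejected as not supported.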
- Initialize the scatter base tensor with the identity of the reduction:
0 for `add`, 1 for `mul`, and the output dtype's maximum value for
`min`.
### Tests
Add TFLite frontend tests for:
- `test_segment_sum`
- `test_unsorted_segment_min`
- `test_unsorted_segment_prod`
Each test verifies the imported Relax IR lowers to `R.scatter_nd` with
the expected reduction mode and base tensor initialization.
## Testing
All targeted tests pass:
```bash
python -m pytest \
tests/python/relax/test_frontend_tflite.py::test_scatter_nd \
tests/python/relax/test_frontend_tflite.py::test_segment_sum \
tests/python/relax/test_frontend_tflite.py::test_unsorted_segment_min \
tests/python/relax/test_frontend_tflite.py::test_unsorted_segment_prod \
-q
```
## References
- Issue #19412: TFLite Relax frontend operator support tracking
- Related PR #19490: Adds SCATTER_ND support
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 103 +++++++++++++++++++++
tests/python/relax/test_frontend_tflite.py | 95 +++++++++++++++++++
2 files changed, 198 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index ebfbcacf9c..8d112b91d6 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -223,6 +223,9 @@ class OperatorConverter:
"SCATTER_ND": self.convert_scatter_nd,
"SELECT": self.convert_select,
"SELECT_V2": self.convert_select,
+ "SEGMENT_SUM": functools.partial(
+ self._convert_segment_op, op_name="SEGMENT_SUM",
reduction="add"
+ ),
"SHAPE": self.convert_shape,
"SIN": functools.partial(self._convert_unary_elemwise,
relax_op=_op.sin),
"SLICE": self.convert_slice,
@@ -246,6 +249,12 @@ class OperatorConverter:
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
+ "UNSORTED_SEGMENT_MIN": functools.partial(
+ self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN",
reduction="min"
+ ),
+ "UNSORTED_SEGMENT_PROD": functools.partial(
+ self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD",
reduction="mul"
+ ),
# "UNIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
@@ -2586,6 +2595,100 @@ class OperatorConverter:
data = relax.op.zeros(shape, updates_dtype)
return relax.op.scatter_nd(data, indices, updates, "update")
+ def _get_segment_scatter_base(self, output_shape, output_dtype, reduction):
+ """Create the identity base tensor for scatter-based segment
reductions."""
+ if reduction == "add":
+ return relax.op.zeros(output_shape, output_dtype)
+ if reduction == "mul":
+ return relax.op.full(output_shape, relax.const(1, output_dtype),
output_dtype)
+ if reduction == "min":
+ np_dtype = np.dtype(output_dtype)
+ if np.issubdtype(np_dtype, np.floating):
+ identity = np.finfo(np_dtype).max
+ elif np.issubdtype(np_dtype, np.integer):
+ identity = np.iinfo(np_dtype).max
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"UNSORTED_SEGMENT_MIN does not support output dtype
{output_dtype}."
+ )
+ return relax.op.full(output_shape, relax.const(identity,
output_dtype), output_dtype)
+
+ raise ValueError(f"Unsupported segment reduction mode: {reduction}")
+
+ def _get_segment_num_segments(self, op_name, input_tensors):
+ if op_name == "SEGMENT_SUM":
+ segment_ids_tensor = input_tensors[1]
+ if self.has_expr(segment_ids_tensor.tensor_idx):
+ raise tvm.error.OpNotImplemented(
+ "TFLite SEGMENT_SUM with runtime segment_ids is not
supported, "
+ "because TFLite does not encode a reliable output segment
count."
+ )
+ segment_ids = self.get_tensor_value(segment_ids_tensor)
+ if np.any(segment_ids < 0):
+ raise tvm.error.OpNotImplemented(
+ "TFLite SEGMENT_SUM with negative segment ids is not
supported."
+ )
+ return int(np.max(segment_ids)) + 1 if segment_ids.size else 0
+
+ num_segments_tensor = input_tensors[2]
+ if self.has_expr(num_segments_tensor.tensor_idx):
+ raise tvm.error.OpNotImplemented(
+ f"TFLite {op_name} with runtime num_segments is not supported."
+ )
+ num_segments_value = self.get_tensor_value(num_segments_tensor)
+ assert num_segments_value.size == 1, f"{op_name} num_segments should
be a scalar tensor"
+ num_segments = int(num_segments_value.item())
+ assert num_segments >= 0, f"{op_name} num_segments should be
non-negative"
+ return num_segments
+
+ def _convert_segment_op(self, op, op_name, reduction):
+ """Convert TFLite segment ops through relax.op.scatter_nd."""
+ from tflite.TensorType import TensorType
+
+ input_tensors = self.get_input_tensors(op)
+ expected_inputs = 2 if op_name == "SEGMENT_SUM" else 3
+ assert len(input_tensors) == expected_inputs, (
+ f"{op_name} should have {expected_inputs} input tensors"
+ )
+
+ data_tensor = input_tensors[0]
+ segment_ids_tensor = input_tensors[1]
+ for t in input_tensors:
+ assert not t.qnn_params, "Quantized input is not expected."
+
+ segment_ids_type = segment_ids_tensor.tensor.Type()
+ assert segment_ids_type in (TensorType.INT32, TensorType.INT64)
+ if op_name != "SEGMENT_SUM":
+ num_segments_type = input_tensors[2].tensor.Type()
+ assert num_segments_type in (TensorType.INT32, TensorType.INT64)
+ if not self.has_expr(segment_ids_tensor.tensor_idx):
+ segment_ids_value = self.get_tensor_value(segment_ids_tensor)
+ if np.any(segment_ids_value < 0):
+ raise tvm.error.OpNotImplemented(
+ f"TFLite {op_name} with negative segment ids is not
supported."
+ )
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, f"{op_name} should have 1 output
tensor"
+ output_tensor = output_tensors[0]
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+ data_shape = to_int_list(self.get_tensor_shape(data_tensor))
+ segment_ids_shape =
to_int_list(self.get_tensor_shape(segment_ids_tensor))
+ segment_ids_rank = len(segment_ids_shape)
+ assert data_shape[:segment_ids_rank] == segment_ids_shape, (
+ f"{op_name} requires segment_ids shape to match a prefix of data
shape"
+ )
+ num_segments = self._get_segment_num_segments(op_name, input_tensors)
+ output_shape = [num_segments] + data_shape[segment_ids_rank:]
+
+ data = self.get_tensor_expr(data_tensor)
+ segment_ids = self.get_tensor_expr(segment_ids_tensor)
+ indices = relax.op.expand_dims(segment_ids, axis=[segment_ids_rank])
+
+ base = self._get_segment_scatter_base(output_shape, output_dtype,
reduction)
+ return relax.op.scatter_nd(base, indices, data, reduction)
+
def convert_select(self, op):
"""Convert TFLite SELECT"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index c5531ccf73..a2d2612232 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1783,6 +1783,101 @@ def test_scatter_nd():
verify(Model)
+def test_segment_sum():
+ """SEGMENT_SUM lowers to scatter_nd with add reduction."""
+
+ class Model(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2),
dtype=tf.float32)])
+ def func(self, data):
+ return tf.raw_ops.SegmentSum(
+ data=data, segment_ids=tf.constant([0, 0, 1, 2],
dtype=tf.int32)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((3, 2), dtype="float32") = R.zeros(R.shape([3,
2]), dtype="float32")
+ lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+ R.const([0, 0, 1, 2], "int32"), axis=[1]
+ )
+ gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+ lv, lv1, data, reduction="add"
+ )
+ R.output(gv)
+ return gv
+
+ verify(Model, Expected)
+
+
+def test_unsorted_segment_min():
+ """UNSORTED_SEGMENT_MIN lowers to scatter_nd with min reduction."""
+
+ class Model(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2),
dtype=tf.float32)])
+ def func(self, data):
+ return tf.raw_ops.UnsortedSegmentMin(
+ data=data,
+ segment_ids=tf.constant([2, 0, 2, 1], dtype=tf.int32),
+ num_segments=tf.constant(3, dtype=tf.int32),
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((3, 2), dtype="float32") = R.full(
+ R.shape([3, 2]), R.const(np.finfo(np.float32).max,
"float32"), dtype="float32"
+ )
+ lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+ R.const([2, 0, 2, 1], "int32"), axis=[1]
+ )
+ gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+ lv, lv1, data, reduction="min"
+ )
+ R.output(gv)
+ return gv
+
+ verify(Model, Expected)
+
+
+def test_unsorted_segment_prod():
+ """UNSORTED_SEGMENT_PROD lowers to scatter_nd with mul reduction."""
+
+ class Model(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2),
dtype=tf.float32)])
+ def func(self, data):
+ return tf.raw_ops.UnsortedSegmentProd(
+ data=data,
+ segment_ids=tf.constant([1, 0, 1, 2], dtype=tf.int32),
+ num_segments=tf.constant(3, dtype=tf.int32),
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((3, 2), dtype="float32") = R.full(
+ R.shape([3, 2]), R.const(1, "float32"), dtype="float32"
+ )
+ lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+ R.const([1, 0, 1, 2], "int32"), axis=[1]
+ )
+ gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+ lv, lv1, data, reduction="mul"
+ )
+ R.output(gv)
+ return gv
+
+ verify(Model, Expected)
+
+
def test_batch_matmul():
class BatchMatMul(tf.Module):
@tf.function(