This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 90c678abc9 [Relax][Frontend][TFLite] Fix `STRIDED_SLICE` negative
stride and add `STRIDED_SLICE/SPLIT_V` tests (#19468)
90c678abc9 is described below
commit 90c678abc927806053d89fae0889ff5560e711c8
Author: HoYi <[email protected]>
AuthorDate: Thu Apr 30 04:53:59 2026 +0800
[Relax][Frontend][TFLite] Fix `STRIDED_SLICE` negative stride and add
`STRIDED_SLICE/SPLIT_V` tests (#19468)
## Summary
This PR continues the TFLite frontend work tracked in #18971 for
`STRIDED_SLICE` and `SPLIT_V`.
Since the dynamic `FILL` / `SPLIT_V` partial-implementation work has
already been handled separately in #19433, this PR focuses on the
remaining pieces in this branch:
- fixing negative-stride `STRIDED_SLICE` conversion in the TFLite
frontend
- adding regression coverage for `STRIDED_SLICE` and static `SPLIT_V`
Relates to #18971.
## Changes
1. **`STRIDED_SLICE` negative-stride fix**
- Update the TFLite frontend `convert_strided_slice` handling of
`end_mask` when `stride < 0`.
- Use `-dim - 1` (previously `0`) as the exclusive end bound, which is
compatible with Relax slicing semantics, so reverse slices like `x[::-1]`
now include index `0` correctly.
2. **TFLite frontend test coverage**
- Add `test_strided_slice_stride` to cover non-unit stride handling.
- Add `test_strided_slice_negative_stride` to cover reverse slicing with
negative strides.
- Add `test_split_v_static` to cover static `SPLIT_V` conversion.
3. **Scope clarification**
- Keep this PR focused on the remaining `STRIDED_SLICE` / `SPLIT_V` work
from #18971.
- Exclude the dynamic `FILL` / `SPLIT_V` changes that are already
addressed in #19433.
## Testing
```bash
python -m pytest -n 1 tests/python/relax/test_frontend_tflite.py::test_split_v_static -q
python -m pytest -n 1 tests/python/relax/test_frontend_tflite.py::test_strided_slice_stride -q
python -m pytest -n 1 tests/python/relax/test_frontend_tflite.py::test_strided_slice_negative_stride -q
```
## Result
- The added STRIDED_SLICE and SPLIT_V tests passed locally.
- The negative-stride STRIDED_SLICE path now matches Relax slicing
semantics.
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 7 +-
tests/python/relax/test_frontend_tflite.py | 86 ++++++++++++++++++++++
2 files changed, 92 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 80c5e024f9..b7a7e42c48 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -1704,7 +1704,12 @@ class OperatorConverter:
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
- m_end[final_index] = 0 if stride[index] < 0 else
data_shape[final_index]
+ if stride[index] < 0:
+ # Relax negative-step slicing excludes the end
index, so an
+ # unspecified lower bound needs one extra step
past index 0.
+ m_end[final_index] = -data_shape[final_index] - 1
+ else:
+ m_end[final_index] = data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index bd0031d0aa..69aab2d43b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -229,6 +229,43 @@ def test_split_v_dynamic():
assert "R.scatter_elements" in ir
+def test_split_v_static():
+ """SPLIT_V with static unequal size_splits lowers to Relax split."""
+
+ class SplitVUnequal(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(2, 10, 4),
dtype=tf.float32)])
+ def func(self, x):
+ return tf.split(x, [2, 3, 5], axis=1)
+
+ @I.ir_module
+ class ExpectedUnequal:
+ @R.function
+ def main(x: R.Tensor((2, 10, 4), dtype="float32")) -> R.Tuple(
+ R.Tensor((2, 2, 4), dtype="float32"),
+ R.Tensor((2, 3, 4), dtype="float32"),
+ R.Tensor((2, 5, 4), dtype="float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((2, 2, 4), dtype="float32"),
+ R.Tensor((2, 3, 4), dtype="float32"),
+ R.Tensor((2, 5, 4), dtype="float32"),
+ ) = R.split(x, indices_or_sections=[2, 5], axis=1)
+ lv1: R.Tensor((2, 2, 4), dtype="float32") = lv[0]
+ lv2: R.Tensor((2, 3, 4), dtype="float32") = lv[1]
+ lv3: R.Tensor((2, 5, 4), dtype="float32") = lv[2]
+ gv: R.Tuple(
+ R.Tensor((2, 2, 4), dtype="float32"),
+ R.Tensor((2, 3, 4), dtype="float32"),
+ R.Tensor((2, 5, 4), dtype="float32"),
+ ) = lv1, lv2, lv3
+ R.output(gv)
+ return gv
+
+ verify(SplitVUnequal, ExpectedUnequal)
+
+
def test_pack():
class Pack(tf.Module):
@tf.function(
@@ -1159,6 +1196,55 @@ def test_slice():
verify(Slice, Expected)
+def test_strided_slice_stride():
+ class StridedSliceStride(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4, 6),
dtype=tf.float32)])
+ def func(self, x):
+ return x[0:2, 1:5:2]
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((4, 6), dtype="float32")) -> R.Tensor((2, 2),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.strided_slice(
+ x,
+ axes=[0, 1],
+ begin=[0, 1],
+ end=[2, 5],
+ strides=[1, 2],
+ assume_inbound=False,
+ )
+ gv: R.Tensor((2, 2), dtype="float32") = R.reshape(lv,
R.shape([2, 2]))
+ R.output(gv)
+ return gv
+
+ verify(StridedSliceStride, Expected)
+
+
+def test_strided_slice_negative_stride():
+ class StridedSliceNegativeStride(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(4,),
dtype=tf.float32)])
+ def func(self, x):
+ return x[::-1]
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="float32") = R.strided_slice(
+ x, axes=[0], begin=[4], end=[-5], strides=[-1],
assume_inbound=False
+ )
+ gv: R.Tensor((4,), dtype="float32") = R.reshape(lv,
R.shape([4]))
+ R.output(gv)
+ return gv
+
+ verify(StridedSliceNegativeStride, Expected)
+
+
def test_reverse_v2():
class ReverseV2(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3),
dtype=tf.float32)])