This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9e3eaf39a0 [Relax][Frontend][TFLite] Support dynamic
DYNAMIC_UPDATE_SLICE starts (#19881)
9e3eaf39a0 is described below
commit 9e3eaf39a091c32fdc5a016413ca8a9e6675b59f
Author: Hongyi Wu <[email protected]>
AuthorDate: Fri Jun 26 01:30:25 2026 +0800
[Relax][Frontend][TFLite] Support dynamic DYNAMIC_UPDATE_SLICE starts
(#19881)
## Summary
This PR adds Relax TFLite frontend support for runtime (dynamic) start indices
in `STABLEHLO_DYNAMIC_UPDATE_SLICE`, addressing the `DYNAMIC_UPDATE_SLICE` item
from #19412 section B.

`_convert_stablehlo_dynamic_update_slice` (added in #19587) previously raised
`OpNotImplemented` when the start-index scalars were runtime (non-constant)
values, handling only compile-time-constant starts. Models that compute the
update offset at runtime could therefore not be imported. This PR makes the
dynamic-start path work, with StableHLO clamping semantics, without adding a new
Relax op. The change is limited to this converter and its test.
## Design
### Dynamic start indices via scatter_nd
The existing static path already lowers `STABLEHLO_DYNAMIC_UPDATE_SLICE` to
`relax.op.scatter_nd`, building the scatter index grid at compile time with
`numpy.indices`. `scatter_nd` accepts a general **runtime** `indices` tensor and
returns the `data` (operand) shape unchanged, so the dynamic case needs no new
op and introduces no symbolic dimensions; only the index grid is built
in-graph instead of in NumPy.
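As a rough NumPy-only illustration of the constant-start grid described above (a sketch, not the converter code; the helper name `static_index_grid` and the example shapes are assumptions made here), the grid can be built with `numpy.indices`, offset by the starts, and rearranged so the coordinate axis comes last:

```python
import numpy as np


def static_index_grid(update_shape, starts):
    """Hypothetical helper: scatter index grid for constant start indices."""
    # np.indices yields shape (rank, *update_shape): one coordinate plane per axis.
    grid = np.indices(update_shape, dtype=np.int64)
    for axis, start in enumerate(starts):
        grid[axis] += start  # shift each axis by its constant start
    # Move the coordinate axis last, giving (*update_shape, rank).
    return np.moveaxis(grid, 0, -1)


# Example values chosen for illustration only.
grid = static_index_grid((2, 2), [1, 2])
```

Each entry `grid[i, j]` then holds the operand coordinate `(i + 1, j + 2)` that receives `update[i, j]`.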
For runtime starts, the converter builds the index grid per axis `a` (rank is
statically known from the operand/update shapes):
- clamp the start to `[0, operand_dim - update_dim]` with `relax.op.maximum` /
  `relax.op.minimum` (StableHLO clamps out-of-range starts rather than erroring);
- `idx = arange(update_dim) + clamped_start`;
- reshape `idx` to broadcast on axis `a` and `broadcast_to` the update shape;
- `expand_dims` a trailing index axis.
`concat` over the axes produces an int64 index tensor of shape
`(*update_shape, rank)`, which is fed to the same
`relax.op.scatter_nd(operand, indices, update, "update")` call the static path
uses.
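The per-axis steps above can be mirrored in plain NumPy to see what the in-graph ops compute. This is only a sketch under assumptions: `dynamic_index_grid` and `scatter_update` are hypothetical names, and `scatter_update` is a NumPy stand-in that assumes `scatter_nd` in `"update"` mode simply overwrites the addressed elements.

```python
import numpy as np


def dynamic_index_grid(operand_shape, update_shape, starts):
    """Hypothetical NumPy mirror of the in-graph index grid construction."""
    rank = len(update_shape)
    per_axis = []
    for axis in range(rank):
        # Clamp the start to [0, operand_dim - update_dim].
        max_start = operand_shape[axis] - update_shape[axis]
        start = min(max(int(starts[axis]), 0), max_start)
        idx = np.arange(update_shape[axis], dtype=np.int64) + start
        # Reshape so idx varies only along `axis`, then broadcast to the update shape.
        bshape = [1] * rank
        bshape[axis] = update_shape[axis]
        idx = np.broadcast_to(idx.reshape(bshape), update_shape)
        # Add the trailing index axis that concat will join on.
        per_axis.append(np.expand_dims(idx, axis=-1))
    return np.concatenate(per_axis, axis=-1)


def scatter_update(operand, indices, update):
    """Stand-in for scatter_nd with "update" reduction (assumed semantics)."""
    out = operand.copy()
    out[tuple(np.moveaxis(indices, -1, 0))] = update
    return out


# Illustrative values: both starts are out of range and get clamped to (1, 0).
operand = np.zeros((3, 4), dtype=np.float32)
update = np.ones((2, 2), dtype=np.float32)
grid = dynamic_index_grid(operand.shape, update.shape, starts=[5, -1])
result = scatter_update(operand, grid, update)
```

With these inputs the update lands in rows 1 to 2 and columns 0 to 1 of the operand, which is the clamped position rather than an error.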
The static (constant-start) path is unchanged, including its compile-time
out-of-bounds rejection.
## Operator Support
| Operator | TFLite inputs | Relax lowering | Supported subset |
|---|---|---|---|
| `STABLEHLO_DYNAMIC_UPDATE_SLICE` | `operand`, `update`, N scalar `start` indices | `relax.op.scatter_nd` with a NumPy index grid (constant starts) or an in-graph `arange` + clamp index grid (runtime starts) | static operand/update shapes; constant or runtime start indices |
## Not Included
- Dynamic (non-static) operand or update shapes: the index grid is built from
  the statically known update shape, so operand/update shapes must be static.
  Runtime *start indices* are supported; runtime *tensor shapes* are not.
## Tests
The dynamic-start test compiles the imported module and runs it on the Relax VM,
comparing the output against a NumPy reference; it includes an out-of-range
start to exercise clamping. The static structural-equal and out-of-bounds tests
are unchanged.
| Test | Coverage |
|---|---|
| `test_stablehlo_dynamic_update_slice` | constant start indices, structural-equal (existing) |
| `test_stablehlo_dynamic_update_slice_dynamic_starts` | runtime start indices, compile + run, including an out-of-range start that is clamped |
| `test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported` | constant-start path rejects out-of-bounds updates (existing) |
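A NumPy reference of the kind mentioned for the dynamic-start test could look roughly like the following. It is a sketch only: the function name `dynamic_update_slice_ref` and the input values are assumptions, not the code used in `test_frontend_tflite.py`.

```python
import numpy as np


def dynamic_update_slice_ref(operand, update, starts):
    """Hypothetical reference: write `update` into `operand` at clamped starts."""
    out = operand.copy()
    slices = []
    for dim, size, start in zip(operand.shape, update.shape, starts):
        # StableHLO-style clamping keeps the whole update inside the operand.
        clamped = int(np.clip(start, 0, dim - size))
        slices.append(slice(clamped, clamped + size))
    out[tuple(slices)] = update
    return out


# Illustrative inputs: start 2 on axis 0 is out of range and is clamped to 1.
operand = np.arange(12, dtype=np.float32).reshape(3, 4)
update = np.full((2, 2), -1.0, dtype=np.float32)
expected = dynamic_update_slice_ref(operand, update, starts=[2, 1])
```

A result produced by the compiled module could then be compared against `expected` with `np.testing.assert_allclose`.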
Local validation:
```bash
python -m ruff format --check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py -k dynamic_update_slice -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py -q
```
Result:
```text
ruff format --check: 2 files already formatted
ruff check: All checks passed
dynamic_update_slice tests: 3 passed, 555 deselected
full TFLite pytest: 558 passed
```
## References
- Issue #19412 section B: `DYNAMIC_UPDATE_SLICE`
- PR #19587: introduced `STABLEHLO_DYNAMIC_UPDATE_SLICE` (constant starts) and
  multi-subgraph / StableHLO region support
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 52 +++++++++++++++++++---
tests/python/relax/test_frontend_tflite.py | 49 ++++++++++++++++----
2 files changed, 87 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f22786a4c4..e2ab3a7b27 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3349,7 +3349,13 @@ class OperatorConverter:
return self.bb.normalize(relax.op.dynamic_strided_slice(operand,
begin, end, strides))
def _convert_stablehlo_dynamic_update_slice(self, op):
- """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static
starts."""
+ """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax.
+
+ Lowers to ``relax.op.scatter_nd``. Constant start indices build the
index
+ grid at compile time; runtime (dynamic) start indices build it in-graph
+ with ``arange`` + broadcast, clamping each start to
+ ``[0, operand_dim - update_dim]`` per StableHLO semantics.
+ """
input_tensors = self.get_input_tensors(op)
# operand + update + N start-index scalars
assert len(input_tensors) >= 3, "input tensors length should be >= 3"
@@ -3368,11 +3374,21 @@ class OperatorConverter:
"STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, "
"and start-index ranks to match"
)
+ for dim, size in zip(operand_shape, update_shape):
+ if size > dim:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_UPDATE_SLICE update shape must be
smaller than "
+ "or equal to operand shape for all dimensions"
+ )
+
+ operand = self.get_tensor_expr(operand_tensor)
+ update = self.get_tensor_expr(update_tensor)
if any(self.has_expr(t.tensor_idx) for t in start_tensors):
- raise tvm.error.OpNotImplemented(
- "STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is
not supported"
+ indices = self._build_dynamic_update_slice_indices(
+ start_tensors, operand_shape, update_shape, rank
)
+ return self.bb.normalize(relax.op.scatter_nd(operand, indices,
update, "update"))
start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t
in start_tensors]
for start, size, dim in zip(start_vals, update_shape, operand_shape):
@@ -3387,11 +3403,37 @@ class OperatorConverter:
update_indices[axis] += start
update_indices = np.moveaxis(update_indices, 0, -1)
- operand = self.get_tensor_expr(operand_tensor)
- update = self.get_tensor_expr(update_tensor)
indices = self.bb.normalize(relax.const(update_indices, dtype="int64"))
return self.bb.normalize(relax.op.scatter_nd(operand, indices, update,
"update"))
+ def _build_dynamic_update_slice_indices(self, start_tensors,
operand_shape, update_shape, rank):
+ """Build the scatter_nd index grid for runtime DYNAMIC_UPDATE_SLICE
starts.
+
+ Returns an int64 tensor of shape ``(*update_shape, rank)`` where axis
``a``
+ holds ``arange(update_shape[a]) + clamp(start[a], 0, operand_dim -
update_dim)``,
+ broadcast over the other axes (StableHLO clamps out-of-range starts).
+ """
+ axis_indices = []
+ for axis in range(rank):
+ start_expr = self.bb.normalize(
+ relax.op.astype(self.get_tensor_expr(start_tensors[axis]),
"int64")
+ )
+ max_start = operand_shape[axis] - update_shape[axis]
+ start_expr = relax.op.maximum(start_expr, relax.const(0, "int64"))
+ start_expr = relax.op.minimum(start_expr, relax.const(max_start,
"int64"))
+
+ base = relax.op.arange(0, update_shape[axis], 1, "int64")
+ idx = relax.op.add(base, start_expr)
+
+ broadcast_shape = [1] * rank
+ broadcast_shape[axis] = update_shape[axis]
+ idx = self.bb.normalize(relax.op.reshape(idx, broadcast_shape))
+ idx = self.bb.normalize(relax.op.broadcast_to(idx, update_shape))
+ idx = self.bb.normalize(relax.op.expand_dims(idx, axis=-1))
+ axis_indices.append(idx)
+
+ return self.bb.normalize(relax.op.concat(axis_indices, axis=-1))
+
def _convert_stablehlo_dot_general(self, op):
"""Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax
matmul."""
from tflite.StablehloDotGeneralOptions import
StablehloDotGeneralOptions
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 9f9d4a0e8a..c259900aef 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -10614,16 +10614,47 @@ def test_stablehlo_dynamic_update_slice():
tvm.ir.assert_structural_equal(mod, Expected)
-def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported():
- """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is
unsupported."""
- buf = _build_stablehlo_dynamic_update_slice_model([0, 0],
dynamic_starts=True)
- if hasattr(tflite.Model, "Model"):
- tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
- else:
- tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+def test_stablehlo_dynamic_update_slice_dynamic_starts():
+ """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts lowers
structurally."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_dynamic_update_slice_model([0, 0],
dynamic_starts=True)
+ )
- with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
- from_tflite(tflite_model)
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ operand: R.Tensor((3, 4), dtype="float32"),
+ update: R.Tensor((2, 2), dtype="float32"),
+ s0: R.Tensor((), dtype="int32"),
+ s1: R.Tensor((), dtype="int32"),
+ ) -> R.Tensor((3, 4), dtype="float32"):
+ R.func_attr({"num_input": 4})
+ with R.dataflow():
+ lv: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1,
dtype="int64")
+ lv1: R.Tensor((), dtype="int64") = R.astype(s0, dtype="int64")
+ lv2: R.Tensor((), dtype="int64") = R.maximum(lv1, R.const(0,
"int64"))
+ lv3: R.Tensor((), dtype="int64") = R.minimum(lv2, R.const(1,
"int64"))
+ lv4: R.Tensor((2,), dtype="int64") = R.add(lv, lv3)
+ lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, (2, 1))
+ lv6: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv5, (2,
2))
+ lv7: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1,
dtype="int64")
+ lv8: R.Tensor((), dtype="int64") = R.astype(s1, dtype="int64")
+ lv9: R.Tensor((), dtype="int64") = R.maximum(lv8, R.const(0,
"int64"))
+ lv10: R.Tensor((), dtype="int64") = R.minimum(lv9, R.const(2,
"int64"))
+ lv11: R.Tensor((2,), dtype="int64") = R.add(lv7, lv10)
+ lv12: R.Tensor((1, 2), dtype="int64") = R.reshape(lv11, (1, 2))
+ lv13: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv12,
(2, 2))
+ lv14: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv6,
axis=[-1])
+ lv15: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv13,
axis=[-1])
+ lv16: R.Tensor((2, 2, 2), dtype="int64") = R.concat((lv14,
lv15), axis=-1)
+ gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd(
+ operand, lv16, update, reduction="update"
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported():