This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ccd81f220f [Relax][Frontend][TFLite] Fix dynamic FILL/SPLIT_V partial implementations (#19433)
ccd81f220f is described below
commit ccd81f220f81e5de6ed6a05f850a2affbbf74ce6
Author: HoYi <[email protected]>
AuthorDate: Wed Apr 29 02:45:25 2026 +0800
[Relax][Frontend][TFLite] Fix dynamic FILL/SPLIT_V partial implementations (#19433)
This PR completes the previously partial TFLite frontend support for
dynamic `FILL` and `SPLIT_V`.
Key changes:
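- Allow `FILL` to accept runtime `dims` tensors by converting them with
  `relax.op.tensor_to_shape` before calling `relax.op.full`. To support
  this, the `relax.full` legalization now binds non-literal shape
  arguments to fresh symbolic shape variables via `match_cast`, and
  `relax.full` is registered with `FDataDependent` and without
  `RequiresArgumentShapes`.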
- Allow `SPLIT_V` to accept runtime `size_splits` tensors by decomposing
  the op into `cumsum`, `scatter_elements`, and `dynamic_strided_slice`.
- Use the TFLite output tuple arity as the source of truth for the number
  of outputs of dynamic `SPLIT_V`, instead of relying on static shape
  information of `size_splits`.
- Add TFLite frontend regression tests covering dynamic `FILL`
  import/compile behavior and dynamic `SPLIT_V` import behavior.
This addresses the `FILL` and `SPLIT_V` items under the "Fix partial
implementations" section of #19412.
Validation:
```bash
python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py
```
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py \
-k "split_v_dynamic or fill_dynamic_dims" -v
```
Result:
- All checks passed
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 81 +++++++++++++++++-----
python/tvm/relax/transform/legalize_ops/create.py | 16 ++++-
src/relax/op/tensor/create.cc | 2 +
tests/python/relax/test_frontend_tflite.py | 42 +++++++++++
4 files changed, 121 insertions(+), 20 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 732950ca68..5536c369db 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -31,7 +31,7 @@ import math
import numpy as np
import tvm
-from tvm import relax
+from tvm import relax, tirx
from tvm.relax import op as _op
from .tflite_flexbuffer import FlexBufferDecoder
@@ -1770,14 +1770,24 @@ class OperatorConverter:
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
- if self.has_expr(input_tensors[0].tensor_idx):
- raise tvm.error.OpNotImplemented(
- "For dims parameter of Fill operator, only constant values are
supported."
- )
-
- in_dims = list(self.get_tensor_value(input_tensors[0]))
+ dims_tensor = input_tensors[0]
in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
- out = relax.op.full(in_dims, in_value_expr)
+
+ if self.has_expr(dims_tensor.tensor_idx):
+ dims_expr = self.get_expr(dims_tensor.tensor_idx)
+ dims_ndim = int(self.get_tensor_shape(dims_tensor)[0])
+
+ # Bind runtime dims to fresh symbolic shape vars so the imported
+ # module remains well formed before LegalizeOps runs.
+ dims_expr = self.bb.match_cast(dims_expr,
relax.TensorStructInfo([dims_ndim], "int32"))
+ dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
+ shape_dataflow_var =
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
+ shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in
range(dims_ndim)]
+ self.bb.match_cast(shape_dataflow_var,
relax.ShapeStructInfo(shape_vars))
+ out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr)
+ else:
+ in_dims = list(self.get_tensor_value(dims_tensor))
+ out = relax.op.full(in_dims, in_value_expr)
return out
@@ -2331,6 +2341,7 @@ class OperatorConverter:
def convert_split_v(self, op):
"""SPLIT_V implementation."""
input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"
@@ -2338,22 +2349,56 @@ class OperatorConverter:
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)
- if self.has_expr(input_tensors[1].tensor_idx):
- raise tvm.error.OpNotImplemented(
- "For size_splits parameter of SPLIT_V operator, only constant
values are supported."
- )
- size_splits = list(self.get_tensor_value(input_tensors[1]))
- size_splits = tuple(np.cumsum(size_splits)[:-1])
-
axis_tensor = input_tensors[2]
- split_axis = self.get_tensor_value(axis_tensor)
+ split_axis = int(self.get_tensor_value(axis_tensor))
+
+ size_splits_tensor = input_tensors[1]
+
+ if self.has_expr(size_splits_tensor.tensor_idx):
+ # Dynamic size_splits case: decompose into dynamic strided slices.
+ size_splits_expr = self.get_expr(size_splits_tensor.tensor_idx)
+ cumsum = relax.op.cumsum(size_splits_expr, axis=0, dtype="int64")
+ # Pad a leading zero so that cumsum[i-1] can be read uniformly
+ # via strided_slice even for i == 0.
+ zero = relax.const(np.array([0], dtype="int64"), "int64")
+ padded_cumsum = relax.op.concat([zero, cumsum], axis=0)
+ # TFLite fixes the tuple arity in the graph, even when the split
+ # sizes themselves are supplied at runtime.
+ num_splits = len(output_tensors)
+ rank = len(in_expr.struct_info.shape)
+
+ # end_base is the full input shape; only split_axis changes per
slice.
+ end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr))
+ begin_base = relax.const(np.zeros((rank,), dtype="int64"), "int64")
+ strides = relax.const(np.ones((rank,), dtype="int64"), "int64")
+ scatter_idx = relax.const([split_axis], "int64")
+
+ outputs = []
+ for i in range(num_splits):
+ start_val = relax.op.strided_slice(
+ padded_cumsum, axes=[0], begin=[i], end=[i + 1]
+ )
+ end_val = relax.op.strided_slice(
+ padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2]
+ )
+
+ begin = relax.op.scatter_elements(begin_base, scatter_idx,
start_val)
+ end = relax.op.scatter_elements(end_base, scatter_idx, end_val)
+ slice_i = relax.op.dynamic_strided_slice(in_expr, begin, end,
strides)
+ outputs.append(slice_i)
+
+ out = relax.Tuple(outputs)
+ else:
+ # Static size_splits case
+ size_splits = list(self.get_tensor_value(size_splits_tensor))
+ size_splits = tuple(np.cumsum(size_splits)[:-1])
+ out = relax.op.split(in_expr, size_splits, axis=split_axis)
- out = relax.op.split(in_expr, size_splits, axis=int(split_axis))
# Relay does not like a TupleWrapper of 1 element, further this
# only shows up with tf1.13 if we use a split with num_splits==1.
# In tf 1.14 this doesn't appear as it is automatically a reshape
# operation.
- if isinstance(out, relax.Tuple) and out.size == 1:
+ if isinstance(out, relax.Tuple) and len(out.fields) == 1:
out = out[0]
return out
diff --git a/python/tvm/relax/transform/legalize_ops/create.py
b/python/tvm/relax/transform/legalize_ops/create.py
index 99b4449ebf..6708859caa 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -23,7 +23,8 @@ import numpy as np
from tvm import tirx, topi
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr, PrimValue, const
+from ...expr import Call, Expr, PrimValue, ShapeExpr, const
+from ...struct_info import ShapeStructInfo
from .common import LegalizeFunc, _try_convert_to_scalar_const,
register_legalize
@@ -34,10 +35,21 @@ def _full(is_like: bool, fill_value: float | None,
primfunc_name: str) -> Legali
if fill_value is None
else fill_value
)
+ shape = call.args[0].struct_info.shape if is_like else call.args[0]
+
+ if isinstance(shape, ShapeExpr):
+ output_shape = shape.values
+ else:
+ assert isinstance(shape.struct_info, ShapeStructInfo)
+ assert shape.struct_info.ndim >= 0
+
+ shape = bb.emit(shape)
+ output_shape = [tirx.Var(f"s{i}", "int64") for i in
range(shape.struct_info.ndim)]
+ bb.match_cast(shape, ShapeStructInfo(output_shape))
return bb.call_te(
topi.full,
- call.args[0].struct_info.shape if is_like else call.args[0],
+ output_shape,
call.struct_info.dtype,
_fill_value,
primfunc_name_hint=primfunc_name,
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 88ea67ab99..fb2aa448e3 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -96,6 +96,8 @@ TVM_REGISTER_OP("relax.full")
.add_argument("shape", "Shape", "The shape of the created tensor.")
.add_argument("fill_value", "Tensor", "The scalar tensor, denoting the
value to fill.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<Bool>("FDataDependent", Bool(true))
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 908868faf0..23002c8668 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -208,6 +208,26 @@ def test_split():
verify(Split, Expected)
+def test_split_v_dynamic():
+ """SPLIT_V with runtime split sizes imports shape-aware Relax IR."""
+
+ class TfSplitVDynamic(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(10,), dtype=tf.float32),
+ tf.TensorSpec(shape=(3,), dtype=tf.int32),
+ ]
+ )
+ def func(self, x, size_splits):
+ return tf.split(x, size_splits, axis=0)
+
+ cf = TfSplitVDynamic().func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ ir = mod.script()
+ assert "R.dynamic_strided_slice" in ir
+ assert "R.scatter_elements" in ir
+
+
def test_pack():
class Pack(tf.Module):
@tf.function(
@@ -592,6 +612,28 @@ def test_fill():
verify(TfInput, Expected)
+def test_fill_dynamic_dims():
+ """FILL with runtime dims legalizes and compiles."""
+
+ class TfFillDynamic(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2,), dtype=tf.int32),
+ tf.TensorSpec(shape=(), dtype=tf.float32),
+ ]
+ )
+ def func(self, dims, value):
+ return tf.fill(dims, value)
+
+ cf = TfFillDynamic().func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ ir = mod.script()
+ assert "R.tensor_to_shape" in ir
+ assert "R.full" in ir
+ tvm.compile(mod, tvm.target.Target("llvm"))
+ verify(cf)
+
+
@pytest.mark.parametrize(
"tf_op, relax_op",
[