This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ccd81f220f [Relax][Frontend][TFLite] Fix dynamic FILL/SPLIT_V partial implementations (#19433)
ccd81f220f is described below

commit ccd81f220f81e5de6ed6a05f850a2affbbf74ce6
Author: HoYi <[email protected]>
AuthorDate: Wed Apr 29 02:45:25 2026 +0800

    [Relax][Frontend][TFLite] Fix dynamic FILL/SPLIT_V partial implementations (#19433)
    
    This PR fixes partial TFLite frontend support for dynamic `FILL` and
    `SPLIT_V`.
    
    Key changes:
    - Allow `FILL` to accept runtime `dims` tensors by converting them with
      `relax.op.tensor_to_shape` before calling `relax.op.full`.
    - Allow `SPLIT_V` to accept runtime `size_splits` tensors by decomposing
      the op into `cumsum` and `dynamic_strided_slice`.
    - Use the TFLite output tuple arity as the source of truth for dynamic
      `SPLIT_V`, instead of relying on `size_splits` static shape information.
    - Legalize `relax.full` when its shape argument is not a literal
      `ShapeExpr` by binding it to fresh symbolic shape variables, and register
      `FDataDependent` and `RequiresArgumentShapes=false` on `relax.full`.
    - Add TFLite frontend regression tests covering dynamic `FILL`
      import/compile behavior and dynamic `SPLIT_V` import.
    
    This addresses the `FILL` and `SPLIT_V` items under the "Fix partial
    implementations" section of #19412.
    
    Validation:
    ```bash
    python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py \
        tests/python/relax/test_frontend_tflite.py
    ```
    
    ```bash
    python -m pytest tests/python/relax/test_frontend_tflite.py \
    -k "split_v_dynamic or fill_dynamic_dims" -v
    ```
    
    Result:
    - All checks passed
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 81 +++++++++++++++++-----
 python/tvm/relax/transform/legalize_ops/create.py  | 16 ++++-
 src/relax/op/tensor/create.cc                      |  2 +
 tests/python/relax/test_frontend_tflite.py         | 42 +++++++++++
 4 files changed, 121 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 732950ca68..5536c369db 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -31,7 +31,7 @@ import math
 import numpy as np
 
 import tvm
-from tvm import relax
+from tvm import relax, tirx
 from tvm.relax import op as _op
 
 from .tflite_flexbuffer import FlexBufferDecoder
@@ -1770,14 +1770,24 @@ class OperatorConverter:
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 2, "input tensors length should be 2"
 
-        if self.has_expr(input_tensors[0].tensor_idx):
-            raise tvm.error.OpNotImplemented(
-                "For dims parameter of Fill operator, only constant values are 
supported."
-            )
-
-        in_dims = list(self.get_tensor_value(input_tensors[0]))
+        dims_tensor = input_tensors[0]
         in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
-        out = relax.op.full(in_dims, in_value_expr)
+
+        if self.has_expr(dims_tensor.tensor_idx):
+            dims_expr = self.get_expr(dims_tensor.tensor_idx)
+            dims_ndim = int(self.get_tensor_shape(dims_tensor)[0])
+
+            # Bind runtime dims to fresh symbolic shape vars so the imported
+            # module remains well formed before LegalizeOps runs.
+            dims_expr = self.bb.match_cast(dims_expr, 
relax.TensorStructInfo([dims_ndim], "int32"))
+            dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
+            shape_dataflow_var = 
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
+            shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in 
range(dims_ndim)]
+            self.bb.match_cast(shape_dataflow_var, 
relax.ShapeStructInfo(shape_vars))
+            out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr)
+        else:
+            in_dims = list(self.get_tensor_value(dims_tensor))
+            out = relax.op.full(in_dims, in_value_expr)
 
         return out
 
@@ -2331,6 +2341,7 @@ class OperatorConverter:
     def convert_split_v(self, op):
         """SPLIT_V implementation."""
         input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
 
         assert len(input_tensors) == 3, "input tensors length should be 3"
 
@@ -2338,22 +2349,56 @@ class OperatorConverter:
         input_tensor_idx = input_tensor.tensor_idx
         in_expr = self.get_expr(input_tensor_idx)
 
-        if self.has_expr(input_tensors[1].tensor_idx):
-            raise tvm.error.OpNotImplemented(
-                "For size_splits parameter of SPLIT_V operator, only constant 
values are supported."
-            )
-        size_splits = list(self.get_tensor_value(input_tensors[1]))
-        size_splits = tuple(np.cumsum(size_splits)[:-1])
-
         axis_tensor = input_tensors[2]
-        split_axis = self.get_tensor_value(axis_tensor)
+        split_axis = int(self.get_tensor_value(axis_tensor))
+
+        size_splits_tensor = input_tensors[1]
+
+        if self.has_expr(size_splits_tensor.tensor_idx):
+            # Dynamic size_splits case: decompose into dynamic strided slices.
+            size_splits_expr = self.get_expr(size_splits_tensor.tensor_idx)
+            cumsum = relax.op.cumsum(size_splits_expr, axis=0, dtype="int64")
+            # Pad a leading zero so that cumsum[i-1] can be read uniformly
+            # via strided_slice even for i == 0.
+            zero = relax.const(np.array([0], dtype="int64"), "int64")
+            padded_cumsum = relax.op.concat([zero, cumsum], axis=0)
+            # TFLite fixes the tuple arity in the graph, even when the split
+            # sizes themselves are supplied at runtime.
+            num_splits = len(output_tensors)
+            rank = len(in_expr.struct_info.shape)
+
+            # end_base is the full input shape; only split_axis changes per 
slice.
+            end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr))
+            begin_base = relax.const(np.zeros((rank,), dtype="int64"), "int64")
+            strides = relax.const(np.ones((rank,), dtype="int64"), "int64")
+            scatter_idx = relax.const([split_axis], "int64")
+
+            outputs = []
+            for i in range(num_splits):
+                start_val = relax.op.strided_slice(
+                    padded_cumsum, axes=[0], begin=[i], end=[i + 1]
+                )
+                end_val = relax.op.strided_slice(
+                    padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2]
+                )
+
+                begin = relax.op.scatter_elements(begin_base, scatter_idx, 
start_val)
+                end = relax.op.scatter_elements(end_base, scatter_idx, end_val)
+                slice_i = relax.op.dynamic_strided_slice(in_expr, begin, end, 
strides)
+                outputs.append(slice_i)
+
+            out = relax.Tuple(outputs)
+        else:
+            # Static size_splits case
+            size_splits = list(self.get_tensor_value(size_splits_tensor))
+            size_splits = tuple(np.cumsum(size_splits)[:-1])
+            out = relax.op.split(in_expr, size_splits, axis=split_axis)
 
-        out = relax.op.split(in_expr, size_splits, axis=int(split_axis))
         # Relay does not like a TupleWrapper of 1 element, further this
         # only shows up with tf1.13 if we use a split with num_splits==1.
         # In tf 1.14 this doesn't appear as it is automatically a reshape
         # operation.
-        if isinstance(out, relax.Tuple) and out.size == 1:
+        if isinstance(out, relax.Tuple) and len(out.fields) == 1:
             out = out[0]
 
         return out
diff --git a/python/tvm/relax/transform/legalize_ops/create.py 
b/python/tvm/relax/transform/legalize_ops/create.py
index 99b4449ebf..6708859caa 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -23,7 +23,8 @@ import numpy as np
 from tvm import tirx, topi
 
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr, PrimValue, const
+from ...expr import Call, Expr, PrimValue, ShapeExpr, const
+from ...struct_info import ShapeStructInfo
 from .common import LegalizeFunc, _try_convert_to_scalar_const, 
register_legalize
 
 
@@ -34,10 +35,21 @@ def _full(is_like: bool, fill_value: float | None, 
primfunc_name: str) -> Legali
             if fill_value is None
             else fill_value
         )
+        shape = call.args[0].struct_info.shape if is_like else call.args[0]
+
+        if isinstance(shape, ShapeExpr):
+            output_shape = shape.values
+        else:
+            assert isinstance(shape.struct_info, ShapeStructInfo)
+            assert shape.struct_info.ndim >= 0
+
+            shape = bb.emit(shape)
+            output_shape = [tirx.Var(f"s{i}", "int64") for i in 
range(shape.struct_info.ndim)]
+            bb.match_cast(shape, ShapeStructInfo(output_shape))
 
         return bb.call_te(
             topi.full,
-            call.args[0].struct_info.shape if is_like else call.args[0],
+            output_shape,
             call.struct_info.dtype,
             _fill_value,
             primfunc_name_hint=primfunc_name,
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 88ea67ab99..fb2aa448e3 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -96,6 +96,8 @@ TVM_REGISTER_OP("relax.full")
     .add_argument("shape", "Shape", "The shape of the created tensor.")
     .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the 
value to fill.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<Bool>("FDataDependent", Bool(true))
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 908868faf0..23002c8668 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -208,6 +208,26 @@ def test_split():
     verify(Split, Expected)
 
 
+def test_split_v_dynamic():
+    """SPLIT_V with runtime split sizes imports shape-aware Relax IR."""
+
+    class TfSplitVDynamic(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(10,), dtype=tf.float32),
+                tf.TensorSpec(shape=(3,), dtype=tf.int32),
+            ]
+        )
+        def func(self, x, size_splits):
+            return tf.split(x, size_splits, axis=0)
+
+    cf = TfSplitVDynamic().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+    assert "R.dynamic_strided_slice" in ir
+    assert "R.scatter_elements" in ir
+
+
 def test_pack():
     class Pack(tf.Module):
         @tf.function(
@@ -592,6 +612,28 @@ def test_fill():
     verify(TfInput, Expected)
 
 
+def test_fill_dynamic_dims():
+    """FILL with runtime dims legalizes and compiles."""
+
+    class TfFillDynamic(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2,), dtype=tf.int32),
+                tf.TensorSpec(shape=(), dtype=tf.float32),
+            ]
+        )
+        def func(self, dims, value):
+            return tf.fill(dims, value)
+
+    cf = TfFillDynamic().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+    assert "R.tensor_to_shape" in ir
+    assert "R.full" in ir
+    tvm.compile(mod, tvm.target.Target("llvm"))
+    verify(cf)
+
+
 @pytest.mark.parametrize(
     "tf_op, relax_op",
     [

Reply via email to