This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 87bf3022b7 [Relax][Frontend][TFLite] Add test coverage for SPACE_TO_BATCH_ND and BATCH_TO_SPACE_ND (#19499)
87bf3022b7 is described below

commit 87bf3022b799c5204cc0971dedf57bfa1717cf9d
Author: Bana <[email protected]>
AuthorDate: Mon May 4 04:12:22 2026 +0300

    [Relax][Frontend][TFLite] Add test coverage for SPACE_TO_BATCH_ND and BATCH_TO_SPACE_ND (#19499)
    
    **Changes**
    Add tests for both ops in `test_frontend_tflite.py`.
    Lower `SPACE_TO_BATCH_ND` / `BATCH_TO_SPACE_ND` through TOPI in
    `tflite_frontend.py`.
    Use `tf.raw_ops.BatchToSpaceND` in the test because `tf.batch_to_space_nd`
    is not available in this TF build.
    
    **Why the TFLite frontend changed**
    The frontend was calling `relax.op.nn.space_to_batch_nd` /
    `relax.op.nn.batch_to_space_nd`, which are not implemented in this
    checkout. I updated the TFLite frontend to lower these ops through
    `relax.op.call_dps_packed` calls to `topi.nn.space_to_batch_nd` /
    `topi.nn.batch_to_space_nd`, so conversion works and the new tests pass.
    
    
    **Test:**
    ```
    pytest tests/python/relax/test_frontend_tflite.py -k "test_space_to_batch_nd or test_batch_to_space_nd"
    ```
    Related to #18971.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 47 +++++++++++++--
 tests/python/relax/test_frontend_tflite.py         | 68 ++++++++++++++++++++++
 2 files changed, 109 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 8d112b91d6..e66dff8356 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3280,10 +3280,27 @@ class OperatorConverter:
         input_tensor_idx = input_tensor.tensor_idx
         in_expr = self.get_expr(input_tensor_idx)
 
-        block_shape = list(self.get_tensor_value(input_tensors[1]))
-        crops = self.get_tensor_value(input_tensors[2]).tolist()
+        block_shape = to_int_list(self.get_tensor_value(input_tensors[1]))
+        crops = self.get_tensor_value(input_tensors[2])
+        crop_begin = to_int_list(crops[:, 0])
+        crop_end = to_int_list(crops[:, 1])
 
-        out = relax.op.nn.batch_to_space_nd(in_expr, block_shape, crops)
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_shape = to_int_list(self.get_tensor_shape(output_tensor))
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+        out = relax.op.call_dps_packed(
+            "topi.nn.batch_to_space_nd",
+            (
+                in_expr,
+                relax.ShapeExpr(block_shape),
+                relax.ShapeExpr(crop_begin),
+                relax.ShapeExpr(crop_end),
+            ),
+            out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
+        )
 
         return out
 
@@ -3389,10 +3406,28 @@ class OperatorConverter:
         input_tensor_idx = input_tensor.tensor_idx
         in_expr = self.get_expr(input_tensor_idx)
 
-        block_shape = list(self.get_tensor_value(input_tensors[1]))
-        paddings = self.get_tensor_value(input_tensors[2]).tolist()
+        block_shape = to_int_list(self.get_tensor_value(input_tensors[1]))
+        paddings = self.get_tensor_value(input_tensors[2])
+        pad_before = to_int_list(paddings[:, 0])
+        pad_after = to_int_list(paddings[:, 1])
 
-        out = relax.op.nn.space_to_batch_nd(in_expr, block_shape, paddings)
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_shape = to_int_list(self.get_tensor_shape(output_tensor))
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+        out = relax.op.call_dps_packed(
+            "topi.nn.space_to_batch_nd",
+            (
+                in_expr,
+                relax.ShapeExpr(block_shape),
+                relax.ShapeExpr(pad_before),
+                relax.ShapeExpr(pad_after),
+                0.0,
+            ),
+            out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
+        )
 
         return out
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index a2d2612232..69e9b290fd 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3099,6 +3099,74 @@ def test_space_to_depth():
     verify(SpaceToDepth, Expected)
 
 
+@pytest.mark.parametrize(
+    "input_shape, block_shape, paddings, expected_out_shape",
+    [
+        ((1, 2, 2, 1), [2, 2], [[0, 0], [0, 0]], (4, 1, 1, 1)),
+        ((1, 2, 3, 1), [2, 2], [[0, 0], [1, 0]], (4, 1, 2, 1)),
+    ],
+)
+def test_space_to_batch_nd(input_shape, block_shape, paddings, expected_out_shape):
+    """SPACE_TO_BATCH_ND imports to Relax and preserves expected output shape."""
+
+    class SpaceToBatchND(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
+        def func(self, x):
+            return tf.space_to_batch_nd(
+                x,
+                tf.constant(block_shape, dtype=tf.int32),
+                tf.constant(paddings, dtype=tf.int32),
+            )
+
+    cf = SpaceToBatchND().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+
+    assert "space_to_batch_nd" in ir
+    assert len(mod["main"].params) == 1
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo(expected_out_shape, "float32"),
+    )
+
+    if "CI_ENV_NIGHTLY" in os.environ:
+        verify(SpaceToBatchND)
+
+
+@pytest.mark.parametrize(
+    "input_shape, block_shape, crops, expected_out_shape",
+    [
+        ((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]], (1, 2, 2, 1)),
+        ((4, 1, 2, 1), [2, 2], [[0, 0], [1, 0]], (1, 2, 3, 1)),
+    ],
+)
+def test_batch_to_space_nd(input_shape, block_shape, crops, expected_out_shape):
+    """BATCH_TO_SPACE_ND imports to Relax and preserves expected output shape."""
+
+    class BatchToSpaceND(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
+        def func(self, x):
+            return tf.raw_ops.BatchToSpaceND(
+                input=x,
+                block_shape=tf.constant(block_shape, dtype=tf.int32),
+                crops=tf.constant(crops, dtype=tf.int32),
+            )
+
+    cf = BatchToSpaceND().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+
+    assert "batch_to_space_nd" in ir
+    assert len(mod["main"].params) == 1
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo(expected_out_shape, "float32"),
+    )
+
+    if "CI_ENV_NIGHTLY" in os.environ:
+        verify(BatchToSpaceND)
+
+
 def test_leaky_relu():
     class LeakyReLU(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), 
dtype=tf.float32)])

Reply via email to