This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb0d55d25b [TFLite] Use structural checks in frontend tests (#19888)
cb0d55d25b is described below

commit cb0d55d25b2d9584511f1d1ba4d35935de0007c0
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 25 07:14:18 2026 -0400

    [TFLite] Use structural checks in frontend tests (#19888)
    
    This PR:
    - Replaces script-string based TFLite frontend IR checks with structural
    checks.
    - Keeps numeric/runtime tests focused on VM output validation instead of
    mixing in structural assertions.
    - Leaves the shared `verify` helper unchanged and uses direct
    `tvm.ir.assert_structural_equal` where structural checks are needed.
---
 tests/python/relax/test_frontend_tflite.py | 1160 +++++++++++++++++++++++-----
 1 file changed, 983 insertions(+), 177 deletions(-)

diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 6f2845da50..9f9d4a0e8a 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -263,11 +263,115 @@ def test_split_v_dynamic():
         def func(self, x, size_splits):
             return tf.split(x, size_splits, axis=0)
 
-    cf = TfSplitVDynamic().func.get_concrete_function()
-    mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
-    assert "R.dynamic_strided_slice" in ir
-    assert "R.scatter_elements" in ir
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((10,), dtype="float32"),
+            size_splits: R.Tensor((3,), dtype="int32"),
+        ) -> R.Tuple(
+            R.Tensor(dtype="float32", ndim=1),
+            R.Tensor(dtype="float32", ndim=1),
+            R.Tensor(dtype="float32", ndim=1),
+        ):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int64") = R.cumsum(
+                    size_splits, axis=0, dtype="int64", exclusive=False
+                )
+                lv1: R.Tensor((4,), dtype="int64") = R.concat((R.const([0], 
"int64"), lv), axis=0)
+                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    R.const([0], "int64"),
+                    R.const([0], "int64"),
+                    lv2,
+                    axis=0,
+                    reduction="update",
+                )
+                lv4: R.Shape([10]) = R.shape_of(x)
+                lv5: R.Tensor((1,), dtype="int64") = R.shape_to_tensor(lv4)
+                lv6: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    assume_inbound=False,
+                )
+                lv7: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    lv5, R.const([0], "int64"), lv6, axis=0, reduction="update"
+                )
+                lv8: R.Tensor(dtype="float32", ndim=1) = 
R.dynamic_strided_slice(
+                    x, lv3, lv7, R.const([1], "int64")
+                )
+                lv9: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    assume_inbound=False,
+                )
+                lv10: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    R.const([0], "int64"),
+                    R.const([0], "int64"),
+                    lv9,
+                    axis=0,
+                    reduction="update",
+                )
+                lv11: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    assume_inbound=False,
+                )
+                lv12: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    lv5, R.const([0], "int64"), lv11, axis=0, 
reduction="update"
+                )
+                lv13: R.Tensor(dtype="float32", ndim=1) = 
R.dynamic_strided_slice(
+                    x, lv10, lv12, R.const([1], "int64")
+                )
+                lv14: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    assume_inbound=False,
+                )
+                lv15: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    R.const([0], "int64"),
+                    R.const([0], "int64"),
+                    lv14,
+                    axis=0,
+                    reduction="update",
+                )
+                lv16: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(0),),
+                    (R.prim_value(3),),
+                    (R.prim_value(4),),
+                    assume_inbound=False,
+                )
+                lv17: R.Tensor((1,), dtype="int64") = R.scatter_elements(
+                    lv5, R.const([0], "int64"), lv16, axis=0, 
reduction="update"
+                )
+                lv18: R.Tensor(dtype="float32", ndim=1) = 
R.dynamic_strided_slice(
+                    x, lv15, lv17, R.const([1], "int64")
+                )
+                gv: R.Tuple(
+                    R.Tensor(dtype="float32", ndim=1),
+                    R.Tensor(dtype="float32", ndim=1),
+                    R.Tensor(dtype="float32", ndim=1),
+                ) = (lv8, lv13, lv18)
+                R.output(gv)
+            return gv
+
+    verify(TfSplitVDynamic, Expected)
 
 
 def test_split_v_static():
@@ -530,10 +634,28 @@ def test_unique():
         def func(self, x):
             return tf.raw_ops.Unique(x=x, out_idx=tf.int64)
 
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((6,), dtype="int32"),
+        ) -> R.Tuple(R.Tensor(dtype="int32", ndim=1), R.Tensor(dtype="int64", 
ndim=1)):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tuple(R.Tensor(dtype="int32", ndim=1), 
R.Tensor(dtype="int64", ndim=1)) = (
+                    R.unique(x, R.prim_value(0), R.prim_value(0), 
R.prim_value(1), R.prim_value(0))
+                )
+                lv1: R.Tensor(dtype="int32", ndim=1) = lv[0]
+                lv2: R.Tensor(dtype="int64", ndim=1) = lv[1]
+                gv: R.Tuple(R.Tensor(dtype="int32", ndim=1), 
R.Tensor(dtype="int64", ndim=1)) = (
+                    lv1,
+                    lv2,
+                )
+                R.output(gv)
+            return gv
+
     mod = _get_mod_from_cfunc(Unique().func.get_concrete_function())
-    values, inverse_indices = _run_module(mod, np.array([3, 1, 3, 2, 1, 2], 
dtype=np.int32))
-    np.testing.assert_array_equal(values, np.array([3, 1, 2], dtype=np.int32))
-    np.testing.assert_array_equal(inverse_indices, np.array([0, 1, 0, 2, 1, 
2], dtype=np.int64))
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_expand_dims():
@@ -612,7 +734,23 @@ def test_shape(input_shape, out_type):
         def func(self, x):
             return tf.shape(x, out_type=out_type)
 
-    verify(Shape)
+    out_dtype = "int32" if out_type == tf.int32 else "int64"
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(input_shape, dtype="float32"),
+        ) -> R.Tensor((len(input_shape),), dtype=out_dtype):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((len(input_shape),), dtype=out_dtype) = R.const(
+                    list(input_shape), out_dtype
+                )
+                R.output(gv)
+            return gv
+
+    verify(Shape, Expected)
 
 
 def test_shape_dynamic_dim():
@@ -623,7 +761,19 @@ def test_shape_dynamic_dim():
         def func(self, x):
             return tf.shape(x, out_type=tf.int32)
 
-    verify(ShapeDynamic)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((2,), 
dtype="int32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Shape([1, 3]) = R.shape_of(x)
+                lv1: R.Tensor((2,), dtype="int64") = R.shape_to_tensor(lv)
+                gv: R.Tensor((2,), dtype="int32") = R.astype(lv1, 
dtype="int32")
+                R.output(gv)
+            return gv
+
+    verify(ShapeDynamic, Expected)
 
 
 def _build_rank_model():
@@ -714,7 +864,23 @@ def test_range(start, limit, delta, dtype):
         def func(self):
             return tf.range(start, limit, delta, dtype=dtype)
 
-    verify(Range)
+    np_dtype = np.float32 if dtype == tf.float32 else np.int64 if dtype == 
tf.int64 else np.int32
+    expected_range = np.arange(start, limit, delta, dtype=np_dtype)
+    out_dtype = np.dtype(np_dtype).name
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main() -> R.Tensor((len(expected_range),), dtype=out_dtype):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                gv: R.Tensor((len(expected_range),), dtype=out_dtype) = 
R.const(
+                    expected_range, out_dtype
+                )
+                R.output(gv)
+            return gv
+
+    verify(Range, Expected)
 
 
 @pytest.mark.parametrize(
@@ -787,7 +953,6 @@ def test_tile_ir():
         ((2, 3), [2, 1], tf.float32),
         ((1, 4, 2), [3, 1, 2], tf.float32),
         ((2, 1, 3, 1), [1, 2, 1, 4], tf.float32),
-        ((2, 3), [1, 1], tf.float32),
         ((3,), [2], tf.float32),
         ((2, 3), [4, 2], tf.float32),
         ((2, 2), [1, 3], tf.int32),
@@ -801,7 +966,116 @@ def test_tile(input_shape, multiples, dtype):
         def func(self, x):
             return tf.tile(x, multiples)
 
-    verify(Tile)
+    if input_shape == (2, 3) and multiples == [2, 1]:
+
+        @I.ir_module
+        class ExpectedTile2x3Repeat2x1:
+            @R.function
+            def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 3), 
dtype="float32"):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((4, 3), dtype="float32") = R.tile(x, 
repeats=[2, 1])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTile2x3Repeat2x1
+
+    elif input_shape == (1, 4, 2):
+
+        @I.ir_module
+        class ExpectedTile1x4x2:
+            @R.function
+            def main(x: R.Tensor((1, 4, 2), dtype="float32")) -> R.Tensor(
+                (3, 4, 4), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((3, 4, 4), dtype="float32") = R.tile(x, 
repeats=[3, 1, 2])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTile1x4x2
+
+    elif input_shape == (2, 1, 3, 1):
+
+        @I.ir_module
+        class ExpectedTile2x1x3x1:
+            @R.function
+            def main(x: R.Tensor((2, 1, 3, 1), dtype="float32")) -> R.Tensor(
+                (2, 2, 3, 4), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((2, 2, 3, 4), dtype="float32") = R.tile(x, 
repeats=[1, 2, 1, 4])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTile2x1x3x1
+
+    elif input_shape == (3,):
+
+        @I.ir_module
+        class ExpectedTile3:
+            @R.function
+            def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), 
dtype="float32"):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((6,), dtype="float32") = R.tile(x, 
repeats=[2])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTile3
+
+    elif input_shape == (2, 3) and multiples == [4, 2]:
+
+        @I.ir_module
+        class ExpectedTile2x3Repeat4x2:
+            @R.function
+            def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((8, 6), 
dtype="float32"):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((8, 6), dtype="float32") = R.tile(x, 
repeats=[4, 2])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTile2x3Repeat4x2
+
+    else:
+
+        @I.ir_module
+        class ExpectedTileInt32:
+            @R.function
+            def main(x: R.Tensor((2, 2), dtype="int32")) -> R.Tensor((2, 6), 
dtype="int32"):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv: R.Tensor((2, 6), dtype="int32") = R.tile(x, 
repeats=[1, 3])
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedTileInt32
+
+    verify(Tile, expected)
+
+
+def test_tile_identity():
+    """TILE with all repeat factors set to one imports as identity."""
+
+    class Tile(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), 
dtype=tf.float32)])
+        def func(self, x):
+            return tf.tile(x, [1, 1])
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = x
+                R.output(gv)
+            return gv
+
+    verify(Tile, Expected)
 
 
 def test_concat_v2():
@@ -928,7 +1202,8 @@ def test_swish():
 
 
 def test_prelu_constant_alpha():
-    alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, 
dtype=np.float32))
+    alpha = np.linspace(0.1, 0.3, 30, dtype=np.float32)
+    alpha_init = tf.keras.initializers.Constant(alpha)
     prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init)
 
     class TfInput(tf.Module):
@@ -936,7 +1211,23 @@ def test_prelu_constant_alpha():
         def func(self, x):
             return prelu(x)
 
-    verify(TfInput)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((1, 30), dtype="float32") = R.broadcast_to(
+                    R.const(alpha), R.shape([1, 30])
+                )
+                lv1: R.Tensor((30,), dtype="float32") = R.reshape(x, 
R.shape([30]))
+                lv2: R.Tensor((30,), dtype="float32") = R.reshape(lv, 
R.shape([30]))
+                lv3: R.Tensor((30,), dtype="float32") = R.nn.prelu(lv1, lv2, 
axis=0)
+                gv: R.Tensor((1, 30), dtype="float32") = R.reshape(lv3, 
R.shape([1, 30]))
+                R.output(gv)
+            return gv
+
+    verify(TfInput, Expected)
 
 
 def test_fill():
@@ -979,13 +1270,31 @@ def test_fill_dynamic_dims():
         def func(self, dims, value):
             return tf.fill(dims, value)
 
-    cf = TfFillDynamic().func.get_concrete_function()
-    mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
-    assert "R.tensor_to_shape" in ir
-    assert "R.full" in ir
-    tvm.compile(mod, tvm.target.Target("llvm"))
-    verify(cf)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            dims: R.Tensor((2,), dtype="int32"), value: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor(dtype="float32", ndim=2):
+            R.func_attr({"num_input": 2})
+            fill_dim_0 = T.int64()
+            fill_dim_1 = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.match_cast(
+                    dims, R.Tensor((2,), dtype="int32")
+                )
+                lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, 
dtype="int64")
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
+                _: R.Shape([fill_dim_0, fill_dim_1]) = R.match_cast(
+                    lv2, R.Shape([fill_dim_0, fill_dim_1])
+                )
+                gv: R.Tensor((fill_dim_0, fill_dim_1), dtype="float32") = 
R.full(
+                    R.shape([fill_dim_0, fill_dim_1]), value, dtype="void"
+                )
+                R.output(gv)
+            return gv
+
+    verify(TfFillDynamic, Expected)
 
 
 def test_random_uniform_dynamic_shape():
@@ -996,11 +1305,38 @@ def test_random_uniform_dynamic_shape():
         def func(self, shape):
             return tf.raw_ops.RandomUniform(shape=shape, dtype=tf.float32, 
seed=7, seed2=11)
 
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(shape: R.Tensor((2,), dtype="int32")) -> 
R.Tensor(dtype="float32", ndim=2):
+            R.func_attr({"num_input": 1})
+            random_uniform_dim_0 = T.int64()
+            random_uniform_dim_1 = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.match_cast(
+                    shape, R.Tensor((2,), dtype="int32")
+                )
+                lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, 
dtype="int64")
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
+                _: R.Shape([random_uniform_dim_0, random_uniform_dim_1]) = 
R.match_cast(
+                    lv2, R.Shape([random_uniform_dim_0, random_uniform_dim_1])
+                )
+                gv = R.call_dps_packed(
+                    "tvm.contrib.random.uniform",
+                    (
+                        R.prim_value(7),
+                        R.prim_value(11),
+                        R.prim_value(T.float64(0.0)),
+                        R.prim_value(T.float64(1.0)),
+                    ),
+                    out_ty=R.Tensor((random_uniform_dim_0, 
random_uniform_dim_1), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
     cf = TfRandomUniform().func.get_concrete_function()
     mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
-    assert "R.tensor_to_shape" in ir
-    assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+    tvm.ir.assert_structural_equal(mod, Expected)
 
     _verify_random_with_inputs(cf, [np.array([2, 3], dtype="int32")])
 
@@ -1013,11 +1349,43 @@ def test_random_standard_normal_dynamic_shape():
         def func(self, shape):
             return tf.raw_ops.RandomStandardNormal(shape=shape, 
dtype=tf.float32, seed=3, seed2=5)
 
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(shape: R.Tensor((2,), dtype="int32")) -> 
R.Tensor(dtype="float32", ndim=2):
+            R.func_attr({"num_input": 1})
+            random_standard_normal_dim_0 = T.int64()
+            random_standard_normal_dim_1 = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.match_cast(
+                    shape, R.Tensor((2,), dtype="int32")
+                )
+                lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, 
dtype="int64")
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
+                _: R.Shape([random_standard_normal_dim_0, 
random_standard_normal_dim_1]) = (
+                    R.match_cast(
+                        lv2, R.Shape([random_standard_normal_dim_0, 
random_standard_normal_dim_1])
+                    )
+                )
+                gv = R.call_dps_packed(
+                    "tvm.contrib.random.normal",
+                    (
+                        R.prim_value(3),
+                        R.prim_value(5),
+                        R.prim_value(T.float64(0.0)),
+                        R.prim_value(T.float64(1.0)),
+                    ),
+                    out_ty=R.Tensor(
+                        (random_standard_normal_dim_0, 
random_standard_normal_dim_1),
+                        dtype="float32",
+                    ),
+                )
+                R.output(gv)
+            return gv
+
     cf = TfRandomStandardNormal().func.get_concrete_function()
     mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
-    assert "R.tensor_to_shape" in ir
-    assert 'R.call_dps_packed("tvm.contrib.random.normal"' in ir
+    tvm.ir.assert_structural_equal(mod, Expected)
 
     _verify_random_with_inputs(cf, [np.array([2, 4], dtype="int32")])
 
@@ -1041,14 +1409,58 @@ def test_multinomial_dynamic_num_samples():
                 seed2=17,
             )
 
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            logits: R.Tensor((2, 3), dtype="float32"),
+            num_samples: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor(dtype="int64", ndim=2):
+            R.func_attr({"num_input": 2})
+            multinomial_num_samples = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int32") = R.match_cast(
+                    num_samples, R.Tensor((), dtype="int32")
+                )
+                lv1: R.Tensor((), dtype="int64") = R.astype(lv, dtype="int64")
+                lv2: R.Tensor((1,), dtype="int64") = R.reshape(lv1, 
R.shape([1]))
+                lv3: R.Shape(ndim=1) = R.tensor_to_shape(lv2)
+                _: R.Shape([multinomial_num_samples]) = R.match_cast(
+                    lv3, R.Shape([multinomial_num_samples])
+                )
+                lv5: R.Tensor((2, 3), dtype="float32") = R.nn.softmax(logits, 
axis=-1)
+                lv6 = R.call_dps_packed(
+                    "tvm.contrib.random.uniform",
+                    (
+                        R.prim_value(13),
+                        R.prim_value(17),
+                        R.prim_value(T.float64(0.0)),
+                        R.prim_value(T.float64(1.0)),
+                    ),
+                    out_ty=R.Tensor((2 * multinomial_num_samples, 1), 
dtype="float32"),
+                )
+                lv7: R.Tensor((2,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(2), R.prim_value(1), 
dtype="int64"
+                )
+                lv8: R.Tensor((2, 1), dtype="int64") = R.expand_dims(lv7, 
axis=[1])
+                lv9: R.Tensor((2, multinomial_num_samples), dtype="int64") = 
R.broadcast_to(
+                    lv8, R.shape([2, multinomial_num_samples])
+                )
+                lv10: R.Tensor((2 * multinomial_num_samples, 1), 
dtype="int64") = R.reshape(
+                    lv9, R.shape([2 * multinomial_num_samples, 1])
+                )
+                lv11: R.Tensor((2 * multinomial_num_samples, 1), 
dtype="int64") = (
+                    R.multinomial_from_uniform(lv5, lv6, lv10, dtype="int64")
+                )
+                gv: R.Tensor((2, multinomial_num_samples), dtype="int64") = 
R.reshape(
+                    lv11, R.shape([2, multinomial_num_samples])
+                )
+                R.output(gv)
+            return gv
+
     cf = TfMultinomial().func.get_concrete_function()
     mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
-    assert "R.nn.softmax" in ir
-    assert "R.multinomial_from_uniform" in ir
-    assert "R.tensor_to_shape" in ir
-    assert "multinomial_num_samples" in ir
-    assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+    tvm.ir.assert_structural_equal(mod, Expected)
 
     _verify_random_with_inputs(
         cf,
@@ -1455,7 +1867,23 @@ def test_fully_connected():
             out = tf.matmul(x, weight, transpose_b=True)
             return tf.nn.bias_add(out, bias)
 
-    verify(FullyConnected)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 8), dtype="float32")) -> R.Tensor((1, 3), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((8, 3), dtype="float32") = R.permute_dims(
+                    R.const(np.arange(24, dtype=np.float32).reshape((3, 8))), 
axes=[1, 0]
+                )
+                lv1: R.Tensor((1, 3), dtype="float32") = R.matmul(x, lv, 
out_dtype="void")
+                gv: R.Tensor((1, 3), dtype="float32") = R.add(
+                    lv1, R.const(np.array([0.5, 1.0, -1.0], dtype=np.float32))
+                )
+                R.output(gv)
+            return gv
+
+    verify(FullyConnected, Expected)
 
 
 def test_depthwise_conv2d():
@@ -1474,7 +1902,38 @@ def test_depthwise_conv2d():
                 padding="SAME",
             )
 
-    verify(DepthwiseConv2D)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 8, 2), dtype="float32"),
+            kernel: R.Tensor((3, 3, 2, 1), dtype="float32"),
+        ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 3, 2), dtype="float32") = R.reshape(
+                    kernel, R.shape([1, 3, 3, 2])
+                )
+                lv1: R.Tensor((3, 3, 2, 1), dtype="float32") = R.reshape(lv, 
R.shape([3, 3, 2, 1]))
+                lv2: R.Tensor((1, 8, 8, 2), dtype="float32") = R.nn.conv2d(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=2,
+                    data_layout="NHWC",
+                    kernel_layout="HWOI",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                gv: R.Tensor((1, 8, 8, 2), dtype="float32") = R.add(
+                    lv2, R.const(np.zeros((2,), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(DepthwiseConv2D, Expected)
 
 
 def test_transpose_conv():
@@ -1495,7 +1954,36 @@ def test_transpose_conv():
                 padding="SAME",
             )
 
-    verify(TransposeConv)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 8, 2), dtype="float32"),
+            kernel: R.Tensor((3, 3, 3, 2), dtype="float32"),
+        ) -> R.Tensor((1, 8, 8, 3), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((3, 3, 3, 2), dtype="float32") = R.permute_dims(
+                    kernel, axes=[2, 0, 1, 3]
+                )
+                lv1: R.Tensor((2, 3, 3, 3), dtype="float32") = 
R.permute_dims(lv, axes=[3, 0, 1, 2])
+                gv: R.Tensor((1, 8, 8, 3), dtype="float32") = 
R.nn.conv2d_transpose(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    output_padding=[0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="IOHW",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                R.output(gv)
+            return gv
+
+    verify(TransposeConv, Expected)
 
 
 def test_l2_pool2d():
@@ -1535,7 +2023,23 @@ def test_l2_normalization():
         def func(self, x):
             return tf.nn.l2_normalize(x, axis=-1)
 
-    verify(L2Normalization)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((2, 4), dtype="float32") = R.square(x)
+                lv1: R.Tensor((2, 1), dtype="float32") = R.sum(lv, axis=[1], 
keepdims=True)
+                lv2: R.Tensor((2, 1), dtype="float32") = R.add(
+                    lv1, R.const(9.999999960041972e-13, "float32")
+                )
+                lv3: R.Tensor((2, 1), dtype="float32") = R.sqrt(lv2)
+                gv: R.Tensor((2, 4), dtype="float32") = R.divide(x, lv3)
+                R.output(gv)
+            return gv
+
+    verify(L2Normalization, Expected)
 
 
 def test_local_response_normalization():
@@ -1550,7 +2054,42 @@ def test_local_response_normalization():
                 beta=0.75,
             )
 
-    verify(LocalResponseNormalization)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 8, 8, 4), dtype="float32")) -> R.Tensor(
+            (1, 8, 8, 4), dtype="float32"
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((1, 8, 8, 4), dtype="float32") = R.square(x)
+                lv1: R.Tensor((64, 4, 1, 1), dtype="float32") = R.reshape(
+                    lv, R.shape([64, 4, 1, 1])
+                )
+                lv2: R.Tensor((64, 4, 1, 1), dtype="float32") = 
R.nn.avg_pool2d(
+                    lv1,
+                    pool_size=[5, 1],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[2, 0, 2, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                lv3: R.Tensor((1, 8, 8, 4), dtype="float32") = R.reshape(lv2, 
R.shape([1, 8, 8, 4]))
+                lv4: R.Tensor((1, 8, 8, 4), dtype="float32") = R.multiply(
+                    R.const(0.00049999996554106474, "float32"), lv3
+                )
+                lv5: R.Tensor((1, 8, 8, 4), dtype="float32") = 
R.add(R.const(1.0, "float32"), lv4)
+                lv6: R.Tensor((1, 8, 8, 4), dtype="float32") = R.power(
+                    lv5, R.const(0.75, "float32")
+                )
+                gv: R.Tensor((1, 8, 8, 4), dtype="float32") = R.divide(x, lv6)
+                R.output(gv)
+            return gv
+
+    verify(LocalResponseNormalization, Expected)
 
 
 def test_slice():
@@ -1660,11 +2199,6 @@ def test_reverse_sequence():
                 R.output(gv)
             return gv
 
-    tvm.ir.assert_structural_equal(mod, Expected)
-    ir = mod.script()
-    assert "R.reverse_sequence" in ir
-    assert 'R.call_dps_packed("topi.reverse_sequence"' not in ir
-
     data = np.arange(24, dtype="float32").reshape((2, 4, 3))
     seq_lengths = np.array([1, 3], dtype="int32")
     expected = data.copy()
@@ -2138,7 +2672,30 @@ def test_avg_pool2d_valid():
     Pool2DModule = _make_pool2d_module(
         tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), 
"VALID"
     )
-    verify(Pool2DModule)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 127, 127, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 127, 127, 32), dtype="float32") = 
R.nn.avg_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
 
 
 def test_max_pool2d_same():
@@ -2174,7 +2731,30 @@ def test_max_pool2d_valid():
     Pool2DModule = _make_pool2d_module(
         tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), 
"VALID"
     )
-    verify(Pool2DModule)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 127, 127, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 127, 127, 32), dtype="float32") = 
R.nn.max_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
 
 
 @pytest.mark.parametrize(
@@ -2236,6 +2816,8 @@ def test_networks(net, shape):
     model = NetworkModule()
     concrete_func = 
model.func.get_concrete_function(tf.TensorSpec(shape=shape, dtype=tf.float32))
 
+    mod = _get_mod_from_cfunc(concrete_func)
+    tvm.ir.assert_structural_equal(mod["main"].ret_ty, relax.TensorType((1, 
1000), "float32"))
     verify(concrete_func)
 
 
@@ -2245,14 +2827,38 @@ def test_broadcast_to():
         def func(self, x):
             return tf.broadcast_to(x, [3, 2, 2])
 
-    verify(Model)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((3, 2, 2), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((3, 2, 2), dtype="float32") = R.multiply(
+                    x, R.const(np.ones((3, 2, 2), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
 
     class ModelScalarAndInt(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.int32)])
         def func(self, x):
             return tf.broadcast_to(x, [4, 4])
 
-    verify(ModelScalarAndInt)
+    @I.ir_module
+    class ExpectedScalarAndInt:
+        @R.function
+        def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((4, 4), 
dtype="int32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((4, 4), dtype="int32") = R.multiply(
+                    x, R.const(np.ones((4, 4), dtype="int32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(ModelScalarAndInt, ExpectedScalarAndInt)
 
 
 def test_embedding_lookup():
@@ -2262,7 +2868,23 @@ def test_embedding_lookup():
             params = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32)
             return tf.nn.embedding_lookup(params, indices)
 
-    verify(Model)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(indices: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int32") = R.astype(indices, 
dtype="int32")
+                gv: R.Tensor((3, 2), dtype="float32") = R.take(
+                    R.const(np.array([[1, 2], [3, 4], [5, 6]], 
dtype=np.float32)),
+                    lv,
+                    axis=0,
+                    mode="fast",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
 
     class ModelMultidim(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), 
dtype=tf.int32)])
@@ -2270,7 +2892,23 @@ def test_embedding_lookup():
             params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], 
dtype=tf.float32)
             return tf.nn.embedding_lookup(params, indices)
 
-    verify(ModelMultidim)
+    @I.ir_module
+    class ExpectedMultidim:
+        @R.function
+        def main(indices: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3, 
2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="int32") = R.astype(indices, 
dtype="int32")
+                gv: R.Tensor((2, 3, 2), dtype="float32") = R.take(
+                    R.const(np.array([[1, 2], [3, 4], [5, 6], [7, 8]], 
dtype=np.float32)),
+                    lv,
+                    axis=0,
+                    mode="fast",
+                )
+                R.output(gv)
+            return gv
+
+    verify(ModelMultidim, ExpectedMultidim)
 
 
 def test_select_v2():
@@ -2285,7 +2923,21 @@ def test_select_v2():
         def func(self, condition, x, y):
             return tf.where(condition, x, y)
 
-    verify(Model)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            condition: R.Tensor((2, 2), dtype="bool"),
+            x: R.Tensor((2, 2), dtype="float32"),
+            y: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.where(condition, x, 
y)
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
 
     class ModelBroadcasting(tf.Module):
         @tf.function(
@@ -2298,7 +2950,21 @@ def test_select_v2():
         def func(self, condition, x, y):
             return tf.where(condition, x, y)
 
-    verify(ModelBroadcasting)
+    @I.ir_module
+    class ExpectedBroadcasting:
+        @R.function
+        def main(
+            condition: R.Tensor((2, 1), dtype="bool"),
+            x: R.Tensor((2, 2), dtype="float32"),
+            y: R.Tensor((), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.where(condition, x, 
y)
+                R.output(gv)
+            return gv
+
+    verify(ModelBroadcasting, ExpectedBroadcasting)
 
 
 def test_scatter_nd():
@@ -2313,7 +2979,27 @@ def test_scatter_nd():
         def func(self, indices, updates, shape):
             return tf.scatter_nd(indices, updates, shape)
 
-    verify(Model)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            indices: R.Tensor((4, 1), dtype="int32"),
+            updates: R.Tensor((4,), dtype="float32"),
+            shape: R.Tensor((1,), dtype="int32"),
+        ) -> R.Tensor(dtype="float32", ndim=1):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                lv: R.Tensor((1,), dtype="int64") = R.astype(shape, 
dtype="int64")
+                lv1: R.Shape(ndim=1) = R.tensor_to_shape(lv)
+                lv2: R.Tensor(lv1, dtype="float32") = R.zeros(lv1, 
dtype="float32")
+                lv3: R.Tensor((1, 4), dtype="int32") = R.permute_dims(indices, 
axes=[-1, 0])
+                gv: R.Tensor(dtype="float32", ndim=1) = R.scatter_nd(
+                    lv2, lv3, updates, reduction="update"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
 
 
 def test_segment_sum():
@@ -2985,17 +3671,16 @@ def test_nms_v5_ir():
         score_threshold=0.0,
     )
 
-    ir = mod.script()
-
-    # Validate correct sorting/id indices are passed to valid_counts
-    assert "score_index=0" in ir
-    assert "id_index=-1" in ir
-    # NMS size limit validation
-    assert f"max_output_size={max_output_size}" in ir
-    # Valid output shape must be () statically
-    assert 'R.Tensor((), dtype="int32")' in ir
-    # Bounding boxes / scores tensor bounds checks
-    assert f"R.Tensor(({max_output_size},)" in ir
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_ty,
+        relax.TupleType(
+            [
+                relax.TensorType((max_output_size,), "int32"),
+                relax.TensorType((max_output_size,), "float32"),
+                relax.TensorType((), "int32"),
+            ]
+        ),
+    )
 
 
 def test_nms_v5_soft_ir():
@@ -3010,14 +3695,16 @@ def test_nms_v5_soft_ir():
         soft_nms_sigma=0.5,
     )
 
-    ir = mod.script()
-
-    # soft_nms_sigma must appear in the IR
-    assert "soft_nms_sigma=0.5" in ir
-    # score_threshold must also be forwarded
-    assert "score_threshold=0.0" in ir
-    # Soft-NMS padded scores must be clipped to non-negative values.
-    assert "R.clip(" in ir
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_ty,
+        relax.TupleType(
+            [
+                relax.TensorType((max_output_size,), "int32"),
+                relax.TensorType((max_output_size,), "float32"),
+                relax.TensorType((), "int32"),
+            ]
+        ),
+    )
 
 
 _NMS_V4_CASES = [
@@ -3099,19 +3786,15 @@ def test_nms_v4_ir():
         score_threshold=0.0,
     )
 
-    ir = mod.script()
-
-    # Validate correct sorting/id indices are passed to valid_counts
-    assert "score_index=0" in ir
-    assert "id_index=-1" in ir
-    # NMS size limit validation
-    assert f"max_output_size={max_output_size}" in ir
-    # Valid output shape must be () statically
-    assert 'R.Tensor((), dtype="int32")' in ir
-    # Selected indices tensor bounds check
-    assert f"R.Tensor(({max_output_size},)" in ir
-    # V4 must use hard-NMS (soft_nms_sigma left at default 0.0)
-    assert "soft_nms_sigma=0.0" in ir
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_ty,
+        relax.TupleType(
+            [
+                relax.TensorType((max_output_size,), "int32"),
+                relax.TensorType((), "int32"),
+            ]
+        ),
+    )
 
 
 _DETECTION_POSTPROCESS_SMOKE_CASES = [
@@ -3188,16 +3871,7 @@ _DETECTION_POSTPROCESS_SHAPE_CASES = [
 )
 def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, 
expected_keep_background):
     mod = _build_detection_postprocess_mod(**build_kwargs)
-    ir = mod.script()
-
-    assert "R.vision.multibox_transform_loc" in ir
-    assert "R.vision.all_class_non_max_suppression" in ir
-    assert 'output_format="tensorflow"' in ir
-    assert "R.where" in ir
-    assert "R.gather_elements" in ir
-    assert "R.gather_nd" in ir
-    assert ir.count("R.topk(") == expected_topk_count
-    assert f"keep_background={expected_keep_background}" in ir
+
     expected_batch = build_kwargs["batch_size"]
     expected_max_detections = build_kwargs["max_detections"]
     tvm.ir.assert_structural_equal(
@@ -3213,9 +3887,6 @@ def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected
     )
 
     legalized = relax.transform.LegalizeOps()(mod)
-    legalized_ir = legalized.script()
-    assert "R.vision.all_class_non_max_suppression(" not in legalized_ir
-    assert "R.call_tir(" in legalized_ir
     tvm.ir.assert_structural_equal(legalized["main"].ret_ty, 
mod["main"].ret_ty)
 
 
@@ -3705,19 +4376,59 @@ def test_space_to_batch_nd(input_shape, block_shape, paddings, expected_out_shap
                 tf.constant(paddings, dtype=tf.int32),
             )
 
-    cf = SpaceToBatchND().func.get_concrete_function()
-    mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
+    if expected_out_shape == (4, 1, 1, 1):
+
+        @I.ir_module
+        class ExpectedSpaceToBatchNoPadding:
+            @R.function
+            def main(x: R.Tensor((1, 2, 2, 1), dtype="float32")) -> R.Tensor(
+                (4, 1, 1, 1), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv = R.call_dps_packed(
+                        "topi.nn.space_to_batch_nd",
+                        (
+                            x,
+                            R.shape([2, 2]),
+                            R.shape([0, 0]),
+                            R.shape([0, 0]),
+                            R.prim_value(T.float64(0.0)),
+                        ),
+                        out_ty=R.Tensor((4, 1, 1, 1), dtype="float32"),
+                    )
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedSpaceToBatchNoPadding
 
-    assert "space_to_batch_nd" in ir
-    assert len(mod["main"].params) == 1
-    tvm.ir.assert_structural_equal(
-        mod["main"].ret_ty,
-        relax.TensorType(expected_out_shape, "float32"),
-    )
+    else:
 
-    if "CI_ENV_NIGHTLY" in os.environ:
-        verify(SpaceToBatchND)
+        @I.ir_module
+        class ExpectedSpaceToBatchWithPadding:
+            @R.function
+            def main(x: R.Tensor((1, 2, 3, 1), dtype="float32")) -> R.Tensor(
+                (4, 1, 2, 1), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv = R.call_dps_packed(
+                        "topi.nn.space_to_batch_nd",
+                        (
+                            x,
+                            R.shape([2, 2]),
+                            R.shape([0, 1]),
+                            R.shape([0, 0]),
+                            R.prim_value(T.float64(0.0)),
+                        ),
+                        out_ty=R.Tensor((4, 1, 2, 1), dtype="float32"),
+                    )
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedSpaceToBatchWithPadding
+
+    verify(SpaceToBatchND, expected)
 
 
 @pytest.mark.parametrize(
@@ -3739,19 +4450,47 @@ def test_batch_to_space_nd(input_shape, block_shape, crops, expected_out_shape):
                 crops=tf.constant(crops, dtype=tf.int32),
             )
 
-    cf = BatchToSpaceND().func.get_concrete_function()
-    mod = _get_mod_from_cfunc(cf)
-    ir = mod.script()
+    if expected_out_shape == (1, 2, 2, 1):
+
+        @I.ir_module
+        class ExpectedBatchToSpaceNoCrop:
+            @R.function
+            def main(x: R.Tensor((4, 1, 1, 1), dtype="float32")) -> R.Tensor(
+                (1, 2, 2, 1), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv = R.call_dps_packed(
+                        "topi.nn.batch_to_space_nd",
+                        (x, R.shape([2, 2]), R.shape([0, 0]), R.shape([0, 0])),
+                        out_ty=R.Tensor((1, 2, 2, 1), dtype="float32"),
+                    )
+                    R.output(gv)
+                return gv
+
+        expected = ExpectedBatchToSpaceNoCrop
 
-    assert "batch_to_space_nd" in ir
-    assert len(mod["main"].params) == 1
-    tvm.ir.assert_structural_equal(
-        mod["main"].ret_ty,
-        relax.TensorType(expected_out_shape, "float32"),
-    )
+    else:
+
+        @I.ir_module
+        class ExpectedBatchToSpaceWithCrop:
+            @R.function
+            def main(x: R.Tensor((4, 1, 2, 1), dtype="float32")) -> R.Tensor(
+                (1, 2, 3, 1), dtype="float32"
+            ):
+                R.func_attr({"num_input": 1})
+                with R.dataflow():
+                    gv = R.call_dps_packed(
+                        "topi.nn.batch_to_space_nd",
+                        (x, R.shape([2, 2]), R.shape([0, 1]), R.shape([0, 0])),
+                        out_ty=R.Tensor((1, 2, 3, 1), dtype="float32"),
+                    )
+                    R.output(gv)
+                return gv
 
-    if "CI_ENV_NIGHTLY" in os.environ:
-        verify(BatchToSpaceND)
+        expected = ExpectedBatchToSpaceWithCrop
+
+    verify(BatchToSpaceND, expected)
 
 
 def test_leaky_relu():
@@ -3918,7 +4657,8 @@ def test_fake_quant_narrow_range_vector():
 
 
 def test_prelu_basic():
-    alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, 
dtype=np.float32))
+    alpha = np.linspace(0.1, 0.3, 30, dtype=np.float32)
+    alpha_init = tf.keras.initializers.Constant(alpha)
     prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init)
 
     class TfInput(tf.Module):
@@ -3926,7 +4666,23 @@ def test_prelu_basic():
         def func(self, x):
             return prelu(x)
 
-    verify(TfInput)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), 
dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((1, 30), dtype="float32") = R.broadcast_to(
+                    R.const(alpha), R.shape([1, 30])
+                )
+                lv1: R.Tensor((30,), dtype="float32") = R.reshape(x, 
R.shape([30]))
+                lv2: R.Tensor((30,), dtype="float32") = R.reshape(lv, 
R.shape([30]))
+                lv3: R.Tensor((30,), dtype="float32") = R.nn.prelu(lv1, lv2, 
axis=0)
+                gv: R.Tensor((1, 30), dtype="float32") = R.reshape(lv3, 
R.shape([1, 30]))
+                R.output(gv)
+            return gv
+
+    verify(TfInput, Expected)
 
 
 @pytest.mark.parametrize(
@@ -4760,11 +5516,6 @@ def test_rfft2d_static_pair_output():
         )
     )
 
-    mod_script = mod.script()
-    assert "tflite_rfft2d" in mod_script
-    assert "R.call_tir" in mod_script
-    assert 'R.Tensor((2, 3, 2), dtype="float32")' in mod_script
-
     data = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]], 
dtype="float32")
     expected = np.fft.rfft2(data).astype(np.complex64)
     # atol accommodates the float32 reference kernel: numpy's rfft2 internally 
uses
@@ -5975,6 +6726,11 @@ def test_while_subgraphs_repeated_cond_body_pair():
     mod = _load_model_from_buffer(_build_tflite_repeated_while_model())
     names = [gv.name_hint for gv in mod.get_global_vars()]
     assert names.count("tflite_while_subgraph_1_2") == 1
+    tvm.ir.assert_structural_equal(mod["main"].ret_ty, relax.TensorType((), 
"int32"))
+    tvm.ir.assert_structural_equal(
+        mod["tflite_while_subgraph_1_2"].ret_ty,
+        relax.TensorType((), "int32"),
+    )
 
 
 def _build_tflite_two_var_while_model():
@@ -8393,14 +9149,12 @@ def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type,
     )
 
 
-def _run_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, 
init_state):
-    """Import, compile, and execute an RNG model, returning (output_state, 
output)."""
-    buf = _build_stablehlo_rng_model(algorithm, state_len, out_shape, 
out_tensor_type)
-    mod = _load_model_from_buffer(buf)
-    ex = tvm.compile(mod, tvm.target.Target("llvm"))
-    vm = relax.VirtualMachine(ex, tvm.cpu())
-    result = vm["main"](tvm.runtime.tensor(np.array(init_state, 
dtype="uint64")))
-    return result[0].numpy(), result[1].numpy()
+_TFL_TENSOR_TYPE_TO_DTYPE = {
+    _tfl_tensor_type.INT32: "int32",
+    _tfl_tensor_type.UINT32: "uint32",
+    _tfl_tensor_type.INT64: "int64",
+    _tfl_tensor_type.UINT64: "uint64",
+}
 
 
 # Expected vectors are taken verbatim from the TFLite runtime kernel test
@@ -8465,9 +9219,12 @@ _RNG_PHILOX_STATE = {
 )
 def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type):
     """TFLite STABLEHLO_RNG_BIT_GENERATOR THREEFRY matches the runtime kernel 
bit-exactly."""
-    state, output = _run_stablehlo_rng_model(
-        _tfl_rng_algorithm.THREEFRY, 2, [2, 3], out_tensor_type, [1, 2]
-    )
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.THREEFRY, 2, [2, 3], 
out_tensor_type)
+    mod = _load_model_from_buffer(buf)
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    result = vm["main"](tvm.runtime.tensor(np.array([1, 2], dtype="uint64")))
+    state, output = result[0].numpy(), result[1].numpy()
     assert output.flatten().tolist() == _RNG_THREEFRY_EXPECTED[out_dtype]
     assert state.tolist() == _RNG_THREEFRY_STATE[out_dtype]
 
@@ -8483,18 +9240,24 @@ def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type):
 )
 def test_stablehlo_rng_bit_generator_philox(out_dtype, out_tensor_type):
     """TFLite STABLEHLO_RNG_BIT_GENERATOR PHILOX matches the runtime kernel 
bit-exactly."""
-    state, output = _run_stablehlo_rng_model(
-        _tfl_rng_algorithm.PHILOX, 3, [2, 3], out_tensor_type, [1, 2, 3]
-    )
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], 
out_tensor_type)
+    mod = _load_model_from_buffer(buf)
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    result = vm["main"](tvm.runtime.tensor(np.array([1, 2, 3], 
dtype="uint64")))
+    state, output = result[0].numpy(), result[1].numpy()
     assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED[out_dtype]
     assert state.tolist() == _RNG_PHILOX_STATE[out_dtype]
 
 
 def test_stablehlo_rng_bit_generator_default_matches_philox():
     """TFLite STABLEHLO_RNG_BIT_GENERATOR DEFAULT resolves to the PHILOX 
algorithm."""
-    state, output = _run_stablehlo_rng_model(
-        _tfl_rng_algorithm.DEFAULT, 3, [2, 3], _tfl_tensor_type.INT32, [1, 2, 
3]
-    )
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.DEFAULT, 3, [2, 3], 
_tfl_tensor_type.INT32)
+    mod = _load_model_from_buffer(buf)
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    result = vm["main"](tvm.runtime.tensor(np.array([1, 2, 3], 
dtype="uint64")))
+    state, output = result[0].numpy(), result[1].numpy()
     assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED["int32"]
     assert state.tolist() == _RNG_PHILOX_STATE["int32"]
 
@@ -12854,12 +13617,15 @@ def test_svdf_none_activation():
 
     fn = mod["main"]
     assert len(fn.params) == 2, f"expected 2 params (input, state), got 
{len(fn.params)}"
-    in_shape = fn.params[0].ty.shape
-    assert tuple(int(d) for d in in_shape) == (batch, input_size)
-    state_shape = fn.params[1].ty.shape
-    assert tuple(int(d) for d in state_shape) == (batch, num_filters * 
memory_size)
-    out_shape = fn.ret_ty.shape
-    assert tuple(int(d) for d in out_shape) == (batch, num_units)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((batch, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.params[1].ty,
+        relax.TensorType((batch, num_filters * memory_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(fn.ret_ty, relax.TensorType((batch, 
num_units), "float32"))
 
 
 def _build_two_step_shared_state_svdf_model(
@@ -13236,10 +14002,15 @@ def test_unidirectional_sequence_lstm_none_activation():
         )
     )
 
-    script = mod.script(show_meta=True)
-    assert script.count("R.sigmoid") == 2
-    assert "R.tanh" not in script
-    assert "R.multiply" in script
+    fn = mod["main"]
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((batch, time, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((batch, time, num_units), "float32"),
+    )
 
 
 def test_unidirectional_sequence_lstm_tanh_activation():
@@ -13276,10 +14047,15 @@ def test_unidirectional_sequence_lstm_tanh_activation():
         )
     )
 
-    script = mod.script(show_meta=True)
-    assert script.count("R.sigmoid") == 2
-    assert script.count("R.tanh") == 2
-    assert "R.multiply" in script
+    fn = mod["main"]
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((batch, time, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((batch, time, num_units), "float32"),
+    )
 
 
 def test_unidirectional_sequence_lstm_time_major():
@@ -13312,8 +14088,14 @@ def test_unidirectional_sequence_lstm_time_major():
     )
 
     fn = mod["main"]
-    assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, 
input_size)
-    assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((time, batch, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((time, batch, num_units), "float32"),
+    )
 
 
 def test_unidirectional_sequence_lstm_rejects_projection():
@@ -13584,8 +14366,14 @@ def test_bidirectional_sequence_rnn_time_major():
     )
 
     fn = mod["main"]
-    assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, 
input_size)
-    assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units * 
2)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((time, batch, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((time, batch, num_units * 2), "float32"),
+    )
 
 
 def test_bidirectional_sequence_rnn_rejects_aux_input():
@@ -13836,11 +14624,15 @@ def test_bidirectional_sequence_lstm_none_activation():
         )
     )
 
-    script = mod.script(show_meta=True)
-    assert script.count("R.sigmoid") == 4
-    assert "R.tanh" not in script
-    assert script.count("R.stack") == 2
-    assert "R.concat" in script
+    fn = mod["main"]
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((batch, time, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((batch, time, num_units * 2), "float32"),
+    )
 
 
 def test_bidirectional_sequence_lstm_time_major():
@@ -13882,8 +14674,14 @@ def test_bidirectional_sequence_lstm_time_major():
     )
 
     fn = mod["main"]
-    assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, 
input_size)
-    assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units * 
2)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((time, batch, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((time, batch, num_units * 2), "float32"),
+    )
 
 
 def test_bidirectional_sequence_lstm_rejects_aux_input():
@@ -14092,10 +14890,14 @@ def test_unidirectional_sequence_rnn_relu_activation():
 
     fn = mod["main"]
     assert len(fn.params) == 1, "only the sequence input should be a graph 
input"
-    in_shape = fn.params[0].ty.shape
-    assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
-    out_shape = fn.ret_ty.shape
-    assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((batch, time, input_size), "float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((batch, time, num_units), "float32"),
+    )
 
 
 def test_unidirectional_sequence_rnn_time_major():
@@ -14124,11 +14926,15 @@ def test_unidirectional_sequence_rnn_time_major():
 
     fn = mod["main"]
     # Input to the graph is the raw time-major tensor [time, batch, 
input_size].
-    in_shape = fn.params[0].ty.shape
-    assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
+    tvm.ir.assert_structural_equal(
+        fn.params[0].ty,
+        relax.TensorType((time, batch, input_size), "float32"),
+    )
     # Output is always batch-major [batch, time, num_units].
-    out_shape = fn.ret_ty.shape
-    assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
+    tvm.ir.assert_structural_equal(
+        fn.ret_ty,
+        relax.TensorType((batch, time, num_units), "float32"),
+    )
 
 
 def test_real():

Reply via email to