This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 14751b3491 [relax][tflite] Add PRELU/LRN/SQUARED_DIFFERENCE tests (partial #18971) (#19404)
14751b3491 is described below

commit 14751b34913394bd4b621338f702874c51cffb9a
Author: Ahmad Jahaf <[email protected]>
AuthorDate: Wed Apr 15 07:40:50 2026 +0300

    [relax][tflite] Add PRELU/LRN/SQUARED_DIFFERENCE tests (partial #18971) (#19404)
    
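    ## Summary
    This PR partially addresses the test coverage requested in issue #18971
    for the Relax TFLite frontend operator tests.
    
    Added explicit tests in
    [tests/python/relax/test_frontend_tflite.py](tests/python/relax/test_frontend_tflite.py):
    - PRELU
    - SQUARED_DIFFERENCE
    - LOCAL_RESPONSE_NORMALIZATION
    
    The corresponding converters in
    python/tvm/relax/frontend/tflite/tflite_frontend.py are updated as well:
    - LOCAL_RESPONSE_NORMALIZATION is lowered to `square`, a channel-wise
      `avg_pool2d`, `power`, and `divide`, because Relax does not expose a
      dedicated LRN op.
    - SQUARED_DIFFERENCE now passes `op` before `_op.subtract` when calling
      `_convert_elemwise`.
    - PRELU now fetches the input expression before the alpha expression.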
    
    ## Validation
    Ran:
    - `pytest tests/python/relax/test_frontend_tflite.py -k 'test_prelu or test_squared_difference or test_local_response_normalization' -q`
    
    Result:
    - 3 passed
    
    Refs: #18971
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 28 +++++++++--
 tests/python/relax/test_frontend_tflite.py         | 55 +++++++++++++++++++++-
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 0f9f168a13..ce74f707cf 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -844,7 +844,29 @@ class OperatorConverter:
         size = (radius * 2) + 1
         alpha = alpha * size
         axis = 3  # NHWC format
-        out = relax.op.nn.lrn(in_expr, size=size, axis=axis, bias=bias, 
alpha=alpha, beta=beta)
+        data_shape = to_int_list(self.get_tensor_shape(input_tensor))
+        in_type = self.get_tensor_type_str(input_tensor.tensor.Type())
+
+        # Relax currently does not expose a dedicated LRN op. Implement NHWC 
channel LRN
+        # by pooling squared values over the channel axis.
+        squared = self.bb.normalize(relax.op.square(in_expr))
+        squared_2d = _op.reshape(squared, [-1, data_shape[axis], 1, 1])
+        pooled = self.bb.normalize(
+            relax.op.nn.avg_pool2d(
+                squared_2d,
+                pool_size=[size, 1],
+                strides=[1, 1],
+                padding=[radius, 0, radius, 0],
+                layout="NHWC",
+                count_include_pad=True,
+            )
+        )
+        pooled = self.bb.normalize(_op.reshape(pooled, data_shape))
+        denom = relax.op.power(
+            relax.op.add(relax.const(bias, in_type), 
relax.op.multiply(relax.const(alpha, in_type), pooled)),
+            relax.const(beta, in_type),
+        )
+        out = relax.op.divide(in_expr, denom)
 
         return out
 
@@ -1421,7 +1443,7 @@ class OperatorConverter:
             out_f32 = relax.op.subtract(lhs_expr_f32, rhs_expr_f32)
             return self.quantize(out_f32 * out_f32, output_tensors[0])
 
-        difference = self._convert_elemwise(_op.subtract, op)
+        difference = self._convert_elemwise(op, _op.subtract)
         # _convert_elemwise has guaranteed only have one output tensor
         exp_type = 
self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
         out = relax.op.power(difference, relax.const(2, exp_type))
@@ -3010,10 +3032,10 @@ class OperatorConverter:
         input_tensor = input_tensors[0]
         alpha_tensor = input_tensors[1]
         data_shape = to_int_list(self.get_tensor_shape(input_tensor))
+        in_expr = self.get_tensor_expr(input_tensor)
         alpha_expr = self.get_tensor_expr(alpha_tensor)
         alpha_expr = self.bb.normalize(relax.op.broadcast_to(alpha_expr, 
data_shape))
         alpha_expr = self.bb.normalize(relax.op.reshape(alpha_expr, [-1]))
-        in_expr = self.get_tensor_expr(input_tensor)
         out = relax.op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
         out = relax.op.reshape(out, data_shape)
         return out
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 8eb2c8e13b..a116daebb1 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -518,6 +518,18 @@ def test_swish():
     verify(TfInput, Expected)
 
 
+def test_prelu():
+    alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, 
dtype=np.float32))
+    prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init)
+
+    class TfInput(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), 
dtype=tf.float32)])
+        def func(self, x):
+            return prelu(x)
+
+    verify(TfInput)
+
+
 def test_fill():
     class TfInput(tf.Module):
         @tf.function(
@@ -800,6 +812,33 @@ def test_split_binary(tf_op, relax_op):
     verify(Binary, Expected)
 
 
+def test_squared_difference():
+    class SquaredDifference(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
+                tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
+            ]
+        )
+        def func(self, x, y):
+            return tf.math.squared_difference(x, y)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), 
dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.subtract(x, y)
+                gv: R.Tensor((2, 3), dtype="float32") = R.power(lv, 
R.const(2.0, "float32"))
+                R.output(gv)
+            return gv
+
+    verify(SquaredDifference, Expected)
+
+
 @pytest.mark.parametrize(
     "tf_op, relax_op, axis, out_shape",
     [
@@ -918,6 +957,21 @@ def test_l2_normalization():
     verify(L2Normalization)
 
 
+def test_local_response_normalization():
+    class LocalResponseNormalization(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 4), 
dtype=tf.float32)])
+        def func(self, x):
+            return tf.nn.local_response_normalization(
+                x,
+                depth_radius=2,
+                bias=1.0,
+                alpha=1e-4,
+                beta=0.75,
+            )
+
+    verify(LocalResponseNormalization)
+
+
 def test_slice():
     class Slice(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), 
dtype=tf.float32)])
@@ -957,7 +1011,6 @@ def test_reverse_v2():
 
     verify(ReverseV2, Expected)
 
-
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, 
padding):
     class Conv2DModule(tf.Module):
         @tf.function(

Reply via email to