This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a633edaa20 [Relax][Frontend][TFLite] Add BROADCAST_ARGS operator 
mapping (#19487)
a633edaa20 is described below

commit a633edaa20b425054e9708a4fd4d7175e5fa4771
Author: as4230 <[email protected]>
AuthorDate: Fri May 1 06:46:19 2026 -0400

    [Relax][Frontend][TFLite] Add BROADCAST_ARGS operator mapping (#19487)
    
    This PR adds TFLite frontend support for the BROADCAST_ARGS operator,
    which computes the broadcast shape of two input shape vectors.
    
    The operator is decomposed into existing Relax primitives instead of
    registering a new op:
    - relax.op.full + relax.op.concat align both inputs to a common length
      by left-padding with 1s
    - relax.op.where + relax.op.maximum apply the per-axis broadcast rule
    
    The new tests cover equal-length and different-length input cases.
    
    Validation:
        python -m pytest tests/python/relax/test_frontend_tflite.py -k broadcast_args
    
    Addresses the BROADCAST_ARGS item under #19412.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 32 +++++++++
 tests/python/relax/test_frontend_tflite.py         | 78 ++++++++++++++++++++++
 2 files changed, 110 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 0b1097b095..9d0fdaf587 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -126,6 +126,7 @@ class OperatorConverter:
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
             "BATCH_MATMUL": self.convert_batch_matmul,
             "BITCAST": self.convert_bitcast,
+            "BROADCAST_ARGS": self.convert_broadcast_args,
             "CAST": self.convert_cast,
             "CEIL": functools.partial(self._convert_unary_elemwise, 
relax_op=_op.ceil),
             "CONCATENATION": self.convert_concatenation,
@@ -2510,6 +2511,37 @@ class OperatorConverter:
 
         return relax.op.memory.view(in_expr, shape=output_shape, 
dtype=output_dtype)
 
+    def convert_broadcast_args(self, op):
+        """Convert TFLite BROADCAST_ARGS"""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+
+        s0 = self.get_tensor_expr(input_tensors[0])
+        s1 = self.get_tensor_expr(input_tensors[1])
+        s0_len = to_int_list(self.get_tensor_shape(input_tensors[0]))[0]
+        s1_len = to_int_list(self.get_tensor_shape(input_tensors[1]))[0]
+        out_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
+
+        # Left-pad the shorter input with 1s to length target_len.
+        target_len = tirx.max(s0_len, s1_len)
+        one = relax.const(1, dtype=out_dtype)
+        s0 = relax.op.concat(
+            [relax.op.full([target_len - s0_len], one, dtype=out_dtype), s0], 
axis=0
+        )
+        s1 = relax.op.concat(
+            [relax.op.full([target_len - s1_len], one, dtype=out_dtype), s1], 
axis=0
+        )
+        # Per-dim broadcast. If either side is 1 take the other, else 
elementwise max.
+        s0_is_one = relax.op.equal(s0, one)
+        s1_is_one = relax.op.equal(s1, one)
+        return relax.op.where(
+            s0_is_one,
+            s1,
+            relax.op.where(s1_is_one, s0, relax.op.maximum(s0, s1)),
+        )
+
     def convert_cast(self, op):
         """Convert TFLite CAST"""
 
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 64e4a6e953..f4a0612705 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -836,6 +836,84 @@ def test_square():
     verify(TfInput, Expected)
 
 
+def test_broadcast_args():
+    class TfInput(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(3,), dtype=tf.int32),
+                tf.TensorSpec(shape=(3,), dtype=tf.int32),
+            ]
+        )
+        def func(self, s0, s1):
+            return tf.broadcast_dynamic_shape(s0, s1)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            s0: R.Tensor((3,), dtype="int32"), s1: R.Tensor((3,), 
dtype="int32")
+        ) -> R.Tensor((3,), dtype="int32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((0,), dtype="int32") = R.full(
+                    R.shape([0]), R.const(1, "int32"), dtype="int32"
+                )
+                lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0)
+                lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1, 
"int32"))
+                lv3: R.Tensor((0,), dtype="int32") = R.full(
+                    R.shape([0]), R.const(1, "int32"), dtype="int32"
+                )
+                lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1), 
axis=0)
+                lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1, 
"int32"))
+                lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4)
+                lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6)
+                gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7)
+                R.output(gv)
+            return gv
+
+    verify(TfInput, Expected)
+
+
+def test_broadcast_args_diff_length():
+    """BROADCAST_ARGS with shape inputs of different lengths."""
+
+    class TfInput(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(1,), dtype=tf.int32),
+                tf.TensorSpec(shape=(3,), dtype=tf.int32),
+            ]
+        )
+        def func(self, s0, s1):
+            return tf.broadcast_dynamic_shape(s0, s1)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            s0: R.Tensor((1,), dtype="int32"), s1: R.Tensor((3,), 
dtype="int32")
+        ) -> R.Tensor((3,), dtype="int32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.full(
+                    R.shape([2]), R.const(1, "int32"), dtype="int32"
+                )
+                lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0)
+                lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1, 
"int32"))
+                lv3: R.Tensor((0,), dtype="int32") = R.full(
+                    R.shape([0]), R.const(1, "int32"), dtype="int32"
+                )
+                lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1), 
axis=0)
+                lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1, 
"int32"))
+                lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4)
+                lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6)
+                gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7)
+                R.output(gv)
+            return gv
+
+    verify(TfInput, Expected)
+
+
 @pytest.mark.parametrize(
     "tf_op, relax_op",
     [

Reply via email to