This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a633edaa20 [Relax][Frontend][TFLite] Add BROADCAST_ARGS operator
mapping (#19487)
a633edaa20 is described below
commit a633edaa20b425054e9708a4fd4d7175e5fa4771
Author: as4230 <[email protected]>
AuthorDate: Fri May 1 06:46:19 2026 -0400
[Relax][Frontend][TFLite] Add BROADCAST_ARGS operator mapping (#19487)
This PR adds TFLite frontend support for the BROADCAST_ARGS operator,
which computes the broadcasted shape of two input shape vectors.
It decomposes the operator into existing Relax primitives instead of
registering a new op:
- relax.op.full + relax.op.concat align both inputs by left-padding the
shorter one with 1s
- relax.op.where + relax.op.maximum apply the per-axis broadcast rule
The added tests cover the equal-length and different-length input cases.
Validation:
python -m pytest tests/python/relax/test_frontend_tflite.py -k broadcast_args
Addresses the BROADCAST_ARGS item under #19412.
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 32 +++++++++
tests/python/relax/test_frontend_tflite.py | 78 ++++++++++++++++++++++
2 files changed, 110 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 0b1097b095..9d0fdaf587 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -126,6 +126,7 @@ class OperatorConverter:
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
"BITCAST": self.convert_bitcast,
+ "BROADCAST_ARGS": self.convert_broadcast_args,
"CAST": self.convert_cast,
"CEIL": functools.partial(self._convert_unary_elemwise,
relax_op=_op.ceil),
"CONCATENATION": self.convert_concatenation,
@@ -2510,6 +2511,37 @@ class OperatorConverter:
return relax.op.memory.view(in_expr, shape=output_shape,
dtype=output_dtype)
+ def convert_broadcast_args(self, op):
+ """Convert TFLite BROADCAST_ARGS"""
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+
+ s0 = self.get_tensor_expr(input_tensors[0])
+ s1 = self.get_tensor_expr(input_tensors[1])
+ s0_len = to_int_list(self.get_tensor_shape(input_tensors[0]))[0]
+ s1_len = to_int_list(self.get_tensor_shape(input_tensors[1]))[0]
+ out_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
+
+ # Left-pad the shorter input with 1s to length target_len.
+ target_len = tirx.max(s0_len, s1_len)
+ one = relax.const(1, dtype=out_dtype)
+ s0 = relax.op.concat(
+ [relax.op.full([target_len - s0_len], one, dtype=out_dtype), s0],
axis=0
+ )
+ s1 = relax.op.concat(
+ [relax.op.full([target_len - s1_len], one, dtype=out_dtype), s1],
axis=0
+ )
+ # Per-dim broadcast. If either side is 1 take the other, else elementwise max.
+ s0_is_one = relax.op.equal(s0, one)
+ s1_is_one = relax.op.equal(s1, one)
+ return relax.op.where(
+ s0_is_one,
+ s1,
+ relax.op.where(s1_is_one, s0, relax.op.maximum(s0, s1)),
+ )
+
def convert_cast(self, op):
"""Convert TFLite CAST"""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 64e4a6e953..f4a0612705 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -836,6 +836,84 @@ def test_square():
verify(TfInput, Expected)
+def test_broadcast_args():
+ class TfInput(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(3,), dtype=tf.int32),
+ tf.TensorSpec(shape=(3,), dtype=tf.int32),
+ ]
+ )
+ def func(self, s0, s1):
+ return tf.broadcast_dynamic_shape(s0, s1)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ s0: R.Tensor((3,), dtype="int32"), s1: R.Tensor((3,),
dtype="int32")
+ ) -> R.Tensor((3,), dtype="int32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((0,), dtype="int32") = R.full(
+ R.shape([0]), R.const(1, "int32"), dtype="int32"
+ )
+ lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0)
+ lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1,
"int32"))
+ lv3: R.Tensor((0,), dtype="int32") = R.full(
+ R.shape([0]), R.const(1, "int32"), dtype="int32"
+ )
+ lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1),
axis=0)
+ lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1,
"int32"))
+ lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4)
+ lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6)
+ gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7)
+ R.output(gv)
+ return gv
+
+ verify(TfInput, Expected)
+
+
+def test_broadcast_args_diff_length():
+ """BROADCAST_ARGS with shape inputs of different lengths."""
+
+ class TfInput(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(1,), dtype=tf.int32),
+ tf.TensorSpec(shape=(3,), dtype=tf.int32),
+ ]
+ )
+ def func(self, s0, s1):
+ return tf.broadcast_dynamic_shape(s0, s1)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ s0: R.Tensor((1,), dtype="int32"), s1: R.Tensor((3,),
dtype="int32")
+ ) -> R.Tensor((3,), dtype="int32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2,), dtype="int32") = R.full(
+ R.shape([2]), R.const(1, "int32"), dtype="int32"
+ )
+ lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0)
+ lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1,
"int32"))
+ lv3: R.Tensor((0,), dtype="int32") = R.full(
+ R.shape([0]), R.const(1, "int32"), dtype="int32"
+ )
+ lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1),
axis=0)
+ lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1,
"int32"))
+ lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4)
+ lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6)
+ gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7)
+ R.output(gv)
+ return gv
+
+ verify(TfInput, Expected)
+
+
@pytest.mark.parametrize(
"tf_op, relax_op",
[