This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6157f49205 [Relax][Frontend][TFLite] Add BROADCAST_TO, 
EMBEDDING_LOOKUP, and SELECT_V2 (#19489)
6157f49205 is described below

commit 6157f49205e5dbfa40e617a4b2fb72d4c913776b
Author: Bana <[email protected]>
AuthorDate: Sat May 2 12:41:41 2026 +0300

    [Relax][Frontend][TFLite] Add BROADCAST_TO, EMBEDDING_LOOKUP, and SELECT_V2 
(#19489)
    
    This PR adds support for three new operators in the Relax TFLite
    frontend: `BROADCAST_TO`, `EMBEDDING_LOOKUP`, and `SELECT_V2`.
    
    
    All newly added unit tests pass with:
    ```
    pytest tests/python/relax/test_frontend_tflite.py -k "test_broadcast_to or test_embedding_lookup or test_select_v2"
    ```
    
    reference #19412
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 30 ++++++++++-
 tests/python/relax/test_frontend_tflite.py         | 61 ++++++++++++++++++++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index e45f569856..ebfbcacf9c 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -126,6 +126,7 @@ class OperatorConverter:
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
             "BATCH_MATMUL": self.convert_batch_matmul,
             "BITCAST": self.convert_bitcast,
+            "BROADCAST_TO": self.convert_broadcast_to,
             "BROADCAST_ARGS": self.convert_broadcast_args,
             "CAST": self.convert_cast,
             "CEIL": functools.partial(self._convert_unary_elemwise, 
relax_op=_op.ceil),
@@ -141,6 +142,7 @@ class OperatorConverter:
             "DILATE": self.convert_dilate,
             "DIV": functools.partial(self._convert_elemwise, 
relax_op=_op.divide),
             "ELU": self.convert_elu,
+            "EMBEDDING_LOOKUP": self.convert_embedding_lookup,
             "EQUAL": functools.partial(
                 self._convert_elemwise, relax_op=_op.equal, comparison_op=True
             ),
@@ -220,6 +222,7 @@ class OperatorConverter:
             "REVERSE_V2": self.convert_reverse_v2,
             "SCATTER_ND": self.convert_scatter_nd,
             "SELECT": self.convert_select,
+            "SELECT_V2": self.convert_select,
             "SHAPE": self.convert_shape,
             "SIN": functools.partial(self._convert_unary_elemwise, 
relax_op=_op.sin),
             "SLICE": self.convert_slice,
@@ -1572,7 +1575,7 @@ class OperatorConverter:
         assert axis < data_dim, "Axis out of bounds"
 
         if self.has_expr(indices.tensor_idx):
-            indices_expr = relax.op.cast(self.get_expr(indices.tensor_idx), 
"int32")
+            indices_expr = relax.op.astype(self.get_expr(indices.tensor_idx), 
"int32")
         else:
             indices_val = self.get_tensor_value(indices)
             indices_expr = self.exp_tab.new_const(
@@ -3181,6 +3184,31 @@ class OperatorConverter:
 
         return out
 
+    def convert_broadcast_to(self, op):
+        """Convert TFLite BROADCAST_TO: broadcast input 0 to the shape held in input 1."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        data = self.get_tensor_expr(input_tensors[0])
+        shape_tensor = input_tensors[1]
+        if self.has_expr(shape_tensor.tensor_idx):  # shape already bound to a Relax expr
+            shape_expr = self.get_expr(shape_tensor.tensor_idx)
+            shape = self.bb.emit(relax.op.tensor_to_shape(shape_expr))
+        else:  # otherwise read the shape buffer's values as a Python int list
+            shape = to_int_list(self.get_tensor_value(shape_tensor))
+        return relax.op.broadcast_to(data, shape)
+
+    def convert_embedding_lookup(self, op):
+        """Convert TFLite EMBEDDING_LOOKUP: take rows of input 1 (value) at input 0 (ids)."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        params = self.get_tensor_expr(input_tensors[1])  # TFLite operand order: (lookup, value)
+        indices_tensor = input_tensors[0]
+        if self.has_expr(indices_tensor.tensor_idx):
+            indices = relax.op.astype(self.get_expr(indices_tensor.tensor_idx), "int32")
+        else:
+            indices = self.get_tensor_expr(indices_tensor)
+        return relax.op.take(params, indices, axis=0)
+
     def convert_batch_matmul(self, op):
         """batch_matmul implementation."""
 
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index beef66e09b..c5531ccf73 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1707,6 +1707,67 @@ def test_networks(net, shape):
     verify(concrete_func)
 
 
+def test_broadcast_to():
+    class Model(tf.Module):  # (2, 2) float32 -> (3, 2, 2)
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), 
dtype=tf.float32)])
+        def func(self, x):
+            return tf.broadcast_to(x, [3, 2, 2])
+
+    verify(Model)
+
+    class ModelScalarAndInt(tf.Module):  # scalar int32 -> (4, 4)
+        @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.int32)])
+        def func(self, x):
+            return tf.broadcast_to(x, [4, 4])
+
+    verify(ModelScalarAndInt)
+
+
+def test_embedding_lookup():  # NOTE(review): TF may lower this to GATHER; confirm coverage
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(3,), 
dtype=tf.int32)])
+        def func(self, indices):
+            params = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32)
+            return tf.nn.embedding_lookup(params, indices)
+
+    verify(Model)
+
+    class ModelMultidim(tf.Module):  # NOTE(review): TFLite EMBEDDING_LOOKUP needs 1-D ids
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), 
dtype=tf.int32)])
+        def func(self, indices):
+            params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], 
dtype=tf.float32)
+            return tf.nn.embedding_lookup(params, indices)
+
+    verify(ModelMultidim)
+
+
+def test_select_v2():
+    class Model(tf.Module):  # condition, x and y all (2, 2)
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 2), dtype=tf.bool),
+                tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+                tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+            ]
+        )
+        def func(self, condition, x, y):
+            return tf.where(condition, x, y)
+
+    verify(Model)
+
+    class ModelBroadcasting(tf.Module):  # (2, 1) condition and scalar y vs (2, 2) x
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 1), dtype=tf.bool),
+                tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+                tf.TensorSpec(shape=(), dtype=tf.float32),
+            ]
+        )
+        def func(self, condition, x, y):
+            return tf.where(condition, x, y)
+
+    verify(ModelBroadcasting)
+
 def test_scatter_nd():
     class Model(tf.Module):
         @tf.function(

Reply via email to