This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6157f49205 [Relax][Frontend][TFLite] Add BROADCAST_TO,
EMBEDDING_LOOKUP, and SELECT_V2 (#19489)
6157f49205 is described below
commit 6157f49205e5dbfa40e617a4b2fb72d4c913776b
Author: Bana <[email protected]>
AuthorDate: Sat May 2 12:41:41 2026 +0300
[Relax][Frontend][TFLite] Add BROADCAST_TO, EMBEDDING_LOOKUP, and SELECT_V2
(#19489)
This PR adds support for three new operators in the Relax TFLite
frontend: `BROADCAST_TO`, `EMBEDDING_LOOKUP`, and `SELECT_V2`.
All newly added unit tests pass when run with:
```
pytest tests/python/relax/test_frontend_tflite.py -k "test_broadcast_to or test_embedding_lookup or test_select_v2"
```
reference #19412
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 30 ++++++++++-
tests/python/relax/test_frontend_tflite.py | 61 ++++++++++++++++++++++
2 files changed, 90 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index e45f569856..ebfbcacf9c 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -126,6 +126,7 @@ class OperatorConverter:
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
"BITCAST": self.convert_bitcast,
+ "BROADCAST_TO": self.convert_broadcast_to,
"BROADCAST_ARGS": self.convert_broadcast_args,
"CAST": self.convert_cast,
"CEIL": functools.partial(self._convert_unary_elemwise,
relax_op=_op.ceil),
@@ -141,6 +142,7 @@ class OperatorConverter:
"DILATE": self.convert_dilate,
"DIV": functools.partial(self._convert_elemwise,
relax_op=_op.divide),
"ELU": self.convert_elu,
+ "EMBEDDING_LOOKUP": self.convert_embedding_lookup,
"EQUAL": functools.partial(
self._convert_elemwise, relax_op=_op.equal, comparison_op=True
),
@@ -220,6 +222,7 @@ class OperatorConverter:
"REVERSE_V2": self.convert_reverse_v2,
"SCATTER_ND": self.convert_scatter_nd,
"SELECT": self.convert_select,
+ "SELECT_V2": self.convert_select,
"SHAPE": self.convert_shape,
"SIN": functools.partial(self._convert_unary_elemwise,
relax_op=_op.sin),
"SLICE": self.convert_slice,
@@ -1572,7 +1575,7 @@ class OperatorConverter:
assert axis < data_dim, "Axis out of bounds"
if self.has_expr(indices.tensor_idx):
- indices_expr = relax.op.cast(self.get_expr(indices.tensor_idx),
"int32")
+ indices_expr = relax.op.astype(self.get_expr(indices.tensor_idx),
"int32")
else:
indices_val = self.get_tensor_value(indices)
indices_expr = self.exp_tab.new_const(
@@ -3181,6 +3184,31 @@ class OperatorConverter:
return out
+ def convert_broadcast_to(self, op):
+ """Convert TFLite BROADCAST_TO"""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ data = self.get_tensor_expr(input_tensors[0])
+ shape_tensor = input_tensors[1]
+ if self.has_expr(shape_tensor.tensor_idx):
+ shape_expr = self.get_expr(shape_tensor.tensor_idx)
+ shape = self.bb.emit(relax.op.tensor_to_shape(shape_expr))
+ else:
+ shape = to_int_list(self.get_tensor_value(shape_tensor))
+ return relax.op.broadcast_to(data, shape)
+
+ def convert_embedding_lookup(self, op):
+ """Convert TFLite EMBEDDING_LOOKUP"""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ params = self.get_tensor_expr(input_tensors[0])
+ indices_tensor = input_tensors[1]
+ if self.has_expr(indices_tensor.tensor_idx):
+ indices =
relax.op.astype(self.get_expr(indices_tensor.tensor_idx), "int32")
+ else:
+ indices = self.get_tensor_expr(indices_tensor)
+ return relax.op.take(params, indices, axis=0)
+
def convert_batch_matmul(self, op):
"""batch_matmul implementation."""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index beef66e09b..c5531ccf73 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1707,6 +1707,67 @@ def test_networks(net, shape):
verify(concrete_func)
+def test_broadcast_to():
+ class Model(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2),
dtype=tf.float32)])
+ def func(self, x):
+ return tf.broadcast_to(x, [3, 2, 2])
+
+ verify(Model)
+
+ class ModelScalarAndInt(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.int32)])
+ def func(self, x):
+ return tf.broadcast_to(x, [4, 4])
+
+ verify(ModelScalarAndInt)
+
+
+def test_embedding_lookup():
+ class Model(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(3,),
dtype=tf.int32)])
+ def func(self, indices):
+ params = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32)
+ return tf.nn.embedding_lookup(params, indices)
+
+ verify(Model)
+
+ class ModelMultidim(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3),
dtype=tf.int32)])
+ def func(self, indices):
+ params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]],
dtype=tf.float32)
+ return tf.nn.embedding_lookup(params, indices)
+
+ verify(ModelMultidim)
+
+
+def test_select_v2():
+ class Model(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2, 2), dtype=tf.bool),
+ tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+ tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+ ]
+ )
+ def func(self, condition, x, y):
+ return tf.where(condition, x, y)
+
+ verify(Model)
+
+ class ModelBroadcasting(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2, 1), dtype=tf.bool),
+ tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
+ tf.TensorSpec(shape=(), dtype=tf.float32),
+ ]
+ )
+ def func(self, condition, x, y):
+ return tf.where(condition, x, y)
+
+ verify(ModelBroadcasting)
+
def test_scatter_nd():
class Model(tf.Module):
@tf.function(