This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4a688ddcbc [Relax][Frontend][TFLite] Add EMBEDDING_LOOKUP_SPARSE converter (#19652)
4a688ddcbc is described below
commit 4a688ddcbcb6c51d52b5458ee00b70925785155f
Author: YinHanke <[email protected]>
AuthorDate: Wed Jun 3 01:50:30 2026 +0800
[Relax][Frontend][TFLite] Add EMBEDDING_LOOKUP_SPARSE converter (#19652)
## Summary
Add Relax TFLite frontend support for `EMBEDDING_LOOKUP_SPARSE`.
This PR adds a converter for `EMBEDDING_LOOKUP_SPARSE` in the Relax
TFLite frontend. The implementation supports the `SUM`, `MEAN`, and
`SQRTN` combiners and handles a `dense_shape` of rank greater than two
(i.e. sparse indices wider than two columns). The `dense_shape` input
must be a constant. The sparse aggregation is lowered through `scatter_nd`
to match TFLite operator semantics for the supported cases.
The PR also adds handcrafted TFLite frontend tests covering:
- `SUM`
- `MEAN`
- `MEAN` with negative and zero weights (a zero-weight bucket yields NaN)
- `SQRTN`
- a 3D indices case
## Testing
Ran `pytest tests/python/relax/test_frontend_tflite.py -k 'embedding_lookup_sparse'`.
Part of #19519
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 118 ++++++++++++
tests/python/relax/test_frontend_tflite.py | 213 +++++++++++++++++++++
2 files changed, 331 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index bf90895cfc..67d57e5866 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -224,6 +224,7 @@ class OperatorConverter:
"DIV": functools.partial(self._convert_elemwise,
relax_op=_op.divide),
"ELU": self.convert_elu,
"EMBEDDING_LOOKUP": self.convert_embedding_lookup,
+ "EMBEDDING_LOOKUP_SPARSE": self.convert_embedding_lookup_sparse,
"EQUAL": functools.partial(
self._convert_elemwise, relax_op=_op.equal, comparison_op=True
),
@@ -6339,6 +6340,123 @@ class OperatorConverter:
indices = self.get_tensor_expr(indices_tensor)
return relax.op.take(params, indices, axis=0)
+    def convert_embedding_lookup_sparse(self, op):
+        """Convert TFLite EMBEDDING_LOOKUP_SPARSE.
+
+        The operator takes ``(ids, indices, dense_shape, weights, params)``. Row
+        ``params[ids[i]]`` is scaled by ``weights[i]`` and accumulated into the output
+        bucket addressed by ``indices[i, :-1]``; the last index column does not select
+        a bucket. The result has shape ``dense_shape[:-1] + params.shape[1:]``, and
+        each bucket is combined according to the operator's combiner:
+
+        * ``SUM``: the weighted sum.
+        * ``MEAN``: the weighted sum divided by the sum of the bucket's weights.
+        * ``SQRTN``: the weighted sum divided by the square root of the sum of the
+          bucket's squared weights.
+
+        Buckets that receive no lookups are zero for every combiner. A missing
+        options table is treated as ``SUM``. All lookups are aggregated at once with
+        ``scatter_nd``. NOTE(review): this assumes ``indices`` is in canonical
+        row-major order, where it matches the TFLite reference kernel. Confirm the
+        behavior for unsorted inputs.
+
+        Raises
+        ------
+        tvm.error.OpNotImplemented
+            If ``dense_shape`` is not a constant, constant ``ids`` are negative, or the
+            combiner is not SUM, MEAN or SQRTN.
+        ValueError
+            If constant ``ids`` index past the first dimension of ``params``.
+        """
+        from tflite.CombinerType import CombinerType
+        from tflite.EmbeddingLookupSparseOptions import EmbeddingLookupSparseOptions
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "EMBEDDING_LOOKUP_SPARSE should have 5 input tensors"
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "EMBEDDING_LOOKUP_SPARSE should have 1 output tensor"
+
+        ids_tensor, indices_tensor, dense_shape_tensor, weights_tensor, params_tensor = (
+            input_tensors
+        )
+        output_tensor = output_tensors[0]
+
+        for tensor in input_tensors:
+            assert not tensor.qnn_params, "Quantized input is not expected."
+
+        # Only int32 index tensors and float32 data/outputs are handled.
+        assert ids_tensor.tensor.Type() == TensorType.INT32
+        assert indices_tensor.tensor.Type() == TensorType.INT32
+        assert dense_shape_tensor.tensor.Type() == TensorType.INT32
+        assert weights_tensor.tensor.Type() == TensorType.FLOAT32
+        assert params_tensor.tensor.Type() == TensorType.FLOAT32
+        assert output_tensor.tensor.Type() == TensorType.FLOAT32
+
+        ids_shape = to_int_list(self.get_tensor_shape(ids_tensor))
+        indices_shape = to_int_list(self.get_tensor_shape(indices_tensor))
+        dense_shape_shape = to_int_list(self.get_tensor_shape(dense_shape_tensor))
+        weights_shape = to_int_list(self.get_tensor_shape(weights_tensor))
+        params_shape = to_int_list(self.get_tensor_shape(params_tensor))
+
+        assert len(ids_shape) == 1, "EMBEDDING_LOOKUP_SPARSE ids must be rank 1"
+        assert len(indices_shape) == 2, "EMBEDDING_LOOKUP_SPARSE indices must be rank 2"
+        assert len(dense_shape_shape) == 1, "EMBEDDING_LOOKUP_SPARSE dense_shape must be rank 1"
+        assert len(weights_shape) == 1, "EMBEDDING_LOOKUP_SPARSE weights must be rank 1"
+        assert len(params_shape) >= 2, "EMBEDDING_LOOKUP_SPARSE params must be rank >= 2"
+        assert indices_shape[0] == ids_shape[0], (
+            "EMBEDDING_LOOKUP_SPARSE ids and indices must agree on lookup count"
+        )
+        assert weights_shape[0] == ids_shape[0], (
+            "EMBEDDING_LOOKUP_SPARSE ids and weights must agree on lookup count"
+        )
+
+        if self.has_expr(dense_shape_tensor.tensor_idx):
+            raise tvm.error.OpNotImplemented(
+                "TFLite EMBEDDING_LOOKUP_SPARSE with runtime dense_shape is not supported."
+            )
+
+        dense_shape = to_int_list(self.get_tensor_value(dense_shape_tensor))
+        lookup_rank = indices_shape[1]
+        assert len(dense_shape) == lookup_rank, (
+            "EMBEDDING_LOOKUP_SPARSE dense_shape length must match indices width"
+        )
+        assert lookup_rank >= 1, "EMBEDDING_LOOKUP_SPARSE indices width must be positive"
+
+        num_rows = params_shape[0]
+        if not self.has_expr(ids_tensor.tensor_idx):
+            # Reject out-of-range constant ids here, rather than emitting a gather
+            # that would read outside ``params``.
+            ids_value = self.get_tensor_value(ids_tensor)
+            if np.any(ids_value < 0):
+                raise tvm.error.OpNotImplemented(
+                    "TFLite EMBEDDING_LOOKUP_SPARSE with negative ids is not supported."
+                )
+            if np.any(ids_value >= num_rows):
+                raise ValueError(
+                    f"TFLite EMBEDDING_LOOKUP_SPARSE ids must be less than {num_rows}, "
+                    "the number of rows in params."
+                )
+
+        # The options table is optional in the flatbuffer; fall back to SUM when absent.
+        op_options = op.BuiltinOptions()
+        if op_options is None:
+            combiner = CombinerType.SUM
+        else:
+            sparse_options = EmbeddingLookupSparseOptions()
+            sparse_options.Init(op_options.Bytes, op_options.Pos)
+            combiner = sparse_options.Combiner()
+        if combiner not in (CombinerType.SUM, CombinerType.MEAN, CombinerType.SQRTN):
+            raise tvm.error.OpNotImplemented(
+                f"Unsupported TFLite EMBEDDING_LOOKUP_SPARSE combiner value {combiner}"
+            )
+
+        params = self.get_tensor_expr(params_tensor)
+        ids = self.get_tensor_expr(ids_tensor)
+        weights = self.get_tensor_expr(weights_tensor)
+        indices = self.get_tensor_expr(indices_tensor)
+
+        embedding_tail_shape = params_shape[1:]
+        count_shape = dense_shape[:-1]
+        output_shape = count_shape + embedding_tail_shape
+        # Shape that lets a per-bucket tensor broadcast across the embedding dims.
+        broadcast_shape = count_shape + [1] * len(embedding_tail_shape)
+
+        if lookup_rank == 1:
+            # A single index column leaves no bucket dimensions, so every lookup falls
+            # into the one output slot. Aggregate with a reduction instead of a
+            # zero-width scatter_nd index.
+            bucket_indices = None
+        else:
+            # Buckets are addressed by every sparse index column except the last one.
+            bucket_indices = relax.op.strided_slice(
+                indices, axes=[1], begin=[0], end=[lookup_rank - 1]
+            )
+
+        def bucket_sum(updates, base_shape):
+            """Sum per-lookup ``updates`` (leading axis = lookup) into their buckets."""
+            if bucket_indices is None:
+                return relax.op.sum(updates, axis=0)
+            base = relax.const(np.zeros(base_shape, dtype=np.float32), "float32")
+            return relax.op.scatter_nd(base, bucket_indices, updates, "add")
+
+        lookup = relax.op.take(params, ids, axis=0)
+        weight_expand_shape = [ids_shape[0]] + [1] * len(embedding_tail_shape)
+        weighted_lookup = relax.op.multiply(lookup, relax.op.reshape(weights, weight_expand_shape))
+        summed_lookup = bucket_sum(weighted_lookup, output_shape)
+        if combiner == CombinerType.SUM:
+            return summed_lookup
+
+        if combiner == CombinerType.MEAN:
+            denominator = bucket_sum(weights, count_shape)
+        else:  # CombinerType.SQRTN
+            denominator = relax.op.sqrt(
+                bucket_sum(relax.op.multiply(weights, weights), count_shape)
+            )
+        ones = relax.const(np.ones(ids_shape, dtype=np.float32), "float32")
+        bucket_counts = bucket_sum(ones, count_shape)
+
+        denominator = relax.op.broadcast_to(
+            relax.op.reshape(denominator, broadcast_shape), output_shape
+        )
+        bucket_counts = relax.op.broadcast_to(
+            relax.op.reshape(bucket_counts, broadcast_shape), output_shape
+        )
+        normalized = relax.op.divide(summed_lookup, denominator)
+        # Empty buckets would otherwise be 0 / 0; keep them at zero, as SUM does.
+        # A non-empty bucket whose weights sum to zero still divides by zero.
+        zeros = relax.const(np.zeros(output_shape, dtype=np.float32), "float32")
+        has_lookups = relax.op.greater(bucket_counts, relax.const(0.0, "float32"))
+        return relax.op.where(has_lookups, normalized, zeros)
+
def convert_batch_matmul(self, op):
"""batch_matmul implementation."""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index e4866d7096..e4483b9d41 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4039,6 +4039,17 @@ def _build_hashtable_options(
return hashtable_options.HashtableOptionsEnd(builder)
+def _build_embedding_lookup_sparse_options(builder, combiner):
+    """Serialize an EmbeddingLookupSparseOptions table carrying ``combiner``.
+
+    Returns the value produced by ``EmbeddingLookupSparseOptionsEnd``. Skips the
+    calling test when the installed TFLite schema has no such table.
+    """
+    try:
+        options_module = _get_tflite_schema_module("EmbeddingLookupSparseOptions")
+    except ModuleNotFoundError:
+        pytest.skip("TFLite schema does not provide EmbeddingLookupSparseOptions")
+
+    options_module.EmbeddingLookupSparseOptionsStart(builder)
+    options_module.EmbeddingLookupSparseOptionsAddCombiner(builder, combiner)
+    options_offset = options_module.EmbeddingLookupSparseOptionsEnd(builder)
+    return options_offset
+
+
def _load_model_from_buffer(model_bytes):
if hasattr(tflite.Model, "Model"):
tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
@@ -4067,6 +4078,10 @@ def _run_module(mod, *inputs):
return tuple(output.numpy() for output in outputs)
+def _run_no_input_module(mod):
+    """Run ``mod`` without passing any runtime inputs and return its outputs."""
+    outputs = _run_module(mod)
+    return outputs
+
+
def _build_tflite_call_model(
call_subgraph_index=1,
callee_inputs=None,
@@ -5858,6 +5873,88 @@ def _build_tflite_hashtable_size_uninitialized_model():
)
+def _build_tflite_embedding_lookup_sparse_model(
+    combiner, indices_data, dense_shape_data, weights_data=None, ids_data=None
+):
+    """Build a single-op EMBEDDING_LOOKUP_SPARSE model whose inputs are all constants.
+
+    Parameters
+    ----------
+    combiner : int
+        A ``tflite.CombinerType`` value stored in the operator options.
+    indices_data, dense_shape_data : array_like
+        Sparse indices (one row per lookup) and the dense shape they index into.
+    weights_data : array_like, optional
+        Per-lookup weights. Defaults to ``[1.0, 2.0, 4.0]``.
+    ids_data : array_like, optional
+        Per-lookup row ids into ``params``. Defaults to ``[1, 3, 0]``.
+
+    ``params`` is a fixed ``(4, 3, 2)`` table with
+    ``params[r, i, j] == r + 0.1 * i + 0.01 * j``. The subgraph has no runtime
+    inputs and a single float32 output of shape ``dense_shape[:-1] + params.shape[1:]``.
+    """
+    builder = flatbuffers.Builder(4096)
+
+    ids_data = np.array([1, 3, 0] if ids_data is None else ids_data, dtype=np.int32)
+    indices_data = np.array(indices_data, dtype=np.int32)
+    dense_shape_data = np.array(dense_shape_data, dtype=np.int32)
+    weights_data = np.array(
+        [1.0, 2.0, 4.0] if weights_data is None else weights_data, dtype=np.float32
+    )
+    params_data = np.array(
+        [
+            [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21]],
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+            [[2.00, 2.01], [2.10, 2.11], [2.20, 2.21]],
+            [[3.00, 3.01], [3.10, 3.11], [3.20, 3.21]],
+        ],
+        dtype=np.float32,
+    )
+
+    output_shape = dense_shape_data[:-1].tolist() + list(params_data.shape[1:])
+    sparse_options = _build_embedding_lookup_sparse_options(builder, combiner)
+
+    # Operator inputs in TFLite order. Tensor ``i`` is built with index ``i``, matching
+    # the order of ``buffers`` below (presumably the buffer index; see _build_tensor).
+    constant_inputs = [
+        (ids_data, _tfl_tensor_type.INT32),
+        (indices_data, _tfl_tensor_type.INT32),
+        (dense_shape_data, _tfl_tensor_type.INT32),
+        (weights_data, _tfl_tensor_type.FLOAT32),
+        (params_data, _tfl_tensor_type.FLOAT32),
+    ]
+    tensors = [
+        _build_tensor(builder, index, list(data.shape), tensor_type=tensor_type)
+        for index, (data, tensor_type) in enumerate(constant_inputs)
+    ]
+    output_index = len(constant_inputs)
+    tensors.append(
+        _build_tensor(builder, output_index, output_shape, tensor_type=_tfl_tensor_type.FLOAT32)
+    )
+
+    sparse_op = _build_operator(
+        builder,
+        0,
+        list(range(output_index)),
+        [output_index],
+        builtin_options_type=_get_builtin_options_type("EmbeddingLookupSparseOptions"),
+        builtin_options=sparse_options,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[sparse_op],
+        inputs=[],
+        outputs=[output_index],
+    )
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("EMBEDDING_LOOKUP_SPARSE"))
+    ]
+    # Constant inputs carry their bytes; the output buffer is left empty.
+    buffers = [_build_buffer(builder, data.tobytes()) for data, _ in constant_inputs]
+    buffers.append(_build_buffer(builder))
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None):
"""Build a model containing one HASHTABLE_LOOKUP operator."""
builder = flatbuffers.Builder(1024)
@@ -5952,6 +6049,122 @@ def test_hashtable_size_uninitialized_unsupported():
_load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())
+def test_embedding_lookup_sparse_sum():
+    """SUM: each bucket holds the weighted sum of its rows.
+
+    Bucket 0 receives ``1 * params[1]``, and bucket 2 receives
+    ``2 * params[3] + 4 * params[0]``. Bucket 1 has no lookups, so it stays zero.
+    """
+    from tflite.CombinerType import CombinerType
+
+    model_bytes = _build_tflite_embedding_lookup_sparse_model(
+        CombinerType.SUM,
+        indices_data=[[0, 0], [2, 0], [2, 1]],
+        dense_shape_data=[3, 2],
+    )
+    mod = _load_model_from_buffer(model_bytes)
+
+    expected = np.zeros((3, 3, 2), dtype=np.float32)
+    expected[0] = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
+    expected[2] = [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]]
+    actual = _run_no_input_module(mod)
+    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_mean():
+    """MEAN: each bucket's weighted sum is divided by the bucket's total weight.
+
+    Bucket 2 is ``(2 * params[3] + 4 * params[0]) / (2 + 4)``. Bucket 1 is empty
+    and stays zero.
+    """
+    from tflite.CombinerType import CombinerType
+
+    model_bytes = _build_tflite_embedding_lookup_sparse_model(
+        CombinerType.MEAN,
+        indices_data=[[0, 0], [2, 0], [2, 1]],
+        dense_shape_data=[3, 2],
+    )
+    mod = _load_model_from_buffer(model_bytes)
+
+    expected = np.zeros((3, 3, 2), dtype=np.float32)
+    expected[0] = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
+    expected[2] = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
+    actual = _run_no_input_module(mod)
+    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_mean_negative_weights():
+    """MEAN divides by the signed weight sum, so a zero-weight bucket yields NaN.
+
+    Bucket 0 is ``(1 * params[1] - 2 * params[3]) / (1 - 2)``. Bucket 2 holds a single
+    zero-weight lookup, giving ``0 / 0``. Bucket 1 is empty and stays zero.
+    """
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.MEAN,
+            indices_data=[[0, 0], [0, 1], [2, 0]],
+            dense_shape_data=[3, 2],
+            weights_data=[1.0, -2.0, 0.0],
+        )
+    )
+
+    out = _run_no_input_module(mod)
+    expected = np.array(
+        [
+            [[5.0, 5.01], [5.1, 5.11], [5.2, 5.21]],
+            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+            [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]],
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5, equal_nan=True)
+
+
+def test_embedding_lookup_sparse_sqrtn():
+    """SQRTN: each bucket's weighted sum is divided by sqrt(sum of squared weights).
+
+    Bucket 0 has a single weight of 1, so it is unchanged. Bucket 2 uses weights 2 and
+    4, so its sum is scaled by ``1 / sqrt(20)``. Bucket 1 is empty and stays zero.
+    """
+    from tflite.CombinerType import CombinerType
+
+    model_bytes = _build_tflite_embedding_lookup_sparse_model(
+        CombinerType.SQRTN,
+        indices_data=[[0, 0], [2, 0], [2, 1]],
+        dense_shape_data=[3, 2],
+    )
+    mod = _load_model_from_buffer(model_bytes)
+
+    scale = np.float32(np.sqrt(2.0**2 + 4.0**2))
+    bucket_two_sum = np.array(
+        [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]], dtype=np.float32
+    )
+    expected = np.zeros((3, 3, 2), dtype=np.float32)
+    expected[0] = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
+    expected[2] = bucket_two_sum / scale
+    actual = _run_no_input_module(mod)
+    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_indices_3d():
+    """SUM with width-3 indices: buckets are addressed by the first two index columns.
+
+    The output shape is ``dense_shape[:-1] + params.shape[1:] == (3, 2, 3, 2)``. Only
+    buckets ``(0, 0)`` and ``(2, 0)`` receive lookups.
+    """
+    from tflite.CombinerType import CombinerType
+
+    model_bytes = _build_tflite_embedding_lookup_sparse_model(
+        CombinerType.SUM,
+        indices_data=[[0, 0, 0], [2, 0, 0], [2, 0, 1]],
+        dense_shape_data=[3, 2, 2],
+    )
+    mod = _load_model_from_buffer(model_bytes)
+
+    first_bucket = np.array([[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], dtype=np.float32)
+    last_bucket = np.array([[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]], dtype=np.float32)
+    expected = np.zeros((3, 2, 3, 2), dtype=np.float32)
+    expected[0, 0] = first_bucket
+    expected[2, 0] = last_bucket
+    actual = _run_no_input_module(mod)
+    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
+
+
def test_hashtable_lookup_1d_value():
mod =
_load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3]))