This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 066bf777b8 [Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter (#19654)
066bf777b8 is described below
commit 066bf777b841af6ad791b86747cb934ce8c8b09f
Author: YinHanke <[email protected]>
AuthorDate: Tue Jun 2 02:41:27 2026 +0800
[Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter (#19654)
## Summary
Add Relax TFLite frontend support for `HASHTABLE_LOOKUP`.
This PR adds a converter for `HASHTABLE_LOOKUP` in the Relax TFLite
frontend. The implementation supports non-string value tensors with
`int32` lookup and key tensors. It lowers the lookup through `bucketize`,
`take`, and `where`: keys that are found return the matching value row,
and missing keys return zero-filled values. A `uint8` hits mask is
returned alongside the values, matching TFLite semantics for the
supported cases.
The PR also adds handcrafted TFLite frontend tests covering:
- 1D float value tensors
- 2D float value tensors
- the currently unsupported string-value case
## Testing
Ran `pytest tests/python/relax/test_frontend_tflite.py -k 'hashtable_lookup'`.
Part of #19519
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 83 ++++++++++++++++++++
tests/python/relax/test_frontend_tflite.py | 89 ++++++++++++++++++++++
2 files changed, 172 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 2a4455eb30..fc3d61713d 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -248,6 +248,7 @@ class OperatorConverter:
"HASHTABLE": self.convert_hashtable,
"HASHTABLE_FIND": self.convert_hashtable_find,
"HASHTABLE_IMPORT": self.convert_hashtable_import,
+ "HASHTABLE_LOOKUP": self.convert_hashtable_lookup,
"HASHTABLE_SIZE": self.convert_hashtable_size,
"IF": self.convert_if,
"L2_NORMALIZATION": self.convert_l2_normalization,
@@ -755,6 +756,88 @@ class OperatorConverter:
"HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite
frontend"
)
+ def convert_hashtable_lookup(self, op):
+ """Convert TFLite HASHTABLE_LOOKUP for non-string value tensors.
+
+ The operator must have three inputs (lookup, key, value) and two
+ outputs (output, hits). Every lookup entry is matched against the
+ key tensor; a hit yields the value row stored at the matching key
+ position, a miss yields zeros. The hits output is the per-lookup
+ found mask cast to uint8.
+
+ Accepted configuration, as enforced by the checks below:
+ lookup and key are rank-1 int32 tensors, value has rank >= 1 and is
+ not a string tensor, key and value agree on a non-zero row count,
+ output has the value dtype and shape [len(lookup)] + value.shape[1:],
+ and hits is a uint8 tensor of shape [len(lookup)].
+
+ Parameters
+ ----------
+ op
+ The TFLite operator whose tensors are read through
+ ``get_input_tensors`` and ``get_output_tensors``.
+
+ Returns
+ -------
+ relax.Tuple
+ A two-element tuple ``[output, hits]``.
+
+ Raises
+ ------
+ tvm.error.OpNotImplemented
+ If the tensor count, dtypes, ranks, or shapes do not match the
+ accepted configuration described above.
+ """
+ from tflite.TensorType import TensorType
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(input_tensors) != 3 or len(output_tensors) != 2:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP expects lookup, key, and value inputs with
two outputs"
+ )
+
+ lookup_tensor, key_tensor, value_tensor = input_tensors
+ output_tensor, hits_tensor = output_tensors
+
+ # Dtype checks: int32 lookup/key, non-string values, output dtype equal
+ # to the value dtype, and a uint8 hits output.
+ if (
+ lookup_tensor.tensor.Type() != TensorType.INT32
+ or key_tensor.tensor.Type() != TensorType.INT32
+ ):
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP requires int32 lookup and key tensors"
+ )
+ if self._is_tflite_string_type(value_tensor.tensor.Type()):
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP with TensorType.STRING values is not
supported"
+ )
+ if value_tensor.tensor.Type() != output_tensor.tensor.Type():
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP output dtype must match the value tensor
dtype"
+ )
+ if hits_tensor.tensor.Type() != TensorType.UINT8:
+ raise tvm.error.OpNotImplemented("HASHTABLE_LOOKUP hits output
must be uint8")
+
+ # The shapes are compared below against plain Python lists.
+ lookup_shape = to_int_list(self.get_tensor_shape(lookup_tensor))
+ key_shape = to_int_list(self.get_tensor_shape(key_tensor))
+ value_shape = to_int_list(self.get_tensor_shape(value_tensor))
+ output_shape = to_int_list(self.get_tensor_shape(output_tensor))
+ hits_shape = to_int_list(self.get_tensor_shape(hits_tensor))
+
+ # Structural checks: rank-1 lookup/key, one key per value row, a
+ # non-empty table, and output/hits shapes derived from the lookup length.
+ if len(lookup_shape) != 1 or len(key_shape) != 1 or len(value_shape) <
1:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP requires rank-1 lookup/key and rank>=1 value
tensors"
+ )
+ if key_shape[0] != value_shape[0]:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP requires key and value tensors to agree on
row count"
+ )
+ if key_shape[0] == 0:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP requires a non-empty key/value table"
+ )
+ if output_shape != [lookup_shape[0]] + value_shape[1:]:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP output shape must match lookup count and
value tail shape"
+ )
+ if hits_shape != [lookup_shape[0]]:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_LOOKUP hits output shape must match lookup count"
+ )
+
+ lookup = self.get_tensor_expr(lookup_tensor)
+ key = self.get_tensor_expr(key_tensor)
+ value = self.get_tensor_expr(value_tensor)
+
+ # NOTE(review): bucketize is given key as its boundaries, so this
+ # presumably relies on the key tensor being sorted ascending; nothing
+ # here checks that -- confirm against the TFLite kernel contract.
+ positions = relax.op.bucketize(lookup, key, out_int32=True,
right=False)
+ # positions is not guaranteed to be a valid row index (see in_range),
+ # so both take calls use mode="clip" to stay in bounds.
+ candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
+ in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
+ # A lookup is a hit only when its position is in range and the key
+ # stored at that position equals the lookup value.
+ found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys,
lookup))
+
+ # Rows are gathered for every lookup; rows for misses are replaced by
+ # zeros in the where() below.
+ gathered_values = relax.op.take(value, positions, axis=0, mode="clip")
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+ zero_values = relax.op.zeros(output_shape, output_dtype)
+
+ # found has one entry per lookup. For values of rank > 1, add the
+ # trailing axes and broadcast so the mask has the output shape.
+ if len(value_shape) > 1:
+ found_values = relax.op.expand_dims(found, axis=list(range(1,
len(value_shape))))
+ found_values = relax.op.broadcast_to(found_values, output_shape)
+ else:
+ found_values = found
+
+ # Second output: the found mask cast to uint8.
+ output = relax.op.where(found_values, gathered_values, zero_values)
+ hits = relax.op.astype(found, "uint8")
+ return relax.Tuple([output, hits])
+
def convert_hashtable_size(self, op):
"""Convert HASHTABLE_SIZE for a statically imported TFLite
hashtable."""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 7c3e526d99..c34da605de 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4053,6 +4053,18 @@ def _get_builtin_operator(builtin_name):
return getattr(_tfl_builtin_operator, builtin_name)
+def _run_module(mod, *inputs):
+ """Compile ``mod`` for the "c" target and run ``main`` on the Relax VM.
+
+ The VM is created on ``tvm.cpu()`` and ``inputs`` are forwarded
+ positionally to ``main``.
+
+ Returns a single NumPy array when the VM output exposes ``numpy``,
+ otherwise a tuple holding one NumPy array per output.
+ """
+ tgt = tvm.target.Target("c")
+ ex = tvm.compile(mod, tgt)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ # Stateful calling sequence: bind inputs, invoke, then fetch outputs.
+ vm.set_input("main", *inputs)
+ vm.invoke_stateful("main")
+ outputs = vm.get_outputs("main")
+ # A single tensor result has .numpy(); otherwise iterate the outputs.
+ if hasattr(outputs, "numpy"):
+ return outputs.numpy()
+ return tuple(output.numpy() for output in outputs)
+
+
def _build_tflite_call_model(
call_subgraph_index=1,
callee_inputs=None,
@@ -5844,6 +5856,36 @@ def _build_tflite_hashtable_size_uninitialized_model():
)
+def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None):
+ """Build a model containing one HASHTABLE_LOOKUP operator.
+
+ The subgraph holds five tensors: an int32 lookup tensor of shape [4],
+ an int32 key tensor of shape [3], a value tensor of shape
+ ``value_shape``, an output tensor of shape [4, *value_shape[1:]], and
+ a uint8 hits tensor of shape [4]. Tensors 0-2 are the graph inputs and
+ tensors 3-4 are the graph outputs.
+
+ Parameters
+ ----------
+ value_shape
+ Shape of the value tensor; its tail (``value_shape[1:]``) is reused
+ for the output tensor.
+ value_type
+ Tensor type used for both the value and the output tensor.
+ ``None`` selects ``_tfl_tensor_type.FLOAT32``.
+
+ Returns
+ -------
+ The result of ``_finish_tflite_model`` for the assembled model.
+ """
+ builder = flatbuffers.Builder(1024)
+
+ value_type = _tfl_tensor_type.FLOAT32 if value_type is None else value_type
+
+ # NOTE(review): the second argument (0..4) is presumably the buffer index
+ # of each tensor, matching the five buffers created below -- confirm
+ # against _build_tensor.
+ lookup_tensor = _build_tensor(builder, 0, [4],
tensor_type=_tfl_tensor_type.INT32)
+ key_tensor = _build_tensor(builder, 1, [3],
tensor_type=_tfl_tensor_type.INT32)
+ value_tensor = _build_tensor(builder, 2, value_shape,
tensor_type=value_type)
+ output_tensor = _build_tensor(builder, 3, [4, *value_shape[1:]],
tensor_type=value_type)
+ hits_tensor = _build_tensor(builder, 4, [4],
tensor_type=_tfl_tensor_type.UINT8)
+
+ # Operator inputs are tensors [0, 1, 2] and its outputs are tensors [3, 4].
+ hashtable_lookup = _build_operator(builder, 0, [0, 1, 2], [3, 4])
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[lookup_tensor, key_tensor, value_tensor, output_tensor,
hits_tensor],
+ operators=[hashtable_lookup],
+ inputs=[0, 1, 2],
+ outputs=[3, 4],
+ )
+ operator_codes = [_build_operator_code(builder,
_get_builtin_operator("HASHTABLE_LOOKUP"))]
+ buffers = [_build_buffer(builder) for _ in range(5)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
def test_resource_variable_call_once_init_read():
"""Test reading a resource variable initialized by a supported CALL_ONCE
subgraph."""
mod = _load_model_from_buffer(_build_tflite_resource_variable_model())
@@ -5908,6 +5950,53 @@ def test_hashtable_size_uninitialized_unsupported():
_load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())
+def test_hashtable_lookup_1d_value():
+ """Check HASHTABLE_LOOKUP output and hits for a rank-1 float value tensor."""
+ mod =
_load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3]))
+
+ # Arguments after mod are the lookup, key, and value tensors, in that order.
+ output, hits = _run_module(
+ mod,
+ np.array([1234, -292, -11, 0], dtype=np.int32),
+ np.array([-11, 0, 1234], dtype=np.int32),
+ np.array([0.0, 0.1, 0.4], dtype=np.float32),
+ )
+
+ # -292 is not a key, so it yields 0.0 with hit 0. Key -11 also maps to
+ # 0.0, so only the hits mask tells the two apart.
+ np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1],
dtype=np.float32))
+ np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
+
+
+def test_hashtable_lookup_2d_value():
+ """Check HASHTABLE_LOOKUP output and hits for a rank-2 float value tensor."""
+ mod =
_load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3,
2]))
+
+ # Arguments after mod are the lookup, key, and value tensors, in that order.
+ output, hits = _run_module(
+ mod,
+ np.array([1234, -292, -11, 0], dtype=np.int32),
+ np.array([-11, 0, 1234], dtype=np.int32),
+ np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32),
+ )
+
+ # Each lookup selects a whole value row; the missing key -292 produces
+ # an all-zero row.
+ np.testing.assert_allclose(
+ output,
+ np.array(
+ [
+ [2.0, 2.1],
+ [0.0, 0.0],
+ [0.0, 0.1],
+ [1.0, 1.1],
+ ],
+ dtype=np.float32,
+ ),
+ )
+ np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
+
+
+def test_hashtable_lookup_string_value_unsupported():
+ """Loading a HASHTABLE_LOOKUP model with string values must fail."""
+ string_type = _get_string_tensor_type()
+ # NOTE(review): the asserted error is a ValueError about an unknown
+ # `string` dtype, not the tvm.error.OpNotImplemented raised by
+ # convert_hashtable_lookup -- presumably the string dtype is rejected
+ # earlier during import; confirm which code path raises.
+ with pytest.raises(ValueError, match="unknown dtype `string`"):
+ _load_model_from_buffer(
+ _build_tflite_hashtable_lookup_model(value_shape=[3],
value_type=string_type)
+ )
+
+
def _get_stablehlo_builtin_operator(builtin_name):
if not hasattr(_tfl_builtin_operator, builtin_name):
pytest.skip(f"TFLite schema does not provide
BuiltinOperator.{builtin_name}")