This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5a8dae4d95 [Relax][Frontend][TFLite] Support static hashtable find (#19879)
5a8dae4d95 is described below
commit 5a8dae4d95c55c8fec9246a607a28c3ff54ffe05
Author: Hongyi Wu <[email protected]>
AuthorDate: Wed Jun 24 14:27:49 2026 +0800
[Relax][Frontend][TFLite] Support static hashtable find (#19879)
## Summary
This PR extends the TFLite resource hashtable family from #19519 item G from "import + size" to a constant-foldable `HASHTABLE_FIND` subset. It builds on the static `HASHTABLE` / `HASHTABLE_IMPORT` import support added in #19639 and the `HASHTABLE_LOOKUP` converter from #19654.
Relax has no string tensor type and no runtime hashtable operator, so this PR targets the only subset that can be lowered today: tables whose keys and values are constants imported in a `CALL_ONCE` init subgraph and queried by a constant string tensor. In that case the lookup is resolved at import time and emitted as a `relax.const`. Runtime string queries are still rejected with an explicit error instead of being lowered with incorrect semantics.
## Design
### Static String Tensor Decoding
The frontend now decodes constant TFLite string tensors through a shared helper, `_get_string_tensor_value`, which parses the TFLite string-tensor binary layout:
```text
[count(i32)][offset_0 .. offset_count(i32)][utf8 string data]
```
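The header holds `count + 1` little-endian int32 offsets, so string `i` occupies bytes `[offset_i, offset_{i+1})` and the header itself takes `4 * (count + 2)` bytes. The helper validates the buffer length against the declared count, checks that offsets are in bounds and monotonically non-decreasing, rejects invalid UTF-8, and verifies that the decoded element count matches the tensor shape. This is reusable infrastructure for any future TFLite string support, independent of `HASHTABLE_FIND`.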
### Hashtable Import State
`HASHTABLE_IMPORT` now stores the actual constant keys and values (numeric or decoded string buffers) in shared conversion state, instead of only recording table metadata (size, key/value dtype). Duplicate keys are rejected because the constant-fold lookup assumes a unique key mapping. The captured table is keyed by the same `table_id` / handle resolution used by `HASHTABLE` and `HASHTABLE_SIZE`, so a `CALL_ONCE` init subgraph and the main graph agree on the same logical table.
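The import-side bookkeeping can be sketched with a plain dictionary. This is an illustration under assumptions: `import_hashtable`, the `"table:0"` key in the usage, and the use of `NotImplementedError` are hypothetical stand-ins for the frontend's conversion state and `tvm.error.OpNotImplemented`; only the duplicate-key test (`np.unique(keys).size != keys.size`) and the stored field names come from the diff.

```python
import numpy as np


def import_hashtable(hashtable_values, table_key, keys, values, key_dtype, value_dtype):
    """Record constant keys/values for one logical table (hypothetical helper)."""
    keys = np.asarray(keys)
    values = np.asarray(values)
    # The constant-fold lookup needs a unique key -> value mapping.
    if np.unique(keys).size != keys.size:
        raise NotImplementedError("HASHTABLE_IMPORT with duplicate keys is not supported")
    # Only the first import for a table key is stored, mirroring the
    # `if table_key not in hashtable_values` check in the converter.
    if table_key not in hashtable_values:
        hashtable_values[table_key] = {
            "size": keys.size,
            "key_dtype": key_dtype,
            "value_dtype": value_dtype,
            "keys": keys,
            "values": values,
        }
    return hashtable_values[table_key]
```

Because the init subgraph and the main graph resolve to the same table key, a later `HASHTABLE_FIND` can read `keys` and `values` back from the same dictionary entry.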
### Constant-Foldable Find
`HASHTABLE_FIND` resolves the table handle through the importer-local handle map and the statically imported keys/values. For the supported subset it builds a Python key -> value map, fills the result with the default value (broadcast when it is a scalar, taken element-wise otherwise), overwrites the entries whose query key is found in the table, preserves the query tensor shape, and emits the result as a `relax.const`. The op produces no runtime Relax computation, which matches the fact that both the table and the query are compile-time constants.
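The fold itself can be sketched in plain NumPy. `fold_hashtable_find` is a hypothetical standalone function that mirrors the converter logic in the diff (default fill, then overwrite hits via `np.ndenumerate`); in the frontend, the resulting array would be wrapped in `relax.const(..., "int64")`.

```python
import numpy as np


def fold_hashtable_find(table_keys, table_values, query, default):
    """Resolve a constant string -> int64 lookup at import time (sketch)."""
    query = np.asarray(query, dtype=object)
    default = np.asarray(default, dtype=np.int64)
    # Start from the default: broadcast a single value, or copy a
    # per-element default that already has the query's shape.
    if default.ndim == 0 or default.size == 1:
        result = np.full(query.shape, int(default.item()), dtype=np.int64)
    elif default.shape == query.shape:
        result = default.copy()
    else:
        raise NotImplementedError("default value must be scalar or match query shape")

    table_map = dict(
        zip(np.ravel(table_keys).tolist(), np.ravel(table_values).astype(np.int64).tolist())
    )
    # Overwrite the default only where the query key is present in the table.
    for index, key in np.ndenumerate(query):
        if key in table_map:
            result[index] = table_map[key]
    return result
```

With the test table `alpha -> 10, beta -> 20, gamma -> 30` and default `-1`, the query `["alpha", "missing", "beta"]` folds to `[10, -1, 20]`, the same constant that `test_hashtable_call_once_import_find_string_to_int64` expects.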
The supported subset is intentionally narrow and guard-first:
- the table must be a constant `string -> int64` table imported via a supported `CALL_ONCE` `HASHTABLE_IMPORT`
- the query tensor must be a constant string buffer
- the default tensor must be constant and either a scalar or match the query shape
- the output shape must match the query shape
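The guard-first ordering can be sketched as a pure check that collects every violated condition. `TensorInfo` and `find_subset_violations` are hypothetical: dtypes are plain strings rather than TFLite `TensorType` enums, and the real converter raises `tvm.error.OpNotImplemented` on the first failure instead of collecting them. The checks cover the conditions listed above plus the dtype-agreement, constant-default, and output-shape checks the converter performs; the converter additionally treats any single-element default as a scalar, which this sketch omits.

```python
from dataclasses import dataclass


@dataclass
class TensorInfo:
    """Hypothetical stand-in for what the importer knows about a tensor."""

    dtype: str
    shape: tuple
    is_constant: bool = True


def find_subset_violations(table_key_dtype, table_value_dtype, table_from_call_once,
                           query, default, output):
    """Return the list of HASHTABLE_FIND subset conditions that are not met."""
    problems = []
    if not table_from_call_once:
        problems.append("table not initialized by a supported CALL_ONCE subgraph")
    if (query.dtype != table_key_dtype or default.dtype != table_value_dtype
            or output.dtype != table_value_dtype):
        problems.append("key/value dtypes mismatch")
    if (table_key_dtype, table_value_dtype) != ("string", "int64"):
        problems.append("only constant string -> int64 tables are supported")
    if not query.is_constant:
        problems.append("runtime string queries are not supported")
    if not default.is_constant:
        problems.append("default value must be constant")
    if output.shape != query.shape:
        problems.append("output shape must match query shape")
    if default.shape not in ((), query.shape):
        problems.append("default must be a scalar or match the query shape")
    return problems
```

An empty list means the pattern can be constant-folded; any entry corresponds to a case the frontend rejects.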
### String Graph Input Guard
`TensorType.STRING` graph inputs are now rejected in `from_tflite` with a clear `OpNotImplemented` instead of a low-level FFI `unknown dtype string` error, since Relax cannot represent a string tensor. This is the path hit by runtime string queries for `HASHTABLE_FIND` and by string-valued `HASHTABLE_LOOKUP`, so the guard gives a frontend-level diagnostic for both.
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `HASHTABLE_IMPORT` | `HashtableImportOptions` | store constant keys/values in importer state | `CALL_ONCE` init, constant keys/values, no duplicate keys |
| `HASHTABLE_FIND` | `HashtableFindOptions` | constant-fold to `relax.const` | constant `string -> int64` table + constant string query |
## Not Included
- Runtime (non-constant) string queries and runtime hashtable lookup.
- `int64 -> string` find, which would require string-typed Relax outputs.
- Any general `TensorType.STRING` tensor representation in Relax.
- A runtime Relax hashtable operator with string hashing / comparison.
- Mutable runtime resource-state threading through Relax functions.
These require core Relax support and are out of scope for the frontend.
## Tests
The tests manually build minimal TFLite flatbuffers and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported patterns are checked with `pytest.raises`.
| Test | Coverage |
|---|---|
| `test_hashtable_call_once_import_find_string_to_int64` | constant `string -> int64` find folds to a `relax.const` |
| `test_hashtable_call_once_import_find_string_to_int64_2d_query` | query shape preserved for a 2-D query |
| `test_hashtable_call_once_import_find_int64_to_string_unsupported` | `int64 -> string` table rejected |
| `test_hashtable_call_once_import_find_runtime_query_unsupported` | runtime string query rejected |
| `test_hashtable_call_once_import_duplicate_keys_unsupported` | duplicate static keys rejected |
| `test_hashtable_lookup_string_value_unsupported` | string graph input now gives a clean `OpNotImplemented` |
Local validation:
```bash
python -m ruff format --check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k "hashtable or resource or variable" -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py -q
```
Result:
```text
ruff format --check: 2 files already formatted
ruff check: All checks passed
14 passed, 535 deselected
549 passed
```
## References
- Issue #19519 item G: TFLite resource / variable / hashtable operators
- PR #19639: TFLite resource variable and static hashtable import support
- PR #19654: TFLite `HASHTABLE_LOOKUP` converter
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 135 +++++++++++-
tests/python/relax/test_frontend_tflite.py | 229 +++++++++++++++++++--
2 files changed, 342 insertions(+), 22 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index c6fb45597c..f6db5f77e2 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -721,6 +721,45 @@ class OperatorConverter:
and tensor_wrapper.buffer.DataLength() > 0
)
+ def _get_string_tensor_value(self, tensor_wrapper, op_name):
+ """Decode a constant TFLite string tensor buffer."""
+ if not self._is_tflite_string_type(tensor_wrapper.tensor.Type()):
+ raise tvm.error.OpNotImplemented(f"{op_name} requires a TensorType.STRING tensor")
+ if not self._has_tensor_buffer_data(tensor_wrapper):
+ raise tvm.error.OpNotImplemented(f"{op_name} requires a constant string tensor")
+
+ data = bytes(tensor_wrapper.buffer.DataAsNumpy())
+ if len(data) < 4:
+ raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor buffer")
+
+ count = int(np.frombuffer(data, dtype="<i4", count=1)[0])
+ if count < 0:
+ raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor count")
+
+ header_size = 4 * (count + 2)
+ if len(data) < header_size:
+ raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor offsets")
+
+ offsets = np.frombuffer(data, dtype="<i4", count=count + 1, offset=4).astype(np.int64)
+ if np.any(offsets < header_size) or np.any(offsets > len(data)):
+ raise tvm.error.OpNotImplemented(f"{op_name} has out-of-bounds string tensor offsets")
+ if np.any(offsets[:-1] > offsets[1:]):
+ raise tvm.error.OpNotImplemented(f"{op_name} has non-monotonic string tensor offsets")
+
+ try:
+ values = [
+ data[int(offsets[i]) : int(offsets[i + 1])].decode("utf-8") for i in range(count)
+ ]
+ except UnicodeDecodeError as e:
+ raise tvm.error.OpNotImplemented(f"{op_name} has invalid UTF-8 string data: {e}") from e
+ shape = self._get_tensor_shape_tuple(tensor_wrapper)
+ expected_count = math.prod(shape) if shape else 1
+ if expected_count != count:
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} string tensor buffer count does not match its shape"
+ )
+ return np.array(values, dtype=object).reshape(shape)
+
def convert_hashtable(self, op):
"""Convert a TFLite HASHTABLE into an importer-local table handle."""
input_tensors = self.get_input_tensors(op)
@@ -769,6 +808,20 @@ class OperatorConverter:
):
raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values")
+ if self._is_tflite_string_type(table_info["key_dtype"]):
+ keys = self._get_string_tensor_value(key_tensor, "HASHTABLE_IMPORT")
+ else:
+ keys = self.get_tensor_value(key_tensor)
+ if self._is_tflite_string_type(table_info["value_dtype"]):
+ values = self._get_string_tensor_value(value_tensor, "HASHTABLE_IMPORT")
+ else:
+ values = self.get_tensor_value(value_tensor)
+
+ if np.unique(keys).size != keys.size:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_IMPORT with duplicate keys is not supported"
+ )
+
hashtable_values = self.conversion_state["hashtable_values"]
table_key = table_info["table_key"]
if table_key not in hashtable_values:
@@ -776,14 +829,81 @@ class OperatorConverter:
"size": math.prod(key_shape) if key_shape else 1,
"key_dtype": table_info["key_dtype"],
"value_dtype": table_info["value_dtype"],
+ "keys": keys,
+ "values": values,
}
return None
def convert_hashtable_find(self, op):
- """Reject HASHTABLE_FIND until Relax can represent TFLite string tensors."""
- raise tvm.error.OpNotImplemented(
- "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend"
+ """Convert the constant-foldable string-to-int64 HASHTABLE_FIND subset."""
+ from tflite.TensorType import TensorType
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(input_tensors) != 3 or len(output_tensors) != 1:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_FIND expects table, query, and default inputs with one output"
+ )
+
+ table_tensor, query_tensor, default_tensor = input_tensors
+ output_tensor = output_tensors[0]
+ table_info = self._get_hashtable_info_for_handle(table_tensor, "HASHTABLE_FIND")
+ table_key = table_info["table_key"]
+ hashtable_values = self.conversion_state["hashtable_values"]
+ if table_key not in hashtable_values:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_FIND requires a table initialized by a supported CALL_ONCE subgraph"
+ )
+ table_values = hashtable_values[table_key]
+
+ if (
+ query_tensor.tensor.Type() != table_values["key_dtype"]
+ or default_tensor.tensor.Type() != table_values["value_dtype"]
+ or output_tensor.tensor.Type() != table_values["value_dtype"]
+ ):
+ raise tvm.error.OpNotImplemented("HASHTABLE_FIND key/value dtypes mismatch")
+
+ if not (
+ self._is_tflite_string_type(table_values["key_dtype"])
+ and table_values["value_dtype"] == TensorType.INT64
+ ):
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_FIND only supports constant string -> int64 tables"
+ )
+ if not self._has_tensor_buffer_data(query_tensor):
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_FIND with runtime string queries is not supported"
+ )
+ if not self._has_tensor_buffer_data(default_tensor):
+ raise tvm.error.OpNotImplemented("HASHTABLE_FIND requires constant default values")
+
+ query_shape = self._get_tensor_shape_tuple(query_tensor)
+ output_shape = self._get_tensor_shape_tuple(output_tensor)
+ if output_shape != query_shape:
+ raise tvm.error.OpNotImplemented("HASHTABLE_FIND output shape must match query shape")
+
+ query_values = self._get_string_tensor_value(query_tensor, "HASHTABLE_FIND")
+ default_values = self.get_tensor_value(default_tensor)
+ default_shape = self._get_tensor_shape_tuple(default_tensor)
+ if default_shape == () or default_values.size == 1:
+ result = np.full(output_shape, int(default_values.item()), dtype=np.int64)
+ elif default_shape == query_shape:
+ result = default_values.astype(np.int64).copy()
+ else:
+ raise tvm.error.OpNotImplemented(
+ "HASHTABLE_FIND default value must be scalar or match query shape"
+ )
+
+ table_map = dict(
+ zip(
+ table_values["keys"].reshape(-1).tolist(),
+ table_values["values"].reshape(-1).astype(np.int64).tolist(),
+ )
)
+ for index, key in np.ndenumerate(query_values):
+ if key in table_map:
+ result[index] = table_map[key]
+ return relax.const(result.astype(np.int64), "int64")
def convert_hashtable_lookup(self, op):
"""Convert TFLite HASHTABLE_LOOKUP for non-string value tensors."""
@@ -8834,6 +8954,15 @@ def from_tflite(
dtype = "float32"
if shape is not None:
shape = tuple(shape) + (2,)
+ if dtype == "string":
+ # Relax has no string tensor type, so TFLite TensorType.STRING graph
+ # inputs cannot be represented. This also covers runtime string queries
+ # for ops like HASHTABLE_FIND, whose constant-foldable subset is handled
+ # in the op converter.
+ raise tvm.error.OpNotImplemented(
+ "Relax TFLite frontend does not support TensorType.STRING graph inputs "
+ "(e.g. runtime string queries)"
+ )
input_var = relax.Var(
name_hint=model_input_name,
ty=relax.TensorType(shape=shape, dtype=dtype),
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 4d3a27dfc3..6f2845da50 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -6496,25 +6496,57 @@ def _build_tflite_resource_read_uninitialized_model():
)
-def _build_tflite_hashtable_find_model():
- """Build a model that imports a static hashtable and finds runtime query keys."""
+def _build_tflite_hashtable_find_string_to_int64_model(
+ query_values=None,
+ query_shape=None,
+ default_values=None,
+ default_shape=None,
+ table_keys=None,
+ table_values=None,
+ query_is_input=False,
+):
+ """Build a static string-to-int64 HASHTABLE_FIND model."""
builder = flatbuffers.Builder(1024)
resource_type = _get_resource_tensor_type()
string_type = _get_string_tensor_type()
- table_keys = np.array([10, 20], dtype=np.int64)
- table_values = _build_tflite_string_buffer(["one hundred", "two hundred"])
- default_value = _build_tflite_string_buffer(["missing"])
+ query_values = ["alpha", "missing", "beta"] if query_values is None else query_values
+ query_shape = [len(query_values)] if query_shape is None else query_shape
+ default_shape = [] if default_shape is None else default_shape
+ table_keys = ["alpha", "beta", "gamma"] if table_keys is None else table_keys
+ table_values = (
+ np.array([10, 20, 30], dtype=np.int64)
+ if table_values is None
+ else np.array(table_values, dtype=np.int64)
+ )
+ default_values = (
+ np.array(-1, dtype=np.int64)
+ if default_values is None
+ else np.array(default_values, dtype=np.int64)
+ )
+ query_buffer = _build_tflite_string_buffer(query_values)
+ table_key_buffer = _build_tflite_string_buffer(table_keys)
call_once_options = _build_call_once_options(builder, 1)
- main_table_options = _build_hashtable_options(builder, table_id=0)
+ main_table_options = _build_hashtable_options(
+ builder,
+ table_id=0,
+ key_dtype=string_type,
+ value_dtype=_tfl_tensor_type.INT64,
+ )
find_options = _build_empty_builtin_options(builder, "HashtableFindOptions")
- init_table_options = _build_hashtable_options(builder, table_id=0)
+ init_table_options = _build_hashtable_options(
+ builder,
+ table_id=0,
+ key_dtype=string_type,
+ value_dtype=_tfl_tensor_type.INT64,
+ )
import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
- query_tensor = _build_tensor(builder, 0, [3], tensor_type=_tfl_tensor_type.INT64)
+ query_buffer_idx = 0 if query_is_input else 1
+ query_tensor = _build_tensor(builder, query_buffer_idx, query_shape, tensor_type=string_type)
table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
- default_tensor = _build_tensor(builder, 1, [], tensor_type=string_type)
- output_tensor = _build_tensor(builder, 0, [3], tensor_type=string_type)
+ default_tensor = _build_tensor(builder, 2, default_shape, tensor_type=_tfl_tensor_type.INT64)
+ output_tensor = _build_tensor(builder, 0, query_shape, tensor_type=_tfl_tensor_type.INT64)
main_call_once = _build_operator(
builder,
0,
@@ -6543,18 +6575,119 @@ def _build_tflite_hashtable_find_model():
builder,
tensors=[query_tensor, table_tensor, default_tensor, output_tensor],
operators=[main_call_once, main_hashtable, main_find],
- inputs=[0],
+ inputs=[0] if query_is_input else [],
outputs=[3],
)
init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
- init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT64)
+ init_keys_tensor = _build_tensor(builder, 3, [len(table_keys)], tensor_type=string_type)
init_values_tensor = _build_tensor(
+ builder,
+ 4,
+ [len(table_values)],
+ tensor_type=_tfl_tensor_type.INT64,
+ )
+ init_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=init_table_options,
+ )
+ init_import = _build_operator(
builder,
3,
- [2],
- tensor_type=string_type,
+ [0, 1, 2],
+ [],
+ builtin_options_type=_get_builtin_options_type("HashtableImportOptions"),
+ builtin_options=import_options,
+ )
+ init_subgraph = _build_subgraph(
+ builder,
+ tensors=[init_table_tensor, init_keys_tensor, init_values_tensor],
+ operators=[init_hashtable, init_import],
+ inputs=[],
+ outputs=[],
+ )
+
+ operator_codes = [
+ _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_FIND")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")),
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, b"" if query_is_input else query_buffer),
+ _build_buffer(builder, default_values.tobytes()),
+ _build_buffer(builder, table_key_buffer),
+ _build_buffer(builder, table_values.tobytes()),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[init_subgraph],
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
+def _build_tflite_hashtable_find_int64_to_string_model():
+ """Build a static int64-to-string HASHTABLE_FIND model."""
+ builder = flatbuffers.Builder(1024)
+ resource_type = _get_resource_tensor_type()
+ string_type = _get_string_tensor_type()
+ query_values = np.array([10, 30], dtype=np.int64)
+ table_keys = np.array([10, 20], dtype=np.int64)
+ table_values = _build_tflite_string_buffer(["ten", "twenty"])
+ default_value = _build_tflite_string_buffer(["missing"])
+
+ call_once_options = _build_call_once_options(builder, 1)
+ main_table_options = _build_hashtable_options(builder, table_id=0)
+ find_options = _build_empty_builtin_options(builder, "HashtableFindOptions")
+ init_table_options = _build_hashtable_options(builder, table_id=0)
+ import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
+
+ query_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64)
+ table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ default_tensor = _build_tensor(builder, 2, [], tensor_type=string_type)
+ output_tensor = _build_tensor(builder, 0, [2], tensor_type=string_type)
+ main_call_once = _build_operator(
+ builder,
+ 0,
+ [],
+ [],
+ builtin_options_type=_get_builtin_options_type("CallOnceOptions"),
+ builtin_options=call_once_options,
+ )
+ main_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [1],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=main_table_options,
+ )
+ main_find = _build_operator(
+ builder,
+ 2,
+ [1, 0, 2],
+ [3],
+ builtin_options_type=_get_builtin_options_type("HashtableFindOptions"),
+ builtin_options=find_options,
)
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[query_tensor, table_tensor, default_tensor, output_tensor],
+ operators=[main_call_once, main_hashtable, main_find],
+ inputs=[],
+ outputs=[3],
+ )
+
+ init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ init_keys_tensor = _build_tensor(builder, 3, [2], tensor_type=_tfl_tensor_type.INT64)
+ init_values_tensor = _build_tensor(builder, 4, [2], tensor_type=string_type)
init_hashtable = _build_operator(
builder,
1,
@@ -6587,6 +6720,7 @@ def _build_tflite_hashtable_find_model():
]
buffers = [
_build_buffer(builder),
+ _build_buffer(builder, query_values.tobytes()),
_build_buffer(builder, default_value),
_build_buffer(builder, table_keys.tobytes()),
_build_buffer(builder, table_values),
@@ -6934,10 +7068,67 @@ def test_read_variable_uninitialized_unsupported():
_load_model_from_buffer(_build_tflite_resource_read_uninitialized_model())
-def test_hashtable_call_once_import_find_unsupported():
- """Test HASHTABLE_FIND remains unsupported until TFLite string tensors are supported."""
- with pytest.raises(tvm.error.OpNotImplemented, match="TensorType.STRING"):
- _load_model_from_buffer(_build_tflite_hashtable_find_model())
+def test_hashtable_call_once_import_find_string_to_int64():
+ """Test HASHTABLE_FIND for a static string-to-int64 table."""
+ mod = _load_model_from_buffer(_build_tflite_hashtable_find_string_to_int64_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((3,), dtype="int64"):
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ gv: R.Tensor((3,), dtype="int64") = R.const([10, -1, 20], "int64")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_hashtable_call_once_import_find_string_to_int64_2d_query():
+ """Test HASHTABLE_FIND preserves the static query shape."""
+ mod = _load_model_from_buffer(
+ _build_tflite_hashtable_find_string_to_int64_model(
+ query_values=["alpha", "beta", "missing", "gamma"],
+ query_shape=[2, 2],
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((2, 2), dtype="int64"):
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="int64") = R.const([[10, 20], [-1, 30]], "int64")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_hashtable_call_once_import_find_int64_to_string_unsupported():
+ """Test HASHTABLE_FIND rejects int64-to-string tables until string outputs exist."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="string -> int64"):
+ _load_model_from_buffer(_build_tflite_hashtable_find_int64_to_string_model())
+
+
+def test_hashtable_call_once_import_find_runtime_query_unsupported():
+ """Test HASHTABLE_FIND rejects runtime string queries."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="string queries|STRING graph inputs"):
+ _load_model_from_buffer(
+ _build_tflite_hashtable_find_string_to_int64_model(query_is_input=True)
+ )
+
+
+def test_hashtable_call_once_import_duplicate_keys_unsupported():
+ """Test HASHTABLE_IMPORT rejects duplicate static keys."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="duplicate keys"):
+ _load_model_from_buffer(
+ _build_tflite_hashtable_find_string_to_int64_model(
+ table_keys=["alpha", "alpha"], table_values=[10, 20]
+ )
+ )
def test_hashtable_call_once_import_size():
@@ -7126,7 +7317,7 @@ def test_hashtable_lookup_2d_value():
def test_hashtable_lookup_string_value_unsupported():
string_type = _get_string_tensor_type()
- with pytest.raises(ValueError, match="unknown dtype `string`"):
+ with pytest.raises(tvm.error.OpNotImplemented, match="STRING graph inputs"):
_load_model_from_buffer(
_build_tflite_hashtable_lookup_model(value_shape=[3], value_type=string_type)
)