This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a8dae4d95 [Relax][Frontend][TFLite] Support static hashtable find (#19879)
5a8dae4d95 is described below

commit 5a8dae4d95c55c8fec9246a607a28c3ff54ffe05
Author: Hongyi Wu <[email protected]>
AuthorDate: Wed Jun 24 14:27:49 2026 +0800

    [Relax][Frontend][TFLite] Support static hashtable find (#19879)
    
    ## Summary
    
    This PR extends the TFLite resource hashtable family from #19519 item G from
    "import + size" to a constant-foldable `HASHTABLE_FIND` subset. It builds on
    the static `HASHTABLE` / `HASHTABLE_IMPORT` import support added in #19639
    and the `HASHTABLE_LOOKUP` converter from #19654.
    
    Relax has no string tensor type and no runtime hashtable operator, so this
    PR targets the only subset that can be lowered today: tables whose keys and
    values are constants imported in a `CALL_ONCE` init subgraph, queried by a
    constant string tensor. In that case the lookup is resolved at import time
    and emitted as a `relax.const`. Runtime string queries remain explicitly
    guarded instead of being lowered with incorrect semantics.
    
    ## Design
    
    ### Static String Tensor Decoding
    
    The frontend now decodes constant TFLite string tensors through a shared
    helper, `_get_string_tensor_value`. It parses the TFLite string-tensor
    binary layout:
    
    ```text
    [count(i32)][offset_0 .. offset_count(i32)][utf8 string data]
    ```
    
    The helper validates the buffer length against the declared count, checks
    that offsets are in bounds and monotonically non-decreasing, rejects
    invalid UTF-8 data, and verifies that the decoded element count matches the
    tensor shape. This is reusable infrastructure for any future TFLite string
    support, independent of `HASHTABLE_FIND`.
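A minimal Python sketch of the layout above, assuming little-endian int32 header fields (the `<i4` dtype the helper reads). `encode_string_tensor` and `decode_string_tensor` are hypothetical names for illustration, not frontend APIs; the decoder mirrors the validation steps listed in the previous paragraph, raising `ValueError` where the frontend raises `OpNotImplemented`.

```python
import struct


def encode_string_tensor(strings):
    """Pack str items as [count][offset_0 .. offset_count][utf8 data] (hypothetical helper)."""
    encoded = [s.encode("utf-8") for s in strings]
    count = len(encoded)
    header_size = 4 * (count + 2)  # one count field plus (count + 1) offsets
    offsets = [header_size]
    for item in encoded:
        offsets.append(offsets[-1] + len(item))
    header = struct.pack(f"<{count + 2}i", count, *offsets)
    return header + b"".join(encoded)


def decode_string_tensor(data, expected_count):
    """Decode and validate a string buffer (sketch of the checks described above)."""
    if len(data) < 4:
        raise ValueError("buffer too short for the count field")
    (count,) = struct.unpack_from("<i", data, 0)
    if count < 0:
        raise ValueError("negative string count")
    header_size = 4 * (count + 2)
    if len(data) < header_size:
        raise ValueError("buffer too short for the offset table")
    offsets = struct.unpack_from(f"<{count + 1}i", data, 4)
    if any(o < header_size or o > len(data) for o in offsets):
        raise ValueError("out-of-bounds string offset")
    if any(a > b for a, b in zip(offsets, offsets[1:])):
        raise ValueError("non-monotonic string offsets")
    if count != expected_count:
        raise ValueError("string count does not match the tensor shape")
    # UnicodeDecodeError is a ValueError subclass, so bad UTF-8 also surfaces here.
    return [data[offsets[i] : offsets[i + 1]].decode("utf-8") for i in range(count)]
```

For example, `encode_string_tensor(["alpha", "beta"])` yields a 25-byte buffer whose header is `count=2, offsets=[16, 21, 25]`.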
    
    ### Hashtable Import State
    
    `HASHTABLE_IMPORT` now stores the actual constant keys and values (numeric
    or decoded string buffers) in shared conversion state, instead of only
    recording table metadata (size, key/value dtype). Duplicate keys are
    rejected because the constant-fold lookup assumes a unique key mapping. The
    captured table is keyed by the same `table_id` / handle resolution used by
    `HASHTABLE` and `HASHTABLE_SIZE`, so a `CALL_ONCE` init subgraph and the
    main graph agree on the same logical table.
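The import-state bookkeeping can be sketched as below, assuming a plain dict stands in for `conversion_state["hashtable_values"]` and a local `UnsupportedError` stands in for `tvm.error.OpNotImplemented`; `record_hashtable_import` is a hypothetical helper name. As in the diff, the duplicate check compares `np.unique(keys).size` with `keys.size`, and only the first import for a given table key is recorded.

```python
import numpy as np


class UnsupportedError(Exception):
    """Stand-in for tvm.error.OpNotImplemented (assumption for this sketch)."""


def record_hashtable_import(state, table_key, keys, values):
    """Store constant keys/values for one logical table, rejecting duplicate keys."""
    keys = np.asarray(keys, dtype=object)
    values = np.asarray(values)
    # A constant-fold lookup needs a unique key -> value mapping.
    if np.unique(keys).size != keys.size:
        raise UnsupportedError("HASHTABLE_IMPORT with duplicate keys is not supported")
    if table_key not in state:
        # First import wins, mirroring `if table_key not in hashtable_values`.
        state[table_key] = {"size": keys.size, "keys": keys, "values": values}
    return state[table_key]
```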
    
    ### Constant-Foldable Find
    
    `HASHTABLE_FIND` resolves the table handle through the importer-local
    handle map and the statically imported keys/values. For the supported
    subset it builds a Python key -> value map, fills the result with the
    default value (broadcast when it is a scalar, copied element-wise when it
    matches the query shape), overwrites hits from the table, preserves the
    query tensor shape, and emits the result as a `relax.const`. The op
    produces no runtime Relax computation, which matches the fact that both the
    table and the query are compile-time constants.
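The fold itself reduces to a dictionary lookup over the query elements. The sketch below is a standalone NumPy approximation of that logic, not the converter: `fold_hashtable_find` is a hypothetical name, inputs are plain Python/NumPy values rather than TFLite tensors, and the default handling follows the diff (a single-element default is broadcast, otherwise it must match the query shape).

```python
import numpy as np


def fold_hashtable_find(keys, values, query, default):
    """Resolve a string -> int64 lookup at import time (sketch, not the TVM converter)."""
    query = np.asarray(query, dtype=object)
    default = np.asarray(default, dtype=np.int64)
    if default.size == 1:
        # Scalar or single-element default: broadcast to the query shape.
        result = np.full(query.shape, int(default.item()), dtype=np.int64)
    elif default.shape == query.shape:
        # Per-element default: copy so that writing hits does not alias the input.
        result = default.copy()
    else:
        raise ValueError("default value must be scalar or match query shape")
    table = dict(
        zip(np.ravel(keys).tolist(), np.ravel(values).astype(np.int64).tolist())
    )
    # Overwrite only the positions whose key is present in the table.
    for index, key in np.ndenumerate(query):
        if key in table:
            result[index] = table[key]
    return result
```

With the default test table, `fold_hashtable_find(["alpha", "beta", "gamma"], [10, 20, 30], ["alpha", "missing", "beta"], -1)` returns `[10, -1, 20]`, the same constant expected by `test_hashtable_call_once_import_find_string_to_int64`.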
    
    The supported subset is intentionally narrow and guard-first:
    
    - the table must be a constant `string -> int64` table imported via a
      supported `CALL_ONCE` `HASHTABLE_IMPORT`
    - the query tensor must be a constant string buffer
    - the default tensor must be a scalar or match the query shape
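The guard-first ordering can be expressed as a single predicate that returns the first unsupported reason, or `None` when the pattern is foldable. `find_unsupported_reason` is a hypothetical helper that takes dtype names as strings and a boolean for query constness; the actual converter inspects TFLite `TensorType` values and buffers, and additionally checks table initialization, output dtype, and output shape, which this sketch omits.

```python
import math


def find_unsupported_reason(key_dtype, value_dtype, query_is_constant, query_shape, default_shape):
    """Return the first reason a HASHTABLE_FIND pattern cannot be folded, else None."""
    # 1. Only string -> int64 tables fold to an int64 constant.
    if (key_dtype, value_dtype) != ("string", "int64"):
        return "only constant string -> int64 tables are supported"
    # 2. The query must be a constant buffer; runtime strings have no Relax type.
    if not query_is_constant:
        return "runtime string queries are not supported"
    # 3. A one-element default is broadcast; otherwise it must match the query.
    if math.prod(default_shape) != 1 and tuple(default_shape) != tuple(query_shape):
        return "default value must be scalar or match query shape"
    return None
```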
    
    ### String Graph Input Guard
    
    `TensorType.STRING` graph inputs are now rejected in `from_tflite` with a
    clear `OpNotImplemented` instead of a low-level FFI `unknown dtype string`
    error, since Relax cannot represent a string tensor. This is the path hit by
    runtime string queries for `HASHTABLE_FIND` and string-valued
    `HASHTABLE_LOOKUP`, so the guard gives a frontend-level diagnostic for both.
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `HASHTABLE_IMPORT` | `HashtableImportOptions` | store constant keys/values in importer state | `CALL_ONCE` init, constant keys/values, no duplicate keys |
    | `HASHTABLE_FIND` | `HashtableFindOptions` | constant-fold to `relax.const` | constant `string -> int64` table + constant string query |
    
    ## Not Included
    
    - Runtime (non-constant) string queries and runtime hashtable lookup.
    - `int64 -> string` find, which would require string-typed Relax outputs.
    - Any general `TensorType.STRING` tensor representation in Relax.
    - A runtime Relax hashtable operator with string hashing / comparison.
    - Mutable runtime resource-state threading through Relax functions.
    
    These require core Relax support and are out of scope for the frontend.
    
    ## Tests
    
    The tests manually build minimal TFLite flatbuffers and compare the imported
    Relax IR with `tvm.ir.assert_structural_equal`. Unsupported patterns use
    `pytest.raises`.
    
    | Test | Coverage |
    |---|---|
    | `test_hashtable_call_once_import_find_string_to_int64` | constant `string -> int64` find folds to a `relax.const` |
    | `test_hashtable_call_once_import_find_string_to_int64_2d_query` | query shape preserved for a 2-D query |
    | `test_hashtable_call_once_import_find_int64_to_string_unsupported` | `int64 -> string` table rejected |
    | `test_hashtable_call_once_import_find_runtime_query_unsupported` | runtime string query rejected |
    | `test_hashtable_call_once_import_duplicate_keys_unsupported` | duplicate static keys rejected |
    | `test_hashtable_lookup_string_value_unsupported` | string graph input now gives a clean `OpNotImplemented` |
    
    Local validation:
    
    ```bash
    python -m ruff format --check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k "hashtable or resource or variable" -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py -q
    ```
    
    Result:
    
    ```text
    ruff format --check: 2 files already formatted
    ruff check: All checks passed
    14 passed, 535 deselected
    549 passed
    ```
    
    ## References
    
    - Issue #19519 item G: TFLite resource / variable / hashtable operators
    - PR #19639: TFLite resource variable and static hashtable import support
    - PR #19654: TFLite `HASHTABLE_LOOKUP` converter
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 135 +++++++++++-
 tests/python/relax/test_frontend_tflite.py         | 229 +++++++++++++++++++--
 2 files changed, 342 insertions(+), 22 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index c6fb45597c..f6db5f77e2 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -721,6 +721,45 @@ class OperatorConverter:
             and tensor_wrapper.buffer.DataLength() > 0
         )
 
+    def _get_string_tensor_value(self, tensor_wrapper, op_name):
+        """Decode a constant TFLite string tensor buffer."""
+        if not self._is_tflite_string_type(tensor_wrapper.tensor.Type()):
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a TensorType.STRING tensor")
+        if not self._has_tensor_buffer_data(tensor_wrapper):
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a constant string tensor")
+
+        data = bytes(tensor_wrapper.buffer.DataAsNumpy())
+        if len(data) < 4:
+            raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor buffer")
+
+        count = int(np.frombuffer(data, dtype="<i4", count=1)[0])
+        if count < 0:
+            raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor count")
+
+        header_size = 4 * (count + 2)
+        if len(data) < header_size:
+            raise tvm.error.OpNotImplemented(f"{op_name} has an invalid string tensor offsets")
+
+        offsets = np.frombuffer(data, dtype="<i4", count=count + 1, offset=4).astype(np.int64)
+        if np.any(offsets < header_size) or np.any(offsets > len(data)):
+            raise tvm.error.OpNotImplemented(f"{op_name} has out-of-bounds string tensor offsets")
+        if np.any(offsets[:-1] > offsets[1:]):
+            raise tvm.error.OpNotImplemented(f"{op_name} has non-monotonic string tensor offsets")
+
+        try:
+            values = [
+                data[int(offsets[i]) : int(offsets[i + 1])].decode("utf-8") for i in range(count)
+            ]
+        except UnicodeDecodeError as e:
+            raise tvm.error.OpNotImplemented(f"{op_name} has invalid UTF-8 string data: {e}") from e
+        shape = self._get_tensor_shape_tuple(tensor_wrapper)
+        expected_count = math.prod(shape) if shape else 1
+        if expected_count != count:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} string tensor buffer count does not match its shape"
+            )
+        return np.array(values, dtype=object).reshape(shape)
+
     def convert_hashtable(self, op):
         """Convert a TFLite HASHTABLE into an importer-local table handle."""
         input_tensors = self.get_input_tensors(op)
@@ -769,6 +808,20 @@ class OperatorConverter:
         ):
             raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values")
 
+        if self._is_tflite_string_type(table_info["key_dtype"]):
+            keys = self._get_string_tensor_value(key_tensor, "HASHTABLE_IMPORT")
+        else:
+            keys = self.get_tensor_value(key_tensor)
+        if self._is_tflite_string_type(table_info["value_dtype"]):
+            values = self._get_string_tensor_value(value_tensor, "HASHTABLE_IMPORT")
+        else:
+            values = self.get_tensor_value(value_tensor)
+
+        if np.unique(keys).size != keys.size:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_IMPORT with duplicate keys is not supported"
+            )
+
         hashtable_values = self.conversion_state["hashtable_values"]
         table_key = table_info["table_key"]
         if table_key not in hashtable_values:
@@ -776,14 +829,81 @@ class OperatorConverter:
                 "size": math.prod(key_shape) if key_shape else 1,
                 "key_dtype": table_info["key_dtype"],
                 "value_dtype": table_info["value_dtype"],
+                "keys": keys,
+                "values": values,
             }
         return None
 
     def convert_hashtable_find(self, op):
-        """Reject HASHTABLE_FIND until Relax can represent TFLite string tensors."""
-        raise tvm.error.OpNotImplemented(
-            "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend"
+        """Convert the constant-foldable string-to-int64 HASHTABLE_FIND subset."""
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 3 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_FIND expects table, query, and default inputs with one output"
+            )
+
+        table_tensor, query_tensor, default_tensor = input_tensors
+        output_tensor = output_tensors[0]
+        table_info = self._get_hashtable_info_for_handle(table_tensor, "HASHTABLE_FIND")
+        table_key = table_info["table_key"]
+        hashtable_values = self.conversion_state["hashtable_values"]
+        if table_key not in hashtable_values:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_FIND requires a table initialized by a supported CALL_ONCE subgraph"
+            )
+        table_values = hashtable_values[table_key]
+
+        if (
+            query_tensor.tensor.Type() != table_values["key_dtype"]
+            or default_tensor.tensor.Type() != table_values["value_dtype"]
+            or output_tensor.tensor.Type() != table_values["value_dtype"]
+        ):
+            raise tvm.error.OpNotImplemented("HASHTABLE_FIND key/value dtypes mismatch")
+
+        if not (
+            self._is_tflite_string_type(table_values["key_dtype"])
+            and table_values["value_dtype"] == TensorType.INT64
+        ):
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_FIND only supports constant string -> int64 tables"
+            )
+        if not self._has_tensor_buffer_data(query_tensor):
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_FIND with runtime string queries is not supported"
+            )
+        if not self._has_tensor_buffer_data(default_tensor):
+            raise tvm.error.OpNotImplemented("HASHTABLE_FIND requires constant default values")
+
+        query_shape = self._get_tensor_shape_tuple(query_tensor)
+        output_shape = self._get_tensor_shape_tuple(output_tensor)
+        if output_shape != query_shape:
+            raise tvm.error.OpNotImplemented("HASHTABLE_FIND output shape must match query shape")
+
+        query_values = self._get_string_tensor_value(query_tensor, "HASHTABLE_FIND")
+        default_values = self.get_tensor_value(default_tensor)
+        default_shape = self._get_tensor_shape_tuple(default_tensor)
+        if default_shape == () or default_values.size == 1:
+            result = np.full(output_shape, int(default_values.item()), dtype=np.int64)
+        elif default_shape == query_shape:
+            result = default_values.astype(np.int64).copy()
+        else:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_FIND default value must be scalar or match query shape"
+            )
+
+        table_map = dict(
+            zip(
+                table_values["keys"].reshape(-1).tolist(),
+                table_values["values"].reshape(-1).astype(np.int64).tolist(),
+            )
         )
+        for index, key in np.ndenumerate(query_values):
+            if key in table_map:
+                result[index] = table_map[key]
+        return relax.const(result.astype(np.int64), "int64")
 
     def convert_hashtable_lookup(self, op):
         """Convert TFLite HASHTABLE_LOOKUP for non-string value tensors."""
@@ -8834,6 +8954,15 @@ def from_tflite(
                     dtype = "float32"
                     if shape is not None:
                         shape = tuple(shape) + (2,)
+                if dtype == "string":
+                    # Relax has no string tensor type, so TFLite TensorType.STRING graph
+                    # inputs cannot be represented. This also covers runtime string queries
+                    # for ops like HASHTABLE_FIND, whose constant-foldable subset is handled
+                    # in the op converter.
+                    raise tvm.error.OpNotImplemented(
+                        "Relax TFLite frontend does not support TensorType.STRING graph inputs "
+                        "(e.g. runtime string queries)"
+                    )
                 input_var = relax.Var(
                     name_hint=model_input_name,
                     ty=relax.TensorType(shape=shape, dtype=dtype),
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 4d3a27dfc3..6f2845da50 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -6496,25 +6496,57 @@ def _build_tflite_resource_read_uninitialized_model():
     )
 
 
-def _build_tflite_hashtable_find_model():
-    """Build a model that imports a static hashtable and finds runtime query keys."""
+def _build_tflite_hashtable_find_string_to_int64_model(
+    query_values=None,
+    query_shape=None,
+    default_values=None,
+    default_shape=None,
+    table_keys=None,
+    table_values=None,
+    query_is_input=False,
+):
+    """Build a static string-to-int64 HASHTABLE_FIND model."""
     builder = flatbuffers.Builder(1024)
     resource_type = _get_resource_tensor_type()
     string_type = _get_string_tensor_type()
-    table_keys = np.array([10, 20], dtype=np.int64)
-    table_values = _build_tflite_string_buffer(["one hundred", "two hundred"])
-    default_value = _build_tflite_string_buffer(["missing"])
+    query_values = ["alpha", "missing", "beta"] if query_values is None else query_values
+    query_shape = [len(query_values)] if query_shape is None else query_shape
+    default_shape = [] if default_shape is None else default_shape
+    table_keys = ["alpha", "beta", "gamma"] if table_keys is None else table_keys
+    table_values = (
+        np.array([10, 20, 30], dtype=np.int64)
+        if table_values is None
+        else np.array(table_values, dtype=np.int64)
+    )
+    default_values = (
+        np.array(-1, dtype=np.int64)
+        if default_values is None
+        else np.array(default_values, dtype=np.int64)
+    )
+    query_buffer = _build_tflite_string_buffer(query_values)
+    table_key_buffer = _build_tflite_string_buffer(table_keys)
 
     call_once_options = _build_call_once_options(builder, 1)
-    main_table_options = _build_hashtable_options(builder, table_id=0)
+    main_table_options = _build_hashtable_options(
+        builder,
+        table_id=0,
+        key_dtype=string_type,
+        value_dtype=_tfl_tensor_type.INT64,
+    )
     find_options = _build_empty_builtin_options(builder, "HashtableFindOptions")
-    init_table_options = _build_hashtable_options(builder, table_id=0)
+    init_table_options = _build_hashtable_options(
+        builder,
+        table_id=0,
+        key_dtype=string_type,
+        value_dtype=_tfl_tensor_type.INT64,
+    )
     import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
 
-    query_tensor = _build_tensor(builder, 0, [3], tensor_type=_tfl_tensor_type.INT64)
+    query_buffer_idx = 0 if query_is_input else 1
+    query_tensor = _build_tensor(builder, query_buffer_idx, query_shape, tensor_type=string_type)
     table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
-    default_tensor = _build_tensor(builder, 1, [], tensor_type=string_type)
-    output_tensor = _build_tensor(builder, 0, [3], tensor_type=string_type)
+    default_tensor = _build_tensor(builder, 2, default_shape, tensor_type=_tfl_tensor_type.INT64)
+    output_tensor = _build_tensor(builder, 0, query_shape, tensor_type=_tfl_tensor_type.INT64)
     main_call_once = _build_operator(
         builder,
         0,
@@ -6543,18 +6575,119 @@ def _build_tflite_hashtable_find_model():
         builder,
         tensors=[query_tensor, table_tensor, default_tensor, output_tensor],
         operators=[main_call_once, main_hashtable, main_find],
-        inputs=[0],
+        inputs=[0] if query_is_input else [],
         outputs=[3],
     )
 
     init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
-    init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT64)
+    init_keys_tensor = _build_tensor(builder, 3, [len(table_keys)], tensor_type=string_type)
     init_values_tensor = _build_tensor(
+        builder,
+        4,
+        [len(table_values)],
+        tensor_type=_tfl_tensor_type.INT64,
+    )
+    init_hashtable = _build_operator(
+        builder,
+        1,
+        [],
+        [0],
+        builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+        builtin_options=init_table_options,
+    )
+    init_import = _build_operator(
         builder,
         3,
-        [2],
-        tensor_type=string_type,
+        [0, 1, 2],
+        [],
+        builtin_options_type=_get_builtin_options_type("HashtableImportOptions"),
+        builtin_options=import_options,
+    )
+    init_subgraph = _build_subgraph(
+        builder,
+        tensors=[init_table_tensor, init_keys_tensor, init_values_tensor],
+        operators=[init_hashtable, init_import],
+        inputs=[],
+        outputs=[],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")),
+        _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+        _build_operator_code(builder, _get_builtin_operator("HASHTABLE_FIND")),
+        _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, b"" if query_is_input else query_buffer),
+        _build_buffer(builder, default_values.tobytes()),
+        _build_buffer(builder, table_key_buffer),
+        _build_buffer(builder, table_values.tobytes()),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[init_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_tflite_hashtable_find_int64_to_string_model():
+    """Build a static int64-to-string HASHTABLE_FIND model."""
+    builder = flatbuffers.Builder(1024)
+    resource_type = _get_resource_tensor_type()
+    string_type = _get_string_tensor_type()
+    query_values = np.array([10, 30], dtype=np.int64)
+    table_keys = np.array([10, 20], dtype=np.int64)
+    table_values = _build_tflite_string_buffer(["ten", "twenty"])
+    default_value = _build_tflite_string_buffer(["missing"])
+
+    call_once_options = _build_call_once_options(builder, 1)
+    main_table_options = _build_hashtable_options(builder, table_id=0)
+    find_options = _build_empty_builtin_options(builder, "HashtableFindOptions")
+    init_table_options = _build_hashtable_options(builder, table_id=0)
+    import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
+
+    query_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64)
+    table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+    default_tensor = _build_tensor(builder, 2, [], tensor_type=string_type)
+    output_tensor = _build_tensor(builder, 0, [2], tensor_type=string_type)
+    main_call_once = _build_operator(
+        builder,
+        0,
+        [],
+        [],
+        builtin_options_type=_get_builtin_options_type("CallOnceOptions"),
+        builtin_options=call_once_options,
+    )
+    main_hashtable = _build_operator(
+        builder,
+        1,
+        [],
+        [1],
+        builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+        builtin_options=main_table_options,
+    )
+    main_find = _build_operator(
+        builder,
+        2,
+        [1, 0, 2],
+        [3],
+        builtin_options_type=_get_builtin_options_type("HashtableFindOptions"),
+        builtin_options=find_options,
     )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[query_tensor, table_tensor, default_tensor, output_tensor],
+        operators=[main_call_once, main_hashtable, main_find],
+        inputs=[],
+        outputs=[3],
+    )
+
+    init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+    init_keys_tensor = _build_tensor(builder, 3, [2], tensor_type=_tfl_tensor_type.INT64)
+    init_values_tensor = _build_tensor(builder, 4, [2], tensor_type=string_type)
     init_hashtable = _build_operator(
         builder,
         1,
@@ -6587,6 +6720,7 @@ def _build_tflite_hashtable_find_model():
     ]
     buffers = [
         _build_buffer(builder),
+        _build_buffer(builder, query_values.tobytes()),
         _build_buffer(builder, default_value),
         _build_buffer(builder, table_keys.tobytes()),
         _build_buffer(builder, table_values),
@@ -6934,10 +7068,67 @@ def test_read_variable_uninitialized_unsupported():
         _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model())
 
 
-def test_hashtable_call_once_import_find_unsupported():
-    """Test HASHTABLE_FIND remains unsupported until TFLite string tensors are supported."""
-    with pytest.raises(tvm.error.OpNotImplemented, match="TensorType.STRING"):
-        _load_model_from_buffer(_build_tflite_hashtable_find_model())
+def test_hashtable_call_once_import_find_string_to_int64():
+    """Test HASHTABLE_FIND for a static string-to-int64 table."""
+    mod = _load_model_from_buffer(_build_tflite_hashtable_find_string_to_int64_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main() -> R.Tensor((3,), dtype="int64"):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                gv: R.Tensor((3,), dtype="int64") = R.const([10, -1, 20], "int64")
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_hashtable_call_once_import_find_string_to_int64_2d_query():
+    """Test HASHTABLE_FIND preserves the static query shape."""
+    mod = _load_model_from_buffer(
+        _build_tflite_hashtable_find_string_to_int64_model(
+            query_values=["alpha", "beta", "missing", "gamma"],
+            query_shape=[2, 2],
+        )
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main() -> R.Tensor((2, 2), dtype="int64"):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="int64") = R.const([[10, 20], [-1, 30]], "int64")
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_hashtable_call_once_import_find_int64_to_string_unsupported():
+    """Test HASHTABLE_FIND rejects int64-to-string tables until string outputs exist."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="string -> int64"):
+        _load_model_from_buffer(_build_tflite_hashtable_find_int64_to_string_model())
+
+
+def test_hashtable_call_once_import_find_runtime_query_unsupported():
+    """Test HASHTABLE_FIND rejects runtime string queries."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="string queries|STRING graph inputs"):
+        _load_model_from_buffer(
+            _build_tflite_hashtable_find_string_to_int64_model(query_is_input=True)
+        )
+
+
+def test_hashtable_call_once_import_duplicate_keys_unsupported():
+    """Test HASHTABLE_IMPORT rejects duplicate static keys."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="duplicate keys"):
+        _load_model_from_buffer(
+            _build_tflite_hashtable_find_string_to_int64_model(
+                table_keys=["alpha", "alpha"], table_values=[10, 20]
+            )
+        )
 
 
 def test_hashtable_call_once_import_size():
@@ -7126,7 +7317,7 @@ def test_hashtable_lookup_2d_value():
 
 def test_hashtable_lookup_string_value_unsupported():
     string_type = _get_string_tensor_type()
-    with pytest.raises(ValueError, match="unknown dtype `string`"):
+    with pytest.raises(tvm.error.OpNotImplemented, match="STRING graph inputs"):
         _load_model_from_buffer(
             _build_tflite_hashtable_lookup_model(value_shape=[3], value_type=string_type)
         )
