This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
new b971a75de4 [Relax][Frontend][TFLite] Add TFLite Resource Variable and Static Hashtable Import Support (#19639)

b971a75de4 is described below

commit b971a75de46ea12692e36946f7285537057f0cc6
Author: HoYi <[email protected]>
AuthorDate: Sat May 30 02:38:12 2026 +0800

[Relax][Frontend][TFLite] Add TFLite Resource Variable and Static Hashtable Import Support (#19639)

## Summary

This PR adds incremental Relax TFLite frontend support for the resource
variable initialization subset:

- `VAR_HANDLE`
- `ASSIGN_VARIABLE`
- `READ_VARIABLE`

It builds on the TFLite control-flow / multi-subgraph support from #19616,
especially `CALL_ONCE`. TFLite commonly represents initialization through a
`CALL_ONCE` init subgraph and then uses resource handles in the main subgraph
to read the initialized variables. This PR supports that constrained
initialization pattern without introducing general mutable runtime state into
Relax.

The PR also adds explicit frontend guards for the TFLite builtin hashtable
operators:

- `HASHTABLE`
- `HASHTABLE_IMPORT`
- `HASHTABLE_FIND`
- `HASHTABLE_SIZE`

These operators are intentionally left unsupported for now. TFLite builtin
hashtable kernels are not generic tensor maps: their runtime implementations
only cover the `int64 -> string` and `string -> int64` table variants, and
correct import requires proper `TensorType.STRING` support. Rejecting the
operators is safer than lowering them with synthetic numeric-table semantics
that TFLite does not actually implement.

## Design

### Shared Initialization State

The frontend now keeps resource initialization data in shared conversion state:

- `conversion_state["resource_values"]`
- `conversion_state["hashtable_values"]`
- `conversion_state["in_call_once_init"]`

This state is shared by the main graph converter and the `CALL_ONCE` init
subgraph converter. Each converter instance still keeps its own local
`self.resource_handles` (and `self.hashtable_handles`) map, keyed by TFLite
tensor name.

Resource variables are keyed by the `(container, shared_name)` pair from
`VarHandleOptions` when either field is set. Otherwise the key falls back to
the handle tensor name. This keeps tensor-name bindings scoped to each
subgraph while letting init subgraphs and the main graph agree on the same
logical resource.
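
Below is a minimal sketch of this keying rule. It uses plain dicts in place of the converter objects. `resource_key` is a hypothetical helper written for illustration; in the diff, the same logic lives in `_get_var_handle_resource_key`, which reads the two fields from `VarHandleOptions`.

```python
def resource_key(container, shared_name, handle_tensor_name):
    """Hypothetical helper: logical key for one VAR_HANDLE op."""
    if container or shared_name:
        # Either field set: match handles across subgraphs by (container, shared_name).
        return (container, shared_name)
    # Neither set: fall back to the tensor name, which is only unique per subgraph.
    return ("", handle_tensor_name)


# Shared across converters, like conversion_state["resource_values"].
resource_values = {}

# Per-converter tensor-name -> key maps, like self.resource_handles.
init_handles = {"init_var": resource_key("", "resource_var", "init_var")}
main_handles = {"main_var": resource_key("", "resource_var", "main_var")}

# The init subgraph assigns through its own handle tensor...
resource_values[init_handles["init_var"]] = [1.0, 2.0]
# ...and the main subgraph reads through a differently named tensor.
print(resource_values[main_handles["main_var"]])  # [1.0, 2.0]
```

Without a `shared_name`, the two handles above would get different keys and the read would miss. This is why the tensor-name fallback only matches handles within a single subgraph.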

### CALL_ONCE Init Subgraphs

`CALL_ONCE` now accepts a non-empty init subgraph when all of its operators are
in the supported initialization subset:

- `VAR_HANDLE`
- `ASSIGN_VARIABLE`
- `HASHTABLE` and `HASHTABLE_IMPORT` (see Hashtable Operators below)

The init subgraph must still have no inputs and no outputs. The converter first
checks every operator against this allowlist. It then converts the init
subgraph with a fresh `ExprTable` and the shared conversion state.

The init subconverter deliberately shares the parent `BlockBuilder`. This is
safe for the current subset because every supported init operator only updates
importer state and returns `None`; none of them emits Relax bindings. A code
comment notes that this should be revisited if future `CALL_ONCE` init
operators emit Relax expressions.
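
The allowlist-then-convert flow can be sketched in plain Python. `OpNotImplemented`, `convert_init_subgraph` and the `convert_one` callback are hypothetical stand-ins: the diff uses `tvm.error.OpNotImplemented` and a fresh `OperatorConverter` that shares the parent `BlockBuilder`. For brevity, the allowlist is trimmed to the two resource-variable operators. The sketch shows the ordering: every operator is checked before any state changes, and `in_call_once_init` is restored in a `finally` block even when conversion fails.

```python
class OpNotImplemented(Exception):
    """Stand-in for tvm.error.OpNotImplemented."""


SUPPORTED_INIT_OPS = {"VAR_HANDLE", "ASSIGN_VARIABLE"}


def convert_init_subgraph(op_names, state, convert_one):
    """Hypothetical sketch of CALL_ONCE init-subgraph conversion."""
    # 1. Validate the whole subgraph before touching any shared state.
    for name in op_names:
        if name not in SUPPORTED_INIT_OPS:
            raise OpNotImplemented(f"CALL_ONCE init subgraph operator {name} is not supported")

    # 2. Enter the init context; restore the previous flag even on failure.
    previous = state["in_call_once_init"]
    state["in_call_once_init"] = True
    try:
        for name in op_names:
            # Supported init ops only update importer state, so they must return None.
            if convert_one(name, state) is not None:
                raise OpNotImplemented(f"{name} unexpectedly produced a Relax binding")
    finally:
        state["in_call_once_init"] = previous


state = {"in_call_once_init": False}
seen = []
convert_init_subgraph(
    ["VAR_HANDLE", "ASSIGN_VARIABLE"],
    state,
    lambda name, st: seen.append((name, st["in_call_once_init"])),
)
print(seen)  # [('VAR_HANDLE', True), ('ASSIGN_VARIABLE', True)]
```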

### Resource Variables

`VAR_HANDLE` is declarative. It registers the output resource tensor in the
current converter's local `resource_handles` map and returns `None`.

`ASSIGN_VARIABLE` is accepted only while converting a supported `CALL_ONCE`
init subgraph. It resolves the resource handle through the init converter's
local handle map. It then stores the assigned tensor expression in the shared
`conversion_state["resource_values"]`.

`READ_VARIABLE` resolves the main graph's resource handle and returns the
initialized expression from the shared state. If the resource has not been
initialized by a supported `CALL_ONCE` path, the frontend raises
`OpNotImplemented`.

This covers the common static-initialization inference pattern while avoiding
incorrect lowering of runtime mutation.
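
The toy model below makes these three rules concrete. The `ResourceSketch` class and its method names are hypothetical. Values are opaque placeholders standing in for Relax expressions, and `OpNotImplemented` stands in for `tvm.error.OpNotImplemented`.

```python
class OpNotImplemented(Exception):
    """Stand-in for tvm.error.OpNotImplemented."""


class ResourceSketch:
    """Toy model of VAR_HANDLE / ASSIGN_VARIABLE / READ_VARIABLE import rules."""

    def __init__(self, shared_state):
        self.state = shared_state  # shared by the main and CALL_ONCE converters
        self.resource_handles = {}  # local: handle tensor name -> resource key

    def var_handle(self, out_tensor, key):
        # Declarative: record the handle, emit nothing.
        self.resource_handles[out_tensor] = key
        return None

    def _key(self, handle_tensor, op_name):
        if handle_tensor not in self.resource_handles:
            raise OpNotImplemented(f"{op_name} requires a VAR_HANDLE in the same subgraph")
        return self.resource_handles[handle_tensor]

    def assign_variable(self, handle_tensor, value):
        if not self.state["in_call_once_init"]:
            raise OpNotImplemented("ASSIGN_VARIABLE outside CALL_ONCE initialization")
        self.state["resource_values"][self._key(handle_tensor, "ASSIGN_VARIABLE")] = value
        return None

    def read_variable(self, handle_tensor):
        key = self._key(handle_tensor, "READ_VARIABLE")
        if key not in self.state["resource_values"]:
            raise OpNotImplemented("READ_VARIABLE requires a CALL_ONCE-initialized resource")
        # Import-time folding: the read becomes the initializer expression itself.
        return self.state["resource_values"][key]


shared = {"resource_values": {}, "in_call_once_init": False}

# CALL_ONCE init subgraph (the caller sets the flag while converting it).
shared["in_call_once_init"] = True
init = ResourceSketch(shared)
init.var_handle("init_var", ("", "resource_var"))
init.assign_variable("init_var", "const([1.0, 2.0])")
shared["in_call_once_init"] = False

# Main subgraph: a separate converter with its own handle map.
main = ResourceSketch(shared)
main.var_handle("main_var", ("", "resource_var"))
print(main.read_variable("main_var"))  # const([1.0, 2.0])
```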

### Hashtable Operators

`HASHTABLE` registers the table handle and validates the dtype pair against
TFLite kernel constraints (`int64/string` or `string/int64`).

`HASHTABLE_IMPORT` in a supported `CALL_ONCE` init subgraph captures static
metadata (table size, key/value dtypes). It does not store the actual string
data, because Relax does not yet support `TensorType.STRING`.

`HASHTABLE_SIZE` returns a one-element `int64` Relax constant (`[size]`) for
statically imported tables.
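
The sketch below roughly models the static bookkeeping on the hashtable path. It assumes dtypes are plain strings (the diff compares `tflite.TensorType` enum values) and shapes are tuples. `hashtable_metadata` and `hashtable_size_constant` are hypothetical names. The second one mirrors the `[size]` int64 array that the diff passes to `relax.const`. The table contents themselves are not modeled.

```python
import math

import numpy as np

# Key/value dtype pairs that the TFLite builtin hashtable kernels implement.
SUPPORTED_PAIRS = {("int64", "string"), ("string", "int64")}


def hashtable_metadata(key_dtype, value_dtype, key_shape, value_shape):
    """Hypothetical: metadata HASHTABLE/HASHTABLE_IMPORT record for one table."""
    if (key_dtype, value_dtype) not in SUPPORTED_PAIRS:
        raise ValueError("only int64/string or string/int64 tables are supported")
    if tuple(key_shape) != tuple(value_shape):
        raise ValueError("keys and values must have the same shape")
    # Number of entries; a scalar (shape ()) key tensor counts as one entry.
    size = math.prod(key_shape) if key_shape else 1
    return {"size": size, "key_dtype": key_dtype, "value_dtype": value_dtype}


def hashtable_size_constant(metadata):
    """Hypothetical: the value HASHTABLE_SIZE folds to, a one-element int64 array."""
    return np.array([metadata["size"]], dtype=np.int64)


meta = hashtable_metadata("int64", "string", (2,), (2,))
print(hashtable_size_constant(meta))  # [2]
```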

`HASHTABLE_FIND` is rejected with `OpNotImplemented` because Relax cannot
represent TFLite string tensors or the runtime lookup semantics.
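
For context on what string support would involve: a TFLite `STRING` tensor is stored as one packed buffer. It holds an `int32` count, then `count + 1` `int32` byte offsets, then the concatenated UTF-8 bytes. This is the layout that the test helper `_build_tflite_string_buffer` produces. The sketch below encodes and decodes that layout in plain Python. The function names are hypothetical, and the sketch assumes little-endian `int32` fields (the test helper uses host byte order via NumPy).

```python
import struct


def encode_string_buffer(values):
    """Pack strings as [count][offset_0 .. offset_count][utf-8 bytes...]."""
    encoded = [v.encode("utf-8") for v in values]
    cursor = 4 * (len(encoded) + 2)  # header: count + (count + 1) offsets
    offsets = []
    for item in encoded:
        offsets.append(cursor)
        cursor += len(item)
    offsets.append(cursor)  # final offset marks the end of the last string
    header = struct.pack(f"<{len(offsets) + 1}i", len(encoded), *offsets)
    return header + b"".join(encoded)


def decode_string_buffer(buf):
    """Recover the strings; string i spans buf[offset_i:offset_(i+1)]."""
    (count,) = struct.unpack_from("<i", buf, 0)
    offsets = struct.unpack_from(f"<{count + 1}i", buf, 4)
    return [buf[offsets[i] : offsets[i + 1]].decode("utf-8") for i in range(count)]


buf = encode_string_buffer(["one hundred", "two hundred"])
print(len(buf), decode_string_buffer(buf))  # 38 ['one hundred', 'two hundred']
```

Each element's length is only known from the offset table, so these elements have variable length. Relax tensors have a fixed element dtype, which is the gap the guard above refers to.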

## Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `VAR_HANDLE` | `VarHandleOptions` | handle registration only | main graph and supported `CALL_ONCE` init subgraphs |
| `ASSIGN_VARIABLE` | `AssignVariableOptions` | store initialized Relax expression in shared importer state | supported `CALL_ONCE` init subgraphs only |
| `READ_VARIABLE` | `ReadVariableOptions` | return initialized Relax expression | resource must have supported static initialization |
| `HASHTABLE` | `HashtableOptions` | handle registration + dtype validation | validates `int64/string` or `string/int64` pair, rejects other combinations |
| `HASHTABLE_IMPORT` | `HashtableImportOptions` | store static metadata (size, key/value dtype) | `CALL_ONCE` init subgraphs only; keys and values must be constant and have the same shape |
| `HASHTABLE_FIND` | `HashtableFindOptions` | unsupported guard | requires future `TensorType.STRING` support in Relax |
| `HASHTABLE_SIZE` | `HashtableSizeOptions` | one-element Relax constant | returns `[size]` int64 for statically imported tables |

## Safety Checks

- `ASSIGN_VARIABLE` outside `CALL_ONCE` initialization raises `OpNotImplemented`.
- `READ_VARIABLE` without supported initialization raises `OpNotImplemented`.
- `CALL_ONCE` init subgraphs with inputs or outputs remain unsupported.
- `CALL_ONCE` init subgraphs containing operators outside the resource-variable
  initialization allowlist remain unsupported.
- TFLite builtin hashtable operators raise `OpNotImplemented` until the
  frontend can model their real int64/string table semantics.

## Not Included

- Runtime `ASSIGN_VARIABLE` mutation in the main graph.
- Runtime resource-state threading through Relax function parameters and
  returns.
- Cross-subgraph resource handle aliasing beyond the static
  `container/shared_name` matching pattern.
- Multiple runtime writes with ordering semantics.
- TFLite builtin hashtable lowering.
- `TensorType.STRING` import support.

## Tests

The tests manually build minimal TFLite flatbuffers and compare the imported
Relax IR with `tvm.ir.assert_structural_equal`. Unsupported patterns use
`pytest.raises`.

| Test | Coverage |
|---|---|
| `test_resource_variable_call_once_init_read` | `CALL_ONCE` init subgraph with `VAR_HANDLE + ASSIGN_VARIABLE`, then main graph `READ_VARIABLE` |
| `test_assign_variable_main_subgraph_unsupported` | runtime/main graph `ASSIGN_VARIABLE` guard |
| `test_read_variable_uninitialized_unsupported` | `READ_VARIABLE` without supported initialization guard |
| `test_hashtable_call_once_import_find_unsupported` | hashtable init/find path remains unsupported |
| `test_hashtable_call_once_import_size_unsupported` | hashtable init/size path remains unsupported |
| `test_hashtable_import_main_subgraph_unsupported` | main graph `HASHTABLE_IMPORT` remains unsupported |
| `test_hashtable_size_uninitialized_unsupported` | uninitialized `HASHTABLE_SIZE` remains unsupported |

Local validation:
```bash
python -m py_compile \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff format --check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k "resource_variable or read_variable_uninitialized or hashtable" -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py -q
```
Result:
```text
py_compile: passed
ruff format --check: files already formatted
ruff check: All checks passed
targeted resource/hashtable tests: 6 passed
full test_frontend_tflite.py: 472 passed
```
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 292 +++++++++-
tests/python/relax/test_frontend_tflite.py | 611 +++++++++++++++++++++
2 files changed, 893 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 87697dc6ad..c479ec83c1 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -175,10 +175,18 @@ class OperatorConverter:
                 "lowered_while_functions": {},
                 "lowering_stack": [],
                 "module_builder": ctx,
+                "resource_values": {},
+                "hashtable_values": {},
+                "in_call_once_init": False,
             }
         else:
             conversion_state.setdefault("module_builder", ctx)
+            conversion_state.setdefault("resource_values", {})
+            conversion_state.setdefault("hashtable_values", {})
+            conversion_state.setdefault("in_call_once_init", False)
         self.conversion_state = conversion_state
+        self.resource_handles = {}
+        self.hashtable_handles = {}
 
         # Add more operators
         self.convert_map = {
@@ -187,6 +195,7 @@ class OperatorConverter:
             "ADD_N": self.convert_add_n,
             "ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax),
             "ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin),
+            "ASSIGN_VARIABLE": self.convert_assign_variable,
             "ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2),
             "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"),
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
@@ -234,6 +243,10 @@ class OperatorConverter:
             ),
             "GELU": self.convert_gelu,
             "HARD_SWISH": self.convert_hard_swish,
+            "HASHTABLE": self.convert_hashtable,
+            "HASHTABLE_FIND": self.convert_hashtable_find,
+            "HASHTABLE_IMPORT": self.convert_hashtable_import,
+            "HASHTABLE_SIZE": self.convert_hashtable_size,
             "IF": self.convert_if,
             "L2_NORMALIZATION": self.convert_l2_normalization,
             "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"),
@@ -277,6 +290,7 @@ class OperatorConverter:
             "QUANTIZE": self.convert_quantize,
             "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
             "RANDOM_UNIFORM": self.convert_random_uniform,
+            "READ_VARIABLE": self.convert_read_variable,
             "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min),
             "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max),
             "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max),
@@ -391,6 +405,7 @@ class OperatorConverter:
                 self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
             ),
             # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
+            "VAR_HANDLE": self.convert_var_handle,
             "WHERE": self.convert_select,
             "WHILE": self.convert_while,
             "ZEROS_LIKE": self.convert_zeros_like,
@@ -518,6 +533,244 @@ class OperatorConverter:
                     get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
                 )
 
+    @staticmethod
+    def _decode_tflite_string(value):
+        """Decode a TFLite string field."""
+        if value is None:
+            return ""
+        if isinstance(value, bytes | bytearray):
+            return value.decode("utf-8")
+        return str(value)
+
+    def _get_var_handle_resource_key(self, op, fallback_tensor=None):
+        """Return a stable resource key for a VAR_HANDLE op."""
+        container = ""
+        shared_name = ""
+        if op.BuiltinOptions() is not None:
+            try:
+                from tflite.VarHandleOptions import VarHandleOptions
+
+                opts = self._get_builtin_options(op, VarHandleOptions)
+                if hasattr(opts, "Container"):
+                    container = self._decode_tflite_string(opts.Container())
+                if hasattr(opts, "SharedName"):
+                    shared_name = self._decode_tflite_string(opts.SharedName())
+            except (ImportError, ModuleNotFoundError):
+                pass
+
+        if container or shared_name:
+            return (container, shared_name)
+        if fallback_tensor is not None:
+            return ("", get_tensor_name(self.subgraph, fallback_tensor.tensor_idx))
+        raise tvm.error.OpNotImplemented("VAR_HANDLE requires VarHandleOptions")
+
+    def _get_resource_key_for_handle(self, tensor, op_name):
+        tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx)
+        if tensor_name not in self.resource_handles:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} requires a VAR_HANDLE in the same TFLite subgraph"
+            )
+        return self.resource_handles[tensor_name]
+
+    def convert_var_handle(self, op):
+        """Convert a TFLite VAR_HANDLE into an importer-local resource handle."""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 0 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented("VAR_HANDLE expects no inputs and one output")
+
+        resource_key = self._get_var_handle_resource_key(op, output_tensors[0])
+        resource_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx)
+        self.resource_handles[resource_tensor_name] = resource_key
+        return None
+
+    def convert_assign_variable(self, op):
+        """Convert the CALL_ONCE initialization subset of ASSIGN_VARIABLE."""
+        if not self.conversion_state["in_call_once_init"]:
+            raise tvm.error.OpNotImplemented(
+                "ASSIGN_VARIABLE outside CALL_ONCE initialization is not supported by the "
+                "Relax TFLite frontend yet because it requires mutable resource state modeling."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 2 or len(output_tensors) != 0:
+            raise tvm.error.OpNotImplemented(
+                "ASSIGN_VARIABLE expects a resource handle and value input with no outputs"
+            )
+
+        resource_key = self._get_resource_key_for_handle(input_tensors[0], "ASSIGN_VARIABLE")
+        self.conversion_state["resource_values"][resource_key] = self.get_tensor_expr(
+            input_tensors[1]
+        )
+        return None
+
+    def convert_read_variable(self, op):
+        """Convert READ_VARIABLE for resources initialized by CALL_ONCE."""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 1 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented("READ_VARIABLE expects one input and one output")
+
+        resource_key = self._get_resource_key_for_handle(input_tensors[0], "READ_VARIABLE")
+        resource_values = self.conversion_state["resource_values"]
+        if resource_key not in resource_values:
+            raise tvm.error.OpNotImplemented(
+                "READ_VARIABLE requires a resource initialized by a supported CALL_ONCE subgraph"
+            )
+        return resource_values[resource_key]
+
+    def _is_tflite_string_type(self, tensor_type):
+        from tflite.TensorType import TensorType
+
+        return hasattr(TensorType, "STRING") and tensor_type == TensorType.STRING
+
+    def _is_supported_hashtable_type_pair(self, key_dtype, value_dtype):
+        from tflite.TensorType import TensorType
+
+        return (key_dtype == TensorType.INT64 and self._is_tflite_string_type(value_dtype)) or (
+            self._is_tflite_string_type(key_dtype) and value_dtype == TensorType.INT64
+        )
+
+    def _get_hashtable_key(self, op, fallback_tensor=None):
+        """Return a stable key and TFLite dtype pair for a HASHTABLE resource."""
+        table_id = None
+        key_dtype = None
+        value_dtype = None
+        if op.BuiltinOptions() is not None:
+            try:
+                from tflite.HashtableOptions import HashtableOptions
+
+                opts = self._get_builtin_options(op, HashtableOptions)
+                table_id = int(opts.TableId())
+                key_dtype = int(opts.KeyDtype())
+                value_dtype = int(opts.ValueDtype())
+            except (ImportError, ModuleNotFoundError):
+                pass
+
+        if key_dtype is None or value_dtype is None:
+            raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions")
+        if not self._is_supported_hashtable_type_pair(key_dtype, value_dtype):
+            raise tvm.error.OpNotImplemented(
+                "TFLite HASHTABLE only supports int64/string or string/int64 tables"
+            )
+
+        if table_id is not None:
+            return table_id, key_dtype, value_dtype
+        if fallback_tensor is not None:
+            return (
+                get_tensor_name(self.subgraph, fallback_tensor.tensor_idx),
+                key_dtype,
+                value_dtype,
+            )
+        raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions")
+
+    def _get_hashtable_info_for_handle(self, tensor, op_name):
+        tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx)
+        if tensor_name not in self.hashtable_handles:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} requires a HASHTABLE in the same TFLite subgraph"
+            )
+        return self.hashtable_handles[tensor_name]
+
+    @staticmethod
+    def _get_tensor_shape_tuple(tensor_wrapper):
+        if tensor_wrapper.tensor.ShapeLength() == 0:
+            return ()
+        return tuple(int(dim) for dim in tensor_wrapper.tensor.ShapeAsNumpy())
+
+    @staticmethod
+    def _has_tensor_buffer_data(tensor_wrapper):
+        return (
+            tensor_wrapper.buffer is not None
+            and hasattr(tensor_wrapper.buffer, "DataLength")
+            and tensor_wrapper.buffer.DataLength() > 0
+        )
+
+    def convert_hashtable(self, op):
+        """Convert a TFLite HASHTABLE into an importer-local table handle."""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 0 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented("HASHTABLE expects no inputs and one output")
+
+        table_key, key_dtype, value_dtype = self._get_hashtable_key(op, output_tensors[0])
+        table_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx)
+        self.hashtable_handles[table_tensor_name] = {
+            "table_key": table_key,
+            "key_dtype": key_dtype,
+            "value_dtype": value_dtype,
+        }
+        return None
+
+    def convert_hashtable_import(self, op):
+        """Convert static metadata for the CALL_ONCE HASHTABLE_IMPORT subset."""
+        if not self.conversion_state["in_call_once_init"]:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_IMPORT outside CALL_ONCE initialization is not supported by the "
+                "Relax TFLite frontend yet because it requires mutable resource state modeling."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 3 or len(output_tensors) != 0:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_IMPORT expects table, keys, and values inputs with no outputs"
+            )
+
+        table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_IMPORT")
+        key_tensor = input_tensors[1]
+        value_tensor = input_tensors[2]
+        if (
+            key_tensor.tensor.Type() != table_info["key_dtype"]
+            or value_tensor.tensor.Type() != table_info["value_dtype"]
+        ):
+            raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT key/value dtypes mismatch")
+        key_shape = self._get_tensor_shape_tuple(key_tensor)
+        value_shape = self._get_tensor_shape_tuple(value_tensor)
+        if key_shape != value_shape:
+            raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires keys and values same shape")
+        if not self._has_tensor_buffer_data(key_tensor) or not self._has_tensor_buffer_data(
+            value_tensor
+        ):
+            raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values")
+
+        hashtable_values = self.conversion_state["hashtable_values"]
+        table_key = table_info["table_key"]
+        if table_key not in hashtable_values:
+            hashtable_values[table_key] = {
+                "size": math.prod(key_shape) if key_shape else 1,
+                "key_dtype": table_info["key_dtype"],
+                "value_dtype": table_info["value_dtype"],
+            }
+        return None
+
+    def convert_hashtable_find(self, op):
+        """Reject HASHTABLE_FIND until Relax can represent TFLite string tensors."""
+        raise tvm.error.OpNotImplemented(
+            "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend"
+        )
+
+    def convert_hashtable_size(self, op):
+        """Convert HASHTABLE_SIZE for a statically imported TFLite hashtable."""
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 1 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented("HASHTABLE_SIZE expects one input and one output")
+
+        from tflite.TensorType import TensorType
+
+        if output_tensors[0].tensor.Type() != TensorType.INT64:
+            raise tvm.error.OpNotImplemented("HASHTABLE_SIZE output must be int64")
+        table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_SIZE")
+        table_key = table_info["table_key"]
+        hashtable_values = self.conversion_state["hashtable_values"]
+        if table_key not in hashtable_values:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_SIZE requires a table initialized by a supported CALL_ONCE subgraph"
+            )
+        return relax.const(np.array([hashtable_values[table_key]["size"]], dtype=np.int64), "int64")
+
     def get_op_code_str(self, op):
         """Get TFLite ops string representation"""
 
@@ -2290,13 +2543,7 @@ class OperatorConverter:
         return relax.Call(loop_gv, args)
 
     def convert_call_once(self, op):
-        """Convert the no-op subset of TFLite CALL_ONCE.
-
-        Non-empty CALL_ONCE init subgraphs are used for resource initialization
-        side effects in TFLite. The Relax TFLite frontend does not yet support
-        TFLite resource variable operators, so only the empty no-op form is safe
-        to import.
-        """
+        """Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets."""
         from tflite.CallOnceOptions import CallOnceOptions
 
         opts = self._get_builtin_options(op, CallOnceOptions)
@@ -2312,11 +2559,36 @@ class OperatorConverter:
                 "CALL_ONCE with non-empty init subgraph I/O is not supported"
             )
         if init_subgraph.OperatorsLength() != 0:
-            raise tvm.error.OpNotImplemented(
-                "CALL_ONCE with non-empty init subgraphs is not supported"
-            )
+            self._convert_call_once_init_subgraph(init_subgraph)
         return None
 
+    def _convert_call_once_init_subgraph(self, init_subgraph):
+        """Convert the resource-variable initialization subset of a CALL_ONCE subgraph."""
+        supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE", "HASHTABLE", "HASHTABLE_IMPORT"}
+        for op_idx in range(init_subgraph.OperatorsLength()):
+            op_name = self.get_op_code_str(init_subgraph.Operators(op_idx))
+            if op_name not in supported_init_ops:
+                raise tvm.error.OpNotImplemented(
+                    f"CALL_ONCE init subgraph operator {op_name} is not supported"
+                )
+
+        old_in_call_once_init = self.conversion_state["in_call_once_init"]
+        self.conversion_state["in_call_once_init"] = True
+        try:
+            # The supported init ops below only update importer state and return None.
+            # If future CALL_ONCE ops emit Relax bindings, revisit sharing the parent builder.
+            subgraph_converter = type(self)(
+                self.model,
+                init_subgraph,
+                ExprTable(),
+                self.bb,
+                self.conversion_state,
+            )
+            subgraph_converter.check_unsupported_ops()
+            subgraph_converter.convert_op_to_relax()
+        finally:
+            self.conversion_state["in_call_once_init"] = old_in_call_once_init
+
     def _convert_stablehlo_convert(self, op):
         """Convert STABLEHLO_CONVERT to Relax (astype).
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 263943ad6a..e9ccea7ad1 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3943,6 +3943,78 @@ def _build_call_once_options(builder, init_subgraph_index):
     return _tfl_call_once_options.CallOnceOptionsEnd(builder)
 
 
+def _get_builtin_options_type(options_name):
+    if not hasattr(_tfl_builtin_options, options_name):
+        pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}")
+    return getattr(_tfl_builtin_options, options_name)
+
+
+def _get_resource_tensor_type():
+    if not hasattr(_tfl_tensor_type, "RESOURCE"):
+        pytest.skip("TFLite schema does not provide TensorType.RESOURCE")
+    return getattr(_tfl_tensor_type, "RESOURCE")
+
+
+def _get_string_tensor_type():
+    if not hasattr(_tfl_tensor_type, "STRING"):
+        pytest.skip("TFLite schema does not provide TensorType.STRING")
+    return getattr(_tfl_tensor_type, "STRING")
+
+
+def _build_tflite_string_buffer(values):
+    encoded = [value.encode("utf-8") for value in values]
+    offsets = []
+    cursor = 4 * (len(encoded) + 2)
+    for value in encoded:
+        offsets.append(cursor)
+        cursor += len(value)
+    offsets.append(cursor)
+    header = np.array([len(encoded), *offsets], dtype=np.int32).tobytes()
+    return header + b"".join(encoded)
+
+
+def _build_var_handle_options(builder, shared_name="resource_var", container=""):
+    try:
+        var_handle_options = _get_tflite_schema_module("VarHandleOptions")
+    except ModuleNotFoundError:
+        pytest.skip("TFLite schema does not provide VarHandleOptions")
+    container_offset = builder.CreateString(container)
+    shared_name_offset = builder.CreateString(shared_name)
+    var_handle_options.VarHandleOptionsStart(builder)
+    var_handle_options.VarHandleOptionsAddContainer(builder, container_offset)
+    var_handle_options.VarHandleOptionsAddSharedName(builder, shared_name_offset)
+    return var_handle_options.VarHandleOptionsEnd(builder)
+
+
+def _build_empty_builtin_options(builder, options_name):
+    try:
+        options_module = _get_tflite_schema_module(options_name)
+    except ModuleNotFoundError:
+        pytest.skip(f"TFLite schema does not provide {options_name}")
+    getattr(options_module, f"{options_name}Start")(builder)
+    return getattr(options_module, f"{options_name}End")(builder)
+
+
+def _build_hashtable_options(
+    builder,
+    table_id=0,
+    key_dtype=None,
+    value_dtype=None,
+):
+    try:
+        hashtable_options = _get_tflite_schema_module("HashtableOptions")
+    except ModuleNotFoundError:
+        pytest.skip("TFLite schema does not provide HashtableOptions")
+
+    key_dtype = _tfl_tensor_type.INT64 if key_dtype is None else key_dtype
+    value_dtype = _get_string_tensor_type() if value_dtype is None else value_dtype
+    hashtable_options.HashtableOptionsStart(builder)
+    hashtable_options.HashtableOptionsAddTableId(builder, table_id)
+    hashtable_options.HashtableOptionsAddKeyDtype(builder, key_dtype)
+    hashtable_options.HashtableOptionsAddValueDtype(builder, value_dtype)
+    return hashtable_options.HashtableOptionsEnd(builder)
+
+
 def _load_model_from_buffer(model_bytes):
     if hasattr(tflite.Model, "Model"):
         tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
@@ -5275,6 +5347,545 @@ def test_call_once_invalid_index_unsupported():
         _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2))
 
 
+def _build_tflite_resource_variable_model():
+    """Build a model that initializes a resource variable in CALL_ONCE and reads it."""
+    builder = flatbuffers.Builder(1024)
+    resource_type = _get_resource_tensor_type()
+    initial_value = np.array([1.0, 2.0], dtype=np.float32)
+
+    call_once_options = _build_call_once_options(builder, 1)
+    main_var_handle_options = _build_var_handle_options(builder)
+    main_read_options = _build_empty_builtin_options(builder, "ReadVariableOptions")
+    init_var_handle_options = _build_var_handle_options(builder)
+    init_assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions")
+
+    resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type)
+    main_output_tensor = _build_tensor(builder, 0, [2])
+    main_call_once = _build_operator(
+        builder,
+        0,
+        [],
+        [],
+        builtin_options_type=_get_builtin_options_type("CallOnceOptions"),
+        builtin_options=call_once_options,
+    )
+    main_var_handle = _build_operator(
+        builder,
+        1,
+        [],
+        [0],
+        builtin_options_type=_get_builtin_options_type("VarHandleOptions"),
+        builtin_options=main_var_handle_options,
+    )
+    main_read = _build_operator(
+        builder,
+        2,
+        [0],
+        [1],
+        builtin_options_type=_get_builtin_options_type("ReadVariableOptions"),
+        builtin_options=main_read_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[resource_tensor, main_output_tensor],
+        operators=[main_call_once, main_var_handle, main_read],
+        inputs=[],
+        outputs=[1],
+    )
+
+    init_resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type)
+    init_value_tensor = _build_tensor(builder, 1, [2])
+    init_var_handle = _build_operator(
+        builder,
+        1,
+        [],
+        [0],
+        builtin_options_type=_get_builtin_options_type("VarHandleOptions"),
+        builtin_options=init_var_handle_options,
+    )
+    init_assign = _build_operator(
+        builder,
+        3,
+        [0, 1],
+        [],
+        builtin_options_type=_get_builtin_options_type("AssignVariableOptions"),
+        builtin_options=init_assign_options,
+    )
+    init_subgraph = _build_subgraph(
+        builder,
+        tensors=[init_resource_tensor, init_value_tensor],
+        operators=[init_var_handle, init_assign],
+        inputs=[],
+        outputs=[],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")),
+        _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")),
+        _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")),
+        _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")),
+    ]
+    buffers = [_build_buffer(builder), _build_buffer(builder, initial_value.tobytes())]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[init_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_tflite_resource_assign_in_main_model():
+    """Build a model that attempts to assign a resource variable in the main subgraph."""
+    builder = flatbuffers.Builder(1024)
+    resource_type = _get_resource_tensor_type()
+    value = np.array([1.0, 2.0], dtype=np.float32)
+
+    var_handle_options = _build_var_handle_options(builder)
+    assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions")
+    resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type)
+    value_tensor = _build_tensor(builder, 1, [2])
+    var_handle = _build_operator(
+        builder,
+        0,
+        [],
+        [0],
+        builtin_options_type=_get_builtin_options_type("VarHandleOptions"),
+        builtin_options=var_handle_options,
+    )
+    assign = _build_operator(
+        builder,
+        1,
+        [0, 1],
+        [],
+        builtin_options_type=_get_builtin_options_type("AssignVariableOptions"),
+        builtin_options=assign_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[resource_tensor, value_tensor],
+        operators=[var_handle, assign],
+        inputs=[],
+        outputs=[1],
+    )
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")),
+        _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")),
+    ]
+    buffers = [_build_buffer(builder), _build_buffer(builder, value.tobytes())]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_tflite_resource_read_uninitialized_model():
+    """Build a model that reads a resource variable without CALL_ONCE initialization."""
+    builder = flatbuffers.Builder(1024)
+    resource_type = _get_resource_tensor_type()
+
+    var_handle_options = _build_var_handle_options(builder)
+    read_options = _build_empty_builtin_options(builder, "ReadVariableOptions")
+    resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type)
+    output_tensor = _build_tensor(builder, 0, [2])
+    var_handle = _build_operator(
+        builder,
+        0,
+        [],
+        [0],
+        builtin_options_type=_get_builtin_options_type("VarHandleOptions"),
+        builtin_options=var_handle_options,
+    )
+    read = _build_operator(
+        builder,
+        1,
+        [0],
+        [1],
+        builtin_options_type=_get_builtin_options_type("ReadVariableOptions"),
+        builtin_options=read_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[resource_tensor, output_tensor],
+        operators=[var_handle, read],
+        inputs=[],
+        outputs=[1],
+    )
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")),
+        _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=operator_codes,
+        buffers=[_build_buffer(builder)],
+    )
+
+
+def _build_tflite_hashtable_find_model():
+ """Build a model that imports a static hashtable and finds runtime query
keys."""
+ builder = flatbuffers.Builder(1024)
+ resource_type = _get_resource_tensor_type()
+ string_type = _get_string_tensor_type()
+ table_keys = np.array([10, 20], dtype=np.int64)
+ table_values = _build_tflite_string_buffer(["one hundred", "two hundred"])
+ default_value = _build_tflite_string_buffer(["missing"])
+
+ call_once_options = _build_call_once_options(builder, 1)
+ main_table_options = _build_hashtable_options(builder, table_id=0)
+ find_options = _build_empty_builtin_options(builder, "HashtableFindOptions")
+ init_table_options = _build_hashtable_options(builder, table_id=0)
+ import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
+
+ query_tensor = _build_tensor(builder, 0, [3], tensor_type=_tfl_tensor_type.INT64)
+ table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ default_tensor = _build_tensor(builder, 1, [], tensor_type=string_type)
+ output_tensor = _build_tensor(builder, 0, [3], tensor_type=string_type)
+ main_call_once = _build_operator(
+ builder,
+ 0,
+ [],
+ [],
+ builtin_options_type=_get_builtin_options_type("CallOnceOptions"),
+ builtin_options=call_once_options,
+ )
+ main_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [1],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=main_table_options,
+ )
+ main_find = _build_operator(
+ builder,
+ 2,
+ [1, 0, 2],
+ [3],
+ builtin_options_type=_get_builtin_options_type("HashtableFindOptions"),
+ builtin_options=find_options,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[query_tensor, table_tensor, default_tensor, output_tensor],
+ operators=[main_call_once, main_hashtable, main_find],
+ inputs=[0],
+ outputs=[3],
+ )
+
+ init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT64)
+ init_values_tensor = _build_tensor(
+ builder,
+ 3,
+ [2],
+ tensor_type=string_type,
+ )
+ init_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=init_table_options,
+ )
+ init_import = _build_operator(
+ builder,
+ 3,
+ [0, 1, 2],
+ [],
+ builtin_options_type=_get_builtin_options_type("HashtableImportOptions"),
+ builtin_options=import_options,
+ )
+ init_subgraph = _build_subgraph(
+ builder,
+ tensors=[init_table_tensor, init_keys_tensor, init_values_tensor],
+ operators=[init_hashtable, init_import],
+ inputs=[],
+ outputs=[],
+ )
+
+ operator_codes = [
+ _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_FIND")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")),
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, default_value),
+ _build_buffer(builder, table_keys.tobytes()),
+ _build_buffer(builder, table_values),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[init_subgraph],
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
+def _build_tflite_hashtable_size_model():
+ """Build a model that imports a static hashtable and returns its size."""
+ builder = flatbuffers.Builder(1024)
+ resource_type = _get_resource_tensor_type()
+ string_type = _get_string_tensor_type()
+ table_keys = np.array([10, 20], dtype=np.int64)
+ table_values = _build_tflite_string_buffer(["one hundred", "two hundred"])
+
+ call_once_options = _build_call_once_options(builder, 1)
+ main_table_options = _build_hashtable_options(builder, table_id=0)
+ size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions")
+ init_table_options = _build_hashtable_options(builder, table_id=0)
+ import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
+
+ table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64)
+ main_call_once = _build_operator(
+ builder,
+ 0,
+ [],
+ [],
+ builtin_options_type=_get_builtin_options_type("CallOnceOptions"),
+ builtin_options=call_once_options,
+ )
+ main_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=main_table_options,
+ )
+ main_size = _build_operator(
+ builder,
+ 2,
+ [0],
+ [1],
+ builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"),
+ builtin_options=size_options,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[table_tensor, size_tensor],
+ operators=[main_call_once, main_hashtable, main_size],
+ inputs=[],
+ outputs=[1],
+ )
+
+ init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ init_keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64)
+ init_values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type)
+ init_hashtable = _build_operator(
+ builder,
+ 1,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=init_table_options,
+ )
+ init_import = _build_operator(
+ builder,
+ 3,
+ [0, 1, 2],
+ [],
+ builtin_options_type=_get_builtin_options_type("HashtableImportOptions"),
+ builtin_options=import_options,
+ )
+ init_subgraph = _build_subgraph(
+ builder,
+ tensors=[init_table_tensor, init_keys_tensor, init_values_tensor],
+ operators=[init_hashtable, init_import],
+ inputs=[],
+ outputs=[],
+ )
+
+ operator_codes = [
+ _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")),
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, table_keys.tobytes()),
+ _build_buffer(builder, table_values),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[init_subgraph],
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
+def _build_tflite_hashtable_import_in_main_model():
+ """Build a model that attempts to import hashtable values in the main subgraph."""
+ builder = flatbuffers.Builder(1024)
+ resource_type = _get_resource_tensor_type()
+ string_type = _get_string_tensor_type()
+ table_keys = np.array([10, 20], dtype=np.int64)
+ table_values = _build_tflite_string_buffer(["one hundred", "two hundred"])
+
+ table_options = _build_hashtable_options(builder, table_id=0)
+ import_options = _build_empty_builtin_options(builder, "HashtableImportOptions")
+
+ table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64)
+ values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type)
+ hashtable = _build_operator(
+ builder,
+ 0,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=table_options,
+ )
+ hashtable_import = _build_operator(
+ builder,
+ 1,
+ [0, 1, 2],
+ [],
+ builtin_options_type=_get_builtin_options_type("HashtableImportOptions"),
+ builtin_options=import_options,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[table_tensor, keys_tensor, values_tensor],
+ operators=[hashtable, hashtable_import],
+ inputs=[],
+ outputs=[2],
+ )
+ operator_codes = [
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")),
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, table_keys.tobytes()),
+ _build_buffer(builder, table_values),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
+def _build_tflite_hashtable_size_uninitialized_model():
+ """Build a model that queries the size of a hashtable without importing values."""
+ builder = flatbuffers.Builder(1024)
+ resource_type = _get_resource_tensor_type()
+
+ table_options = _build_hashtable_options(builder, table_id=0)
+ size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions")
+ table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type)
+ size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64)
+ hashtable = _build_operator(
+ builder,
+ 0,
+ [],
+ [0],
+ builtin_options_type=_get_builtin_options_type("HashtableOptions"),
+ builtin_options=table_options,
+ )
+ hashtable_size = _build_operator(
+ builder,
+ 1,
+ [0],
+ [1],
+ builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"),
+ builtin_options=size_options,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=[table_tensor, size_tensor],
+ operators=[hashtable, hashtable_size],
+ inputs=[],
+ outputs=[1],
+ )
+ operator_codes = [
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE")),
+ _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ operator_codes=operator_codes,
+ buffers=[_build_buffer(builder)],
+ )
+
+
+def test_resource_variable_call_once_init_read():
+ """Test reading a resource variable initialized by a supported CALL_ONCE subgraph."""
+ mod = _load_model_from_buffer(_build_tflite_resource_variable_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((2,), dtype="float32"):
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ gv: R.Tensor((2,), dtype="float32") = R.const([1.0, 2.0], "float32")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_assign_variable_main_subgraph_unsupported():
+ """Test ASSIGN_VARIABLE remains unsupported outside CALL_ONCE initialization."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="ASSIGN_VARIABLE outside CALL_ONCE"):
+ _load_model_from_buffer(_build_tflite_resource_assign_in_main_model())
+
+
+def test_read_variable_uninitialized_unsupported():
+ """Test READ_VARIABLE rejects resource handles without supported initialization."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="READ_VARIABLE requires a resource"):
+ _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model())
+
+
+def test_hashtable_call_once_import_find_unsupported():
+ """Test HASHTABLE_FIND remains unsupported until TFLite string tensors are supported."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="TensorType.STRING"):
+ _load_model_from_buffer(_build_tflite_hashtable_find_model())
+
+
+def test_hashtable_call_once_import_size():
+ """Test HASHTABLE_SIZE for a table initialized by a supported CALL_ONCE subgraph."""
+ mod = _load_model_from_buffer(_build_tflite_hashtable_size_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((1,), dtype="int64"):
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ gv: R.Tensor((1,), dtype="int64") = R.const([2], "int64")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_hashtable_import_main_subgraph_unsupported():
+ """Test HASHTABLE_IMPORT remains unsupported outside CALL_ONCE initialization."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_IMPORT outside CALL_ONCE"):
+ _load_model_from_buffer(_build_tflite_hashtable_import_in_main_model())
+
+
+def test_hashtable_size_uninitialized_unsupported():
+ """Test HASHTABLE_SIZE rejects tables without supported initialization."""
+ with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_SIZE requires a table"):
+ _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())
+
+
def _get_stablehlo_builtin_operator(builtin_name):
if not hasattr(_tfl_builtin_operator, builtin_name):
 pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")