This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 23a0ea8d8b [Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR (#19651)
23a0ea8d8b is described below
commit 23a0ea8d8bd8c1d2538408f37dcdd13f55940684
Author: HoYi <[email protected]>
AuthorDate: Tue Jun 2 02:50:40 2026 +0800
[Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR (#19651)
## Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
`STABLEHLO_RNG_BIT_GENERATOR` operator.
Unlike most StableHLO builtins, the TFLite runtime
(`tensorflow/lite/kernels/rng_bit_generator.cc`) implements this op as a real,
deterministic counter-based PRNG, so the importer must reproduce it bit-exactly
rather than map it to an existing op:
- one 1-D uint64 `initial_state` input and two outputs: the uint64
  `output_state` and the random-bit `output` (int32 / int64 / uint32 / uint64);
- `algorithm` in `{DEFAULT, PHILOX, THREEFRY}`, where `DEFAULT` resolves to
  `PHILOX`;
- Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) with the fixed
  constants from `rng_util.cc`;
- state-length constraints: `THREEFRY` requires `u64[2]`, `PHILOX`/`DEFAULT`
  require `u64[2]` or `u64[3]`.
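
For orientation, here is a rough pure-Python sketch of the Threefry2x32 (20-round) block function listed above. It is not the TIR kernel from this commit; the names `threefry2x32` and `rotl32` and the tuple-based interface are illustrative assumptions, while the rotation schedule and parity constant are the ones used in the diff.

```python
MASK32 = 0xFFFFFFFF
THREEFRY_PARITY = 0x1BD11BDA  # key-schedule parity constant
ROTATIONS = (13, 15, 26, 6, 17, 29, 16, 24)


def rotl32(x, r):
    """Rotate a 32-bit word left by r bits (0 < r < 32)."""
    return ((x << r) | (x >> (32 - r))) & MASK32


def threefry2x32(key, counter):
    """Hypothetical helper: one Threefry2x32-20 block.

    key and counter are (low_word, high_word) pairs of uint32 values.
    """
    k0, k1 = key
    ks = (k0, k1, k0 ^ k1 ^ THREEFRY_PARITY)
    x0 = (counter[0] + k0) & MASK32
    x1 = (counter[1] + k1) & MASK32
    for group in range(5):  # 5 groups of 4 rounds = 20 rounds
        for step in range(4):
            rot = ROTATIONS[(group * 4 + step) % 8]
            x0 = (x0 + x1) & MASK32
            x1 = rotl32(x1, rot) ^ x0
        # Key injection after every fourth round.
        x0 = (x0 + ks[(group + 1) % 3]) & MASK32
        x1 = (x1 + ks[(group + 2) % 3] + group + 1) & MASK32
    return x0, x1


if __name__ == "__main__":
    print([hex(w) for w in threefry2x32((0, 0), (0, 0))])
```

With an all-zero key and counter, this sketch produces the words `0x6b200159` and `0x99ba4efe`, which is a quick sanity check before comparing against the runtime's vectors.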
## Design
TVM/Relax has no matching RNG primitive, so the converter generates a TIR kernel
that mirrors the runtime and emits it through `relax.call_tir` with two outputs.
The kernel:
- reinterprets the uint64 state as uint32 words and advances a 64-bit block
  counter (`final counter = initial_state[1] + num_blocks`);
- runs the selected algorithm per block with all round state materialized into
  local buffers, which keeps the generated IR linear instead of an exponentially
  nested expression tree;
- packs the produced uint32 words back into the output dtype, and writes the
  updated state (key unchanged, counter advanced, Philox `u64[3]` tail passed
  through), which is the only state behaviour the runtime relies on.
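
The counter arithmetic and word packing described above can be sketched in plain Python as follows. This is a minimal illustration, not the converter code: `advanced_counter`, `pack_u64`, and `as_signed` are hypothetical helper names, and the block sizes (2 words for Threefry, 4 for Philox) are taken from the diff.

```python
MASK64 = 0xFFFFFFFFFFFFFFFF


def advanced_counter(counter, num_elements, out_bits, algorithm):
    """Return the counter after generating num_elements outputs.

    Each block yields 2 uint32 words (Threefry) or 4 (Philox); a 64-bit
    output element consumes two words, a 32-bit element one.
    """
    block_words = 2 if algorithm == "threefry" else 4
    word_count = num_elements * (out_bits // 32)
    num_blocks = (word_count + block_words - 1) // block_words  # ceil division
    return (counter + num_blocks) & MASK64  # uint64 addition wraps


def pack_u64(low, high):
    """Combine two uint32 words into one uint64 value (low word first)."""
    return (high << 32) | low


def as_signed(value, bits):
    """Reinterpret an unsigned bit pattern as a two's-complement integer."""
    return value - (1 << bits) if value >= 1 << (bits - 1) else value


if __name__ == "__main__":
    # A [2, 3] output (6 elements) starting from counter 2.
    print(advanced_counter(2, 6, 32, "threefry"))  # 3 blocks of 2 words
    print(advanced_counter(2, 6, 64, "philox"))    # 3 blocks of 4 words
    print(as_signed(pack_u64(43444564, 2150618427), 64))
```

These helpers agree with the expected vectors in the tests: a `[2, 3]` output from counter 2 ends at 5 (32-bit) or 8 (64-bit) for THREEFRY and at 4 or 5 for PHILOX, and the first two THREEFRY uint32 words pack into the first uint64/int64 element.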
The kernel is an `s_tir` PrimFunc wrapped in a single opaque structured block so
it remains a well-formed block-structured function for the Relax pipeline
(e.g. `HasReshapePattern`). `get_tensor_type_str` and the input `_decode_type`
map are extended with uint32/uint64 so the uint64 state imports correctly.
Unsupported inputs raise a precise `OpNotImplemented` (non-uint64 / non-1-D
state, mismatched output-state shape, unsupported output dtype, unknown
algorithm, per-algorithm state-length violations).
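
A simplified sketch of these guards, written without TVM so it runs standalone, might look like the following. `UnsupportedRng` is a hypothetical stand-in for `tvm.error.OpNotImplemented`, and the function names and string-based algorithm names are illustrative only.

```python
class UnsupportedRng(Exception):
    """Hypothetical stand-in for tvm.error.OpNotImplemented."""


def resolve_algorithm(name):
    """Map the TFLite algorithm name to the kernel variant."""
    if name == "THREEFRY":
        return "threefry"
    if name in ("PHILOX", "DEFAULT"):  # DEFAULT resolves to PHILOX
        return "philox"
    raise UnsupportedRng(f"algorithm {name} is not supported")


def check_state(algorithm, dtype, shape):
    """Validate the initial state and return its length."""
    if dtype != "uint64":
        raise UnsupportedRng("requires a uint64 initial state")
    if len(shape) != 1:
        raise UnsupportedRng("requires a 1-D initial state")
    allowed = (2,) if algorithm == "threefry" else (2, 3)
    if shape[0] not in allowed:
        raise UnsupportedRng(f"{algorithm} does not accept a u64[{shape[0]}] state")
    return shape[0]


def is_supported(name, dtype, shape):
    """Convenience wrapper: True when the combination passes every guard."""
    try:
        check_state(resolve_algorithm(name), dtype, shape)
    except UnsupportedRng:
        return False
    return True


if __name__ == "__main__":
    print(is_supported("DEFAULT", "uint64", (3,)))
    print(is_supported("THREEFRY", "uint64", (3,)))
```

Checking the cheap, static properties first keeps every failure message specific to one violated constraint.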
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `STABLEHLO_RNG_BIT_GENERATOR` | `StablehloRngBitGeneratorOptions.Algorithm()` from `BuiltinOptions2` | `call_tir` to a generated bit-exact TIR kernel | THREEFRY (`u64[2]`) and PHILOX/DEFAULT (`u64[2]`/`u64[3]`); int32/int64/uint32/uint64 output |
## Tests
Tests build minimal RNG flatbuffers, compile, and execute them, comparing the
output and updated state against the verbatim expected vectors from the TFLite
runtime kernel test (`rng_bit_generator_test.cc`).
| Test | Coverage |
|---|---|
| `test_stablehlo_rng_bit_generator_threefry` | THREEFRY bit-exact, all 4 output dtypes |
| `test_stablehlo_rng_bit_generator_philox` | PHILOX bit-exact, all 4 output dtypes |
| `test_stablehlo_rng_bit_generator_default_matches_philox` | DEFAULT resolves to PHILOX |
| `test_stablehlo_rng_bit_generator_deterministic` | run-to-run bit-identical output |
| `test_stablehlo_rng_bit_generator_unsupported_output_dtype` | output dtype guard |
| `test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported` | THREEFRY `u64[2]` state guard |
| `test_stablehlo_rng_bit_generator_non_uint64_state_unsupported` | uint64 state guard |
Local validation:
```bash
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k rng_bit_generator -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k stablehlo -q
```
Result:
```text
ruff check: All checks passed
rng_bit_generator tests: 13 passed
stablehlo tests: 96 passed
```
## References
- Issue #19519 item I: remaining StableHLO operators in TFLite
- `tensorflow/lite/kernels/rng_bit_generator.cc`, `rng_util.cc`, `rng_bit_generator_test.cc`
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 231 ++++++++++++++++++++
tests/python/relax/test_frontend_tflite.py | 236 +++++++++++++++++++++
2 files changed, 467 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index fc3d61713d..bf90895cfc 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -376,6 +376,7 @@ class OperatorConverter:
"STABLEHLO_REDUCE": self._convert_stablehlo_reduce,
"STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window,
"STABLEHLO_REMAINDER": self._convert_stablehlo_remainder,
+ "STABLEHLO_RNG_BIT_GENERATOR": self._convert_stablehlo_rng_bit_generator,
"STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt),
"STABLEHLO_SCATTER": self._convert_stablehlo_scatter,
"STABLEHLO_SELECT": functools.partial(
@@ -1001,6 +1002,8 @@ class OperatorConverter:
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
TensorType.INT64: np.int64,
+ TensorType.UINT32: np.uint32,
+ TensorType.UINT64: np.uint64,
TensorType.BOOL: np.bool_,
}[tensor_wrapper.tensor.Type()]
@@ -1041,6 +1044,10 @@ class OperatorConverter:
return "int32"
if tensor_type == TensorType.INT64:
return "int64"
+ if tensor_type == TensorType.UINT32:
+ return "uint32"
+ if tensor_type == TensorType.UINT64:
+ return "uint64"
if tensor_type == TensorType.BOOL:
return "bool"
raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported")
@@ -2289,6 +2296,72 @@ class OperatorConverter:
target = call_target_name or "<empty>"
raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target {target} is not supported")
+ def _convert_stablehlo_rng_bit_generator(self, op):
+ """Convert STABLEHLO_RNG_BIT_GENERATOR to a bit-exact call_tir kernel."""
+ from tflite.RngAlgorithm import RngAlgorithm
+ from tflite.StablehloRngBitGeneratorOptions import StablehloRngBitGeneratorOptions
+
+ op_name = "STABLEHLO_RNG_BIT_GENERATOR"
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(input_tensors) != 1 or len(output_tensors) != 2:
+ raise tvm.error.OpNotImplemented(f"{op_name} expects one input and two outputs")
+
+ opts = self._get_stablehlo_options(op, StablehloRngBitGeneratorOptions)
+ algorithm_enum = opts.Algorithm()
+ # DEFAULT resolves to PHILOX in the TFLite runtime kernel.
+ if algorithm_enum == RngAlgorithm.THREEFRY:
+ algorithm = "threefry"
+ elif algorithm_enum in (RngAlgorithm.PHILOX, RngAlgorithm.DEFAULT):
+ algorithm = "philox"
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} algorithm {algorithm_enum} is not supported"
+ )
+
+ state_tensor = input_tensors[0]
+ if self.get_tensor_type_str(state_tensor.tensor.Type()) != "uint64":
+ raise tvm.error.OpNotImplemented(f"{op_name} requires a uint64 initial state")
+ state_shape = self._get_static_tensor_shape(state_tensor, op_name)
+ if len(state_shape) != 1:
+ raise tvm.error.OpNotImplemented(f"{op_name} requires a 1-D initial state")
+ state_len = int(state_shape[0])
+ # State-length constraints mirror the TFLite runtime kernel.
+ if algorithm == "threefry" and state_len != 2:
+ raise tvm.error.OpNotImplemented(f"{op_name} THREEFRY requires a u64[2] state")
+ if algorithm == "philox" and state_len not in (2, 3):
+ raise tvm.error.OpNotImplemented(f"{op_name} PHILOX requires a u64[2] or u64[3] state")
+
+ out_state_tensor, out_tensor = output_tensors
+ if self.get_tensor_type_str(out_state_tensor.tensor.Type()) != "uint64":
+ raise tvm.error.OpNotImplemented(f"{op_name} output state must be uint64")
+ out_state_shape = self._get_static_tensor_shape(out_state_tensor, op_name)
+ if list(out_state_shape) != list(state_shape):
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} output state shape must match the initial state"
+ )
+ out_dtype = self.get_tensor_type_str(out_tensor.tensor.Type())
+ if out_dtype not in ("int32", "int64", "uint32", "uint64"):
+ raise tvm.error.OpNotImplemented(f"{op_name} output dtype {out_dtype} is not supported")
+ out_shape = tuple(self._get_static_tensor_shape(out_tensor, op_name))
+
+ prim_func = _build_stablehlo_rng_bit_generator_primfunc(
+ algorithm, state_len, out_dtype, out_shape
+ )
+ module_builder = self.conversion_state["module_builder"]
+ func_name = f"tflite_stablehlo_rng_{algorithm}_{out_state_tensor.tensor_idx}"
+ gv = module_builder.add_func(prim_func, func_name)
+ state_expr = self.get_tensor_expr(state_tensor)
+ call = relax.call_tir(
+ gv,
+ [state_expr],
+ [
+ relax.TensorStructInfo(tuple(state_shape), "uint64"),
+ relax.TensorStructInfo(out_shape, out_dtype),
+ ],
+ )
+ return self.bb.normalize(call)
+
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
@@ -7430,6 +7503,162 @@ class OperatorConverter:
)
+# Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR,
+# matching tensorflow/lite/kernels/rng_util.cc.
+_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA
+_STABLEHLO_RNG_PHILOX_MUL_A = 0xD2511F53
+_STABLEHLO_RNG_PHILOX_MUL_B = 0xCD9E8D57
+_STABLEHLO_RNG_PHILOX_WEYL_A = 0x9E3779B9
+_STABLEHLO_RNG_PHILOX_WEYL_B = 0xBB67AE85
+
+
+def _build_stablehlo_rng_bit_generator_primfunc(algorithm, state_len, out_dtype, out_shape):
+ """Build a bit-exact TIR kernel for STABLEHLO_RNG_BIT_GENERATOR.
+
+ Mirrors the TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc),
+ implementing the Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds)
+ counter-based PRNGs. The kernel reinterprets the uint64 state as uint32 words,
+ advances a 64-bit block counter, and packs the generated words into the output
+ tensor. The updated state keeps the key unchanged and only advances the counter,
+ which is the only behaviour the runtime relies on.
+ """
+ from tvm.script.parser import tirx as T
+
+ total = 1
+ for dim in out_shape:
+ total *= int(dim)
+ is_64bit = out_dtype in ("int64", "uint64")
+ block_words = 2 if algorithm == "threefry" else 4
+ out_word_count = total * (2 if is_64bit else 1)
+ num_blocks = (out_word_count + block_words - 1) // block_words
+ writes_per_block = block_words // (2 if is_64bit else 1)
+ parity = _STABLEHLO_RNG_THREEFRY_PARITY
+ mul_a, mul_b = _STABLEHLO_RNG_PHILOX_MUL_A, _STABLEHLO_RNG_PHILOX_MUL_B
+ weyl_a, weyl_b = _STABLEHLO_RNG_PHILOX_WEYL_A, _STABLEHLO_RNG_PHILOX_WEYL_B
+
+ def _u32(value):
+ return T.Cast("uint32", value)
+
+ def _u64(value):
+ return T.Cast("uint64", value)
+
+ def _store_value(words, write_index):
+ # Pack the generated uint32 words into one output element, reinterpreting
+ # the bit pattern into the (possibly signed) output dtype.
+ if is_64bit:
+ low = _u64(words[2 * write_index])
+ high = _u64(words[2 * write_index + 1])
+ return T.reinterpret(out_dtype, low | (high << T.uint64(32)))
+ return T.reinterpret(out_dtype, words[write_index])
+
+ if algorithm == "threefry":
+
+ @T.prim_func(private=True, s_tir=True)
+ def kernel(
+ initial_state: T.Buffer((state_len,), "uint64"),
+ output_state: T.Buffer((state_len,), "uint64"),
+ output: T.Buffer(out_shape, out_dtype),
+ ):
+ # A single opaque structured block keeps the imperative kernel as a
+ # well-formed block-structured PrimFunc, as required by the Relax
+ # pipeline (e.g. HasReshapePattern).
+ with T.sblock("rng_bit_generator"):
+ state_key = initial_state[0]
+ state_counter = initial_state[1]
+ key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
+ key_1 = _u32(state_key >> T.uint64(32))
+ output_state[0] = state_key
+ output_state[1] = state_counter + T.uint64(num_blocks)
+ out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
+ keys = T.decl_buffer((3,), "uint32", scope="local")
+ rotations = T.decl_buffer((8,), "uint32", scope="local")
+ ctr = T.decl_buffer((2,), "uint32", scope="local")
+ keys[0] = key_0
+ keys[1] = key_1
+ keys[2] = key_0 ^ key_1 ^ T.uint32(parity)
+ rotations[0] = T.uint32(13)
+ rotations[1] = T.uint32(15)
+ rotations[2] = T.uint32(26)
+ rotations[3] = T.uint32(6)
+ rotations[4] = T.uint32(17)
+ rotations[5] = T.uint32(29)
+ rotations[6] = T.uint32(16)
+ rotations[7] = T.uint32(24)
+ for block in T.serial(num_blocks):
+ counter = state_counter + _u64(block)
+ ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + key_0
+ ctr[1] = _u32(counter >> T.uint64(32)) + key_1
+ for group in T.serial(5):
+ for step in T.serial(4):
+ rot = rotations[(group * 4 + step) % 8]
+ ctr[0] = ctr[0] + ctr[1]
+ ctr[1] = (ctr[1] << rot) | (ctr[1] >> (T.uint32(32) - rot))
+ ctr[1] = ctr[1] ^ ctr[0]
+ ctr[0] = ctr[0] + keys[(group + 1) % 3]
+ ctr[1] = ctr[1] + keys[(group + 2) % 3] + _u32(group + 1)
+ for write_index in T.serial(writes_per_block):
+ element = block * writes_per_block + write_index
+ if element < total:
+ out_flat[element] = _store_value(ctr, write_index)
+
+ return kernel
+
+ @T.prim_func(private=True, s_tir=True)
+ def kernel(
+ initial_state: T.Buffer((state_len,), "uint64"),
+ output_state: T.Buffer((state_len,), "uint64"),
+ output: T.Buffer(out_shape, out_dtype),
+ ):
+ with T.sblock("rng_bit_generator"):
+ state_key = initial_state[0]
+ state_counter = initial_state[1]
+ key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
+ key_1 = _u32(state_key >> T.uint64(32))
+ output_state[0] = state_key
+ output_state[1] = state_counter + T.uint64(num_blocks)
+ out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
+ ctr = T.decl_buffer((4,), "uint32", scope="local")
+ keys = T.decl_buffer((2,), "uint32", scope="local")
+ high_ctr = T.decl_buffer((2,), "uint32", scope="local")
+ if state_len == 3:
+ # PHILOX u64[3]: the third state word feeds the high counter and
+ # is passed through to the output state unchanged.
+ high_state = initial_state[2]
+ output_state[2] = high_state
+ high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF))
+ high_ctr[1] = _u32(high_state >> T.uint64(32))
+ else:
+ high_ctr[0] = key_0
+ high_ctr[1] = key_1
+ for block in T.serial(num_blocks):
+ counter = state_counter + _u64(block)
+ ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF))
+ ctr[1] = _u32(counter >> T.uint64(32))
+ ctr[2] = high_ctr[0]
+ ctr[3] = high_ctr[1]
+ keys[0] = key_0
+ keys[1] = key_1
+ for _round in T.serial(10):
+ prod_0 = T.uint64(mul_a) * _u64(ctr[0])
+ prod_1 = T.uint64(mul_b) * _u64(ctr[2])
+ new_0 = _u32(prod_1 >> T.uint64(32)) ^ ctr[1] ^ keys[0]
+ new_1 = _u32(prod_1 & T.uint64(0xFFFFFFFF))
+ new_2 = _u32(prod_0 >> T.uint64(32)) ^ ctr[3] ^ keys[1]
+ new_3 = _u32(prod_0 & T.uint64(0xFFFFFFFF))
+ ctr[0] = new_0
+ ctr[1] = new_1
+ ctr[2] = new_2
+ ctr[3] = new_3
+ keys[0] = keys[0] + T.uint32(weyl_a)
+ keys[1] = keys[1] + T.uint32(weyl_b)
+ for write_index in T.serial(writes_per_block):
+ element = block * writes_per_block + write_index
+ if element < total:
+ out_flat[element] = _store_value(ctr, write_index)
+
+ return kernel
+
+
# pylint: disable=no-else-return
def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
"""Prepare sparse indices and dense matrix from TFLite sparse parameters."""
@@ -7676,6 +7905,8 @@ def _decode_type(n):
7: "int16",
8: "complex64",
9: "int8",
+ 12: "uint64",
+ 15: "uint32",
}
return _tflite_m[n]
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index c34da605de..e4866d7096 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3697,6 +3697,7 @@ _tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWi
_tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions")
_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
_tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions")
+_tfl_stablehlo_rng_opts = _get_tflite_schema_module("StablehloRngBitGeneratorOptions")
_tfl_call_options = _get_tflite_schema_module("CallOptions")
_tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
@@ -3721,6 +3722,7 @@ _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFo
_tfl_padding = _get_tflite_schema_enum("Padding")
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
+_tfl_rng_algorithm = _get_tflite_schema_enum("RngAlgorithm")
_tfl_lstm_options = _get_tflite_schema_module("LSTMOptions")
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
@@ -7015,6 +7017,240 @@ def test_stablehlo_options_missing_payload_unsupported():
_load_model_from_buffer(buf)
+def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, const_state=None):
+ """Build a STABLEHLO_RNG_BIT_GENERATOR model.
+
+ When ``const_state`` is provided, the uint64 initial state is embedded as a
+ constant tensor (no graph input); otherwise it is a graph input.
+ """
+ builder = flatbuffers.Builder(1024)
+
+ _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder)
+ _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm)
+ rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder)
+
+ rng_builtin = _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR")
+ rng_code = _build_operator_code(builder, rng_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [state_len], tensor_type=_tfl_tensor_type.UINT64),
+ _build_tensor(builder, 1, [state_len], tensor_type=_tfl_tensor_type.UINT64),
+ _build_tensor(builder, 2, list(out_shape), tensor_type=out_tensor_type),
+ ]
+ rng_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1, 2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions,
+ builtin_options2=rng_opts,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[rng_op],
+ inputs=[] if const_state is not None else [0],
+ outputs=[1, 2],
+ )
+
+ state_data = None
+ if const_state is not None:
+ state_data = np.array(const_state, dtype="uint64").tobytes()
+ buffers = [
+ _build_buffer(builder, data=state_data),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ operator_codes=[rng_code],
+ buffers=buffers,
+ )
+
+
+def _run_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, init_state):
+ """Import, compile, and execute an RNG model, returning (output_state, output)."""
+ buf = _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type)
+ mod = _load_model_from_buffer(buf)
+ ex = tvm.compile(mod, tvm.target.Target("llvm"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ result = vm["main"](tvm.runtime.tensor(np.array(init_state, dtype="uint64")))
+ return result[0].numpy(), result[1].numpy()
+
+
+# Expected vectors are taken verbatim from the TFLite runtime kernel test
+# (tensorflow/lite/kernels/rng_bit_generator_test.cc), guaranteeing bit-exact parity.
+_RNG_THREEFRY_EXPECTED = {
+ "int32": [43444564, -2144348869, -315321645, -549236733, 1672743891, -54463903],
+ "uint32": [43444564, 2150618427, 3979645651, 3745730563, 1672743891, 4240503393],
+ "int64": [
+ -9209908263526143660,
+ -2358953802017238317,
+ -233920680524772397,
+ 2658481902456610144,
+ -2022031683723149139,
+ -2324041912354448873,
+ ],
+ "uint64": [
+ 9236835810183407956,
+ 16087790271692313299,
+ 18212823393184779219,
+ 2658481902456610144,
+ 16424712389986402477,
+ 16122702161355102743,
+ ],
+}
+_RNG_THREEFRY_STATE = {"int32": [1, 5], "uint32": [1, 5], "int64": [1, 8], "uint64": [1, 8]}
+_RNG_PHILOX_EXPECTED = {
+ "int32": [-263854262, 1366700262, 495645701, -1243243882, 89414891, 1917262711],
+ "uint32": [4031113034, 1366700262, 495645701, 3051723414, 89414891, 1917262711],
+ "int64": [
+ 5869932932755744586,
+ -5339691813646437371,
+ 8234580641674714347,
+ 2641225993340350124,
+ 1962472297844690804,
+ -3580856229565614135,
+ ],
+ "uint64": [
+ 5869932932755744586,
+ 13107052260063114245,
+ 8234580641674714347,
+ 2641225993340350124,
+ 1962472297844690804,
+ 14865887844143937481,
+ ],
+}
+_RNG_PHILOX_STATE = {
+ "int32": [1, 4, 3],
+ "uint32": [1, 4, 3],
+ "int64": [1, 5, 3],
+ "uint64": [1, 5, 3],
+}
+
+
+@pytest.mark.parametrize(
+ "out_dtype,out_tensor_type",
+ [
+ ("int32", _tfl_tensor_type.INT32),
+ ("uint32", _tfl_tensor_type.UINT32),
+ ("int64", _tfl_tensor_type.INT64),
+ ("uint64", _tfl_tensor_type.UINT64),
+ ],
+)
+def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type):
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR THREEFRY matches the runtime kernel bit-exactly."""
+ state, output = _run_stablehlo_rng_model(
+ _tfl_rng_algorithm.THREEFRY, 2, [2, 3], out_tensor_type, [1, 2]
+ )
+ assert output.flatten().tolist() == _RNG_THREEFRY_EXPECTED[out_dtype]
+ assert state.tolist() == _RNG_THREEFRY_STATE[out_dtype]
+
+
+@pytest.mark.parametrize(
+ "out_dtype,out_tensor_type",
+ [
+ ("int32", _tfl_tensor_type.INT32),
+ ("uint32", _tfl_tensor_type.UINT32),
+ ("int64", _tfl_tensor_type.INT64),
+ ("uint64", _tfl_tensor_type.UINT64),
+ ],
+)
+def test_stablehlo_rng_bit_generator_philox(out_dtype, out_tensor_type):
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR PHILOX matches the runtime kernel bit-exactly."""
+ state, output = _run_stablehlo_rng_model(
+ _tfl_rng_algorithm.PHILOX, 3, [2, 3], out_tensor_type, [1, 2, 3]
+ )
+ assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED[out_dtype]
+ assert state.tolist() == _RNG_PHILOX_STATE[out_dtype]
+
+
+def test_stablehlo_rng_bit_generator_default_matches_philox():
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR DEFAULT resolves to the PHILOX algorithm."""
+ state, output = _run_stablehlo_rng_model(
+ _tfl_rng_algorithm.DEFAULT, 3, [2, 3], _tfl_tensor_type.INT32, [1, 2, 3]
+ )
+ assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED["int32"]
+ assert state.tolist() == _RNG_PHILOX_STATE["int32"]
+
+
+def test_stablehlo_rng_bit_generator_deterministic():
+ """Re-running the imported RNG kernel yields identical bit-exact output."""
+ buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [3, 3], _tfl_tensor_type.INT32)
+ mod = _load_model_from_buffer(buf)
+ ex = tvm.compile(mod, tvm.target.Target("llvm"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ init = tvm.runtime.tensor(np.array([7, 8, 9], dtype="uint64"))
+ first = vm["main"](init)
+ second = vm["main"](init)
+ np.testing.assert_equal(first[1].numpy(), second[1].numpy())
+ np.testing.assert_equal(first[0].numpy(), second[0].numpy())
+
+
+def test_stablehlo_rng_bit_generator_constant_state():
+ """A constant uint64 initial state imports and stays bit-exact (no graph input)."""
+ buf = _build_stablehlo_rng_model(
+ _tfl_rng_algorithm.THREEFRY, 2, [2, 3], _tfl_tensor_type.INT32, const_state=[1, 2]
+ )
+ mod = _load_model_from_buffer(buf)
+ assert len(mod["main"].params) == 0
+ ex = tvm.compile(mod, tvm.target.Target("llvm"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ result = vm["main"]()
+ assert result[1].numpy().flatten().tolist() == _RNG_THREEFRY_EXPECTED["int32"]
+ assert result[0].numpy().tolist() == _RNG_THREEFRY_STATE["int32"]
+
+
+def test_stablehlo_rng_bit_generator_unsupported_output_dtype():
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects non-integer output dtypes."""
+ buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], _tfl_tensor_type.FLOAT32)
+ with pytest.raises(tvm.error.OpNotImplemented, match="output dtype float32 is not supported"):
+ _load_model_from_buffer(buf)
+
+
+def test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported():
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a u64[3] state for THREEFRY."""
+ buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.THREEFRY, 3, [2, 3], _tfl_tensor_type.INT32)
+ with pytest.raises(tvm.error.OpNotImplemented, match="THREEFRY requires a u64.2. state"):
+ _load_model_from_buffer(buf)
+
+
+def test_stablehlo_rng_bit_generator_non_uint64_state_unsupported():
+ """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a non-uint64 initial state."""
+ builder = flatbuffers.Builder(1024)
+ _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder)
+ _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(
+ builder, _tfl_rng_algorithm.PHILOX
+ )
+ rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder)
+ rng_code = _build_operator_code(
+ builder, _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR")
+ )
+ tensors = [
+ _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT64),
+ _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64),
+ _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.INT32),
+ ]
+ rng_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1, 2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions,
+ builtin_options2=rng_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[rng_op], inputs=[0], outputs=[1, 2]
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ buf = _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[rng_code], buffers=buffers
+ )
+ with pytest.raises(tvm.error.OpNotImplemented, match="requires a uint64 initial state"):
+ _load_model_from_buffer(buf)
+
+
def test_stablehlo_while():
"""TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
mod = _load_model_from_buffer(_build_stablehlo_while_model())