This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 71c634f8b1 [Relax][Frontend][TFLite] Add `RANDOM_UNIFORM`,
`RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL` (#19473)
71c634f8b1 is described below
commit 71c634f8b1340abec6b2bd69bf2bea5f41cbb521
Author: HoYi <[email protected]>
AuthorDate: Fri May 1 18:50:07 2026 +0800
[Relax][Frontend][TFLite] Add `RANDOM_UNIFORM`, `RANDOM_STANDARD_NORMAL`,
and `MULTINOMIAL` (#19473)
## Summary
This PR adds support for the TFLite `RANDOM_UNIFORM`,
`RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL` operators in the Relax
TFLite frontend, covering the items I claimed in #19412.
`RANDOM_UNIFORM` and `RANDOM_STANDARD_NORMAL` are lowered to seeded
`tvm.contrib.random` calls with dynamic shape support. `MULTINOMIAL` is
lowered by composing existing Relax ops around
`relax.op.multinomial_from_uniform`.
## Changes
### Frontend
1. Add converter registrations for `RANDOM_UNIFORM`,
`RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL`.
2. Add shared helpers to:
- parse TFLite `RandomOptions` seeds
- convert shape tensors to Relax shape expressions for dynamic-shape
random ops
3. Lower `RANDOM_UNIFORM` to `tvm.contrib.random.uniform`.
4. Lower `RANDOM_STANDARD_NORMAL` to `tvm.contrib.random.normal`.
5. Lower `MULTINOMIAL` by:
- applying `R.nn.softmax` to logits
- generating seeded uniform samples
- calling `R.multinomial_from_uniform`
- reshaping the result to `[batch_size, num_samples]`
### Runtime
1. Extend `src/runtime/contrib/random/random.cc` to accept seeded calls
for `tvm.contrib.random.uniform` and `tvm.contrib.random.normal`.
2. Preserve compatibility with the existing unseeded calling convention.
## Testing
All tests pass:
```bash
pytest \
  tests/python/relax/test_frontend_tflite.py::test_random_uniform_dynamic_shape \
  tests/python/relax/test_frontend_tflite.py::test_random_standard_normal_dynamic_shape \
  tests/python/relax/test_frontend_tflite.py::test_multinomial_dynamic_num_samples \
  -v
```
## References
- #19412
- Claimed items: MULTINOMIAL, RANDOM_STANDARD_NORMAL, RANDOM_UNIFORM
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 153 +++++++++++++++++++--
src/runtime/contrib/random/random.cc | 93 +++++++++++--
tests/python/relax/test_frontend_tflite.py | 109 +++++++++++++++
3 files changed, 331 insertions(+), 24 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 9d0fdaf587..05bda6816b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -188,6 +188,7 @@ class OperatorConverter:
"MINIMUM": functools.partial(self._convert_elemwise,
relax_op=_op.minimum),
"MIRROR_PAD": self.convert_mirror_pad,
"MUL": functools.partial(self._convert_elemwise,
relax_op=_op.multiply),
+ "MULTINOMIAL": self.convert_multinomial,
"NEG": functools.partial(self._convert_unary_elemwise,
relax_op=_op.negative),
"NOT_EQUAL": functools.partial(
self._convert_elemwise, relax_op=_op.not_equal,
comparison_op=True
@@ -200,6 +201,8 @@ class OperatorConverter:
"PRELU": self.convert_prelu,
"RANGE": self.convert_range,
"QUANTIZE": self.convert_quantize,
+ "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
+ "RANDOM_UNIFORM": self.convert_random_uniform,
"REDUCE_ALL": functools.partial(self._convert_reduce_bool,
relax_op=_op.min),
"REDUCE_ANY": functools.partial(self._convert_reduce_bool,
relax_op=_op.max),
"REDUCE_MAX": functools.partial(self._convert_reduce,
relax_op=_op.max),
@@ -542,6 +545,22 @@ class OperatorConverter:
return "bool"
raise NotImplementedError(f"Tensor type {tensor_type!s} is currently
not supported")
+ def _get_shape_expr_from_tensor(self, shape_tensor, prefix):
+ """Convert a TFLite shape tensor to a Relax shape expression."""
+ if self.has_expr(shape_tensor.tensor_idx):
+ dims_expr = self.get_expr(shape_tensor.tensor_idx)
+ dims_ndim = int(self.get_tensor_shape(shape_tensor)[0])
+ dims_dtype = self.get_tensor_type_str(shape_tensor.tensor.Type())
+ dims_expr = self.bb.match_cast(dims_expr,
relax.TensorStructInfo([dims_ndim], dims_dtype))
+ dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
+ shape_dataflow_var =
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
+ shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in
range(dims_ndim)]
+ self.bb.match_cast(shape_dataflow_var,
relax.ShapeStructInfo(shape_vars))
+ return relax.ShapeExpr(shape_vars), shape_vars
+
+ dims = to_int_list(self.get_tensor_value(shape_tensor))
+ return dims, dims
+
def flatten_to_nd(self, x, nd=3):
"""Flatten input tensor to nd rank"""
shape = x.struct_info.shape
@@ -1783,23 +1802,129 @@ class OperatorConverter:
dims_tensor = input_tensors[0]
in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
- if self.has_expr(dims_tensor.tensor_idx):
- dims_expr = self.get_expr(dims_tensor.tensor_idx)
- dims_ndim = int(self.get_tensor_shape(dims_tensor)[0])
+ out_shape, _ = self._get_shape_expr_from_tensor(dims_tensor,
"fill_dim")
+ out = relax.op.full(out_shape, in_value_expr)
- # Bind runtime dims to fresh symbolic shape vars so the imported
- # module remains well formed before LegalizeOps runs.
- dims_expr = self.bb.match_cast(dims_expr,
relax.TensorStructInfo([dims_ndim], "int32"))
- dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
- shape_dataflow_var =
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
- shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in
range(dims_ndim)]
- self.bb.match_cast(shape_dataflow_var,
relax.ShapeStructInfo(shape_vars))
- out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr)
+ return out
+
+ def _get_random_options(self, op):
+ """Return the seed pair for random TFLite operators.
+
+ The runtime imports seeded TFLite random ops with stateless semantics,
so identical
+ non-zero seed pairs produce identical results on every invocation. The
seed pair
+ (0, 0) is forwarded as the TF non-deterministic case.
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.RandomOptions import RandomOptions
+
+ if op.BuiltinOptionsType():
+ assert op.BuiltinOptionsType() == BuiltinOptions.RandomOptions
+ random_options = RandomOptions()
+ op_options = op.BuiltinOptions()
+ random_options.Init(op_options.Bytes, op_options.Pos)
+ return int(random_options.Seed()), int(random_options.Seed2())
+ return 0, 0
+
+ def _check_random_output_dtype(self, op_name, output_dtype,
supported_dtypes):
+ if output_dtype not in supported_dtypes:
+ supported = ", ".join(supported_dtypes)
+ raise tvm.error.OpNotImplemented(
+ f"The TFLite {op_name} converter currently supports output
dtype(s) "
+ f"{supported} only, but got {output_dtype}."
+ )
+
+ def convert_random_uniform(self, op):
+ """Convert TFLite RANDOM_UNIFORM using stateless seeded RNG
semantics."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+ self._check_random_output_dtype("RANDOM_UNIFORM", output_dtype,
["float32"])
+
+ out_shape, _ = self._get_shape_expr_from_tensor(input_tensors[0],
"random_uniform_dim")
+ seed, seed2 = self._get_random_options(op)
+ return relax.op.call_dps_packed(
+ "tvm.contrib.random.uniform",
+ (seed, seed2, 0.0, 1.0),
+ out_sinfo=relax.TensorStructInfo(out_shape, output_dtype),
+ )
+
+ def convert_random_standard_normal(self, op):
+ """Convert TFLite RANDOM_STANDARD_NORMAL using stateless seeded RNG
semantics."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+ self._check_random_output_dtype("RANDOM_STANDARD_NORMAL",
output_dtype, ["float32"])
+
+ out_shape, _ = self._get_shape_expr_from_tensor(
+ input_tensors[0], "random_standard_normal_dim"
+ )
+ seed, seed2 = self._get_random_options(op)
+ return relax.op.call_dps_packed(
+ "tvm.contrib.random.normal",
+ (seed, seed2, 0.0, 1.0),
+ out_sinfo=relax.TensorStructInfo(out_shape, output_dtype),
+ )
+
+ def convert_multinomial(self, op):
+ """Convert TFLite MULTINOMIAL using stateless seeded RNG semantics."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ logits_tensor, num_samples_tensor = input_tensors
+ logits_expr = self.get_tensor_expr(logits_tensor)
+ batch_size = self.get_tensor_shape(logits_tensor)[0]
+ if self.has_expr(num_samples_tensor.tensor_idx):
+ scalar_expr = self.get_expr(num_samples_tensor.tensor_idx)
+ scalar_dtype =
self.get_tensor_type_str(num_samples_tensor.tensor.Type())
+ scalar_expr = self.bb.match_cast(scalar_expr,
relax.TensorStructInfo([], scalar_dtype))
+ scalar_expr = self.bb.normalize(relax.op.astype(scalar_expr,
"int64"))
+ scalar_expr = self.bb.normalize(relax.op.reshape(scalar_expr, [1]))
+ shape_dataflow_var =
self.bb.emit(relax.op.tensor_to_shape(scalar_expr))
+ num_samples = tirx.Var("multinomial_num_samples", "int64")
+ self.bb.match_cast(shape_dataflow_var,
relax.ShapeStructInfo([num_samples]))
else:
- in_dims = list(self.get_tensor_value(dims_tensor))
- out = relax.op.full(in_dims, in_value_expr)
+ value = self.get_tensor_value(num_samples_tensor)
+ assert value.size == 1, (
+ "TFLite MULTINOMIAL num_samples must be a scalar tensor, "
+ f"but got {value.size} values"
+ )
+ num_samples = int(value.item())
+ output_batch = batch_size * num_samples
- return out
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+ self._check_random_output_dtype("MULTINOMIAL", output_dtype, ["int32",
"int64"])
+
+ seed, seed2 = self._get_random_options(op)
+ uniform_sample = relax.op.call_dps_packed(
+ "tvm.contrib.random.uniform",
+ (seed, seed2, 0.0, 1.0),
+ out_sinfo=relax.TensorStructInfo([output_batch, 1], "float32"),
+ )
+ sample_indices = relax.op.reshape(
+ relax.op.broadcast_to(
+ relax.op.expand_dims(relax.op.arange(batch_size,
dtype="int64"), axis=[1]),
+ relax.ShapeExpr([batch_size, num_samples]),
+ ),
+ relax.ShapeExpr([output_batch, 1]),
+ )
+ sampled = relax.op.multinomial_from_uniform(
+ relax.op.nn.softmax(logits_expr, axis=-1),
+ uniform_sample,
+ sample_indices,
+ dtype=output_dtype,
+ )
+ return relax.op.reshape(sampled, relax.ShapeExpr([batch_size,
num_samples]))
def _convert_reduce(self, relax_op, op):
"""Generic method to Convert TFLite REDUCE operators"""
diff --git a/src/runtime/contrib/random/random.cc
b/src/runtime/contrib/random/random.cc
index af94f97ef1..5f6be9a0d6 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -26,6 +26,7 @@
#include <tvm/runtime/data_type.h>
#include <algorithm>
+#include <cstdint>
#include "mt_random_engine.cc"
@@ -79,6 +80,44 @@ RandomThreadLocalEntry*
RandomThreadLocalEntry::ThreadLocal() {
return &inst;
}
+namespace {
+
+unsigned CombineSeeds(int64_t seed, int64_t seed2) {
+ auto mix = [](uint64_t value) {
+ value += 0x9e3779b97f4a7c15ULL;
+ value = (value ^ (value >> 30)) * 0xbf58476d1ce4e5b9ULL;
+ value = (value ^ (value >> 27)) * 0x94d049bb133111ebULL;
+ return value ^ (value >> 31);
+ };
+
+ uint64_t seed_bits = static_cast<uint64_t>(seed);
+ uint64_t seed2_bits = static_cast<uint64_t>(seed2);
+ uint64_t combined = mix(seed_bits) ^ (mix(seed2_bits) +
0x9e3779b97f4a7c15ULL + (seed_bits << 6) +
+ (seed_bits >> 2));
+ return static_cast<unsigned>((combined >> 32) ^ combined);
+}
+
+RandomEngine* GetRandomEngineForArgs(const ffi::PackedArgs& args, int
seed_idx, int seed2_idx) {
+ if (args.size() <= seed2_idx) {
+ return &RandomThreadLocalEntry::ThreadLocal()->random_engine;
+ }
+
+ int64_t seed = args[seed_idx].cast<int64_t>();
+ int64_t seed2 = args[seed2_idx].cast<int64_t>();
+ if (seed == 0 && seed2 == 0) {
+ // TF treats seed=0 and seed2=0 as non-deterministic, so use the global
engine.
+ return &RandomThreadLocalEntry::ThreadLocal()->random_engine;
+ }
+
+ // Seeded TFLite random ops use stateless semantics: identical seed pairs
re-seed the engine
+ // and produce identical outputs on every invocation.
+ static thread_local RandomEngine seeded_engine;
+ seeded_engine.Seed(CombineSeeds(seed, seed2));
+ return &seeded_engine;
+}
+
+} // namespace
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
@@ -118,19 +157,53 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def_packed("tvm.contrib.random.uniform",
[](ffi::PackedArgs args, ffi::Any* ret) {
- RandomThreadLocalEntry* entry =
RandomThreadLocalEntry::ThreadLocal();
- double low = args[0].cast<double>();
- double high = args[1].cast<double>();
- auto out = args[2].cast<DLTensor*>();
- entry->random_engine.SampleUniform(out, low, high);
+ RandomEngine* engine = nullptr;
+ double low = 0.0;
+ double high = 0.0;
+ DLTensor* out = nullptr;
+
+ if (args.size() == 3) {
+ engine =
&RandomThreadLocalEntry::ThreadLocal()->random_engine;
+ low = args[0].cast<double>();
+ high = args[1].cast<double>();
+ out = args[2].cast<DLTensor*>();
+ } else if (args.size() == 5) {
+ engine = GetRandomEngineForArgs(args, 0, 1);
+ low = args[2].cast<double>();
+ high = args[3].cast<double>();
+ out = args[4].cast<DLTensor*>();
+ } else {
+ TVM_FFI_THROW(InternalError)
+ << "tvm.contrib.random.uniform expects either 3 or 5
arguments, but got "
+ << args.size();
+ }
+
+ engine->SampleUniform(out, low, high);
})
.def_packed("tvm.contrib.random.normal",
[](ffi::PackedArgs args, ffi::Any* ret) {
- RandomThreadLocalEntry* entry =
RandomThreadLocalEntry::ThreadLocal();
- double loc = args[0].cast<double>();
- double scale = args[1].cast<double>();
- auto out = args[2].cast<DLTensor*>();
- entry->random_engine.SampleNormal(out, loc, scale);
+ RandomEngine* engine = nullptr;
+ double loc = 0.0;
+ double scale = 0.0;
+ DLTensor* out = nullptr;
+
+ if (args.size() == 3) {
+ engine =
&RandomThreadLocalEntry::ThreadLocal()->random_engine;
+ loc = args[0].cast<double>();
+ scale = args[1].cast<double>();
+ out = args[2].cast<DLTensor*>();
+ } else if (args.size() == 5) {
+ engine = GetRandomEngineForArgs(args, 0, 1);
+ loc = args[2].cast<double>();
+ scale = args[3].cast<double>();
+ out = args[4].cast<DLTensor*>();
+ } else {
+ TVM_FFI_THROW(InternalError)
+ << "tvm.contrib.random.normal expects either 3 or 5
arguments, but got "
+ << args.size();
+ }
+
+ engine->SampleNormal(out, loc, scale);
})
.def_packed("tvm.contrib.random.random_fill",
[](ffi::PackedArgs args, ffi::Any* ret) {
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index f4a0612705..418a7665a7 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -94,6 +94,42 @@ def verify(TestClass, expected=None):
np.testing.assert_allclose(tf_output.numpy(), tvm_output.numpy(),
rtol=1e-5, atol=1e-5)
+def _verify_random_with_inputs(cfunc, inputs):
+ """E2E verify random ops by shape/dtype and TVM seeded self-consistency."""
+ if "CI_ENV_NIGHTLY" not in os.environ:
+ return
+
+ mod = _get_mod_from_cfunc(cfunc)
+ tvm_inputs = [np.asarray(data) for data in inputs]
+ tf_inputs = [tf.constant(data) for data in tvm_inputs]
+
+ tf_output = cfunc(*tf_inputs)
+
+ tgt = tvm.target.Target("llvm")
+ ex = tvm.compile(mod, tgt)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ def run_tvm():
+ vm.set_input("main", *tvm_inputs)
+ vm.invoke_stateful("main")
+ return vm.get_outputs("main")
+
+ tvm_output = run_tvm()
+ tvm_output_again = run_tvm()
+
+ if not isinstance(tf_output, tuple):
+ tf_output = (tf_output,)
+ tvm_output = (tvm_output,)
+ tvm_output_again = (tvm_output_again,)
+
+ for tf_out, tvm_out, tvm_out_again in zip(tf_output, tvm_output,
tvm_output_again):
+ tf_np = tf_out.numpy()
+ tvm_np = tvm_out.numpy()
+ assert tvm_np.shape == tf_np.shape
+ assert tvm_np.dtype == tf_np.dtype
+ np.testing.assert_equal(tvm_np, tvm_out_again.numpy())
+
+
def test_add_one_2d():
class AddOne2D(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 2),
dtype=tf.float32)])
@@ -760,6 +796,79 @@ def test_fill_dynamic_dims():
verify(cf)
+def test_random_uniform_dynamic_shape():
+ """RANDOM_UNIFORM imports dynamic shape and validates random output
metadata."""
+
+ class TfRandomUniform(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(2,),
dtype=tf.int32)])
+ def func(self, shape):
+ return tf.raw_ops.RandomUniform(shape=shape, dtype=tf.float32,
seed=7, seed2=11)
+
+ cf = TfRandomUniform().func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ ir = mod.script()
+ assert "R.tensor_to_shape" in ir
+ assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+
+ _verify_random_with_inputs(cf, [np.array([2, 3], dtype="int32")])
+
+
+def test_random_standard_normal_dynamic_shape():
+ """RANDOM_STANDARD_NORMAL imports dynamic shape and validates random
output metadata."""
+
+ class TfRandomStandardNormal(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(2,),
dtype=tf.int32)])
+ def func(self, shape):
+ return tf.raw_ops.RandomStandardNormal(
+ shape=shape, dtype=tf.float32, seed=3, seed2=5
+ )
+
+ cf = TfRandomStandardNormal().func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ ir = mod.script()
+ assert "R.tensor_to_shape" in ir
+ assert 'R.call_dps_packed("tvm.contrib.random.normal"' in ir
+
+ _verify_random_with_inputs(cf, [np.array([2, 4], dtype="int32")])
+
+
+def test_multinomial_dynamic_num_samples():
+ """MULTINOMIAL lowers through seeded uniform sampling with dynamic
num_samples."""
+
+ class TfMultinomial(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
+ tf.TensorSpec(shape=(), dtype=tf.int32),
+ ]
+ )
+ def func(self, logits, num_samples):
+ return tf.raw_ops.Multinomial(
+ logits=logits,
+ num_samples=num_samples,
+ output_dtype=tf.int64,
+ seed=13,
+ seed2=17,
+ )
+
+ cf = TfMultinomial().func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ ir = mod.script()
+ assert "R.nn.softmax" in ir
+ assert "R.multinomial_from_uniform" in ir
+ assert "R.tensor_to_shape" in ir
+ assert "multinomial_num_samples" in ir
+ assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+
+ _verify_random_with_inputs(
+ cf,
+ [
+ np.array([[2.0, 1.0, 0.5], [0.1, 0.2, 3.0]], dtype="float32"),
+ np.array(4, dtype="int32"),
+ ],
+ )
+
+
@pytest.mark.parametrize(
"tf_op, relax_op",
[