This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 71c634f8b1 [Relax][Frontend][TFLite] Add `RANDOM_UNIFORM`, `RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL` (#19473)
71c634f8b1 is described below

commit 71c634f8b1340abec6b2bd69bf2bea5f41cbb521
Author: HoYi <[email protected]>
AuthorDate: Fri May 1 18:50:07 2026 +0800

    [Relax][Frontend][TFLite] Add `RANDOM_UNIFORM`, `RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL` (#19473)
    
    ## Summary
    
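    This PR adds support for the TFLite `RANDOM_UNIFORM`,
    `RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL` operators in the Relax
    TFLite frontend, covering the items I claimed in #19412.
    
    `RANDOM_UNIFORM` and `RANDOM_STANDARD_NORMAL` are lowered to seeded
    `tvm.contrib.random` calls with dynamic shape support. `MULTINOMIAL` is
    lowered by composing existing Relax ops around
    `relax.op.multinomial_from_uniform`.
    
    Seeded ops use stateless semantics: a non-zero `(seed, seed2)` pair
    re-seeds the RNG on every call, so the same pair yields identical output
    on every invocation. The pair `(0, 0)` is treated as TF's
    non-deterministic case and draws from the global thread-local engine
    instead.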
    
    ## Changes
    
    ### Frontend
    1. Add converter registrations for `RANDOM_UNIFORM`,
       `RANDOM_STANDARD_NORMAL`, and `MULTINOMIAL`.
    2. Add shared helpers to:
       - parse TFLite `RandomOptions` seeds
       - convert shape tensors to Relax shape expressions for dynamic-shape
         random ops (`FILL` now reuses this helper, which also honors the
         shape tensor's actual dtype instead of assuming `int32`)
       - reject unsupported output dtypes with `tvm.error.OpNotImplemented`
         (`float32` only for `RANDOM_UNIFORM` / `RANDOM_STANDARD_NORMAL`;
         `int32` / `int64` for `MULTINOMIAL`)
    3. Lower `RANDOM_UNIFORM` to `tvm.contrib.random.uniform`.
    4. Lower `RANDOM_STANDARD_NORMAL` to `tvm.contrib.random.normal`.
    5. Lower `MULTINOMIAL` by:
       - applying `R.nn.softmax` to logits
       - generating seeded uniform samples
       - calling `R.multinomial_from_uniform`
       - reshaping the result to `[batch_size, num_samples]`
    
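    ### Runtime
    1. Extend `src/runtime/contrib/random/random.cc` so that
       `tvm.contrib.random.uniform` and `tvm.contrib.random.normal` also
       accept a seeded 5-argument form
       `(seed, seed2, low/loc, high/scale, out)`. The two seeds are mixed
       into a single engine seed.
    2. Preserve compatibility with the existing unseeded 3-argument form
       `(low/loc, high/scale, out)`; any other argument count raises an
       error.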
    
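    ## Testing
    All tests pass:
    ```bash
    pytest \
      tests/python/relax/test_frontend_tflite.py::test_random_uniform_dynamic_shape \
      tests/python/relax/test_frontend_tflite.py::test_random_standard_normal_dynamic_shape \
      tests/python/relax/test_frontend_tflite.py::test_multinomial_dynamic_num_samples \
      -v
    ```
    
    Note: the end-to-end execution checks in these tests (output shape,
    dtype, and seeded run-to-run determinism) only run when
    `CI_ENV_NIGHTLY` is set; otherwise only the imported IR is checked.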
    ## References
    
    - #19412
    - Claimed items: `MULTINOMIAL`, `RANDOM_STANDARD_NORMAL`, `RANDOM_UNIFORM`
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 153 +++++++++++++++++++--
 src/runtime/contrib/random/random.cc               |  93 +++++++++++--
 tests/python/relax/test_frontend_tflite.py         | 109 +++++++++++++++
 3 files changed, 331 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 9d0fdaf587..05bda6816b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -188,6 +188,7 @@ class OperatorConverter:
             "MINIMUM": functools.partial(self._convert_elemwise, 
relax_op=_op.minimum),
             "MIRROR_PAD": self.convert_mirror_pad,
             "MUL": functools.partial(self._convert_elemwise, 
relax_op=_op.multiply),
+            "MULTINOMIAL": self.convert_multinomial,
             "NEG": functools.partial(self._convert_unary_elemwise, 
relax_op=_op.negative),
             "NOT_EQUAL": functools.partial(
                 self._convert_elemwise, relax_op=_op.not_equal, 
comparison_op=True
@@ -200,6 +201,8 @@ class OperatorConverter:
             "PRELU": self.convert_prelu,
             "RANGE": self.convert_range,
             "QUANTIZE": self.convert_quantize,
+            "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
+            "RANDOM_UNIFORM": self.convert_random_uniform,
             "REDUCE_ALL": functools.partial(self._convert_reduce_bool, 
relax_op=_op.min),
             "REDUCE_ANY": functools.partial(self._convert_reduce_bool, 
relax_op=_op.max),
             "REDUCE_MAX": functools.partial(self._convert_reduce, 
relax_op=_op.max),
@@ -542,6 +545,22 @@ class OperatorConverter:
             return "bool"
         raise NotImplementedError(f"Tensor type {tensor_type!s} is currently 
not supported")
 
+    def _get_shape_expr_from_tensor(self, shape_tensor, prefix):
+        """Convert a TFLite shape tensor to a Relax shape expression."""
+        if self.has_expr(shape_tensor.tensor_idx):
+            dims_expr = self.get_expr(shape_tensor.tensor_idx)
+            dims_ndim = int(self.get_tensor_shape(shape_tensor)[0])
+            dims_dtype = self.get_tensor_type_str(shape_tensor.tensor.Type())
+            dims_expr = self.bb.match_cast(dims_expr, 
relax.TensorStructInfo([dims_ndim], dims_dtype))
+            dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
+            shape_dataflow_var = 
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
+            shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in 
range(dims_ndim)]
+            self.bb.match_cast(shape_dataflow_var, 
relax.ShapeStructInfo(shape_vars))
+            return relax.ShapeExpr(shape_vars), shape_vars
+
+        dims = to_int_list(self.get_tensor_value(shape_tensor))
+        return dims, dims
+
     def flatten_to_nd(self, x, nd=3):
         """Flatten input tensor to nd rank"""
         shape = x.struct_info.shape
@@ -1783,23 +1802,129 @@ class OperatorConverter:
         dims_tensor = input_tensors[0]
         in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
 
-        if self.has_expr(dims_tensor.tensor_idx):
-            dims_expr = self.get_expr(dims_tensor.tensor_idx)
-            dims_ndim = int(self.get_tensor_shape(dims_tensor)[0])
+        out_shape, _ = self._get_shape_expr_from_tensor(dims_tensor, 
"fill_dim")
+        out = relax.op.full(out_shape, in_value_expr)
 
-            # Bind runtime dims to fresh symbolic shape vars so the imported
-            # module remains well formed before LegalizeOps runs.
-            dims_expr = self.bb.match_cast(dims_expr, 
relax.TensorStructInfo([dims_ndim], "int32"))
-            dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
-            shape_dataflow_var = 
self.bb.emit(relax.op.tensor_to_shape(dims_expr))
-            shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in 
range(dims_ndim)]
-            self.bb.match_cast(shape_dataflow_var, 
relax.ShapeStructInfo(shape_vars))
-            out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr)
+        return out
+
+    def _get_random_options(self, op):
+        """Return the seed pair for random TFLite operators.
+
+        The runtime imports seeded TFLite random ops with stateless semantics, 
so identical
+        non-zero seed pairs produce identical results on every invocation. The 
seed pair
+        (0, 0) is forwarded as the TF non-deterministic case.
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.RandomOptions import RandomOptions
+
+        if op.BuiltinOptionsType():
+            assert op.BuiltinOptionsType() == BuiltinOptions.RandomOptions
+            random_options = RandomOptions()
+            op_options = op.BuiltinOptions()
+            random_options.Init(op_options.Bytes, op_options.Pos)
+            return int(random_options.Seed()), int(random_options.Seed2())
+        return 0, 0
+
+    def _check_random_output_dtype(self, op_name, output_dtype, 
supported_dtypes):
+        if output_dtype not in supported_dtypes:
+            supported = ", ".join(supported_dtypes)
+            raise tvm.error.OpNotImplemented(
+                f"The TFLite {op_name} converter currently supports output 
dtype(s) "
+                f"{supported} only, but got {output_dtype}."
+            )
+
+    def convert_random_uniform(self, op):
+        """Convert TFLite RANDOM_UNIFORM using stateless seeded RNG 
semantics."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+        self._check_random_output_dtype("RANDOM_UNIFORM", output_dtype, 
["float32"])
+
+        out_shape, _ = self._get_shape_expr_from_tensor(input_tensors[0], 
"random_uniform_dim")
+        seed, seed2 = self._get_random_options(op)
+        return relax.op.call_dps_packed(
+            "tvm.contrib.random.uniform",
+            (seed, seed2, 0.0, 1.0),
+            out_sinfo=relax.TensorStructInfo(out_shape, output_dtype),
+        )
+
+    def convert_random_standard_normal(self, op):
+        """Convert TFLite RANDOM_STANDARD_NORMAL using stateless seeded RNG 
semantics."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+        self._check_random_output_dtype("RANDOM_STANDARD_NORMAL", 
output_dtype, ["float32"])
+
+        out_shape, _ = self._get_shape_expr_from_tensor(
+            input_tensors[0], "random_standard_normal_dim"
+        )
+        seed, seed2 = self._get_random_options(op)
+        return relax.op.call_dps_packed(
+            "tvm.contrib.random.normal",
+            (seed, seed2, 0.0, 1.0),
+            out_sinfo=relax.TensorStructInfo(out_shape, output_dtype),
+        )
+
+    def convert_multinomial(self, op):
+        """Convert TFLite MULTINOMIAL using stateless seeded RNG semantics."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        logits_tensor, num_samples_tensor = input_tensors
+        logits_expr = self.get_tensor_expr(logits_tensor)
+        batch_size = self.get_tensor_shape(logits_tensor)[0]
+        if self.has_expr(num_samples_tensor.tensor_idx):
+            scalar_expr = self.get_expr(num_samples_tensor.tensor_idx)
+            scalar_dtype = 
self.get_tensor_type_str(num_samples_tensor.tensor.Type())
+            scalar_expr = self.bb.match_cast(scalar_expr, 
relax.TensorStructInfo([], scalar_dtype))
+            scalar_expr = self.bb.normalize(relax.op.astype(scalar_expr, 
"int64"))
+            scalar_expr = self.bb.normalize(relax.op.reshape(scalar_expr, [1]))
+            shape_dataflow_var = 
self.bb.emit(relax.op.tensor_to_shape(scalar_expr))
+            num_samples = tirx.Var("multinomial_num_samples", "int64")
+            self.bb.match_cast(shape_dataflow_var, 
relax.ShapeStructInfo([num_samples]))
         else:
-            in_dims = list(self.get_tensor_value(dims_tensor))
-            out = relax.op.full(in_dims, in_value_expr)
+            value = self.get_tensor_value(num_samples_tensor)
+            assert value.size == 1, (
+                "TFLite MULTINOMIAL num_samples must be a scalar tensor, "
+                f"but got {value.size} values"
+            )
+            num_samples = int(value.item())
+        output_batch = batch_size * num_samples
 
-        return out
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+        self._check_random_output_dtype("MULTINOMIAL", output_dtype, ["int32", 
"int64"])
+
+        seed, seed2 = self._get_random_options(op)
+        uniform_sample = relax.op.call_dps_packed(
+            "tvm.contrib.random.uniform",
+            (seed, seed2, 0.0, 1.0),
+            out_sinfo=relax.TensorStructInfo([output_batch, 1], "float32"),
+        )
+        sample_indices = relax.op.reshape(
+            relax.op.broadcast_to(
+                relax.op.expand_dims(relax.op.arange(batch_size, 
dtype="int64"), axis=[1]),
+                relax.ShapeExpr([batch_size, num_samples]),
+            ),
+            relax.ShapeExpr([output_batch, 1]),
+        )
+        sampled = relax.op.multinomial_from_uniform(
+            relax.op.nn.softmax(logits_expr, axis=-1),
+            uniform_sample,
+            sample_indices,
+            dtype=output_dtype,
+        )
+        return relax.op.reshape(sampled, relax.ShapeExpr([batch_size, 
num_samples]))
 
     def _convert_reduce(self, relax_op, op):
         """Generic method to Convert TFLite REDUCE operators"""
diff --git a/src/runtime/contrib/random/random.cc 
b/src/runtime/contrib/random/random.cc
index af94f97ef1..5f6be9a0d6 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -26,6 +26,7 @@
 #include <tvm/runtime/data_type.h>
 
 #include <algorithm>
+#include <cstdint>
 
 #include "mt_random_engine.cc"
 
@@ -79,6 +80,44 @@ RandomThreadLocalEntry* 
RandomThreadLocalEntry::ThreadLocal() {
   return &inst;
 }
 
+namespace {
+
+unsigned CombineSeeds(int64_t seed, int64_t seed2) {
+  auto mix = [](uint64_t value) {
+    value += 0x9e3779b97f4a7c15ULL;
+    value = (value ^ (value >> 30)) * 0xbf58476d1ce4e5b9ULL;
+    value = (value ^ (value >> 27)) * 0x94d049bb133111ebULL;
+    return value ^ (value >> 31);
+  };
+
+  uint64_t seed_bits = static_cast<uint64_t>(seed);
+  uint64_t seed2_bits = static_cast<uint64_t>(seed2);
+  uint64_t combined = mix(seed_bits) ^ (mix(seed2_bits) + 
0x9e3779b97f4a7c15ULL + (seed_bits << 6) +
+                                        (seed_bits >> 2));
+  return static_cast<unsigned>((combined >> 32) ^ combined);
+}
+
+RandomEngine* GetRandomEngineForArgs(const ffi::PackedArgs& args, int 
seed_idx, int seed2_idx) {
+  if (args.size() <= seed2_idx) {
+    return &RandomThreadLocalEntry::ThreadLocal()->random_engine;
+  }
+
+  int64_t seed = args[seed_idx].cast<int64_t>();
+  int64_t seed2 = args[seed2_idx].cast<int64_t>();
+  if (seed == 0 && seed2 == 0) {
+    // TF treats seed=0 and seed2=0 as non-deterministic, so use the global 
engine.
+    return &RandomThreadLocalEntry::ThreadLocal()->random_engine;
+  }
+
+  // Seeded TFLite random ops use stateless semantics: identical seed pairs 
re-seed the engine
+  // and produce identical outputs on every invocation.
+  static thread_local RandomEngine seeded_engine;
+  seeded_engine.Seed(CombineSeeds(seed, seed2));
+  return &seeded_engine;
+}
+
+}  // namespace
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
@@ -118,19 +157,53 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                   })
       .def_packed("tvm.contrib.random.uniform",
                   [](ffi::PackedArgs args, ffi::Any* ret) {
-                    RandomThreadLocalEntry* entry = 
RandomThreadLocalEntry::ThreadLocal();
-                    double low = args[0].cast<double>();
-                    double high = args[1].cast<double>();
-                    auto out = args[2].cast<DLTensor*>();
-                    entry->random_engine.SampleUniform(out, low, high);
+                    RandomEngine* engine = nullptr;
+                    double low = 0.0;
+                    double high = 0.0;
+                    DLTensor* out = nullptr;
+
+                    if (args.size() == 3) {
+                      engine = 
&RandomThreadLocalEntry::ThreadLocal()->random_engine;
+                      low = args[0].cast<double>();
+                      high = args[1].cast<double>();
+                      out = args[2].cast<DLTensor*>();
+                    } else if (args.size() == 5) {
+                      engine = GetRandomEngineForArgs(args, 0, 1);
+                      low = args[2].cast<double>();
+                      high = args[3].cast<double>();
+                      out = args[4].cast<DLTensor*>();
+                    } else {
+                      TVM_FFI_THROW(InternalError)
+                          << "tvm.contrib.random.uniform expects either 3 or 5 
arguments, but got "
+                          << args.size();
+                    }
+
+                    engine->SampleUniform(out, low, high);
                   })
       .def_packed("tvm.contrib.random.normal",
                   [](ffi::PackedArgs args, ffi::Any* ret) {
-                    RandomThreadLocalEntry* entry = 
RandomThreadLocalEntry::ThreadLocal();
-                    double loc = args[0].cast<double>();
-                    double scale = args[1].cast<double>();
-                    auto out = args[2].cast<DLTensor*>();
-                    entry->random_engine.SampleNormal(out, loc, scale);
+                    RandomEngine* engine = nullptr;
+                    double loc = 0.0;
+                    double scale = 0.0;
+                    DLTensor* out = nullptr;
+
+                    if (args.size() == 3) {
+                      engine = 
&RandomThreadLocalEntry::ThreadLocal()->random_engine;
+                      loc = args[0].cast<double>();
+                      scale = args[1].cast<double>();
+                      out = args[2].cast<DLTensor*>();
+                    } else if (args.size() == 5) {
+                      engine = GetRandomEngineForArgs(args, 0, 1);
+                      loc = args[2].cast<double>();
+                      scale = args[3].cast<double>();
+                      out = args[4].cast<DLTensor*>();
+                    } else {
+                      TVM_FFI_THROW(InternalError)
+                          << "tvm.contrib.random.normal expects either 3 or 5 
arguments, but got "
+                          << args.size();
+                    }
+
+                    engine->SampleNormal(out, loc, scale);
                   })
       .def_packed("tvm.contrib.random.random_fill",
                   [](ffi::PackedArgs args, ffi::Any* ret) {
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index f4a0612705..418a7665a7 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -94,6 +94,42 @@ def verify(TestClass, expected=None):
         np.testing.assert_allclose(tf_output.numpy(), tvm_output.numpy(), 
rtol=1e-5, atol=1e-5)
 
 
+def _verify_random_with_inputs(cfunc, inputs):
+    """E2E verify random ops by shape/dtype and TVM seeded self-consistency."""
+    if "CI_ENV_NIGHTLY" not in os.environ:
+        return
+
+    mod = _get_mod_from_cfunc(cfunc)
+    tvm_inputs = [np.asarray(data) for data in inputs]
+    tf_inputs = [tf.constant(data) for data in tvm_inputs]
+
+    tf_output = cfunc(*tf_inputs)
+
+    tgt = tvm.target.Target("llvm")
+    ex = tvm.compile(mod, tgt)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    def run_tvm():
+        vm.set_input("main", *tvm_inputs)
+        vm.invoke_stateful("main")
+        return vm.get_outputs("main")
+
+    tvm_output = run_tvm()
+    tvm_output_again = run_tvm()
+
+    if not isinstance(tf_output, tuple):
+        tf_output = (tf_output,)
+        tvm_output = (tvm_output,)
+        tvm_output_again = (tvm_output_again,)
+
+    for tf_out, tvm_out, tvm_out_again in zip(tf_output, tvm_output, 
tvm_output_again):
+        tf_np = tf_out.numpy()
+        tvm_np = tvm_out.numpy()
+        assert tvm_np.shape == tf_np.shape
+        assert tvm_np.dtype == tf_np.dtype
+        np.testing.assert_equal(tvm_np, tvm_out_again.numpy())
+
+
 def test_add_one_2d():
     class AddOne2D(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), 
dtype=tf.float32)])
@@ -760,6 +796,79 @@ def test_fill_dynamic_dims():
     verify(cf)
 
 
+def test_random_uniform_dynamic_shape():
+    """RANDOM_UNIFORM imports dynamic shape and validates random output 
metadata."""
+
+    class TfRandomUniform(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2,), 
dtype=tf.int32)])
+        def func(self, shape):
+            return tf.raw_ops.RandomUniform(shape=shape, dtype=tf.float32, 
seed=7, seed2=11)
+
+    cf = TfRandomUniform().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+    assert "R.tensor_to_shape" in ir
+    assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+
+    _verify_random_with_inputs(cf, [np.array([2, 3], dtype="int32")])
+
+
+def test_random_standard_normal_dynamic_shape():
+    """RANDOM_STANDARD_NORMAL imports dynamic shape and validates random 
output metadata."""
+
+    class TfRandomStandardNormal(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2,), 
dtype=tf.int32)])
+        def func(self, shape):
+            return tf.raw_ops.RandomStandardNormal(
+                shape=shape, dtype=tf.float32, seed=3, seed2=5
+            )
+
+    cf = TfRandomStandardNormal().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+    assert "R.tensor_to_shape" in ir
+    assert 'R.call_dps_packed("tvm.contrib.random.normal"' in ir
+
+    _verify_random_with_inputs(cf, [np.array([2, 4], dtype="int32")])
+
+
+def test_multinomial_dynamic_num_samples():
+    """MULTINOMIAL lowers through seeded uniform sampling with dynamic 
num_samples."""
+
+    class TfMultinomial(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
+                tf.TensorSpec(shape=(), dtype=tf.int32),
+            ]
+        )
+        def func(self, logits, num_samples):
+            return tf.raw_ops.Multinomial(
+                logits=logits,
+                num_samples=num_samples,
+                output_dtype=tf.int64,
+                seed=13,
+                seed2=17,
+            )
+
+    cf = TfMultinomial().func.get_concrete_function()
+    mod = _get_mod_from_cfunc(cf)
+    ir = mod.script()
+    assert "R.nn.softmax" in ir
+    assert "R.multinomial_from_uniform" in ir
+    assert "R.tensor_to_shape" in ir
+    assert "multinomial_num_samples" in ir
+    assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir
+
+    _verify_random_with_inputs(
+        cf,
+        [
+            np.array([[2.0, 1.0, 0.5], [0.1, 0.2, 3.0]], dtype="float32"),
+            np.array(4, dtype="int32"),
+        ],
+    )
+
+
 @pytest.mark.parametrize(
     "tf_op, relax_op",
     [

Reply via email to