This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e89570fa83 [Relax][Frontend][TFLite] Add RNN converter (#19632)
e89570fa83 is described below

commit e89570fa8321ebf9951e6d8f512eda764216b256
Author: YinHanke <[email protected]>
AuthorDate: Fri May 29 14:17:15 2026 +0800

    [Relax][Frontend][TFLite] Add RNN converter (#19632)
    
    ## Summary
    
    Add Relax TFLite frontend support for `RNN` (BuiltinOperator 24),
    claimed in [#19519](https://github.com/apache/tvm/issues/19519) Group A.
    
    Single-step RNN cell:
    ```
    h = fused_activation(x @ W.T + h @ Wr.T + b)
    ```
    
    ## Changes
    
    - **Handler**: `convert_rnn` registered in `convert_map` (placed
    directly after `RANGE`)
    - **Inputs** (5): `input [batch, input_size]`, `input_weights
    [num_units, input_size]`, `recurrent_weights [num_units, num_units]`,
    `bias [num_units]`, `hidden_state [batch, num_units]` (variable;
    zero-initialised unless the model already provides a value for it)
    - **Output**: `[batch, num_units]`
    - **Activations**: all fused activations via
    `convert_fused_activation_function`
    - **Quantized**: raises `OpNotImplemented`
    
    ## Testing
    
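    Three tests added to `tests/python/relax/test_frontend_tflite.py`:
    
    - `test_rnn_none_activation` — `tvm.ir.assert_structural_equal` with
    identity weights, NONE activation
    - `test_rnn_relu_activation` — shape check, random weights, RELU
    activation
    - `test_rnn_shared_hidden_state_updates_exp_tab` — two RNN ops sharing
    one hidden-state tensor; structural check that the second op reads the
    first op's updated state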
    
    ```bash
    python -m pytest tests/python/relax/test_frontend_tflite.py -k rnn -v
    ```
    
    ## References
    
    - Issue [#19519](https://github.com/apache/tvm/issues/19519) Group A:
    Sequence / recurrent model operators
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  79 ++++++
 tests/python/relax/test_frontend_tflite.py         | 290 +++++++++++++++++++++
 2 files changed, 369 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 65c0faadc2..87f0f12b1b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -273,6 +273,7 @@ class OperatorConverter:
             "POW": functools.partial(self._convert_elemwise, relax_op=_op.power),
             "PRELU": self.convert_prelu,
             "RANGE": self.convert_range,
+            "RNN": self.convert_rnn,
             "QUANTIZE": self.convert_quantize,
             "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
             "RANDOM_UNIFORM": self.convert_random_uniform,
@@ -5044,6 +5045,84 @@ class OperatorConverter:
 
         return squeezed
 
+    def convert_rnn(self, op):
+        """Convert TFLite RNN.
+
+        Single-step RNN cell.
+
+        Inputs (5 tensors):
+          [0] input             [batch, input_size]
+          [1] input_weights     [num_units, input_size]
+          [2] recurrent_weights [num_units, num_units]
+          [3] bias              [num_units]
+          [4] hidden_state      [batch, num_units]  (variable; zeros if no initial value)
+
+        Output:
+          [0] output  [batch, num_units]
+
+        Cell equation:
+          h = fused_activation(x @ W.T + h @ Wr.T + b)
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.RNNOptions import RNNOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented("TFLite quantized RNN is not supported yet.")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensors length should be 5"
+
+        input_tensor = input_tensors[0]
+        weights_tensor = input_tensors[1]
+        recurrent_tensor = input_tensors[2]
+        bias_tensor = input_tensors[3]
+        hidden_state_tensor = input_tensors[4]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions
+        op_options = op.BuiltinOptions()
+        rnn_options = RNNOptions()
+        rnn_options.Init(op_options.Bytes, op_options.Pos)
+        fused_activation_fn = rnn_options.FusedActivationFunction()
+
+        # Weight/bias expressions; NOTE(review): float weights assumed (hybrid int8 untested).
+        weights_expr = self.get_tensor_expr(weights_tensor)  # [num_units, input_size]
+        recurrent_expr = self.get_tensor_expr(recurrent_tensor)  # [num_units, num_units]
+        bias_expr = self.get_tensor_expr(bias_tensor)  # [num_units]
+
+        # Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
+        w_t = relax.op.permute_dims(weights_expr)
+        wr_t = relax.op.permute_dims(recurrent_expr)
+
+        # Resolve the input expression.
+        in_expr = self.get_tensor_expr(input_tensor)
+
+        # Initial hidden state: use the model's tensor value when available (non-zero init or
+        # graph input), otherwise fall back to zeros for the common variable-tensor case.
+        h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+        if self.has_expr(hidden_state_tensor.tensor_idx) or (
+            hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
+        ):
+            h = self.get_tensor_expr(hidden_state_tensor)
+        else:
+            h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+            h = relax.op.zeros(h_shape, dtype=h_dtype)
+
+        gates = relax.op.add(
+            relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h, wr_t)),
+            bias_expr,
+        )
+        h = self.convert_fused_activation_function(gates, fused_activation_fn)
+
+        self.exp_tab.set_expr(  # later RNN ops reading hidden_state see the updated h
+            get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx),
+            h,
+            force_override=True,
+        )
+        return h
+
     def convert_unidirectional_sequence_rnn(self, op):
         """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 9e91c09c2d..7c5951d631 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3722,6 +3722,7 @@ _tfl_reduce_window_function = _get_tflite_schema_enum("ReduceWindowFunction")
 _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
 _tfl_tensor_type = _get_tflite_schema_enum("TensorType")
 
+_tfl_rnn_options = _get_tflite_schema_module("RNNOptions")  # used by the RNN model builders
 _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
 
 _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
@@ -10127,6 +10128,295 @@ def test_dilate_dynamic_dilations():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+# ── RNN ────────────────────────────────────────────────────────────────────────
+
+
+def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights, bias, activation):
+    """Build a minimal TFLite flatbuffer model containing one RNN op.
+
+    Tensor layout (indices 0-5; tensor i uses buffer i):
+      0 - input             [batch, input_size]
+      1 - input_weights     [num_units, input_size]    (constant)
+      2 - recurrent_weights [num_units, num_units]     (constant)
+      3 - bias              [num_units]                (constant)
+      4 - hidden_state      [batch, num_units]         (variable, zero-initialised)
+      5 - output            [batch, num_units]
+    """
+    builder = flatbuffers.Builder(4096)
+
+    _tfl_rnn_options.RNNOptionsStart(builder)
+    _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
+    rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+
+    rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+
+    def _t(buf_idx, shape, is_variable=False):  # FLOAT32 tensor with an explicit rank
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    tensors = [
+        _t(0, [batch, input_size]),
+        _t(1, [num_units, input_size]),
+        _t(2, [num_units, num_units]),
+        _t(3, [num_units]),
+        _t(4, [batch, num_units], is_variable=True),
+        _t(5, [batch, num_units]),
+    ]
+
+    rnn_op = _build_operator(
+        builder,
+        0,
+        [0, 1, 2, 3, 4],
+        [5],
+        builtin_options_type=_tfl_builtin_options.RNNOptions,
+        builtin_options=rnn_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[rnn_op],
+        inputs=[0],
+        outputs=[5],
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, weights.tobytes()),
+        _build_buffer(builder, recurrent_weights.tobytes()),
+        _build_buffer(builder, bias.tobytes()),
+        _build_buffer(builder),  # 4: hidden_state has no data -> converter uses zeros
+        _build_buffer(builder),
+    ]
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[rnn_op_code],
+        buffers=buffers,
+    )
+
+
+def _build_two_step_shared_state_rnn_model(
+    batch, input_size, num_units, weights, recurrent_weights, bias, activation
+):
+    """Build a TFLite model with two RNN ops that share hidden-state tensor 4."""
+    builder = flatbuffers.Builder(4096)
+
+    _tfl_rnn_options.RNNOptionsStart(builder)
+    _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
+    rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+
+    rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+
+    def _t(buf_idx, shape, is_variable=False):
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    tensors = [
+        _t(0, [batch, input_size]),
+        _t(1, [num_units, input_size]),
+        _t(2, [num_units, num_units]),
+        _t(3, [num_units]),
+        _t(4, [batch, num_units], is_variable=True),
+        _t(0, [batch, input_size]),  # 5: second-step input (empty buffer 0)
+        _t(0, [batch, num_units]),  # 6: first-step output
+        _t(0, [batch, num_units]),  # 7: second-step output
+    ]
+
+    first_rnn_op = _build_operator(
+        builder,
+        0,
+        [0, 1, 2, 3, 4],
+        [6],
+        builtin_options_type=_tfl_builtin_options.RNNOptions,
+        builtin_options=rnn_opts,
+    )
+    second_rnn_op = _build_operator(
+        builder,
+        0,
+        [5, 1, 2, 3, 4],
+        [7],
+        builtin_options_type=_tfl_builtin_options.RNNOptions,
+        builtin_options=rnn_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[first_rnn_op, second_rnn_op],
+        inputs=[0, 5],
+        outputs=[7],
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, weights.tobytes()),
+        _build_buffer(builder, recurrent_weights.tobytes()),
+        _build_buffer(builder, bias.tobytes()),
+        _build_buffer(builder),
+    ]
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[rnn_op_code],
+        buffers=buffers,
+    )
+
+
+def test_rnn_none_activation():
+    """RNN with NONE activation lowers to matmul/add.
+
+    Cell equation: h = x @ W.T + h @ Wr.T + b  (no activation for NONE)
+    """
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, input_size, num_units = 2, 2, 2
+    weights = np.eye(num_units, input_size, dtype=np.float32)
+    recurrent_weights = np.eye(num_units, dtype=np.float32)
+    bias = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_rnn_model(
+            batch,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # W.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv, out_dtype="void")
+                lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
+                lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # Wr.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
+                lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(  # + bias, no activation
+                    lv5, R.const(np.zeros(2, dtype=np.float32))
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rnn_relu_activation():
+    """RNN with RELU activation and random weights; checks signature shapes only."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, input_size, num_units = 2, 4, 8
+    np.random.seed(42)
+    weights = np.random.randn(num_units, input_size).astype(np.float32)
+    recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
+    bias = np.random.randn(num_units).astype(np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_rnn_model(
+            batch,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.RELU,
+        )
+    )
+
+    fn = mod["main"]
+    assert len(fn.params) == 1, "only the input should be a graph input"
+    in_shape = fn.params[0].struct_info.shape
+    assert tuple(int(d) for d in in_shape) == (batch, input_size)
+    out_shape = fn.ret_struct_info.shape
+    assert tuple(int(d) for d in out_shape) == (batch, num_units)
+
+
+def test_rnn_shared_hidden_state_updates_exp_tab():
+    """Two consecutive RNN ops sharing hidden_state should use the updated state."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, input_size, num_units = 2, 2, 2
+    weights = np.eye(num_units, input_size, dtype=np.float32)
+    recurrent_weights = np.eye(num_units, dtype=np.float32)
+    bias = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_two_step_shared_state_rnn_model(
+            batch,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x0: R.Tensor((2, 2), dtype="float32"),
+            x1: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # step 1: W.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv, out_dtype="void")
+                lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
+                lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # step 1: Wr.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
+                lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
+                lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # step 2: W.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6, out_dtype="void")
+                lv8: R.Tensor((2, 2), dtype="float32") = R.add(  # step-1 h (updated state)
+                    lv5, R.const(np.zeros(2, dtype=np.float32))
+                )
+                lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(  # step 2: Wr.T
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9, out_dtype="void")
+                lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10)
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(  # step-2 output (+ bias)
+                    lv11, R.const(np.zeros(2, dtype=np.float32))
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 # ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
 
 

Reply via email to