This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e89570fa83 [Relax][Frontend][TFLite] Add RNN converter (#19632)
e89570fa83 is described below
commit e89570fa8321ebf9951e6d8f512eda764216b256
Author: YinHanke <[email protected]>
AuthorDate: Fri May 29 14:17:15 2026 +0800
[Relax][Frontend][TFLite] Add RNN converter (#19632)
## Summary
Add Relax TFLite frontend support for `RNN` (BuiltinOperator 24),
claimed in [#19519](https://github.com/apache/tvm/issues/19519) Group A.
Single-step RNN cell:
```
h = fused_activation(x @ W.T + h @ Wr.T + b)
```
## Changes
- **Handler**: `convert_rnn` registered in `convert_map`, immediately
after the `RANGE` entry
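- **Inputs** (5): `input [batch, input_size]`, `input_weights
[num_units, input_size]`, `recurrent_weights [num_units, num_units]`,
`bias [num_units]`, `hidden_state [batch, num_units]` (variable tensor;
its existing expression or buffer value is used when available, otherwise
it is zero-initialised). The updated state is written back so that later
RNN ops sharing the same hidden-state tensor consume it.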
- **Output**: `[batch, num_units]`
- **Activations**: all fused activations via
`convert_fused_activation_function`
- **Quantized**: raises `OpNotImplemented`
## Testing
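Three tests added to `tests/python/relax/test_frontend_tflite.py`:
- `test_rnn_none_activation` — `tvm.ir.assert_structural_equal` with
identity weights, NONE activation
- `test_rnn_relu_activation` — shape check, random weights, RELU
activation
- `test_rnn_shared_hidden_state_updates_exp_tab` —
`tvm.ir.assert_structural_equal` for two RNN ops sharing one hidden-state
tensor, checking that the second op consumes the first op's updated state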
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k rnn -v
```
## References
- Issue [#19519](https://github.com/apache/tvm/issues/19519) Group A:
Sequence / recurrent model operators
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 79 ++++++
tests/python/relax/test_frontend_tflite.py | 290 +++++++++++++++++++++
2 files changed, 369 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 65c0faadc2..87f0f12b1b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -273,6 +273,7 @@ class OperatorConverter:
"POW": functools.partial(self._convert_elemwise,
relax_op=_op.power),
"PRELU": self.convert_prelu,
"RANGE": self.convert_range,
+ "RNN": self.convert_rnn,
"QUANTIZE": self.convert_quantize,
"RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
"RANDOM_UNIFORM": self.convert_random_uniform,
@@ -5044,6 +5045,84 @@ class OperatorConverter:
return squeezed
+ def convert_rnn(self, op):
+ """Convert TFLite RNN.
+
+ Single-step RNN cell.
+
+ Inputs (5 tensors):
+ [0] input [batch, input_size]
+ [1] input_weights [num_units, input_size]
+ [2] recurrent_weights [num_units, num_units]
+ [3] bias [num_units]
+ [4] hidden_state [batch, num_units] (variable,
zero-initialised)
+
+ Output:
+ [0] output [batch, num_units]
+
+ Cell equation:
+ h = fused_activation(x @ W.T + h @ Wr.T + b)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.RNNOptions import RNNOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented("TFLite quantized RNN is not
supported yet.")
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, "input tensors length should be 5"
+
+ input_tensor = input_tensors[0]
+ weights_tensor = input_tensors[1]
+ recurrent_tensor = input_tensors[2]
+ bias_tensor = input_tensors[3]
+ hidden_state_tensor = input_tensors[4]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions
+ op_options = op.BuiltinOptions()
+ rnn_options = RNNOptions()
+ rnn_options.Init(op_options.Bytes, op_options.Pos)
+ fused_activation_fn = rnn_options.FusedActivationFunction()
+
+ # Constant weight/bias expressions.
+ weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
+ recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
+ bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
+
+ # Transpose to [input_size, num_units] and [num_units, num_units] for
x @ W.T.
+ w_t = relax.op.permute_dims(weights_expr)
+ wr_t = relax.op.permute_dims(recurrent_expr)
+
+ # Resolve the input expression.
+ in_expr = self.get_tensor_expr(input_tensor)
+
+ # Initial hidden state: use the model's tensor value when available
(non-zero init or
+ # graph input), otherwise fall back to zeros for the common
variable-tensor case.
+ h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+ if self.has_expr(hidden_state_tensor.tensor_idx) or (
+ hidden_state_tensor.buffer is not None and
hidden_state_tensor.buffer.DataLength() > 0
+ ):
+ h = self.get_tensor_expr(hidden_state_tensor)
+ else:
+ h_shape =
tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+ h = relax.op.zeros(h_shape, dtype=h_dtype)
+
+ gates = relax.op.add(
+ relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h,
wr_t)),
+ bias_expr,
+ )
+ h = self.convert_fused_activation_function(gates, fused_activation_fn)
+
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx),
+ h,
+ force_override=True,
+ )
+ return h
+
def convert_unidirectional_sequence_rnn(self, op):
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 9e91c09c2d..7c5951d631 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3722,6 +3722,7 @@ _tfl_reduce_window_function =
_get_tflite_schema_enum("ReduceWindowFunction")
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
+_tfl_rnn_options = _get_tflite_schema_module("RNNOptions")
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
@@ -10127,6 +10128,295 @@ def test_dilate_dynamic_dilations():
tvm.ir.assert_structural_equal(mod, Expected)
+# ── RNN
────────────────────────────────────────────────────────────────────────
+
+
+def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights,
bias, activation):
+ """Build a minimal TFLite flatbuffer model containing one RNN op.
+
+ Tensor layout (indices 0-5):
+ 0 - input [batch, input_size]
+ 1 - input_weights [num_units, input_size] (constant)
+ 2 - recurrent_weights [num_units, num_units] (constant)
+ 3 - bias [num_units] (constant)
+ 4 - hidden_state [batch, num_units] (variable,
zero-initialised)
+ 5 - output [batch, num_units]
+ """
+ builder = flatbuffers.Builder(4096)
+
+ _tfl_rnn_options.RNNOptionsStart(builder)
+ _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
+ rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+
+ rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+
+ def _t(buf_idx, shape, is_variable=False):
+ shape_vec = _tflite_shape(builder, shape)
+ _tfl_tensor.TensorStart(builder)
+ _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+ _tfl_tensor.TensorAddHasRank(builder, True)
+ _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+ _tfl_tensor.TensorAddShape(builder, shape_vec)
+ _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+ return _tfl_tensor.TensorEnd(builder)
+
+ tensors = [
+ _t(0, [batch, input_size]),
+ _t(1, [num_units, input_size]),
+ _t(2, [num_units, num_units]),
+ _t(3, [num_units]),
+ _t(4, [batch, num_units], is_variable=True),
+ _t(5, [batch, num_units]),
+ ]
+
+ rnn_op = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2, 3, 4],
+ [5],
+ builtin_options_type=_tfl_builtin_options.RNNOptions,
+ builtin_options=rnn_opts,
+ )
+
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[rnn_op],
+ inputs=[0],
+ outputs=[5],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, weights.tobytes()),
+ _build_buffer(builder, recurrent_weights.tobytes()),
+ _build_buffer(builder, bias.tobytes()),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ ]
+
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[rnn_op_code],
+ buffers=buffers,
+ )
+
+
+def _build_two_step_shared_state_rnn_model(
+ batch, input_size, num_units, weights, recurrent_weights, bias, activation
+):
+ """Build a TFLite model with two RNN ops sharing the same hidden-state
tensor."""
+ builder = flatbuffers.Builder(4096)
+
+ _tfl_rnn_options.RNNOptionsStart(builder)
+ _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
+ rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+
+ rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+
+ def _t(buf_idx, shape, is_variable=False):
+ shape_vec = _tflite_shape(builder, shape)
+ _tfl_tensor.TensorStart(builder)
+ _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+ _tfl_tensor.TensorAddHasRank(builder, True)
+ _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+ _tfl_tensor.TensorAddShape(builder, shape_vec)
+ _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+ return _tfl_tensor.TensorEnd(builder)
+
+ tensors = [
+ _t(0, [batch, input_size]),
+ _t(1, [num_units, input_size]),
+ _t(2, [num_units, num_units]),
+ _t(3, [num_units]),
+ _t(4, [batch, num_units], is_variable=True),
+ _t(0, [batch, input_size]),
+ _t(0, [batch, num_units]),
+ _t(0, [batch, num_units]),
+ ]
+
+ first_rnn_op = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2, 3, 4],
+ [6],
+ builtin_options_type=_tfl_builtin_options.RNNOptions,
+ builtin_options=rnn_opts,
+ )
+ second_rnn_op = _build_operator(
+ builder,
+ 0,
+ [5, 1, 2, 3, 4],
+ [7],
+ builtin_options_type=_tfl_builtin_options.RNNOptions,
+ builtin_options=rnn_opts,
+ )
+
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[first_rnn_op, second_rnn_op],
+ inputs=[0, 5],
+ outputs=[7],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, weights.tobytes()),
+ _build_buffer(builder, recurrent_weights.tobytes()),
+ _build_buffer(builder, bias.tobytes()),
+ _build_buffer(builder),
+ ]
+
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[rnn_op_code],
+ buffers=buffers,
+ )
+
+
+def test_rnn_none_activation():
+ """RNN with NONE activation lowers to matmul/add.
+
+ Cell equation: h = x @ W.T + h @ Wr.T + b (no activation for NONE)
+ """
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 2, 2
+ weights = np.eye(num_units, input_size, dtype=np.float32)
+ recurrent_weights = np.eye(num_units, dtype=np.float32)
+ bias = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_rnn_model(
+ batch,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.NONE,
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2),
dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv,
out_dtype="void")
+ lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2,
2]), dtype="float32")
+ lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3,
out_dtype="void")
+ lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
+ gv: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv5, R.const(np.zeros(2, dtype=np.float32))
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rnn_relu_activation():
+ """RNN with RELU activation and random weights."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 4, 8
+ np.random.seed(42)
+ weights = np.random.randn(num_units, input_size).astype(np.float32)
+ recurrent_weights = np.random.randn(num_units,
num_units).astype(np.float32)
+ bias = np.random.randn(num_units).astype(np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_rnn_model(
+ batch,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.RELU,
+ )
+ )
+
+ fn = mod["main"]
+ assert len(fn.params) == 1, "only the input should be a graph input"
+ in_shape = fn.params[0].struct_info.shape
+ assert tuple(int(d) for d in in_shape) == (batch, input_size)
+ out_shape = fn.ret_struct_info.shape
+ assert tuple(int(d) for d in out_shape) == (batch, num_units)
+
+
+def test_rnn_shared_hidden_state_updates_exp_tab():
+ """Two consecutive RNN ops sharing hidden_state should use the updated
state."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 2, 2
+ weights = np.eye(num_units, input_size, dtype=np.float32)
+ recurrent_weights = np.eye(num_units, dtype=np.float32)
+ bias = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_two_step_shared_state_rnn_model(
+ batch,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.NONE,
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x0: R.Tensor((2, 2), dtype="float32"),
+ x1: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv,
out_dtype="void")
+ lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2,
2]), dtype="float32")
+ lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3,
out_dtype="void")
+ lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
+ lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6,
out_dtype="void")
+ lv8: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv5, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9,
out_dtype="void")
+ lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10)
+ gv: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv11, R.const(np.zeros(2, dtype=np.float32))
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
# ── UNIDIRECTIONAL_SEQUENCE_RNN
───────────────────────────────────────────────