This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dcbebe7bfd [Relax][Frontend][TFLite] Add UNIDIRECTIONAL_SEQUENCE_RNN
converter (#19601)
dcbebe7bfd is described below
commit dcbebe7bfd2fe8ad45f501c00e058b74485824da
Author: YinHanke <[email protected]>
AuthorDate: Wed May 27 12:01:22 2026 +0800
[Relax][Frontend][TFLite] Add UNIDIRECTIONAL_SEQUENCE_RNN converter (#19601)
## Summary
This PR adds Relax TFLite frontend support for
`UNIDIRECTIONAL_SEQUENCE_RNN` (BuiltinOperator 35), claimed in
[#19519](https://github.com/apache/tvm/issues/19519) Group A.
The op executes a simple RNN cell over a time sequence. The converter
unrolls the time steps at graph-construction time using Relax
primitives.
Cell equation:
```
h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
```
## Changes
- **Handler**: `convert_unidirectional_sequence_rnn` registered in
`convert_map` (alphabetical, U-region after `UNPACK`)
- **Inputs** (5): `input [batch, time, input_size]`, `input_weights
[num_units, input_size]`, `recurrent_weights [num_units, num_units]`,
`bias [num_units]`, `hidden_state [batch, num_units]` (variable,
zero-initialised)
- **Output**: `[batch, time, num_units]` (always batch-major)
- **time_major=True**: input is transposed to batch-major before
unrolling
- **Activations**: NONE, RELU, RELU6, TANH, SIGMOID (via
`convert_fused_activation_function`)
- **Quantized**: raises `OpNotImplemented` (not yet supported)
## Testing
Modern TF/Keras (2.x, Keras 3) no longer emits
`UNIDIRECTIONAL_SEQUENCE_RNN`; `SimpleRNN` with `unroll=False` lowers to
`WHILE`+TensorList ops, and `unroll=True` expands to elementwise ops.
Tests therefore follow the same flatbuffer-construction pattern used by
the StableHLO op PRs (#19536, #19587).
Three tests added to `tests/python/relax/test_frontend_tflite.py`:
- `test_unidirectional_sequence_rnn_none_activation` —
`tvm.ir.assert_structural_equal` with identity weights / zero bias, NONE
activation, time=1
- `test_unidirectional_sequence_rnn_relu_activation` — shape check,
random weights, RELU activation, time=3
- `test_unidirectional_sequence_rnn_time_major` — shape check,
`time_major=True` input layout
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k
unidirectional_sequence_rnn -v
```
All 3 tests pass. pre-commit (ASF header, ruff check, ruff format) all
pass.
## References
- Issue [#19519](https://github.com/apache/tvm/issues/19519) Group A:
Sequence / recurrent model operators
Co-authored-by: Copilot <[email protected]>
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 101 ++++++++++
tests/python/relax/test_frontend_tflite.py | 207 +++++++++++++++++++++
2 files changed, 308 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f395c95b6d..8183f64f73 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -381,6 +381,7 @@ class OperatorConverter:
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
+ "UNIDIRECTIONAL_SEQUENCE_RNN":
self.convert_unidirectional_sequence_rnn,
"UNSORTED_SEGMENT_MIN": functools.partial(
self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN",
reduction="min"
),
@@ -4877,6 +4878,106 @@ class OperatorConverter:
return squeezed
+ def convert_unidirectional_sequence_rnn(self, op):
+ """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+ Inputs (5 tensors):
+ [0] input [batch, time, input_size] (or [time, batch,
input_size] if time_major)
+ [1] input_weights [num_units, input_size]
+ [2] recurrent_weights [num_units, num_units]
+ [3] bias [num_units]
+ [4] hidden_state [batch, num_units] (variable, zero-initialised)
+
+ Output:
+ [0] output [batch, time, num_units]
+
+ Cell equation:
+ h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, "input tensors length should be 5"
+
+ input_tensor = input_tensors[0]
+ weights_tensor = input_tensors[1]
+ recurrent_tensor = input_tensors[2]
+ bias_tensor = input_tensors[3]
+ hidden_state_tensor = input_tensors[4]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+ op_options = op.BuiltinOptions()
+ seq_rnn_options = SequenceRNNOptions()
+ seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+ time_major = seq_rnn_options.TimeMajor()
+ fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+ # Constant weight/bias expressions.
+ weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
+ recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
+
+ # bias is optional (tensor_idx == -1 when absent); default to zeros.
+ if bias_tensor.tensor_idx != -1:
+ bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
+ else:
+ num_units = int(self.get_tensor_shape(weights_tensor)[0])
+ bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
+ bias_expr = relax.op.zeros((num_units,), dtype=bias_dtype)
+
+ # Transpose to [input_size, num_units] and [num_units, num_units] for
x @ W.T.
+ w_t = relax.op.permute_dims(weights_expr)
+ wr_t = relax.op.permute_dims(recurrent_expr)
+
+ # Resolve the input expression; normalise to batch-major [batch, time,
input_size].
+ # Only the time dimension must be static (needed for unrolling); batch
may be dynamic.
+ in_expr = self.get_tensor_expr(input_tensor)
+ in_shape = self.get_tensor_shape(input_tensor)
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = int(in_shape[0])
+ else:
+ num_steps = int(in_shape[1])
+
+ # Initial hidden state: use the model's tensor value when available
(non-zero init or
+ # graph input), otherwise fall back to zeros for the common
variable-tensor case.
+ h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+ if self.has_expr(hidden_state_tensor.tensor_idx) or (
+ hidden_state_tensor.buffer is not None and
hidden_state_tensor.buffer.DataLength() > 0
+ ):
+ h = self.get_tensor_expr(hidden_state_tensor)
+ else:
+ h_shape =
tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+ h = relax.op.zeros(h_shape, dtype=h_dtype)
+
+ # Unroll over the time axis.
+ # relax.op.split with 1 section returns the tensor directly; handle
uniformly.
+ if num_steps == 1:
+ steps = [relax.op.squeeze(in_expr, axis=[1])]
+ else:
+ splits = relax.op.split(in_expr, num_steps, axis=1)
+ steps = [relax.op.squeeze(splits[i], axis=[1]) for i in
range(num_steps)]
+
+ outputs = []
+ for x_t in steps: # x_t: [batch, input_size]
+ gates = relax.op.add(
+ relax.op.add(relax.op.matmul(x_t, w_t), relax.op.matmul(h,
wr_t)),
+ bias_expr,
+ )
+ h = self.convert_fused_activation_function(gates,
fused_activation_fn)
+ outputs.append(h)
+
+ # Stack timestep outputs: [batch, time, num_units].
+ return relax.op.stack(outputs, axis=1)
+
"""
def convert_unidirectional_sequence_lstm(self, op):
### Long Short Term Memory for TFLite implementation. ###
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index be762d5cb4..f1abacec27 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3720,6 +3720,8 @@ _tfl_padding = _get_tflite_schema_enum("Padding")
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
+_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
+
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
_DENSIFY_ROW_PTRS = [0, 1, 2]
@@ -9719,5 +9721,210 @@ def test_dilate_dynamic_dilations():
tvm.ir.assert_structural_equal(mod, Expected)
+# ── UNIDIRECTIONAL_SEQUENCE_RNN
───────────────────────────────────────────────
+
+
+def _build_unidirectional_sequence_rnn_model(
+ batch,
+ time,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ activation,
+ *,
+ time_major=False,
+):
+ """Build a minimal TFLite flatbuffer model containing one
UNIDIRECTIONAL_SEQUENCE_RNN op.
+
+ Tensor layout (indices 0-5):
+ 0 - input [batch, time, input_size] (or [time, batch,
input_size] if time_major)
+ 1 - input_weights [num_units, input_size] (constant)
+ 2 - recurrent_wts [num_units, num_units] (constant)
+ 3 - bias [num_units] (constant)
+ 4 - hidden_state [batch, num_units] (variable,
zero-initialised)
+ 5 - output [batch, time, num_units]
+ """
+ builder = flatbuffers.Builder(4096)
+
+ _tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder)
+ _tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder,
time_major)
+
_tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder,
activation)
+ rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder)
+
+ rnn_op_code = _build_operator_code(builder,
_tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN)
+
+ input_shape = [time, batch, input_size] if time_major else [batch, time,
input_size]
+
+ def _t(buf_idx, shape, is_variable=False):
+ shape_vec = _tflite_shape(builder, shape)
+ _tfl_tensor.TensorStart(builder)
+ _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+ _tfl_tensor.TensorAddHasRank(builder, True)
+ _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+ _tfl_tensor.TensorAddShape(builder, shape_vec)
+ _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+ return _tfl_tensor.TensorEnd(builder)
+
+ tensors = [
+ _t(0, input_shape),
+ _t(1, [num_units, input_size]),
+ _t(2, [num_units, num_units]),
+ _t(3, [num_units]),
+ _t(4, [batch, num_units], is_variable=True),
+ _t(5, [batch, time, num_units]),
+ ]
+
+ rnn_op = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2, 3, 4],
+ [5],
+ builtin_options_type=_tfl_builtin_options.SequenceRNNOptions,
+ builtin_options=rnn_opts,
+ )
+
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[rnn_op],
+ inputs=[0],
+ outputs=[5],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, weights.tobytes()),
+ _build_buffer(builder, recurrent_weights.tobytes()),
+ _build_buffer(builder, bias.tobytes()),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ ]
+
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[rnn_op_code],
+ buffers=buffers,
+ )
+
+
+def test_unidirectional_sequence_rnn_none_activation():
+ """UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to
matmul/add/stack.
+
+ Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b (no activation for
NONE)
+ """
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, time, input_size, num_units = 2, 1, 2, 2
+ weights = np.eye(num_units, input_size, dtype=np.float32)
+ recurrent_weights = np.eye(num_units, dtype=np.float32)
+ bias = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_unidirectional_sequence_rnn_model(
+ batch,
+ time,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.NONE,
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1,
2), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
+ lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1,
out_dtype="void")
+ lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2,
2]), dtype="float32")
+ lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4,
out_dtype="void")
+ lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5)
+ lv7: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv6, R.const(np.zeros(2, dtype=np.float32))
+ )
+ gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,),
axis=1)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_unidirectional_sequence_rnn_relu_activation():
+ """UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time
steps."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, time, input_size, num_units = 2, 3, 4, 8
+ np.random.seed(42)
+ weights = np.random.randn(num_units, input_size).astype(np.float32)
+ recurrent_weights = np.random.randn(num_units,
num_units).astype(np.float32)
+ bias = np.random.randn(num_units).astype(np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_unidirectional_sequence_rnn_model(
+ batch,
+ time,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.RELU,
+ )
+ )
+
+ fn = mod["main"]
+ assert len(fn.params) == 1, "only the sequence input should be a graph
input"
+ in_shape = fn.params[0].struct_info.shape
+ assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
+ out_shape = fn.ret_struct_info.shape
+ assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
+
+
+def test_unidirectional_sequence_rnn_time_major():
+ """UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True transposes before
unrolling."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, time, input_size, num_units = 3, 4, 2, 5
+ np.random.seed(7)
+ weights = np.random.randn(num_units, input_size).astype(np.float32)
+ recurrent_weights = np.random.randn(num_units,
num_units).astype(np.float32)
+ bias = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_unidirectional_sequence_rnn_model(
+ batch,
+ time,
+ input_size,
+ num_units,
+ weights,
+ recurrent_weights,
+ bias,
+ ActivationFunctionType.NONE,
+ time_major=True,
+ )
+ )
+
+ fn = mod["main"]
+ # Input to the graph is the raw time-major tensor [time, batch,
input_size].
+ in_shape = fn.params[0].struct_info.shape
+ assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
+ # Output is always batch-major [batch, time, num_units].
+ out_shape = fn.ret_struct_info.shape
+ assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
+
+
if __name__ == "__main__":
pytest.main(["-s", __file__])