This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3933804ee [Relax][Frontend][TFLite] Support sequence LSTM and RNN operators (#19634)
e3933804ee is described below

commit e3933804ee15239d38527a18c69c507d76af8ffe
Author: YinHanke <[email protected]>
AuthorDate: Sun May 31 00:49:53 2026 +0800

    [Relax][Frontend][TFLite] Support sequence LSTM and RNN operators (#19634)
    
    ## Summary
    
    Add three TFLite sequence recurrent operators to the Relax frontend. All
    three are float32-only, and the two LSTM variants support only the
    coupled input-forget gate with the FULL kernel.
    
    - UNIDIRECTIONAL_SEQUENCE_LSTM
    - BIDIRECTIONAL_SEQUENCE_RNN
    - BIDIRECTIONAL_SEQUENCE_LSTM
    
    From #19519.
    
    ## Changes
    
    - **UNIDIRECTIONAL_SEQUENCE_LSTM**: same input layout as the single-step
      LSTM; unrolls over time and stacks the per-step hidden states. Supports
      time_major, cell_clip, proj_clip, and the fused activation.
    - **BIDIRECTIONAL_SEQUENCE_RNN**: separate forward/backward RNN cells; the
      backward cell scans the sequence in reverse. Supports merge_outputs (fw
      and bw concatenated along the last axis) and split outputs returned as a
      Tuple.
    - **BIDIRECTIONAL_SEQUENCE_LSTM**: 48-input operator whose forward and
      backward LSTM cells share the same input tensor. The state tensors are
      at indices 35-38.
    - All converters write their final states back to exp_tab, so subsequent
      ops that read the state tensors see the updated values.
    - Peephole connections, projection, layer norm, and aux inputs are not
      supported (they raise OpNotImplemented).
    
    ## Testing
    
    - `test_unidirectional_sequence_lstm_none_activation` — output shape
      [batch, time, num_units]
    - `test_bidirectional_sequence_rnn_none_activation` — merge_outputs=True,
      output shape [batch, time, 2*num_units]
    - `test_bidirectional_sequence_lstm_none_activation` — merge_outputs=True,
      output shape [batch, time, 2*num_units]
    
    ```bash
    python -m pytest tests/python/relax/test_frontend_tflite.py -k "sequence_lstm or sequence_rnn" -v
    ```
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 670 ++++++++++++----
 tests/python/relax/test_frontend_tflite.py         | 892 +++++++++++++++++++++
 2 files changed, 1425 insertions(+), 137 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index c479ec83c1..7046e43bbe 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -200,6 +200,8 @@ class OperatorConverter:
             "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"),
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
             "BATCH_MATMUL": self.convert_batch_matmul,
+            "BIDIRECTIONAL_SEQUENCE_LSTM": self.convert_bidirectional_sequence_lstm,
+            "BIDIRECTIONAL_SEQUENCE_RNN": self.convert_bidirectional_sequence_rnn,
             "BITCAST": self.convert_bitcast,
             "BROADCAST_TO": self.convert_broadcast_to,
             "BROADCAST_ARGS": self.convert_broadcast_args,
@@ -404,7 +406,7 @@ class OperatorConverter:
             "UNSORTED_SEGMENT_PROD": functools.partial(
                 self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
             ),
-            # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
+            "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
             "VAR_HANDLE": self.convert_var_handle,
             "WHERE": self.convert_select,
             "WHILE": self.convert_while,
@@ -5510,153 +5512,547 @@ class OperatorConverter:
         # Stack timestep outputs: [batch, time, num_units].
         return relax.op.stack(outputs, axis=1)
 
-    """
     def convert_unidirectional_sequence_lstm(self, op):
-        ### Long Short Term Memory for TFLite implementation. ###
+        """Convert TFLite UNIDIRECTIONAL_SEQUENCE_LSTM.
+
+        Inputs (24 tensors, same layout as single-step LSTM):
+          [0]  input                       [batch, time, input_size]
+          [1]  input_to_input_weights      [num_units, input_size]   (optional)
+          [2]  input_to_forget_weights     [num_units, input_size]
+          [3]  input_to_cell_weights       [num_units, input_size]
+          [4]  input_to_output_weights     [num_units, input_size]
+          [5]  recurrent_to_input_weights  [num_units, num_units]   (optional)
+          [6]  recurrent_to_forget_weights [num_units, num_units]
+          [7]  recurrent_to_cell_weights   [num_units, num_units]
+          [8]  recurrent_to_output_weights [num_units, num_units]
+          [9]  cell_to_input_weights       [num_units]              (optional)
+          [10] cell_to_forget_weights      [num_units]              (optional)
+          [11] cell_to_output_weights      [num_units]              (optional)
+          [12] input_gate_bias             [num_units]              (optional)
+          [13] forget_gate_bias            [num_units]
+          [14] cell_gate_bias              [num_units]
+          [15] output_gate_bias            [num_units]
+          [16] projection_weights          [num_units, num_units]   (optional)
+          [17] projection_bias             [num_units]              (optional)
+          [18] output_state                [batch, num_units]       (variable)
+          [19] cell_state                  [batch, num_units]       (variable)
+          [20-23] optional layer norm weights
+
+        Output:
+          [0] output  [batch, time, num_units]
+
+        Uses coupled input-forget gate (i = 1 - f) for the FULL kernel.
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.UnidirectionalSequenceLSTMOptions import UnidirectionalSequenceLSTMOptions
+
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet."
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_LSTM is not supported yet."
             )
 
         input_tensors = self.get_input_tensors(op)
-        assert len(input_tensors) == 24, "input tensors length should be == 24"
+        assert len(input_tensors) == 24, (
+            f"input tensors length should be 24, got {len(input_tensors)}"
+        )
 
-        # Extract input tensor from saved model
-        input_tensor = input_tensors[0]
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.UnidirectionalSequenceLSTMOptions
+        op_options = op.BuiltinOptions()
+        lstm_opts = UnidirectionalSequenceLSTMOptions()
+        lstm_opts.Init(op_options.Bytes, op_options.Pos)
+        time_major = lstm_opts.TimeMajor()
+        fused_activation_fn = lstm_opts.FusedActivationFunction()
+        cell_clip = lstm_opts.CellClip()
+        proj_clip = lstm_opts.ProjClip()
+
+        # Only coupled input-forget gate is supported.
+        if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx != -1:
+            raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+            raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]):
+            raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [20, 21, 22, 23]):
+            raise tvm.error.OpNotImplemented("TFLite layer-norm LSTM is not supported yet.")
+
+        # Weights (transposed once outside the loop).
+        w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[2]))
+        w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[3]))
+        w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[4]))
+        r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[6]))
+        r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[7]))
+        r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[8]))
+
+        # Biases.
+        b_f = self.get_tensor_expr(input_tensors[13])
+        b_c = self.get_tensor_expr(input_tensors[14])
+        b_o = self.get_tensor_expr(input_tensors[15])
+
+        # Initial states.
+        h = self.get_tensor_expr(input_tensors[18])
+        c = self.get_tensor_expr(input_tensors[19])
+
+        # Resolve the input expression; normalise to batch-major [batch, time, input_size].
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        in_shape = self.get_tensor_shape(input_tensors[0])
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = int(in_shape[0])
+        else:
+            num_steps = int(in_shape[1])
+
+        # Unroll over the time axis.
+        if num_steps == 1:
+            steps = [relax.op.squeeze(in_expr, axis=[1])]
+        else:
+            splits = relax.op.split(in_expr, num_steps, axis=1)
+            steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
+
+        one = relax.const(1.0, "float32")
+        outputs = []
+        for x_t in steps:
+            f = relax.op.sigmoid(
+                relax.op.add(
+                    relax.op.add(
+                        relax.op.matmul(x_t, w_f_t),
+                        relax.op.matmul(h, r_f_t),
+                    ),
+                    b_f,
+                )
+            )
+            i = relax.op.subtract(one, f)
+            g = self.convert_fused_activation_function(
+                relax.op.add(
+                    relax.op.add(relax.op.matmul(x_t, w_c_t), relax.op.matmul(h, r_c_t)),
+                    b_c,
+                ),
+                fused_activation_fn,
+            )
+            o = relax.op.sigmoid(
+                relax.op.add(
+                    relax.op.add(
+                        relax.op.matmul(x_t, w_o_t),
+                        relax.op.matmul(h, r_o_t),
+                    ),
+                    b_o,
+                )
+            )
+
+            c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i, g))
+            if cell_clip > 0.0:
+                c_new = relax.op.clip(c_new, -cell_clip, cell_clip)
+
+            h_new = relax.op.multiply(
+                o, self.convert_fused_activation_function(c_new, fused_activation_fn)
+            )
+            if proj_clip > 0.0:
+                h_new = relax.op.clip(h_new, -proj_clip, proj_clip)
+            outputs.append(h_new)
+            h, c = h_new, c_new
+
+        h_out = relax.op.stack(outputs, axis=1)
+        if time_major:
+            h_out = relax.op.permute_dims(h_out, [1, 0, 2])
+
+        # Update state tensors in the expression table for subsequent ops.
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[18].tensor_idx),
+            h,
+            force_override=True,
+        )
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[19].tensor_idx),
+            c,
+            force_override=True,
+        )
+
+        return h_out
+
+    def convert_bidirectional_sequence_rnn(self, op):
+        """Convert TFLite BIDIRECTIONAL_SEQUENCE_RNN.
+
+        Inputs (9 tensors, aux_input not supported):
+          [0] input                [batch, time, input_size]
+          [1] fw_weights           [num_units, input_size]
+          [2] fw_recurrent_weights [num_units, num_units]
+          [3] fw_bias              [num_units]
+          [4] fw_hidden_state      [batch, num_units]         (variable)
+          [5] bw_weights           [num_units, input_size]
+          [6] bw_recurrent_weights [num_units, num_units]
+          [7] bw_bias              [num_units]
+          [8] bw_hidden_state      [batch, num_units]         (variable)
+
+        Output (merge_outputs=True):
+          [0] output  [batch, time, 2 * num_units]  (fw and bw concatenated)
+
+        Output (merge_outputs=False):
+          [0] fw_output  [batch, time, num_units]
+          [1] bw_output  [batch, time, num_units]
+        """
+        from tflite.BidirectionalSequenceRNNOptions import BidirectionalSequenceRNNOptions
+        from tflite.BuiltinOptions import BuiltinOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized BIDIRECTIONAL_SEQUENCE_RNN is not supported yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 12, (
+            f"input tensors length should be 12, got {len(input_tensors)}"
+        )
 
-        # Extract tensors from input tensors from saved model
-        # Input weights
-        input_input_weights = input_tensors[1]
-        input_forget_weights = input_tensors[2]
-        input_cell_weights = input_tensors[3]
-        input_output_weights = input_tensors[4]
-        # Recurrent weights
-        recurrent_input_weights = input_tensors[5]
-        recurrent_forget_weights = input_tensors[6]
-        recurrent_cell_weights = input_tensors[7]
-        recurrent_output_weights = input_tensors[8]
-        # inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied
-        # there locations are -1 in the flatbuffer
-        # Bias weights
-        input_gate_bias = input_tensors[12]
-        forget_gate_bias = input_tensors[13]
-        cell_gate_bias = input_tensors[14]
-        output_gate_bias = input_tensors[15]
-
-        # State input
-        output_state_in = input_tensors[18]
-        cell_state_in = input_tensors[19]
-
-        # Extract output tensor from saved model
         output_tensors = self.get_output_tensors(op)
-        assert len(output_tensors) == 1, "output tensors length should be 1"
-        X_steps = self.unbind(input_tensor, axis=1)
-        weights_dict = {}
-
-        # hidden_state_weights is equivalent to output_state_in in tflite model
-        out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
-        out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type())
-        out_state_in_expr = relax.op.zeros(out_state_in_shape, dtype=out_state_in_dtype)
-        weights_dict["hidden_state"] = relax.op.split(out_state_in_expr, 1)[0]
-
-        # cell_state_weights is equivalent to output_state_in tflite model
-        cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in))
-        cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type())
-        cell_state_in_expr = relax.op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype)
-        weights_dict["cell_state"] = relax.op.split(cell_state_in_expr, 1)[0]
-
-        # Process weight matrix of input: w_inp
-        # Concatenate of [input_input_weight, input_forget_weights,
-        # input_cell_weights, input_output_weights]
-        input_input_weights_default_values = self.get_tensor_value(input_input_weights)
-        input_input_weights_op = relax.op.split(
-            relax.op.const(input_input_weights_default_values.tolist()), 1
-        )
-        input_output_weights_default_values = self.get_tensor_value(input_output_weights)
-        input_output_weights_op = relax.op.split(
-            relax.op.const(input_output_weights_default_values.tolist()), 1
-        )
-        input_forget_weights_default_values = self.get_tensor_value(input_forget_weights)
-        input_forget_weights_op = relax.op.split(
-            relax.op.const(input_forget_weights_default_values.tolist()), 1
-        )
-        input_cell_weights_default_values = self.get_tensor_value(input_cell_weights)
-        input_cell_weights_op = relax.op.split(
-            _op.const(input_cell_weights_default_values.tolist()), 1
-        )
-        weights_dict["w_inp"] = relax.op.concat(
-            [
-                relax.op.squeeze(input_input_weights_op[0]),
-                relax.op.squeeze(input_forget_weights_op[0]),
-                relax.op.squeeze(input_cell_weights_op[0]),
-                relax.op.squeeze(input_output_weights_op[0]),
-            ],
-            axis=0,
-        )
-
-        # Process weight matrix of hidden state:
-        # w_hid to support lstm_cell function. Not used in tflite
-        recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights)
-        recurrent_input_weights_op = relax.op.split(
-            relax.op.const(recurrent_input_weights_values.tolist()), 1
-        )
-        recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights)
-        recurrent_output_weights_op = relax.op.split(
-            relax.op.const(recurrent_output_weights_values.tolist()), 1
-        )
-        recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights)
-        recurrent_forget_weights_op = relax.op.split(
-            relax.op.const(recurrent_forget_weights_values.tolist()), 1
-        )
-        recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights)
-        recurrent_cell_weights_op = relax.op.split(
-            _op.const(recurrent_cell_weights_values.tolist()), 1
-        )
-        weights_dict["w_hid"] = relax.op.concat(
-            [
-                recurrent_input_weights_op[0],
-                recurrent_forget_weights_op[0],
-                recurrent_cell_weights_op[0],
-                recurrent_output_weights_op[0],
-            ],
-            axis=0,
-        )
-
-        # Process weight matrix of bias: b_inp
-        input_gate_bias_values = self.get_tensor_value(input_gate_bias)
-        input_gate_bias_op = relax.op.split(_op.const(input_gate_bias_values.tolist()), 1)
-        output_gate_bias_values = self.get_tensor_value(output_gate_bias)
-        output_gate_bias_op = relax.op.split(_op.const(output_gate_bias_values.tolist()), 1)
-        forget_gate_bias_values = self.get_tensor_value(forget_gate_bias)
-        forget_gate_bias_op = relax.op.split(_op.const(forget_gate_bias_values.tolist()), 1)
-        cell_gate_bias_values = self.get_tensor_value(cell_gate_bias)
-        cell_gate_bias_op = relax.op.split(_op.const(cell_gate_bias_values.tolist()), 1)
-        weights_dict["b_inp"] = relax.op.concat(
-            [
-                input_gate_bias_op[0],
-                forget_gate_bias_op[0],
-                cell_gate_bias_op[0],
-                output_gate_bias_op[0],
-            ],
-            axis=0,
-        )
-
-        # Process weight matrix of hidden bias:
-        # b_hid (with the same shape as b_inp)
-        gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type())
-        weights_dict["b_hid"] = relax.op.split(
-            relax.op.const(
-                np.zeros(self._infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype),
-                dtype=gate_bias_dtype,
-            ),
-            1,
-        )[0]
+        assert len(output_tensors) >= 1, "output tensors length should be at least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.BidirectionalSequenceRNNOptions
+        op_options = op.BuiltinOptions()
+        rnn_opts = BidirectionalSequenceRNNOptions()
+        rnn_opts.Init(op_options.Bytes, op_options.Pos)
+        time_major = rnn_opts.TimeMajor()
+        fused_activation_fn = rnn_opts.FusedActivationFunction()
+        merge_outputs = rnn_opts.MergeOutputs()
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+            raise tvm.error.OpNotImplemented(
+                "TFLite BIDIRECTIONAL_SEQUENCE_RNN aux input is not supported yet."
+            )
 
-        outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict)
+        # Forward weights and biases.
+        fw_weights_expr = self.get_tensor_expr(input_tensors[1])
+        fw_recurrent_expr = self.get_tensor_expr(input_tensors[2])
+        fw_bias_expr = self.get_tensor_expr(input_tensors[3])
+        fw_w_t = relax.op.permute_dims(fw_weights_expr)
+        fw_wr_t = relax.op.permute_dims(fw_recurrent_expr)
 
-        output = relax.op.stack(outputs, axis=1)
-        return output
-    """
+        # Backward weights and biases.
+        bw_weights_expr = self.get_tensor_expr(input_tensors[5])
+        bw_recurrent_expr = self.get_tensor_expr(input_tensors[6])
+        bw_bias_expr = self.get_tensor_expr(input_tensors[7])
+        bw_w_t = relax.op.permute_dims(bw_weights_expr)
+        bw_wr_t = relax.op.permute_dims(bw_recurrent_expr)
+
+        # Resolve the input expression; normalise to batch-major [batch, time, input_size].
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        in_shape = self.get_tensor_shape(input_tensors[0])
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = int(in_shape[0])
+        else:
+            num_steps = int(in_shape[1])
+
+        # Initial hidden states.
+        def _get_hidden_state(tensor):
+            if self.has_expr(tensor.tensor_idx) or (
+                tensor.buffer is not None and tensor.buffer.DataLength() > 0
+            ):
+                return self.get_tensor_expr(tensor)
+            dtype = self.get_tensor_type_str(tensor.tensor.Type())
+            h_shape = tuple(to_int_list(self.get_tensor_shape(tensor)))
+            return relax.op.zeros(h_shape, dtype=dtype)
+
+        fw_h = _get_hidden_state(input_tensors[4])
+        bw_h = _get_hidden_state(input_tensors[8])
+
+        # Unroll over the time axis.
+        if num_steps == 1:
+            steps = [relax.op.squeeze(in_expr, axis=[1])]
+        else:
+            splits = relax.op.split(in_expr, num_steps, axis=1)
+            steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
+
+        # Forward pass.
+        fw_outputs = []
+        for x_t in steps:
+            gates = relax.op.add(
+                relax.op.add(relax.op.matmul(x_t, fw_w_t), relax.op.matmul(fw_h, fw_wr_t)),
+                fw_bias_expr,
+            )
+            fw_h = self.convert_fused_activation_function(gates, fused_activation_fn)
+            fw_outputs.append(fw_h)
+
+        # Backward pass (process steps in reverse).
+        bw_outputs = []
+        for x_t in reversed(steps):
+            gates = relax.op.add(
+                relax.op.add(relax.op.matmul(x_t, bw_w_t), relax.op.matmul(bw_h, bw_wr_t)),
+                bw_bias_expr,
+            )
+            bw_h = self.convert_fused_activation_function(gates, fused_activation_fn)
+            bw_outputs.append(bw_h)
+        bw_outputs.reverse()
+
+        fw_stacked = relax.op.stack(fw_outputs, axis=1)  # [batch, time, num_units]
+        bw_stacked = relax.op.stack(bw_outputs, axis=1)  # [batch, time, num_units]
+        if time_major:
+            fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2])
+            bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2])
+
+        # Update state tensors in the expression table for subsequent ops.
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[4].tensor_idx),
+            fw_h,
+            force_override=True,
+        )
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[8].tensor_idx),
+            bw_h,
+            force_override=True,
+        )
+
+        if merge_outputs:
+            return relax.op.concat([fw_stacked, bw_stacked], axis=-1)
+        else:
+            return relax.Tuple([fw_stacked, bw_stacked])
+
+    def convert_bidirectional_sequence_lstm(self, op):
+        """Convert TFLite BIDIRECTIONAL_SEQUENCE_LSTM.
+
+        Inputs (48 tensors, indices 0-17 forward LSTM, 18-34 backward LSTM, 35-38 states,
+        39-47 optional aux inputs, which are not supported):
+
+        Forward LSTM cell (indices 0-17, same layout as single-step LSTM):
+          [0]  input (shared)              [batch, time, input_size]
+          [1]  fw_input_to_input_weights   (optional)
+          [2]  fw_input_to_forget_weights
+          [3]  fw_input_to_cell_weights
+          [4]  fw_input_to_output_weights
+          [5]  fw_recurrent_to_input_wts   (optional)
+          [6]  fw_recurrent_to_forget_wts
+          [7]  fw_recurrent_to_cell_wts
+          [8]  fw_recurrent_to_output_wts
+          [9-11] fw cell_to_*_weights      (optional, not supported)
+          [12] fw_input_gate_bias          (optional)
+          [13] fw_forget_gate_bias
+          [14] fw_cell_gate_bias
+          [15] fw_output_gate_bias
+          [16] fw_projection_weights       (optional, not supported)
+          [17] fw_projection_bias          (optional, not supported)
+
+        Backward LSTM cell (indices 18-34, same layout as fw):
+          [19] bw_input_to_forget_weights
+          [20] bw_input_to_cell_weights
+          [21] bw_input_to_output_weights
+          [23] bw_recurrent_to_forget_wts
+          [24] bw_recurrent_to_cell_wts
+          [25] bw_recurrent_to_output_wts
+          [30] bw_forget_gate_bias
+          [31] bw_cell_gate_bias
+          [32] bw_output_gate_bias
+
+        State tensors:
+          [35] fw_activation_state  [batch, num_units]
+          [36] fw_cell_state        [batch, num_units]
+          [37] bw_activation_state  [batch, num_units]
+          [38] bw_cell_state        [batch, num_units]
+
+        Output (merge_outputs=True):
+          [0] output  [batch, time, 2 * num_units]
+
+        Output (merge_outputs=False):
+          [0] fw_output  [batch, time, num_units]
+          [1] bw_output  [batch, time, num_units]
+        """
+        from tflite.BidirectionalSequenceLSTMOptions import BidirectionalSequenceLSTMOptions
+        from tflite.BuiltinOptions import BuiltinOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized BIDIRECTIONAL_SEQUENCE_LSTM is not supported yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 48, (
+            f"input tensors length should be 48, got {len(input_tensors)}"
+        )
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.BidirectionalSequenceLSTMOptions
+        op_options = op.BuiltinOptions()
+        lstm_opts = BidirectionalSequenceLSTMOptions()
+        lstm_opts.Init(op_options.Bytes, op_options.Pos)
+        time_major = lstm_opts.TimeMajor()
+        fused_activation_fn = lstm_opts.FusedActivationFunction()
+        merge_outputs = lstm_opts.MergeOutputs()
+        cell_clip = lstm_opts.CellClip()
+        proj_clip = lstm_opts.ProjClip()
+
+        # ── Forward LSTM weights (transposed once outside the loop) ──
+        if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx != -1:
+            raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+            raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]):
+            raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.")
+
+        fw_w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[2]))
+        fw_w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[3]))
+        fw_w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[4]))
+        fw_r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[6]))
+        fw_r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[7]))
+        fw_r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[8]))
+        fw_b_f = self.get_tensor_expr(input_tensors[13])
+        fw_b_c = self.get_tensor_expr(input_tensors[14])
+        fw_b_o = self.get_tensor_expr(input_tensors[15])
+
+        # ── Backward LSTM weights (transposed once outside the loop) ──
+        if input_tensors[18].tensor_idx != -1 or input_tensors[22].tensor_idx != -1:
+            raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [26, 27, 28]):
+            raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in [33, 34]):
+            raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.")
+        if any(input_tensors[idx].tensor_idx != -1 for idx in range(39, 48)):
+            raise tvm.error.OpNotImplemented(
+                "TFLite BIDIRECTIONAL_SEQUENCE_LSTM aux input is not supported yet."
+            )
+
+        bw_w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[19]))
+        bw_w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[20]))
+        bw_w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[21]))
+        bw_r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[23]))
+        bw_r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[24]))
+        bw_r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[25]))
+        bw_b_f = self.get_tensor_expr(input_tensors[30])
+        bw_b_c = self.get_tensor_expr(input_tensors[31])
+        bw_b_o = self.get_tensor_expr(input_tensors[32])
+
+        # ── Initial states ──
+        fw_h = self.get_tensor_expr(input_tensors[35])
+        fw_c = self.get_tensor_expr(input_tensors[36])
+        bw_h = self.get_tensor_expr(input_tensors[37])
+        bw_c = self.get_tensor_expr(input_tensors[38])
+
+        # ── Unroll input ──
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        in_shape = self.get_tensor_shape(input_tensors[0])
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = int(in_shape[0])
+        else:
+            num_steps = int(in_shape[1])
+
+        if num_steps == 1:
+            steps = [relax.op.squeeze(in_expr, axis=[1])]
+        else:
+            splits = relax.op.split(in_expr, num_steps, axis=1)
+            steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
+
+        one = relax.const(1.0, "float32")
+
+        def _lstm_step(x_t, h, c, w_f_t, w_c_t, w_o_t, r_f_t, r_c_t, r_o_t, b_f, b_c, b_o):
+            """Single LSTM step with coupled input-forget gate."""
+            f = relax.op.sigmoid(
+                relax.op.add(
+                    relax.op.add(
+                        relax.op.matmul(x_t, w_f_t),
+                        relax.op.matmul(h, r_f_t),
+                    ),
+                    b_f,
+                )
+            )
+            i = relax.op.subtract(one, f)
+            g = self.convert_fused_activation_function(
+                relax.op.add(
+                    relax.op.add(relax.op.matmul(x_t, w_c_t), relax.op.matmul(h, r_c_t)),
+                    b_c,
+                ),
+                fused_activation_fn,
+            )
+            o = relax.op.sigmoid(
+                relax.op.add(
+                    relax.op.add(
+                        relax.op.matmul(x_t, w_o_t),
+                        relax.op.matmul(h, r_o_t),
+                    ),
+                    b_o,
+                )
+            )
+            c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i, g))
+            if cell_clip > 0.0:
+                c_new = relax.op.clip(c_new, -cell_clip, cell_clip)
+            h_new = relax.op.multiply(
+                o, self.convert_fused_activation_function(c_new, fused_activation_fn)
+            )
+            if proj_clip > 0.0:
+                h_new = relax.op.clip(h_new, -proj_clip, proj_clip)
+            return h_new, c_new
+
+        # ── Forward pass ──
+        fw_outputs = []
+        for x_t in steps:
+            fw_h, fw_c = _lstm_step(
+                x_t,
+                fw_h,
+                fw_c,
+                fw_w_f_t,
+                fw_w_c_t,
+                fw_w_o_t,
+                fw_r_f_t,
+                fw_r_c_t,
+                fw_r_o_t,
+                fw_b_f,
+                fw_b_c,
+                fw_b_o,
+            )
+            fw_outputs.append(fw_h)
+
+        # ── Backward pass ──
+        bw_outputs = []
+        for x_t in reversed(steps):
+            bw_h, bw_c = _lstm_step(
+                x_t,
+                bw_h,
+                bw_c,
+                bw_w_f_t,
+                bw_w_c_t,
+                bw_w_o_t,
+                bw_r_f_t,
+                bw_r_c_t,
+                bw_r_o_t,
+                bw_b_f,
+                bw_b_c,
+                bw_b_o,
+            )
+            bw_outputs.append(bw_h)
+        bw_outputs.reverse()
+
+        fw_stacked = relax.op.stack(fw_outputs, axis=1)
+        bw_stacked = relax.op.stack(bw_outputs, axis=1)
+        if time_major:
+            fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2])
+            bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2])
+
+        # Update state tensors in the expression table for subsequent ops.
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[35].tensor_idx),
+            fw_h,
+            force_override=True,
+        )
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[36].tensor_idx),
+            fw_c,
+            force_override=True,
+        )
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[37].tensor_idx),
+            bw_h,
+            force_override=True,
+        )
+        self.exp_tab.set_expr(
+            get_tensor_name(self.subgraph, input_tensors[38].tensor_idx),
+            bw_c,
+            force_override=True,
+        )
+
+        if merge_outputs:
+            return relax.op.concat([fw_stacked, bw_stacked], axis=-1)
+        else:
+            return relax.Tuple([fw_stacked, bw_stacked])
 
     def convert_batch_to_space_nd(self, op):
         """batch_to_space_nd implementation."""
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index e9ccea7ad1..05a6c1e5e5 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3723,6 +3723,15 @@ _tfl_tensor_type = _get_tflite_schema_enum("TensorType")
 _tfl_lstm_options = _get_tflite_schema_module("LSTMOptions")
 _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
 _tfl_svdf_options = _get_tflite_schema_module("SVDFOptions")
+# Option-table schema modules used by the sequence LSTM/RNN model builders below.
+_tfl_unidirectional_sequence_lstm_options = _get_tflite_schema_module(
+    "UnidirectionalSequenceLSTMOptions"
+)
+_tfl_bidirectional_sequence_rnn_options = _get_tflite_schema_module(
+    "BidirectionalSequenceRNNOptions"
+)
+_tfl_bidirectional_sequence_lstm_options = _get_tflite_schema_module(
+    "BidirectionalSequenceLSTMOptions"
+)
 
 _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
 _DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
@@ -11052,6 +11061,889 @@ def test_svdf_shared_state_updates_exp_tab():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+# ── UNIDIRECTIONAL_SEQUENCE_LSTM ─────────────────────────────────────────────
+
+
+def _build_unidirectional_sequence_lstm_model(
+    batch,
+    time,
+    input_size,
+    num_units,
+    input_to_forget_weights,
+    input_to_cell_weights,
+    input_to_output_weights,
+    recurrent_to_forget_weights,
+    recurrent_to_cell_weights,
+    recurrent_to_output_weights,
+    forget_gate_bias,
+    cell_bias,
+    output_gate_bias,
+    activation,
+    *,
+    time_major=False,
+    cell_clip=0.0,
+    proj_clip=0.0,
+    projection_weights=None,
+):
+    """Build a TFLite flatbuffer model with one UNIDIRECTIONAL_SEQUENCE_LSTM op.
+
+    The cell has no input gate: the input-gate weights and bias (operator input
+    slots 1, 5 and 12) are left absent, as are peepholes and layer norm.
+
+    Subgraph tensors:
+      0  - input                       [batch, time, input_size]
+                                       ([time, batch, input_size] if time_major)
+      1  - input_to_forget_weights     [num_units, input_size]
+      2  - input_to_cell_weights       [num_units, input_size]
+      3  - input_to_output_weights     [num_units, input_size]
+      4  - recurrent_to_forget_weights [num_units, num_units]
+      5  - recurrent_to_cell_weights   [num_units, num_units]
+      6  - recurrent_to_output_weights [num_units, num_units]
+      7  - forget_gate_bias            [num_units]
+      8  - cell_bias                   [num_units]
+      9  - output_gate_bias            [num_units]
+      10 - output_state                [batch, num_units]   (model input)
+      11 - cell_state                  [batch, num_units]   (model input)
+      12 - output                      [batch, time, num_units]
+                                       ([time, batch, num_units] if time_major)
+      13 - projection_weights          [num_units, num_units]
+                                       (only when ``projection_weights`` is given;
+                                       wired to operator input slot 16)
+
+    Returns whatever ``_finish_tflite_model`` produces; the tests pass it to
+    ``_load_model_from_buffer``.
+    """
+    builder = flatbuffers.Builder(4096)
+
+    _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsStart(builder)
+    _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction(
+        builder, activation
+    )
+    _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddTimeMajor(
+        builder, time_major
+    )
+    _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddCellClip(
+        builder, cell_clip
+    )
+    _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddProjClip(
+        builder, proj_clip
+    )
+    lstm_opts = _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsEnd(
+        builder
+    )
+
+    lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_LSTM)
+
+    def _t(buf_idx, shape):
+        # Non-variable FLOAT32 tensor backed by buffer ``buf_idx`` (buffer 0 is empty).
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, False)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
+    output_shape = [time, batch, num_units] if time_major else [batch, time, num_units]
+    tensors = [
+        _t(0, input_shape),  # 0: input
+        _t(1, [num_units, input_size]),  # 1: input_to_forget_weights
+        _t(2, [num_units, input_size]),  # 2: input_to_cell_weights
+        _t(3, [num_units, input_size]),  # 3: input_to_output_weights
+        _t(4, [num_units, num_units]),  # 4: recurrent_to_forget_weights
+        _t(5, [num_units, num_units]),  # 5: recurrent_to_cell_weights
+        _t(6, [num_units, num_units]),  # 6: recurrent_to_output_weights
+        _t(7, [num_units]),  # 7: forget_gate_bias
+        _t(8, [num_units]),  # 8: cell_bias
+        _t(9, [num_units]),  # 9: output_gate_bias
+        _t(0, [batch, num_units]),  # 10: output_state (model input)
+        _t(0, [batch, num_units]),  # 11: cell_state (model input)
+        _t(0, output_shape),  # 12: output
+    ]
+
+    # 24 operator input slots; -1 marks an absent optional input.
+    lstm_inputs = [-1] * 24
+    lstm_inputs[0] = 0  # input
+    lstm_inputs[2:5] = [1, 2, 3]  # input_to_{forget,cell,output}_weights
+    lstm_inputs[6:9] = [4, 5, 6]  # recurrent_to_{forget,cell,output}_weights
+    lstm_inputs[13:16] = [7, 8, 9]  # {forget_gate,cell,output_gate}_bias
+    lstm_inputs[18:20] = [10, 11]  # output_state, cell_state
+
+    buffers = [
+        _build_buffer(builder),  # 0: empty
+        _build_buffer(builder, input_to_forget_weights.tobytes()),  # 1
+        _build_buffer(builder, input_to_cell_weights.tobytes()),  # 2
+        _build_buffer(builder, input_to_output_weights.tobytes()),  # 3
+        _build_buffer(builder, recurrent_to_forget_weights.tobytes()),  # 4
+        _build_buffer(builder, recurrent_to_cell_weights.tobytes()),  # 5
+        _build_buffer(builder, recurrent_to_output_weights.tobytes()),  # 6
+        _build_buffer(builder, forget_gate_bias.tobytes()),  # 7
+        _build_buffer(builder, cell_bias.tobytes()),  # 8
+        _build_buffer(builder, output_gate_bias.tobytes()),  # 9
+    ]
+    if projection_weights is not None:
+        # Extra constant tensor (index 13) wired to the projection_weights slot 16.
+        tensors.append(_t(len(buffers), [num_units, num_units]))
+        lstm_inputs[16] = len(tensors) - 1
+        buffers.append(_build_buffer(builder, projection_weights.tobytes()))
+
+    lstm_op = _build_operator(
+        builder,
+        0,
+        lstm_inputs,
+        [12],
+        builtin_options_type=_tfl_builtin_options.UnidirectionalSequenceLSTMOptions,
+        builtin_options=lstm_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[lstm_op],
+        inputs=[0, 10, 11],
+        outputs=[12],
+    )
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[lstm_op_code],
+        buffers=buffers,
+    )
+
+
+def test_unidirectional_sequence_lstm_none_activation():
+    """UNIDIRECTIONAL_SEQUENCE_LSTM with NONE activation keeps cell activation linear.
+
+    Also checks the batch-major output shape [batch, time, num_units].
+    """
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 1, 2, 2
+    w_f = np.eye(num_units, input_size, dtype=np.float32)
+    w_c = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+    w_o = np.array([[0.5, -0.25], [0.75, 0.5]], dtype=np.float32)
+    r_f = np.eye(num_units, dtype=np.float32)
+    r_c = np.array([[0.5, 0.0], [0.0, 0.25]], dtype=np.float32)
+    r_o = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32)
+    b_f = np.zeros(num_units, dtype=np.float32)
+    b_c = np.zeros(num_units, dtype=np.float32)
+    b_o = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_unidirectional_sequence_lstm_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            w_f,
+            w_c,
+            w_o,
+            r_f,
+            r_c,
+            r_o,
+            b_f,
+            b_c,
+            b_o,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    script = mod.script(show_meta=True)
+    assert script.count("R.sigmoid") == 2
+    assert "R.tanh" not in script
+    assert "R.multiply" in script
+
+    # time_major defaults to False, so the output stays batch-major.
+    out_shape = tuple(int(d) for d in mod["main"].ret_struct_info.shape)
+    assert out_shape == (batch, time, num_units)
+
+
+def test_unidirectional_sequence_lstm_tanh_activation():
+    """A TANH fused activation on UNIDIRECTIONAL_SEQUENCE_LSTM is applied inside the cell."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 1, 2, 2
+    zero_bias = np.zeros(num_units, dtype=np.float32)
+    cell_params = {
+        "input_to_forget_weights": np.eye(num_units, input_size, dtype=np.float32),
+        "input_to_cell_weights": np.array([[1.0, -1.0], [0.25, 0.5]], dtype=np.float32),
+        "input_to_output_weights": np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32),
+        "recurrent_to_forget_weights": np.eye(num_units, dtype=np.float32),
+        "recurrent_to_cell_weights": np.array([[0.0, 0.1], [0.2, 0.0]], dtype=np.float32),
+        "recurrent_to_output_weights": np.array([[0.3, 0.0], [0.0, 0.4]], dtype=np.float32),
+        "forget_gate_bias": zero_bias,
+        "cell_bias": zero_bias,
+        "output_gate_bias": zero_bias,
+    }
+
+    model_buf = _build_unidirectional_sequence_lstm_model(
+        batch,
+        time,
+        input_size,
+        num_units,
+        activation=ActivationFunctionType.TANH,
+        **cell_params,
+    )
+    script = _load_model_from_buffer(model_buf).script(show_meta=True)
+
+    # Expect two sigmoid gates and two TANH applications in the lowered cell.
+    assert (script.count("R.sigmoid"), script.count("R.tanh")) == (2, 2)
+    assert "R.multiply" in script
+
+
+def test_unidirectional_sequence_lstm_time_major():
+    """With time_major=True, UNIDIRECTIONAL_SEQUENCE_LSTM keeps the [time, batch, ...] layout."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 3, 2, 2
+    w_in = np.eye(num_units, input_size, dtype=np.float32)
+    w_rec = np.eye(num_units, dtype=np.float32)
+    b = np.zeros(num_units, dtype=np.float32)
+    # Input weights, recurrent weights and biases, each for forget/cell/output gates.
+    gate_args = (w_in,) * 3 + (w_rec,) * 3 + (b,) * 3
+
+    main_fn = _load_model_from_buffer(
+        _build_unidirectional_sequence_lstm_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            *gate_args,
+            ActivationFunctionType.NONE,
+            time_major=True,
+        )
+    )["main"]
+
+    def _dims(struct_info):
+        return tuple(int(d) for d in struct_info.shape)
+
+    assert _dims(main_fn.params[0].struct_info) == (time, batch, input_size)
+    assert _dims(main_fn.ret_struct_info) == (time, batch, num_units)
+
+
+def test_unidirectional_sequence_lstm_rejects_projection():
+    """A projection-weights input on UNIDIRECTIONAL_SEQUENCE_LSTM raises OpNotImplemented."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 2, 2, 2
+    w_in = np.eye(num_units, input_size, dtype=np.float32)
+    w_rec = np.eye(num_units, dtype=np.float32)
+    b = np.zeros(num_units, dtype=np.float32)
+    gate_args = [w_in] * 3 + [w_rec] * 3 + [b] * 3
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="projection LSTM"):
+        _load_model_from_buffer(
+            _build_unidirectional_sequence_lstm_model(
+                batch,
+                time,
+                input_size,
+                num_units,
+                *gate_args,
+                ActivationFunctionType.NONE,
+                projection_weights=np.eye(num_units, dtype=np.float32),
+            )
+        )
+
+
+# ── BIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
+
+
+def _build_bidirectional_sequence_rnn_model(
+    batch,
+    time,
+    input_size,
+    num_units,
+    fw_weights,
+    fw_recurrent_weights,
+    fw_bias,
+    bw_weights,
+    bw_recurrent_weights,
+    bw_bias,
+    activation,
+    *,
+    time_major=False,
+    merge_outputs=True,
+    with_aux_input=False,
+):
+    """Build a TFLite flatbuffer model with one BIDIRECTIONAL_SEQUENCE_RNN op.
+
+    Subgraph tensors 0-8 are always present and feed operator inputs 0-8:
+      0  - input                [batch, time, input_size]
+                                ([time, batch, input_size] if time_major)
+      1  - fw_weights           [num_units, input_size]
+      2  - fw_recurrent_weights [num_units, num_units]
+      3  - fw_bias              [num_units]
+      4  - fw_hidden_state      [batch, num_units]   (model input)
+      5  - bw_weights           [num_units, input_size]
+      6  - bw_recurrent_weights [num_units, num_units]
+      7  - bw_bias              [num_units]
+      8  - bw_hidden_state      [batch, num_units]   (model input)
+
+    Operator inputs 9-11 (aux_input, fw_aux_weights, bw_aux_weights) are -1
+    unless ``with_aux_input`` is set; then three zero-filled constant tensors
+    are appended at subgraph indices 9-11 and wired to those slots.
+
+    Output tensors are appended after that, so they start at index 9 (12 with
+    aux input):
+      - ``merge_outputs`` True: one merged output of width ``2 * num_units``.
+      - ``merge_outputs`` False: separate forward and backward outputs, each of
+        width ``num_units``.
+    """
+    builder = flatbuffers.Builder(4096)
+
+    _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsStart(builder)
+    _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddTimeMajor(
+        builder, time_major
+    )
+    _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddFusedActivationFunction(
+        builder, activation
+    )
+    _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddMergeOutputs(
+        builder, merge_outputs
+    )
+    rnn_opts = _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsEnd(builder)
+
+    rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_RNN)
+
+    def _t(buf_idx, shape):
+        # Non-variable FLOAT32 tensor backed by buffer ``buf_idx`` (buffer 0 is empty).
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, False)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
+    output_prefix = [time, batch] if time_major else [batch, time]
+    output_shape = output_prefix + ([num_units * 2] if merge_outputs else [num_units])
+
+    tensors = [
+        _t(0, input_shape),  # 0: input
+        _t(1, [num_units, input_size]),  # 1: fw_weights
+        _t(2, [num_units, num_units]),  # 2: fw_recurrent_weights
+        _t(3, [num_units]),  # 3: fw_bias
+        _t(0, [batch, num_units]),  # 4: fw_hidden_state (model input)
+        _t(4, [num_units, input_size]),  # 5: bw_weights
+        _t(5, [num_units, num_units]),  # 6: bw_recurrent_weights
+        _t(6, [num_units]),  # 7: bw_bias
+        _t(0, [batch, num_units]),  # 8: bw_hidden_state (model input)
+    ]
+    buffers = [
+        _build_buffer(builder),  # 0: empty
+        _build_buffer(builder, fw_weights.tobytes()),  # 1
+        _build_buffer(builder, fw_recurrent_weights.tobytes()),  # 2
+        _build_buffer(builder, fw_bias.tobytes()),  # 3
+        _build_buffer(builder, bw_weights.tobytes()),  # 4
+        _build_buffer(builder, bw_recurrent_weights.tobytes()),  # 5
+        _build_buffer(builder, bw_bias.tobytes()),  # 6
+    ]
+    # Slots 0-8 map one-to-one onto tensors 0-8; aux slots 9-11 start absent.
+    rnn_inputs = [*range(9), -1, -1, -1]
+    if with_aux_input:
+        tensors.extend(
+            [
+                _t(len(buffers), input_shape),
+                _t(len(buffers) + 1, [num_units, input_size]),
+                _t(len(buffers) + 2, [num_units, input_size]),
+            ]
+        )
+        rnn_inputs[9:12] = [len(tensors) - 3, len(tensors) - 2, len(tensors) - 1]
+        buffers.extend(
+            [
+                _build_buffer(builder, np.zeros(input_shape, dtype=np.float32).tobytes()),
+                _build_buffer(
+                    builder, np.zeros((num_units, input_size), dtype=np.float32).tobytes()
+                ),
+                _build_buffer(
+                    builder, np.zeros((num_units, input_size), dtype=np.float32).tobytes()
+                ),
+            ]
+        )
+
+    if merge_outputs:
+        tensors.append(_t(0, output_shape))
+        outputs = [len(tensors) - 1]
+    else:
+        tensors.extend([_t(0, output_shape), _t(0, output_shape)])
+        outputs = [len(tensors) - 2, len(tensors) - 1]
+
+    rnn_op = _build_operator(
+        builder,
+        0,
+        rnn_inputs,
+        outputs,
+        builtin_options_type=_tfl_builtin_options.BidirectionalSequenceRNNOptions,
+        builtin_options=rnn_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[rnn_op],
+        inputs=[0, 4, 8],
+        outputs=outputs,
+    )
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[rnn_op_code],
+        buffers=buffers,
+    )
+
+
+def test_bidirectional_sequence_rnn_none_activation():
+    """BIDIRECTIONAL_SEQUENCE_RNN with NONE activation lowers the expected equations.
+
+    With time=1 each direction runs a single step, so the expected module is
+    written out in full:
+      - Each direction computes x @ W^T + h @ R^T + b with no activation.
+      - Each direction stacks its steps on axis 1.
+      - The two directions are concatenated on the last axis (the builder's
+        merge_outputs default is True).
+    """
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 1, 2, 2
+    fw_w = np.array([[1.0, 0.0], [0.5, -1.0]], dtype=np.float32)
+    fw_r = np.array([[0.25, 0.0], [0.0, 0.5]], dtype=np.float32)
+    fw_b = np.zeros(num_units, dtype=np.float32)
+    bw_w = np.array([[0.0, 1.0], [-0.5, 0.75]], dtype=np.float32)
+    bw_r = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32)
+    bw_b = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_bidirectional_sequence_rnn_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            fw_w,
+            fw_r,
+            fw_b,
+            bw_w,
+            bw_r,
+            bw_b,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    # Binding order below must match the converter's emission order exactly;
+    # assert_structural_equal compares the whole module.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 1, 2), dtype="float32"),
+            fw_h: R.Tensor((2, 2), dtype="float32"),
+            bw_h: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 1, 4), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                # time=1: the only step is obtained by squeezing the time axis.
+                x_t: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
+                # Forward direction.
+                fw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(fw_w), axes=None)
+                fw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(x_t, fw_w_t, out_dtype="void")
+                fw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(fw_r), axes=None)
+                fw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul(
+                    fw_h, fw_r_t, out_dtype="void"
+                )
+                fw_out: R.Tensor((2, 2), dtype="float32") = R.add(
+                    R.add(fw_x, fw_h_proj), R.const(fw_b)
+                )
+                fw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack((fw_out,), axis=1)
+                # Backward direction: with a single step it consumes the same x_t.
+                bw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(bw_w), axes=None)
+                bw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(x_t, bw_w_t, out_dtype="void")
+                bw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(bw_r), axes=None)
+                bw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul(
+                    bw_h, bw_r_t, out_dtype="void"
+                )
+                bw_out: R.Tensor((2, 2), dtype="float32") = R.add(
+                    R.add(bw_x, bw_h_proj), R.const(bw_b)
+                )
+                bw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack((bw_out,), axis=1)
+                # merge_outputs: concatenate fw and bw along the feature axis.
+                gv: R.Tensor((2, 1, 4), dtype="float32") = R.concat(
+                    (fw_stacked, bw_stacked), axis=-1
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_bidirectional_sequence_rnn_time_major():
+    """Time-major BIDIRECTIONAL_SEQUENCE_RNN keeps time as the leading axis of its output."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 3, 2, 2
+    direction = (
+        np.eye(num_units, input_size, dtype=np.float32),  # input weights
+        np.eye(num_units, dtype=np.float32),  # recurrent weights
+        np.zeros(num_units, dtype=np.float32),  # bias
+    )
+
+    main_fn = _load_model_from_buffer(
+        _build_bidirectional_sequence_rnn_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            *direction,  # forward cell
+            *direction,  # backward cell
+            ActivationFunctionType.NONE,
+            time_major=True,
+        )
+    )["main"]
+
+    param_dims = tuple(int(d) for d in main_fn.params[0].struct_info.shape)
+    assert param_dims == (time, batch, input_size)
+    ret_dims = tuple(int(d) for d in main_fn.ret_struct_info.shape)
+    assert ret_dims == (time, batch, num_units * 2)
+
+
+def test_bidirectional_sequence_rnn_rejects_aux_input():
+    """Auxiliary input tensors on BIDIRECTIONAL_SEQUENCE_RNN raise OpNotImplemented."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 2, 2, 2
+    direction = (
+        np.eye(num_units, input_size, dtype=np.float32),  # input weights
+        np.eye(num_units, dtype=np.float32),  # recurrent weights
+        np.zeros(num_units, dtype=np.float32),  # bias
+    )
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="aux input"):
+        _load_model_from_buffer(
+            _build_bidirectional_sequence_rnn_model(
+                batch,
+                time,
+                input_size,
+                num_units,
+                *direction,  # forward cell
+                *direction,  # backward cell
+                ActivationFunctionType.NONE,
+                with_aux_input=True,
+            )
+        )
+
+
+# ── BIDIRECTIONAL_SEQUENCE_LSTM ──────────────────────────────────────────────
+
+
+def _build_bidirectional_sequence_lstm_model(
+    batch,
+    time,
+    input_size,
+    num_units,
+    fw_w_f,
+    fw_w_c,
+    fw_w_o,
+    fw_r_f,
+    fw_r_c,
+    fw_r_o,
+    fw_b_f,
+    fw_b_c,
+    fw_b_o,
+    bw_w_f,
+    bw_w_c,
+    bw_w_o,
+    bw_r_f,
+    bw_r_c,
+    bw_r_o,
+    bw_b_f,
+    bw_b_c,
+    bw_b_o,
+    activation,
+    *,
+    time_major=False,
+    merge_outputs=True,
+    cell_clip=0.0,
+    proj_clip=0.0,
+    with_aux_input=False,
+):
+    """Build a TFLite flatbuffer model with one BIDIRECTIONAL_SEQUENCE_LSTM op.
+
+    Both directions read the same input tensor and have no input gate. The
+    input-gate weights/biases, peepholes and projection inputs are left absent.
+
+    Subgraph tensors:
+      0     - input [batch, time, input_size] ([time, batch, input_size] if time_major)
+      1-9   - forward input_to_{f,c,o}, recurrent_to_{f,c,o} weights, {f,c,o} biases
+      10-18 - backward weights and biases, same order as 1-9
+      19-22 - fw activation, fw cell, bw activation, bw cell states
+              (each [batch, num_units], all model inputs)
+      23    - output [batch, time, output_size] ([time, batch, output_size] if
+              time_major); output_size is 2 * num_units if merge_outputs else num_units
+      24    - aux input (only when ``with_aux_input``)
+
+    Operator inputs use 48 slots; absent slots are -1:
+      - 0: shared input
+      - 1-17: forward cell
+      - 18-34: backward cell
+      - 35-38: the four states
+      - 39-47: auxiliary input and its weights
+
+    Only tensor 23 is declared as an operator output. With
+    ``merge_outputs=False`` the model therefore has a single num_units-wide
+    output and no separate backward output tensor.
+    """
+    builder = flatbuffers.Builder(8192)
+
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsStart(builder)
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddFusedActivationFunction(
+        builder, activation
+    )
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddTimeMajor(
+        builder, time_major
+    )
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddMergeOutputs(
+        builder, merge_outputs
+    )
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddCellClip(
+        builder, cell_clip
+    )
+    _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddProjClip(
+        builder, proj_clip
+    )
+    lstm_opts = _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsEnd(
+        builder
+    )
+
+    lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_LSTM)
+
+    def _t(buf_idx, shape):
+        # Non-variable FLOAT32 tensor backed by buffer ``buf_idx`` (buffer 0 is empty).
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, False)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
+    output_size = num_units * 2 if merge_outputs else num_units
+    output_shape = ([time, batch] if time_major else [batch, time]) + [output_size]
+
+    tensors = [
+        _t(0, input_shape),  # 0: input
+        _t(1, [num_units, input_size]),  # 1: fw_w_f
+        _t(2, [num_units, input_size]),  # 2: fw_w_c
+        _t(3, [num_units, input_size]),  # 3: fw_w_o
+        _t(4, [num_units, num_units]),  # 4: fw_r_f
+        _t(5, [num_units, num_units]),  # 5: fw_r_c
+        _t(6, [num_units, num_units]),  # 6: fw_r_o
+        _t(7, [num_units]),  # 7: fw_b_f
+        _t(8, [num_units]),  # 8: fw_b_c
+        _t(9, [num_units]),  # 9: fw_b_o
+        _t(10, [num_units, input_size]),  # 10: bw_w_f
+        _t(11, [num_units, input_size]),  # 11: bw_w_c
+        _t(12, [num_units, input_size]),  # 12: bw_w_o
+        _t(13, [num_units, num_units]),  # 13: bw_r_f
+        _t(14, [num_units, num_units]),  # 14: bw_r_c
+        _t(15, [num_units, num_units]),  # 15: bw_r_o
+        _t(16, [num_units]),  # 16: bw_b_f
+        _t(17, [num_units]),  # 17: bw_b_c
+        _t(18, [num_units]),  # 18: bw_b_o
+        _t(0, [batch, num_units]),  # 19: fw_activation_state (model input)
+        _t(0, [batch, num_units]),  # 20: fw_cell_state (model input)
+        _t(0, [batch, num_units]),  # 21: bw_activation_state (model input)
+        _t(0, [batch, num_units]),  # 22: bw_cell_state (model input)
+        _t(0, output_shape),  # 23: output
+    ]
+
+    # 48 operator input slots; -1 marks an unsupported / absent optional input.
+    lstm_inputs = [-1] * 48
+    lstm_inputs[0] = 0  # input (shared by both directions)
+    lstm_inputs[2:5] = [1, 2, 3]  # fw input_to_{forget,cell,output}_weights
+    lstm_inputs[6:9] = [4, 5, 6]  # fw recurrent_to_{forget,cell,output}_weights
+    lstm_inputs[13:16] = [7, 8, 9]  # fw {forget_gate,cell,output_gate}_bias
+    lstm_inputs[19:22] = [10, 11, 12]  # bw input_to_{forget,cell,output}_weights
+    lstm_inputs[23:26] = [13, 14, 15]  # bw recurrent_to_{forget,cell,output}_weights
+    lstm_inputs[30:33] = [16, 17, 18]  # bw {forget_gate,cell,output_gate}_bias
+    lstm_inputs[35:39] = [19, 20, 21, 22]  # fw/bw activation and cell states
+    if with_aux_input:
+        tensors.append(_t(0, input_shape))
+        lstm_inputs[39] = len(tensors) - 1  # aux_input
+
+    lstm_op = _build_operator(
+        builder,
+        0,
+        lstm_inputs,
+        [23],
+        builtin_options_type=_tfl_builtin_options.BidirectionalSequenceLSTMOptions,
+        builtin_options=lstm_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[lstm_op],
+        inputs=[0, 19, 20, 21, 22],
+        outputs=[23],
+    )
+
+    buffers = [
+        _build_buffer(builder),  # 0: empty
+        _build_buffer(builder, fw_w_f.tobytes()),  # 1
+        _build_buffer(builder, fw_w_c.tobytes()),  # 2
+        _build_buffer(builder, fw_w_o.tobytes()),  # 3
+        _build_buffer(builder, fw_r_f.tobytes()),  # 4
+        _build_buffer(builder, fw_r_c.tobytes()),  # 5
+        _build_buffer(builder, fw_r_o.tobytes()),  # 6
+        _build_buffer(builder, fw_b_f.tobytes()),  # 7
+        _build_buffer(builder, fw_b_c.tobytes()),  # 8
+        _build_buffer(builder, fw_b_o.tobytes()),  # 9
+        _build_buffer(builder, bw_w_f.tobytes()),  # 10
+        _build_buffer(builder, bw_w_c.tobytes()),  # 11
+        _build_buffer(builder, bw_w_o.tobytes()),  # 12
+        _build_buffer(builder, bw_r_f.tobytes()),  # 13
+        _build_buffer(builder, bw_r_c.tobytes()),  # 14
+        _build_buffer(builder, bw_r_o.tobytes()),  # 15
+        _build_buffer(builder, bw_b_f.tobytes()),  # 16
+        _build_buffer(builder, bw_b_c.tobytes()),  # 17
+        _build_buffer(builder, bw_b_o.tobytes()),  # 18
+    ]
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[lstm_op_code],
+        buffers=buffers,
+    )
+
+
+def test_bidirectional_sequence_lstm_none_activation():
+    """BIDIRECTIONAL_SEQUENCE_LSTM with NONE activation keeps both cell activations linear.
+
+    Also checks the merged output shape [batch, time, 2 * num_units].
+    """
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 1, 2, 2
+
+    fw_w_f = np.eye(num_units, input_size, dtype=np.float32)
+    fw_w_c = np.array([[1.0, -0.5], [0.25, 0.75]], dtype=np.float32)
+    fw_w_o = np.array([[0.5, 0.25], [-0.25, 1.0]], dtype=np.float32)
+    fw_r_f = np.eye(num_units, dtype=np.float32)
+    fw_r_c = np.array([[0.2, 0.0], [0.0, 0.3]], dtype=np.float32)
+    fw_r_o = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32)
+    fw_b_f = np.zeros(num_units, dtype=np.float32)
+    fw_b_c = np.zeros(num_units, dtype=np.float32)
+    fw_b_o = np.zeros(num_units, dtype=np.float32)
+
+    bw_w_f = np.eye(num_units, input_size, dtype=np.float32)
+    bw_w_c = np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32)
+    bw_w_o = np.array([[0.25, -0.25], [0.75, 0.5]], dtype=np.float32)
+    bw_r_f = np.array([[0.4, 0.0], [0.0, 0.6]], dtype=np.float32)
+    bw_r_c = np.array([[0.3, 0.0], [0.0, 0.2]], dtype=np.float32)
+    bw_r_o = np.array([[0.2, 0.0], [0.0, 0.1]], dtype=np.float32)
+    bw_b_f = np.zeros(num_units, dtype=np.float32)
+    bw_b_c = np.zeros(num_units, dtype=np.float32)
+    bw_b_o = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_bidirectional_sequence_lstm_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            fw_w_f,
+            fw_w_c,
+            fw_w_o,
+            fw_r_f,
+            fw_r_c,
+            fw_r_o,
+            fw_b_f,
+            fw_b_c,
+            fw_b_o,
+            bw_w_f,
+            bw_w_c,
+            bw_w_o,
+            bw_r_f,
+            bw_r_c,
+            bw_r_o,
+            bw_b_f,
+            bw_b_c,
+            bw_b_o,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    script = mod.script(show_meta=True)
+    assert script.count("R.sigmoid") == 4
+    assert "R.tanh" not in script
+    assert script.count("R.stack") == 2
+    assert "R.concat" in script
+
+    # merge_outputs defaults to True: fw and bw outputs are joined on the last axis.
+    out_shape = tuple(int(d) for d in mod["main"].ret_struct_info.shape)
+    assert out_shape == (batch, time, 2 * num_units)
+
+
+def test_bidirectional_sequence_lstm_time_major():
+    """Time-major BIDIRECTIONAL_SEQUENCE_LSTM keeps time as the leading axis of its output."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 3, 2, 2
+    w_in = np.eye(num_units, input_size, dtype=np.float32)
+    w_rec = np.eye(num_units, dtype=np.float32)
+    b = np.zeros(num_units, dtype=np.float32)
+    # One direction's parameters, in the builder's forget/cell/output order.
+    cell = (w_in,) * 3 + (w_rec,) * 3 + (b,) * 3
+
+    main_fn = _load_model_from_buffer(
+        _build_bidirectional_sequence_lstm_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            *cell,  # forward direction
+            *cell,  # backward direction
+            ActivationFunctionType.NONE,
+            time_major=True,
+        )
+    )["main"]
+
+    def _dims(struct_info):
+        return tuple(int(d) for d in struct_info.shape)
+
+    assert _dims(main_fn.params[0].struct_info) == (time, batch, input_size)
+    assert _dims(main_fn.ret_struct_info) == (time, batch, num_units * 2)
+
+
+def test_bidirectional_sequence_lstm_rejects_aux_input():
+    """An auxiliary input on BIDIRECTIONAL_SEQUENCE_LSTM raises OpNotImplemented."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 2, 2, 2
+    w_in = np.eye(num_units, input_size, dtype=np.float32)
+    w_rec = np.eye(num_units, dtype=np.float32)
+    b = np.zeros(num_units, dtype=np.float32)
+    cell = [w_in] * 3 + [w_rec] * 3 + [b] * 3
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="aux input"):
+        _load_model_from_buffer(
+            _build_bidirectional_sequence_lstm_model(
+                batch,
+                time,
+                input_size,
+                num_units,
+                *cell,  # forward direction
+                *cell,  # backward direction
+                ActivationFunctionType.NONE,
+                with_aux_input=True,
+            )
+        )
+
+
 # ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
 
 

Reply via email to