This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 576e60e974 [Relax][Frontend][TFLite] Add LSTM and SVDF converter
(#19633)
576e60e974 is described below
commit 576e60e9744a1baea921ce054829925e192d3817
Author: YinHanke <[email protected]>
AuthorDate: Sat May 30 02:30:19 2026 +0800
[Relax][Frontend][TFLite] Add LSTM and SVDF converter (#19633)
## Summary
Add LSTM (coupled input-forget) and SVDF single-step converters to the
TFLite frontend. Both are float32-only; quantized variants are not
supported yet.
From #19519.
## Changes
- **LSTM**: FULL kernel type, coupled input-forget gate only. Peephole,
projection, and layer norm are not supported
- **SVDF**: Standard SVDF with feature projection + time filtering +
bias + fused activation
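- Both converters validate unsupported modes (quantized, non-coupled
LSTM) with clear error messages
- Note: the diff below also removes the `RNN` and `REDUCE_WINDOW`
converters together with their tests, and switches the test compile
target from `llvm` to `c`; these changes are not described above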
## Testing
- `test_lstm_none_activation` — verifies LSTM converter produces correct
IR shapes (batch, input_size) → (batch, num_units) with 3 params (input,
h_state, c_state)
- `test_svdf_none_activation` — verifies SVDF converter produces correct
IR shapes (batch, input_size) → (batch, num_filters) with 2 params
(input, state)
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k "lstm or
svdf" -v
```
## References
- TFLite LSTM spec:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/lstm.cc
- TFLite SVDF spec:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/svdf.cc
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 451 ++++----
tests/python/relax/test_frontend_tflite.py | 1170 ++++++++++----------
2 files changed, 830 insertions(+), 791 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 87f0f12b1b..87697dc6ad 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -253,6 +253,7 @@ class OperatorConverter:
"LOGICAL_NOT": self.convert_logical_not,
"LOGICAL_OR": functools.partial(self._convert_logical_binary,
relax_op=_op.logical_or),
"LOGISTIC": self.convert_logistic,
+ "LSTM": self.convert_lstm,
"MATRIX_DIAG": self.convert_matrix_diag,
"MATRIX_SET_DIAG": self.convert_matrix_set_diag,
"MAX_POOL_2D": functools.partial(self.convert_pool2d,
pool_type="max"),
@@ -273,7 +274,6 @@ class OperatorConverter:
"POW": functools.partial(self._convert_elemwise,
relax_op=_op.power),
"PRELU": self.convert_prelu,
"RANGE": self.convert_range,
- "RNN": self.convert_rnn,
"QUANTIZE": self.convert_quantize,
"RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
"RANDOM_UNIFORM": self.convert_random_uniform,
@@ -282,7 +282,6 @@ class OperatorConverter:
"REDUCE_MAX": functools.partial(self._convert_reduce,
relax_op=_op.max),
"REDUCE_MIN": functools.partial(self._convert_reduce,
relax_op=_op.min),
"REDUCE_PROD": functools.partial(self._convert_reduce,
relax_op=_op.prod),
- "REDUCE_WINDOW": self.convert_reduce_window,
"RELU": self.convert_relu,
"RELU6": self.convert_relu6,
"RELU_N1_TO_1": self.convert_relu_n1_to_1,
@@ -376,6 +375,7 @@ class OperatorConverter:
"STRIDED_SLICE": self.convert_strided_slice,
"SUB": functools.partial(self._convert_elemwise,
relax_op=_op.subtract),
"SUM": functools.partial(self._convert_reduce, relax_op=_op.sum),
+ "SVDF": self.convert_svdf,
"TAN": functools.partial(self._convert_unary_elemwise,
relax_op=_op.tan),
"TANH": self.convert_tanh,
"TILE": self.convert_tile,
@@ -3447,171 +3447,6 @@ class OperatorConverter:
return out
- def convert_reduce_window(self, op):
- """Convert TFLite REDUCE_WINDOW."""
-
- from tflite.BuiltinOptions2 import BuiltinOptions2
- from tflite.ReduceWindowFunction import ReduceWindowFunction
- from tflite.ReduceWindowOptions import ReduceWindowOptions
-
- input_tensors = self.get_input_tensors(op)
- output_tensors = self.get_output_tensors(op)
- if len(input_tensors) != 5:
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW requires 5 input tensors."
- )
- if len(output_tensors) != 1:
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW requires 1 output tensor."
- )
-
- if op.BuiltinOptions2Type() != BuiltinOptions2.ReduceWindowOptions:
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW requires ReduceWindowOptions."
- )
-
- (
- input_tensor,
- init_tensor,
- window_shape_tensor,
- window_strides_tensor,
- window_dilations_tensor,
- ) = input_tensors
- output_tensor = output_tensors[0]
-
- if any(
- self.has_expr(tensor.tensor_idx)
- for tensor in [window_shape_tensor, window_strides_tensor,
window_dilations_tensor]
- ):
- raise tvm.error.OpNotImplemented(
- "TFLite REDUCE_WINDOW requires constant window_shape, "
- "window_strides, and window_dilations."
- )
-
- input_shape = to_int_list(self.get_tensor_shape(input_tensor))
- output_shape = to_int_list(self.get_tensor_shape(output_tensor))
- input_dtype = self.get_tensor_type_str(input_tensor.tensor.Type())
- output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
-
- if input_tensor.qnn_params or output_tensor.qnn_params:
- raise tvm.error.OpNotImplemented(
- "Quantized TFLite REDUCE_WINDOW is not yet supported in the
Relax frontend."
- )
-
- if input_dtype != output_dtype:
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW requires input and output dtypes to
match."
- )
-
- init_shape = to_int_list(self.get_tensor_shape(init_tensor))
- if math.prod(init_shape) != 1:
- raise tvm.error.OpNotImplemented(
- "TFLite REDUCE_WINDOW requires init_value to contain exactly
one element."
- )
-
- options = ReduceWindowOptions()
- op_options = op.BuiltinOptions2()
- options.Init(op_options.Bytes, op_options.Pos)
- reduce_function = options.ReduceFunction()
-
- if reduce_function == ReduceWindowFunction.UNSUPPORTED:
- raise tvm.error.OpNotImplemented(
- "TFLite REDUCE_WINDOW with UNSUPPORTED reduce_function is not
supported."
- )
-
- window_shape = to_int_list(self.get_tensor_value(window_shape_tensor))
- window_strides =
to_int_list(self.get_tensor_value(window_strides_tensor))
- window_dilations =
to_int_list(self.get_tensor_value(window_dilations_tensor))
- rank = len(input_shape)
-
- if not (len(window_shape) == len(window_strides) ==
len(window_dilations) == rank):
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW window_shape, window_strides, and
window_dilations "
- "must match input rank."
- )
-
- if any(value <= 0 for value in window_shape + window_strides +
window_dilations):
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW window dimensions, strides, and
dilations must be positive."
- )
-
- dilated_window_shape = [
- (window_dim - 1) * dilation + 1
- for window_dim, dilation in zip(window_shape, window_dilations)
- ]
- expected_output_shape = [
- 0 if input_dim < dilated_dim else (input_dim - dilated_dim) //
stride + 1
- for input_dim, dilated_dim, stride in zip(
- input_shape, dilated_window_shape, window_strides
- )
- ]
-
- numeric_reduce_functions = {
- ReduceWindowFunction.ADD: (relax.op.sum, relax.op.add),
- ReduceWindowFunction.MUL: (relax.op.prod, relax.op.multiply),
- ReduceWindowFunction.MINIMUM: (relax.op.min, relax.op.minimum),
- ReduceWindowFunction.MAXIMUM: (relax.op.max, relax.op.maximum),
- }
- bool_reduce_functions = {
- ReduceWindowFunction.ALL: (relax.op.min, relax.op.logical_and),
- ReduceWindowFunction.ANY: (relax.op.max, relax.op.logical_or),
- }
-
- if reduce_function in numeric_reduce_functions and input_dtype ==
"bool":
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW numeric reductions expect numeric input."
- )
- if reduce_function in bool_reduce_functions and input_dtype != "bool":
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW boolean reductions expect bool input."
- )
-
- if output_shape != expected_output_shape:
- raise tvm.error.OpAttributeUnImplemented(
- "TFLite REDUCE_WINDOW output shape does not match input/window
parameters."
- )
-
- if any(output_dim == 0 for output_dim in output_shape):
- return relax.op.zeros(output_shape, output_dtype)
-
- data = self.get_tensor_expr(input_tensor)
- init_value = self.get_tensor_expr(init_tensor)
- if len(init_shape) != 0:
- init_value = relax.op.reshape(init_value, [])
-
- windowed = relax.op.call_dps_packed(
- "topi.sliding_window",
- (
- data,
- 0,
- relax.ShapeExpr(dilated_window_shape),
- relax.ShapeExpr(window_strides),
- ),
- out_sinfo=relax.TensorStructInfo(output_shape +
dilated_window_shape, input_dtype),
- )
-
- if any(dilation != 1 for dilation in window_dilations):
- windowed = relax.op.strided_slice(
- windowed,
- axes=list(range(rank, 2 * rank)),
- begin=[0] * rank,
- end=dilated_window_shape,
- strides=window_dilations,
- )
-
- reduce_axes = list(range(rank, 2 * rank))
- if reduce_function in numeric_reduce_functions:
- reduce_op, combine_op = numeric_reduce_functions[reduce_function]
- return combine_op(reduce_op(windowed, axis=reduce_axes),
init_value)
- if reduce_function in bool_reduce_functions:
- reduce_op, combine_op = bool_reduce_functions[reduce_function]
- reduced = reduce_op(relax.op.astype(windowed, "int8"),
axis=reduce_axes)
- return combine_op(relax.op.astype(reduced, "bool"), init_value)
-
- raise tvm.error.OpNotImplemented(
- f"TFLite REDUCE_WINDOW reduce_function {reduce_function} is not
supported."
- )
-
def _convert_reduce_bool(self, relax_op, op):
"""Convert TFLite REDUCE_ANY / REDUCE_ALL (bool-only ops).
@@ -5045,83 +4880,263 @@ class OperatorConverter:
return squeezed
- def convert_rnn(self, op):
- """Convert TFLite RNN.
-
- Single-step RNN cell.
-
- Inputs (5 tensors):
- [0] input [batch, input_size]
- [1] input_weights [num_units, input_size]
- [2] recurrent_weights [num_units, num_units]
- [3] bias [num_units]
- [4] hidden_state [batch, num_units] (variable,
zero-initialised)
+ def convert_lstm(self, op):
+ """Convert TFLite LSTM (single-step).
+
+ Standard LSTM cell with FULL kernel and coupled input-forget gate.
+ Peephole, projection, and layer norm are not supported.
+
+ Inputs (24 tensors, many optional):
+ [0] input [batch, input_size]
+ [1] input_to_input_weights (optional, -1 => coupled)
+ [2] input_to_forget_weights [num_units, input_size]
+ [3] input_to_cell_weights [num_units, input_size]
+ [4] input_to_output_weights [num_units, input_size]
+ [5] recurrent_to_input_weights (optional)
+ [6] recurrent_to_forget_weights [num_units, num_units]
+ [7] recurrent_to_cell_weights [num_units, num_units]
+ [8] recurrent_to_output_weights [num_units, num_units]
+ [9-11] cell_to_*_weights (optional, not supported)
+ [12] input_gate_bias (optional)
+ [13] forget_gate_bias [num_units]
+ [14] cell_bias [num_units]
+ [15] output_gate_bias [num_units]
+ [16-17] projection_weights/bias (optional, not supported)
+ [18] output_state [batch, num_units]
+ [19] cell_state [batch, num_units]
+ [20-23] layer_norm (optional, not supported)
Output:
[0] output [batch, num_units]
- Cell equation:
- h = fused_activation(x @ W.T + h @ Wr.T + b)
+ Cell (coupled input-forget):
+ f = sigmoid(x @ W_f.T + h @ R_f.T + b_f)
+ i = 1 - f
+ g = tanh(x @ W_c.T + h @ R_c.T + b_c)
+ o = sigmoid(x @ W_o.T + h @ R_o.T + b_o)
+ c_new = f * c_prev + i * g
+ h_new = o * fused_activation(c_new)
"""
from tflite.BuiltinOptions import BuiltinOptions
- from tflite.RNNOptions import RNNOptions
+ from tflite.LSTMOptions import LSTMOptions
if self.is_quantized(op):
- raise tvm.error.OpNotImplemented("TFLite quantized RNN is not
supported yet.")
+ raise tvm.error.OpNotImplemented("TFLite quantized LSTM is not
supported yet.")
input_tensors = self.get_input_tensors(op)
- assert len(input_tensors) == 5, "input tensors length should be 5"
-
- input_tensor = input_tensors[0]
- weights_tensor = input_tensors[1]
- recurrent_tensor = input_tensors[2]
- bias_tensor = input_tensors[3]
- hidden_state_tensor = input_tensors[4]
+ assert len(input_tensors) == 24, (
+ f"input tensors length should be 24, got {len(input_tensors)}"
+ )
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
- assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions
+ assert op.BuiltinOptionsType() == BuiltinOptions.LSTMOptions
op_options = op.BuiltinOptions()
- rnn_options = RNNOptions()
- rnn_options.Init(op_options.Bytes, op_options.Pos)
- fused_activation_fn = rnn_options.FusedActivationFunction()
+ lstm_opts = LSTMOptions()
+ lstm_opts.Init(op_options.Bytes, op_options.Pos)
- # Constant weight/bias expressions.
- weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
- recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
- bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
+ fused_activation_fn = lstm_opts.FusedActivationFunction()
+ cell_clip = lstm_opts.CellClip()
+ proj_clip = lstm_opts.ProjClip()
- # Transpose to [input_size, num_units] and [num_units, num_units] for
x @ W.T.
- w_t = relax.op.permute_dims(weights_expr)
- wr_t = relax.op.permute_dims(recurrent_expr)
+ in_expr = self.get_tensor_expr(input_tensors[0])
- # Resolve the input expression.
- in_expr = self.get_tensor_expr(input_tensor)
+ # Only coupled input-forget gate is supported.
+ if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx !=
-1:
+ raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM
is supported.")
- # Initial hidden state: use the model's tensor value when available
(non-zero init or
- # graph input), otherwise fall back to zeros for the common
variable-tensor case.
- h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
- if self.has_expr(hidden_state_tensor.tensor_idx) or (
- hidden_state_tensor.buffer is not None and
hidden_state_tensor.buffer.DataLength() > 0
+ # Peephole, projection, and layer norm are not modeled yet.
+ if (
+ any(t.tensor_idx != -1 for t in input_tensors[9:12])
+ or any(t.tensor_idx != -1 for t in input_tensors[16:18])
+ or any(t.tensor_idx != -1 for t in input_tensors[20:24])
):
- h = self.get_tensor_expr(hidden_state_tensor)
- else:
- h_shape =
tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
- h = relax.op.zeros(h_shape, dtype=h_dtype)
+ raise tvm.error.OpNotImplemented(
+ "Peephole, projection, and layer norm LSTM are not supported
yet."
+ )
+
+ # Weights.
+ w_f = self.get_tensor_expr(input_tensors[2])
+ w_c = self.get_tensor_expr(input_tensors[3])
+ w_o = self.get_tensor_expr(input_tensors[4])
+
+ r_f = self.get_tensor_expr(input_tensors[6])
+ r_c = self.get_tensor_expr(input_tensors[7])
+ r_o = self.get_tensor_expr(input_tensors[8])
+
+ # Biases.
+ b_f = self.get_tensor_expr(input_tensors[13])
+ b_c = self.get_tensor_expr(input_tensors[14])
+ b_o = self.get_tensor_expr(input_tensors[15])
+
+ # State inputs.
+ h_prev = self.get_tensor_expr(input_tensors[18])
+ c_prev = self.get_tensor_expr(input_tensors[19])
- gates = relax.op.add(
- relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h,
wr_t)),
- bias_expr,
+ # Coupled input-forget gate.
+ f = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(in_expr, relax.op.permute_dims(w_f)),
+ relax.op.matmul(h_prev, relax.op.permute_dims(r_f)),
+ ),
+ b_f,
+ )
+ )
+ i = relax.op.subtract(
+ relax.const(1.0, "float32"),
+ f,
+ )
+
+ # Cell candidate.
+ g = relax.op.tanh(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(in_expr, relax.op.permute_dims(w_c)),
+ relax.op.matmul(h_prev, relax.op.permute_dims(r_c)),
+ ),
+ b_c,
+ )
)
- h = self.convert_fused_activation_function(gates, fused_activation_fn)
+ # Output gate.
+ o = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(in_expr, relax.op.permute_dims(w_o)),
+ relax.op.matmul(h_prev, relax.op.permute_dims(r_o)),
+ ),
+ b_o,
+ )
+ )
+
+ # Cell state update with optional clipping.
+ c_new = relax.op.add(
+ relax.op.multiply(f, c_prev),
+ relax.op.multiply(i, g),
+ )
+ if cell_clip > 0:
+ c_new = relax.op.clip(c_new, -cell_clip, cell_clip)
+
+ # Hidden state.
+ # TFLite applies the fused activation to the cell state before the
+ # output gate multiply.
+ h_new = relax.op.multiply(
+ o, self.convert_fused_activation_function(c_new,
fused_activation_fn)
+ )
+ if proj_clip > 0:
+ h_new = relax.op.clip(h_new, -proj_clip, proj_clip)
+
+ # Update state tensors in the expression table for subsequent ops.
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[18].tensor_idx),
+ h_new,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[19].tensor_idx),
+ c_new,
+ force_override=True,
+ )
+
+ return h_new
+
+ def convert_svdf(self, op):
+ """Convert TFLite SVDF (single-step).
+
+ Singular Value Decomposition Filter (SVDF), commonly used for keyword spotting.
+
+ Inputs (5 tensors):
+ [0] input [batch, input_size]
+ [1] feature_weights [num_filters, input_size]
+ [2] time_weights [num_filters, memory_size]
+ [3] bias [num_filters] (optional)
+ [4] state [batch, num_filters * memory_size] (variable)
+
+ Output:
+ [0] output [batch, num_units]
+
+ Computation:
+ feat = x @ W_feat.T # feature projection
+ state_r = reshape(state, [B, F, memory_size]) # ring buffer
+ time = sum(state_r * time_weights, axis=-1) # time filtering
+ out = activation(sum(reshape(time, [B, U, rank]), axis=-1) + bias)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.SVDFOptions import SVDFOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented("TFLite quantized SVDF is not
supported yet.")
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, (
+ f"input tensors length should be 5, got {len(input_tensors)}"
+ )
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.SVDFOptions
+ op_options = op.BuiltinOptions()
+ svdf_opts = SVDFOptions()
+ svdf_opts.Init(op_options.Bytes, op_options.Pos)
+
+ rank = svdf_opts.Rank()
+ fused_activation_fn = svdf_opts.FusedActivationFunction()
+
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ feat_weights = self.get_tensor_expr(input_tensors[1])
+ time_weights = self.get_tensor_expr(input_tensors[2])
+
+ batch_size = self.get_tensor_shape(input_tensors[0])[0]
+ if isinstance(batch_size, np.integer | int):
+ batch_size = int(batch_size)
+ num_filters = to_int_list(self.get_tensor_shape(input_tensors[1]))[0]
+ if num_filters % rank != 0:
+ raise tvm.error.OpNotImplemented("SVDF num_filters must be
divisible by rank.")
+ num_units = num_filters // rank
+ memory_size = to_int_list(self.get_tensor_shape(input_tensors[2]))[1]
+
+ # Feature projection: [batch, input_size] @ [input_size, num_filters]
+ feat = relax.op.matmul(in_expr, relax.op.permute_dims(feat_weights))
+
+ # Time filtering: reshape state -> weight -> reduce.
+ state_expr = self.get_tensor_expr(input_tensors[4])
+ state_3d = relax.op.reshape(state_expr, (batch_size, num_filters,
memory_size))
+
+ # time_weights: [num_filters, memory_size], broadcast to [1,
num_filters, memory_size]
+ tw_3d = relax.op.reshape(time_weights, (1, num_filters, memory_size))
+ time_weighted = relax.op.multiply(state_3d, tw_3d)
+ time_output = relax.op.sum(time_weighted, axis=-1, keepdims=False)
+ reduced = relax.op.reshape(time_output, (batch_size, num_units, rank))
+ result = relax.op.sum(reduced, axis=-1, keepdims=False)
+
+ # Add bias if present
+ if input_tensors[3].tensor_idx != -1:
+ bias_expr = self.get_tensor_expr(input_tensors[3])
+ result = relax.op.add(result, bias_expr)
+
+ result = self.convert_fused_activation_function(result,
fused_activation_fn)
+
+ # Update state tensor in the expression table for subsequent steps.
+ # SVDF state is a FIFO ring-buffer: shift left by 1, append new feat.
+ feat_3d = relax.op.expand_dims(feat, axis=-1)
+ if memory_size > 1:
+ shifted_state = relax.op.strided_slice(
+ state_3d, axes=[2], begin=[1], end=[int(memory_size)]
+ )
+ new_state_3d = relax.op.concat([shifted_state, feat_3d], axis=2)
+ else:
+ new_state_3d = feat_3d
+ new_state = relax.op.reshape(new_state_3d, (batch_size, num_filters *
memory_size))
self.exp_tab.set_expr(
- get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx),
- h,
+ get_tensor_name(self.subgraph, input_tensors[4].tensor_idx),
+ new_state,
force_override=True,
)
- return h
+
+ return result
def convert_unidirectional_sequence_rnn(self, op):
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 7c5951d631..263943ad6a 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -85,7 +85,7 @@ def verify(TestClass, expected=None):
tf_output = cf(*tf_inputs)
# TVM Run
- tgt = tvm.target.Target("llvm")
+ tgt = tvm.target.Target("c")
ex = tvm.compile(mod, tgt)
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *tvm_inputs)
@@ -110,7 +110,7 @@ def _verify_random_with_inputs(cfunc, inputs):
tf_output = cfunc(*tf_inputs)
- tgt = tvm.target.Target("llvm")
+ tgt = tvm.target.Target("c")
ex = tvm.compile(mod, tgt)
vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -3705,7 +3705,6 @@ _tfl_model = _get_tflite_schema_module("Model")
_tfl_operator = _get_tflite_schema_module("Operator")
_tfl_operator_code = _get_tflite_schema_module("OperatorCode")
_tfl_quantization_parameters =
_get_tflite_schema_module("QuantizationParameters")
-_tfl_reduce_window_options = _get_tflite_schema_module("ReduceWindowOptions")
_tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters")
_tfl_subgraph = _get_tflite_schema_module("SubGraph")
_tfl_tensor = _get_tflite_schema_module("Tensor")
@@ -3718,12 +3717,12 @@ _tfl_activation_fn =
_get_tflite_schema_enum("ActivationFunctionType")
_tfl_dimension_type = _get_tflite_schema_enum("DimensionType")
_tfl_fc_weights_format =
_get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat")
_tfl_padding = _get_tflite_schema_enum("Padding")
-_tfl_reduce_window_function = _get_tflite_schema_enum("ReduceWindowFunction")
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
-_tfl_rnn_options = _get_tflite_schema_module("RNNOptions")
+_tfl_lstm_options = _get_tflite_schema_module("LSTMOptions")
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
+_tfl_svdf_options = _get_tflite_schema_module("SVDFOptions")
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
@@ -3954,410 +3953,6 @@ def _load_model_from_buffer(model_bytes):
return mod
-def _build_reduce_window_options(builder, reduce_function):
- _tfl_reduce_window_options.ReduceWindowOptionsStart(builder)
- _tfl_reduce_window_options.ReduceWindowOptionsAddReduceFunction(builder,
reduce_function)
- return _tfl_reduce_window_options.ReduceWindowOptionsEnd(builder)
-
-
-def _reduce_window_output_shape(input_shape, window_shape, window_strides,
window_dilations):
- output_shape = []
- for input_dim, window_dim, stride, dilation in zip(
- input_shape, window_shape, window_strides, window_dilations
- ):
- dilated_window = (window_dim - 1) * dilation + 1
- if stride <= 0:
- output_shape.append(0)
- elif input_dim < dilated_window:
- output_shape.append(0)
- else:
- output_shape.append((input_dim - dilated_window) // stride + 1)
- return tuple(output_shape)
-
-
-def _build_reduce_window_model(
- *,
- input_shape,
- init_value,
- init_shape=(),
- window_shape,
- window_strides,
- window_dilations,
- output_shape=None,
- reduce_function,
- tensor_type=None,
- value_dtype=np.float32,
-):
- builder = flatbuffers.Builder(1024)
- if tensor_type is None:
- tensor_type = _tfl_tensor_type.FLOAT32
-
- input_tensor_idx = 0
- init_tensor_idx = 1
- window_shape_tensor_idx = 2
- window_strides_tensor_idx = 3
- window_dilations_tensor_idx = 4
- output_tensor_idx = 5
-
- if output_shape is None:
- output_shape = _reduce_window_output_shape(
- input_shape, window_shape, window_strides, window_dilations
- )
-
- input_tensor = _build_tensor(builder, 1, input_shape,
tensor_type=tensor_type)
- init_tensor = _build_tensor(builder, 2, init_shape,
tensor_type=tensor_type)
- window_shape_tensor = _build_tensor(
- builder, 3, [len(window_shape)], tensor_type=_tfl_tensor_type.INT64
- )
- window_strides_tensor = _build_tensor(
- builder, 4, [len(window_strides)], tensor_type=_tfl_tensor_type.INT64
- )
- window_dilations_tensor = _build_tensor(
- builder, 5, [len(window_dilations)], tensor_type=_tfl_tensor_type.INT64
- )
- output_tensor = _build_tensor(builder, 6, output_shape,
tensor_type=tensor_type)
-
- reduce_window_opts = _build_reduce_window_options(builder, reduce_function)
- reduce_window_op = _build_operator(
- builder,
- 0,
- [
- input_tensor_idx,
- init_tensor_idx,
- window_shape_tensor_idx,
- window_strides_tensor_idx,
- window_dilations_tensor_idx,
- ],
- [output_tensor_idx],
- builtin_options2_type=_tfl_builtin_options2.ReduceWindowOptions,
- builtin_options2=reduce_window_opts,
- )
-
- subgraph = _build_subgraph(
- builder,
- tensors=[
- input_tensor,
- init_tensor,
- window_shape_tensor,
- window_strides_tensor,
- window_dilations_tensor,
- output_tensor,
- ],
- operators=[reduce_window_op],
- inputs=[input_tensor_idx],
- outputs=[output_tensor_idx],
- )
- operator_codes = [_build_operator_code(builder,
_tfl_builtin_operator.REDUCE_WINDOW)]
-
- buffers = [
- _build_buffer(builder),
- _build_buffer(builder),
- _build_buffer(builder, np.asarray([init_value],
dtype=value_dtype).tobytes()),
- _build_buffer(builder, np.asarray(window_shape,
dtype=np.int64).tobytes()),
- _build_buffer(builder, np.asarray(window_strides,
dtype=np.int64).tobytes()),
- _build_buffer(builder, np.asarray(window_dilations,
dtype=np.int64).tobytes()),
- _build_buffer(builder),
- ]
-
- return _finish_tflite_model(
- builder, subgraph=subgraph, operator_codes=operator_codes,
buffers=buffers
- )
-
-
-def _from_reduce_window_model(**kwargs):
- return _load_model_from_buffer(_build_reduce_window_model(**kwargs))
-
-
-def _reduce_window_dilated_shape(window_shape, window_dilations):
- return [
- (window_dim - 1) * dilation + 1
- for window_dim, dilation in zip(window_shape, window_dilations)
- ]
-
-
-def _make_reduce_window_numeric_expected(
- *,
- input_shape,
- init_value,
- init_shape=(),
- window_shape,
- window_strides,
- window_dilations,
- reduce_op,
- combine_op,
- dtype="float32",
-):
- output_shape = _reduce_window_output_shape(
- input_shape, window_shape, window_strides, window_dilations
- )
- dilated_window_shape = _reduce_window_dilated_shape(window_shape,
window_dilations)
- rank = len(input_shape)
-
- bb = relax.BlockBuilder()
- x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape,
dtype))
- with bb.function("main", [x]):
- with bb.dataflow():
- windowed = bb.emit(
- relax.op.call_dps_packed(
- "topi.sliding_window",
- (
- x,
- 0,
- relax.ShapeExpr(dilated_window_shape),
- relax.ShapeExpr(window_strides),
- ),
- out_sinfo=relax.TensorStructInfo(
- output_shape + tuple(dilated_window_shape), dtype
- ),
- )
- )
- if any(dilation != 1 for dilation in window_dilations):
- windowed = bb.emit(
- relax.op.strided_slice(
- windowed,
- axes=list(range(rank, 2 * rank)),
- begin=[0] * rank,
- end=dilated_window_shape,
- strides=window_dilations,
- )
- )
- reduced = bb.emit(reduce_op(windowed, axis=list(range(rank, 2 *
rank))))
- init = relax.const(np.asarray([init_value],
dtype=dtype).reshape(init_shape), dtype)
- if len(init_shape) != 0:
- init = relax.op.reshape(init, [])
- gv = bb.emit_output(combine_op(reduced, init))
- bb.emit_func_output(gv)
-
- mod = bb.get()
- mod["main"] = mod["main"].with_attr("num_input", 1)
- return mod
-
-
-def _make_reduce_window_bool_expected(
- *,
- input_shape,
- init_value,
- window_shape,
- window_strides,
- window_dilations,
- reduce_op,
- combine_op,
-):
- output_shape = _reduce_window_output_shape(
- input_shape, window_shape, window_strides, window_dilations
- )
- dilated_window_shape = _reduce_window_dilated_shape(window_shape,
window_dilations)
- rank = len(input_shape)
-
- bb = relax.BlockBuilder()
- x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape,
"bool"))
- with bb.function("main", [x]):
- with bb.dataflow():
- windowed = bb.emit(
- relax.op.call_dps_packed(
- "topi.sliding_window",
- (
- x,
- 0,
- relax.ShapeExpr(dilated_window_shape),
- relax.ShapeExpr(window_strides),
- ),
- out_sinfo=relax.TensorStructInfo(
- output_shape + tuple(dilated_window_shape), "bool"
- ),
- )
- )
- cast_windowed = bb.emit(relax.op.astype(windowed, "int8"))
- reduced = bb.emit(reduce_op(cast_windowed, axis=list(range(rank, 2
* rank))))
- reduced_bool = bb.emit(relax.op.astype(reduced, "bool"))
- gv = bb.emit_output(combine_op(reduced_bool,
relax.const(init_value, "bool")))
- bb.emit_func_output(gv)
-
- mod = bb.get()
- mod["main"] = mod["main"].with_attr("num_input", 1)
- return mod
-
-
-def _make_reduce_window_empty_expected(*, input_shape, output_shape,
dtype="float32"):
- bb = relax.BlockBuilder()
- x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape,
dtype))
- with bb.function("main", [x]):
- with bb.dataflow():
- gv = bb.emit_output(relax.op.zeros(output_shape, dtype))
- bb.emit_func_output(gv)
-
- mod = bb.get()
- mod["main"] = mod["main"].with_attr("num_input", 1)
- return mod
-
-
-def test_reduce_window_unsupported_function():
- with pytest.raises(tvm.error.OpNotImplemented, match="UNSUPPORTED
reduce_function"):
- _from_reduce_window_model(
- input_shape=(4,),
- init_value=0.0,
- window_shape=[2],
- window_strides=[1],
- window_dilations=[1],
- reduce_function=_tfl_reduce_window_function.UNSUPPORTED,
- )
-
-
[email protected](
- "reduce_function, reduce_op, combine_op",
- [
- (_tfl_reduce_window_function.ADD, relax.op.sum, relax.op.add),
- (_tfl_reduce_window_function.MUL, relax.op.prod, relax.op.multiply),
- (_tfl_reduce_window_function.MINIMUM, relax.op.min, relax.op.minimum),
- (_tfl_reduce_window_function.MAXIMUM, relax.op.max, relax.op.maximum),
- ],
-)
-def test_reduce_window_numeric_modes(reduce_function, reduce_op, combine_op):
- input_shape = (4, 5)
- init_value = 1.0
- window_shape = [2, 2]
- window_strides = [1, 2]
- window_dilations = [2, 1]
- mod = _from_reduce_window_model(
- input_shape=input_shape,
- init_value=init_value,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_function=reduce_function,
- )
- expected = _make_reduce_window_numeric_expected(
- input_shape=input_shape,
- init_value=init_value,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_op=reduce_op,
- combine_op=combine_op,
- )
- tvm.ir.assert_structural_equal(mod, expected)
-
-
-def test_reduce_window_one_element_init_tensor():
- input_shape = (4,)
- init_value = 1.0
- init_shape = (1,)
- window_shape = [2]
- window_strides = [1]
- window_dilations = [1]
- mod = _from_reduce_window_model(
- input_shape=input_shape,
- init_value=init_value,
- init_shape=init_shape,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_function=_tfl_reduce_window_function.ADD,
- )
- expected = _make_reduce_window_numeric_expected(
- input_shape=input_shape,
- init_value=init_value,
- init_shape=init_shape,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_op=relax.op.sum,
- combine_op=relax.op.add,
- )
- tvm.ir.assert_structural_equal(mod, expected)
-
-
[email protected](
- "reduce_function, reduce_op, combine_op, init_value",
- [
- (_tfl_reduce_window_function.ALL, relax.op.min, relax.op.logical_and,
True),
- (_tfl_reduce_window_function.ANY, relax.op.max, relax.op.logical_or,
False),
- ],
-)
-def test_reduce_window_bool_modes(reduce_function, reduce_op, combine_op,
init_value):
- input_shape = (5,)
- window_shape = [3]
- window_strides = [2]
- window_dilations = [1]
- mod = _from_reduce_window_model(
- input_shape=input_shape,
- init_value=init_value,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_function=reduce_function,
- tensor_type=_tfl_tensor_type.BOOL,
- value_dtype=np.bool_,
- )
- expected = _make_reduce_window_bool_expected(
- input_shape=input_shape,
- init_value=init_value,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_op=reduce_op,
- combine_op=combine_op,
- )
- tvm.ir.assert_structural_equal(mod, expected)
-
-
-def test_reduce_window_empty_output_dimension():
- input_shape = (2,)
- window_shape = [3]
- window_strides = [1]
- window_dilations = [1]
- mod = _from_reduce_window_model(
- input_shape=input_shape,
- init_value=0.0,
- window_shape=window_shape,
- window_strides=window_strides,
- window_dilations=window_dilations,
- reduce_function=_tfl_reduce_window_function.ADD,
- )
- expected = _make_reduce_window_empty_expected(
- input_shape=input_shape,
- output_shape=(0,),
- )
- tvm.ir.assert_structural_equal(mod, expected)
-
-
-def test_reduce_window_mismatched_window_rank():
- with pytest.raises(tvm.error.OpAttributeUnImplemented, match="must match
input rank"):
- _from_reduce_window_model(
- input_shape=(4, 5),
- init_value=0.0,
- window_shape=[2],
- window_strides=[1],
- window_dilations=[1],
- reduce_function=_tfl_reduce_window_function.ADD,
- )
-
-
-def test_reduce_window_non_positive_stride():
- with pytest.raises(tvm.error.OpAttributeUnImplemented, match="must be
positive"):
- _from_reduce_window_model(
- input_shape=(4,),
- init_value=0.0,
- window_shape=[2],
- window_strides=[0],
- window_dilations=[1],
- reduce_function=_tfl_reduce_window_function.ADD,
- )
-
-
-def test_reduce_window_inconsistent_output_shape():
- with pytest.raises(tvm.error.OpAttributeUnImplemented, match="output
shape"):
- _from_reduce_window_model(
- input_shape=(5,),
- init_value=0.0,
- window_shape=[2],
- window_strides=[1],
- window_dilations=[1],
- output_shape=(3,),
- reduce_function=_tfl_reduce_window_function.ADD,
- )
-
-
def _get_builtin_operator(builtin_name):
if not hasattr(_tfl_builtin_operator, builtin_name):
pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
@@ -10128,251 +9723,661 @@ def test_dilate_dynamic_dilations():
tvm.ir.assert_structural_equal(mod, Expected)
-# ── RNN ────────────────────────────────────────────────────────────────────────
+# ── LSTM ──────────────────────────────────────────────────────────────────────
-def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights, bias, activation):
- """Build a minimal TFLite flatbuffer model containing one RNN op.
-
- Tensor layout (indices 0-5):
- 0 - input [batch, input_size]
- 1 - input_weights [num_units, input_size] (constant)
- 2 - recurrent_weights [num_units, num_units] (constant)
- 3 - bias [num_units] (constant)
- 4 - hidden_state [batch, num_units] (variable, zero-initialised)
- 5 - output [batch, num_units]
+def _build_lstm_model(
+ batch,
+ input_size,
+ num_units,
+ input_to_forget_weights,
+ input_to_cell_weights,
+ input_to_output_weights,
+ recurrent_to_forget_weights,
+ recurrent_to_cell_weights,
+ recurrent_to_output_weights,
+ forget_gate_bias,
+ cell_bias,
+ output_gate_bias,
+ activation,
+ *,
+ cell_clip=0.0,
+ proj_clip=0.0,
+ include_unsupported=False,
+):
+ """Build a minimal TFLite flatbuffer model with one LSTM op (coupled input-forget).
+
+ Tensor indices:
+ 0 - input [batch, input_size]
+ 1 - input_to_forget_weights [num_units, input_size] (constant)
+ 2 - input_to_cell_weights [num_units, input_size] (constant)
+ 3 - input_to_output_weights [num_units, input_size] (constant)
+ 4 - recurrent_to_forget_weights [num_units, num_units] (constant)
+ 5 - recurrent_to_cell_weights [num_units, num_units] (constant)
+ 6 - recurrent_to_output_weights [num_units, num_units] (constant)
+ 7 - forget_gate_bias [num_units] (constant)
+ 8 - cell_bias [num_units] (constant)
+ 9 - output_gate_bias [num_units] (constant)
+ 10 - output_state [batch, num_units] (input)
+ 11 - cell_state [batch, num_units] (input)
+ 12 - output [batch, num_units]
+
+ Operator input indices (24 entries, -1 for absent):
+ [0, -1, 1, 2, 3, -1, 4, 5, 6, -1, -1, -1, -1, 7, 8, 9, -1, -1, 10, 11, -1, -1, -1, -1]
"""
builder = flatbuffers.Builder(4096)
- _tfl_rnn_options.RNNOptionsStart(builder)
- _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
- rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+ _tfl_lstm_options.LSTMOptionsStart(builder)
+ _tfl_lstm_options.LSTMOptionsAddFusedActivationFunction(builder, activation)
+ _tfl_lstm_options.LSTMOptionsAddCellClip(builder, cell_clip)
+ _tfl_lstm_options.LSTMOptionsAddProjClip(builder, proj_clip)
+ lstm_opts = _tfl_lstm_options.LSTMOptionsEnd(builder)
- rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+ lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.LSTM)
- def _t(buf_idx, shape, is_variable=False):
+ def _t(buf_idx, shape):
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
_tfl_tensor.TensorAddHasRank(builder, True)
- _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+ _tfl_tensor.TensorAddIsVariable(builder, False)
_tfl_tensor.TensorAddShape(builder, shape_vec)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
return _tfl_tensor.TensorEnd(builder)
tensors = [
+ # 0: input
_t(0, [batch, input_size]),
+ # 1: input_to_forget_weights (coupled)
_t(1, [num_units, input_size]),
- _t(2, [num_units, num_units]),
- _t(3, [num_units]),
- _t(4, [batch, num_units], is_variable=True),
- _t(5, [batch, num_units]),
+ # 2: input_to_cell_weights
+ _t(2, [num_units, input_size]),
+ # 3: input_to_output_weights
+ _t(3, [num_units, input_size]),
+ # 4: recurrent_to_forget_weights (coupled)
+ _t(4, [num_units, num_units]),
+ # 5: recurrent_to_cell_weights
+ _t(5, [num_units, num_units]),
+ # 6: recurrent_to_output_weights
+ _t(6, [num_units, num_units]),
+ # 7: forget_gate_bias (coupled)
+ _t(7, [num_units]),
+ # 8: cell_bias
+ _t(8, [num_units]),
+ # 9: output_gate_bias
+ _t(9, [num_units]),
+ # 10: output_state (input)
+ _t(0, [batch, num_units]),
+ # 11: cell_state (input)
+ _t(0, [batch, num_units]),
+ # 12: output
+ _t(0, [batch, num_units]),
]
- rnn_op = _build_operator(
+ if include_unsupported:
+ tensors.extend(
+ [
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units, num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ _t(0, [num_units]),
+ ]
+ )
+
+ # Operator input indices: -1 for absent optional inputs
+ lstm_inputs = [
+ 0,
+ -1,
+ 1,
+ 2,
+ 3,
+ -1,
+ 4,
+ 5,
+ 6,
+ 13 if include_unsupported else -1,
+ 14 if include_unsupported else -1,
+ 15 if include_unsupported else -1,
+ -1,
+ 7,
+ 8,
+ 9,
+ 16 if include_unsupported else -1,
+ 17 if include_unsupported else -1,
+ 10,
+ 11,
+ 18 if include_unsupported else -1,
+ 19 if include_unsupported else -1,
+ 20 if include_unsupported else -1,
+ 21 if include_unsupported else -1,
+ ]
+
+ lstm_op = _build_operator(
builder,
0,
- [0, 1, 2, 3, 4],
- [5],
- builtin_options_type=_tfl_builtin_options.RNNOptions,
- builtin_options=rnn_opts,
+ lstm_inputs,
+ [12],
+ builtin_options_type=_tfl_builtin_options.LSTMOptions,
+ builtin_options=lstm_opts,
)
subgraph = _build_subgraph(
builder,
tensors=tensors,
- operators=[rnn_op],
- inputs=[0],
- outputs=[5],
+ operators=[lstm_op],
+ inputs=[0, 10, 11],
+ outputs=[12],
)
buffers = [
- _build_buffer(builder),
- _build_buffer(builder, weights.tobytes()),
- _build_buffer(builder, recurrent_weights.tobytes()),
- _build_buffer(builder, bias.tobytes()),
- _build_buffer(builder),
- _build_buffer(builder),
+ _build_buffer(builder), # 0: empty
+ _build_buffer(builder, input_to_forget_weights.tobytes()), # 1
+ _build_buffer(builder, input_to_cell_weights.tobytes()), # 2
+ _build_buffer(builder, input_to_output_weights.tobytes()), # 3
+ _build_buffer(builder, recurrent_to_forget_weights.tobytes()), # 4
+ _build_buffer(builder, recurrent_to_cell_weights.tobytes()), # 5
+ _build_buffer(builder, recurrent_to_output_weights.tobytes()), # 6
+ _build_buffer(builder, forget_gate_bias.tobytes()), # 7
+ _build_buffer(builder, cell_bias.tobytes()), # 8
+ _build_buffer(builder, output_gate_bias.tobytes()), # 9
]
+ if include_unsupported:
+ buffers.extend([_build_buffer(builder) for _ in range(9)])
+
return _finish_tflite_model(
builder,
subgraph=subgraph,
- operator_codes=[rnn_op_code],
+ operator_codes=[lstm_op_code],
buffers=buffers,
)
-def _build_two_step_shared_state_rnn_model(
- batch, input_size, num_units, weights, recurrent_weights, bias, activation
+def test_lstm_none_activation():
+ """LSTM with NONE activation uses the cell state before the output gate multiply."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 2, 2
+ w_f = np.eye(num_units, input_size, dtype=np.float32)
+ w_c = np.eye(num_units, input_size, dtype=np.float32)
+ w_o = np.eye(num_units, input_size, dtype=np.float32)
+ r_f = np.eye(num_units, dtype=np.float32)
+ r_c = np.eye(num_units, dtype=np.float32)
+ r_o = np.eye(num_units, dtype=np.float32)
+ b_f = np.zeros(num_units, dtype=np.float32)
+ b_c = np.zeros(num_units, dtype=np.float32)
+ b_o = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_lstm_model(
+ batch,
+ input_size,
+ num_units,
+ w_f,
+ w_c,
+ w_o,
+ r_f,
+ r_c,
+ r_o,
+ b_f,
+ b_c,
+ b_o,
+ ActivationFunctionType.NONE,
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+ tvmgen_tensor_10: R.Tensor((2, 2), dtype="float32"),
+ tvmgen_tensor_11: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv, out_dtype="void"
+ )
+ lv2: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv3: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv2, out_dtype="void"
+ )
+ lv4: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv3)
+ lv5: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv4, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv6: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv5)
+ lv7: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv8: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv7, out_dtype="void"
+ )
+ lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv9, out_dtype="void"
+ )
+ lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv8, lv10)
+ lv12: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv11, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv13: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv12)
+ lv14: R.Tensor((2, 2), dtype="float32") = R.multiply(lv13, tvmgen_tensor_11)
+ lv15: R.Tensor((2, 2), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv13)
+ lv16: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv17: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv16, out_dtype="void"
+ )
+ lv18: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv19: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv18, out_dtype="void"
+ )
+ lv20: R.Tensor((2, 2), dtype="float32") = R.add(lv17, lv19)
+ lv21: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv20, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv22: R.Tensor((2, 2), dtype="float32") = R.tanh(lv21)
+ lv23: R.Tensor((2, 2), dtype="float32") = R.multiply(lv15, lv22)
+ lv24: R.Tensor((2, 2), dtype="float32") = R.add(lv14, lv23)
+ gv: R.Tensor((2, 2), dtype="float32") = R.multiply(lv6, lv24)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_lstm_tanh_activation():
+ """LSTM with TANH activation applies tanh before the output gate multiply."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 2, 2
+ w_f = np.eye(num_units, input_size, dtype=np.float32)
+ w_c = np.eye(num_units, input_size, dtype=np.float32)
+ w_o = np.eye(num_units, input_size, dtype=np.float32)
+ r_f = np.eye(num_units, dtype=np.float32)
+ r_c = np.eye(num_units, dtype=np.float32)
+ r_o = np.eye(num_units, dtype=np.float32)
+ b_f = np.zeros(num_units, dtype=np.float32)
+ b_c = np.zeros(num_units, dtype=np.float32)
+ b_o = np.zeros(num_units, dtype=np.float32)
+
+ mod = _load_model_from_buffer(
+ _build_lstm_model(
+ batch,
+ input_size,
+ num_units,
+ w_f,
+ w_c,
+ w_o,
+ r_f,
+ r_c,
+ r_o,
+ b_f,
+ b_c,
+ b_o,
+ ActivationFunctionType.TANH,
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+ tvmgen_tensor_10: R.Tensor((2, 2), dtype="float32"),
+ tvmgen_tensor_11: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv, out_dtype="void"
+ )
+ lv2: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv3: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv2, out_dtype="void"
+ )
+ lv4: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv3)
+ lv5: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv4, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv6: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv5)
+ lv7: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv8: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv7, out_dtype="void"
+ )
+ lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv9, out_dtype="void"
+ )
+ lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv8, lv10)
+ lv12: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv11, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv13: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv12)
+ lv14: R.Tensor((2, 2), dtype="float32") = R.multiply(lv13, tvmgen_tensor_11)
+ lv15: R.Tensor((2, 2), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv13)
+ lv16: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv17: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0, lv16, out_dtype="void"
+ )
+ lv18: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+ R.const(np.eye(2, dtype=np.float32)), axes=None
+ )
+ lv19: R.Tensor((2, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_10, lv18, out_dtype="void"
+ )
+ lv20: R.Tensor((2, 2), dtype="float32") = R.add(lv17, lv19)
+ lv21: R.Tensor((2, 2), dtype="float32") = R.add(
+ lv20, R.const(np.zeros(2, dtype=np.float32))
+ )
+ lv22: R.Tensor((2, 2), dtype="float32") = R.tanh(lv21)
+ lv23: R.Tensor((2, 2), dtype="float32") = R.multiply(lv15, lv22)
+ lv24: R.Tensor((2, 2), dtype="float32") = R.add(lv14, lv23)
+ lv25: R.Tensor((2, 2), dtype="float32") = R.tanh(lv24)
+ gv: R.Tensor((2, 2), dtype="float32") = R.multiply(lv6, lv25)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_lstm_rejects_unsupported_features():
+ """LSTM with peephole/projection/layer norm tensors should be rejected."""
+ from tflite.ActivationFunctionType import ActivationFunctionType
+
+ batch, input_size, num_units = 2, 2, 2
+ zeros_w = np.zeros((num_units, input_size), dtype=np.float32)
+ zeros_r = np.zeros((num_units, num_units), dtype=np.float32)
+ zeros_b = np.zeros(num_units, dtype=np.float32)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="not supported yet"):
+ _load_model_from_buffer(
+ _build_lstm_model(
+ batch,
+ input_size,
+ num_units,
+ zeros_w,
+ zeros_w,
+ zeros_w,
+ zeros_r,
+ zeros_r,
+ zeros_r,
+ zeros_b,
+ zeros_b,
+ zeros_b,
+ ActivationFunctionType.NONE,
+ include_unsupported=True,
+ )
+ )
+
+
+# ── SVDF ──────────────────────────────────────────────────────────────────────
+
+
+def _build_svdf_model(
+ batch,
+ input_size,
+ num_units,
+ rank,
+ memory_size,
+ num_filters,
+ feat_weights,
+ time_weights,
+ bias,
+ activation,
):
- """Build a TFLite model with two RNN ops sharing the same hidden-state tensor."""
+ """Build a minimal TFLite flatbuffer model containing one SVDF op.
+
+ Tensor indices:
+ 0 - input [batch, input_size] (model input)
+ 1 - feature_weights [num_filters, input_size] (constant)
+ 2 - time_weights [num_filters, memory_size] (constant)
+ 3 - bias [num_units] (constant)
+ 4 - state [batch, num_filters * memory_size] (variable, model input)
+ 5 - output [batch, num_units]
+ """
builder = flatbuffers.Builder(4096)
- _tfl_rnn_options.RNNOptionsStart(builder)
- _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
- rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)
+ _tfl_svdf_options.SVDFOptionsStart(builder)
+ _tfl_svdf_options.SVDFOptionsAddRank(builder, rank)
+ _tfl_svdf_options.SVDFOptionsAddFusedActivationFunction(builder, activation)
+ svdf_opts = _tfl_svdf_options.SVDFOptionsEnd(builder)
- rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)
+ svdf_op_code = _build_operator_code(builder, _tfl_builtin_operator.SVDF)
- def _t(buf_idx, shape, is_variable=False):
+ def _t(buf_idx, shape):
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
_tfl_tensor.TensorAddHasRank(builder, True)
- _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+ _tfl_tensor.TensorAddIsVariable(builder, False)
_tfl_tensor.TensorAddShape(builder, shape_vec)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
return _tfl_tensor.TensorEnd(builder)
tensors = [
- _t(0, [batch, input_size]),
- _t(1, [num_units, input_size]),
- _t(2, [num_units, num_units]),
- _t(3, [num_units]),
- _t(4, [batch, num_units], is_variable=True),
- _t(0, [batch, input_size]),
- _t(0, [batch, num_units]),
- _t(0, [batch, num_units]),
+ _t(0, [batch, input_size]), # 0: input
+ _t(1, [num_filters, input_size]), # 1: feature_weights
+ _t(2, [num_filters, memory_size]), # 2: time_weights
+ _t(3, [num_units]), # 3: bias
+ _t(0, [batch, num_filters * memory_size]), # 4: state (variable, zero-filled)
+ _t(0, [batch, num_units]), # 5: output
]
- first_rnn_op = _build_operator(
+ svdf_op = _build_operator(
builder,
0,
[0, 1, 2, 3, 4],
- [6],
- builtin_options_type=_tfl_builtin_options.RNNOptions,
- builtin_options=rnn_opts,
- )
- second_rnn_op = _build_operator(
- builder,
- 0,
- [5, 1, 2, 3, 4],
- [7],
- builtin_options_type=_tfl_builtin_options.RNNOptions,
- builtin_options=rnn_opts,
+ [5],
+ builtin_options_type=_tfl_builtin_options.SVDFOptions,
+ builtin_options=svdf_opts,
)
subgraph = _build_subgraph(
builder,
tensors=tensors,
- operators=[first_rnn_op, second_rnn_op],
- inputs=[0, 5],
- outputs=[7],
+ operators=[svdf_op],
+ inputs=[0, 4],
+ outputs=[5],
)
buffers = [
- _build_buffer(builder),
- _build_buffer(builder, weights.tobytes()),
- _build_buffer(builder, recurrent_weights.tobytes()),
- _build_buffer(builder, bias.tobytes()),
- _build_buffer(builder),
+ _build_buffer(builder), # 0: empty
+ _build_buffer(builder, feat_weights.tobytes()), # 1
+ _build_buffer(builder, time_weights.tobytes()), # 2
+ _build_buffer(builder, bias.tobytes()), # 3
]
return _finish_tflite_model(
builder,
subgraph=subgraph,
- operator_codes=[rnn_op_code],
+ operator_codes=[svdf_op_code],
buffers=buffers,
)
-def test_rnn_none_activation():
- """RNN with NONE activation lowers to matmul/add.
-
- Cell equation: h = x @ W.T + h @ Wr.T + b (no activation for NONE)
- """
+def test_svdf_none_activation():
+ """SVDF with NONE activation, verifying output shape and params."""
from tflite.ActivationFunctionType import ActivationFunctionType
- batch, input_size, num_units = 2, 2, 2
- weights = np.eye(num_units, input_size, dtype=np.float32)
- recurrent_weights = np.eye(num_units, dtype=np.float32)
+ batch, input_size, num_units, rank, memory_size = 2, 3, 2, 2, 3
+ num_filters = num_units * rank
+ np.random.seed(42)
+ feat_weights = np.random.randn(num_filters, input_size).astype(np.float32)
+ time_weights = np.random.randn(num_filters, memory_size).astype(np.float32)
bias = np.zeros(num_units, dtype=np.float32)
mod = _load_model_from_buffer(
- _build_rnn_model(
+ _build_svdf_model(
batch,
input_size,
num_units,
- weights,
- recurrent_weights,
+ rank,
+ memory_size,
+ num_filters,
+ feat_weights,
+ time_weights,
bias,
ActivationFunctionType.NONE,
)
)
- @I.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
- )
- lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv, out_dtype="void")
- lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
- lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
- )
- lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
- lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
- gv: R.Tensor((2, 2), dtype="float32") = R.add(
- lv5, R.const(np.zeros(2, dtype=np.float32))
- )
- R.output(gv)
- return gv
+ fn = mod["main"]
+ assert len(fn.params) == 2, f"expected 2 params (input, state), got {len(fn.params)}"
+ in_shape = fn.params[0].struct_info.shape
+ assert tuple(int(d) for d in in_shape) == (batch, input_size)
+ state_shape = fn.params[1].struct_info.shape
+ assert tuple(int(d) for d in state_shape) == (batch, num_filters * memory_size)
+ out_shape = fn.ret_struct_info.shape
+ assert tuple(int(d) for d in out_shape) == (batch, num_units)
- tvm.ir.assert_structural_equal(mod, Expected)
+def _build_two_step_shared_state_svdf_model(
+ batch,
+ input_size,
+ num_units,
+ rank,
+ memory_size,
+ feat_weights_0,
+ time_weights_0,
+ bias_0,
+ feat_weights_1,
+ time_weights_1,
+ bias_1,
+ activation,
+):
+ """Build two consecutive SVDF ops sharing a single state tensor."""
+ builder = flatbuffers.Builder(4096)
+ num_filters = num_units * rank
+
+ _tfl_svdf_options.SVDFOptionsStart(builder)
+ _tfl_svdf_options.SVDFOptionsAddRank(builder, rank)
+ _tfl_svdf_options.SVDFOptionsAddFusedActivationFunction(builder, activation)
+ svdf_opts = _tfl_svdf_options.SVDFOptionsEnd(builder)
-def test_rnn_relu_activation():
- """RNN with RELU activation and random weights."""
- from tflite.ActivationFunctionType import ActivationFunctionType
+ svdf_op_code = _build_operator_code(builder, _tfl_builtin_operator.SVDF)
- batch, input_size, num_units = 2, 4, 8
- np.random.seed(42)
- weights = np.random.randn(num_units, input_size).astype(np.float32)
- recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
- bias = np.random.randn(num_units).astype(np.float32)
+ def _t(buf_idx, shape):
+ shape_vec = _tflite_shape(builder, shape)
+ _tfl_tensor.TensorStart(builder)
+ _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+ _tfl_tensor.TensorAddHasRank(builder, True)
+ _tfl_tensor.TensorAddIsVariable(builder, False)
+ _tfl_tensor.TensorAddShape(builder, shape_vec)
+ _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+ return _tfl_tensor.TensorEnd(builder)
- mod = _load_model_from_buffer(
- _build_rnn_model(
- batch,
- input_size,
- num_units,
- weights,
- recurrent_weights,
- bias,
- ActivationFunctionType.RELU,
- )
+ tensors = [
+ _t(0, [batch, input_size]), # 0 input_0
+ _t(1, [num_filters, input_size]), # 1 feat_weights_0
+ _t(2, [num_filters, memory_size]), # 2 time_weights_0
+ _t(3, [num_units]), # 3 bias_0
+ _t(0, [batch, num_filters * memory_size]), # 4 shared state
+ _t(0, [batch, num_units]), # 5 output_0
+ _t(0, [batch, input_size]), # 6 input_1
+ _t(4, [num_filters, input_size]), # 7 feat_weights_1
+ _t(5, [num_filters, memory_size]), # 8 time_weights_1
+ _t(6, [num_units]), # 9 bias_1
+ _t(0, [batch, num_units]), # 10 output_1
+ ]
+
+ svdf_op_0 = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2, 3, 4],
+ [5],
+ builtin_options_type=_tfl_builtin_options.SVDFOptions,
+ builtin_options=svdf_opts,
+ )
+ svdf_op_1 = _build_operator(
+ builder,
+ 0,
+ [6, 7, 8, 9, 4],
+ [10],
+ builtin_options_type=_tfl_builtin_options.SVDFOptions,
+ builtin_options=svdf_opts,
)
- fn = mod["main"]
- assert len(fn.params) == 1, "only the input should be a graph input"
- in_shape = fn.params[0].struct_info.shape
- assert tuple(int(d) for d in in_shape) == (batch, input_size)
- out_shape = fn.ret_struct_info.shape
- assert tuple(int(d) for d in out_shape) == (batch, num_units)
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[svdf_op_0, svdf_op_1],
+ inputs=[0, 6, 4],
+ outputs=[10],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, feat_weights_0.tobytes()),
+ _build_buffer(builder, time_weights_0.tobytes()),
+ _build_buffer(builder, bias_0.tobytes()),
+ _build_buffer(builder, feat_weights_1.tobytes()),
+ _build_buffer(builder, time_weights_1.tobytes()),
+ _build_buffer(builder, bias_1.tobytes()),
+ ]
+
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[svdf_op_code],
+ buffers=buffers,
+ )
-def test_rnn_shared_hidden_state_updates_exp_tab():
- """Two consecutive RNN ops sharing hidden_state should use the updated state."""
+def test_svdf_shared_state_updates_exp_tab():
+ """Two SVDF ops sharing state should use the updated FIFO state in the second step."""
from tflite.ActivationFunctionType import ActivationFunctionType
- batch, input_size, num_units = 2, 2, 2
- weights = np.eye(num_units, input_size, dtype=np.float32)
- recurrent_weights = np.eye(num_units, dtype=np.float32)
- bias = np.zeros(num_units, dtype=np.float32)
+ batch, input_size, num_units, rank, memory_size = 1, 1, 1, 2, 3
+ feat_weights_0 = np.array([[1.0], [2.0]], dtype=np.float32)
+ time_weights_0 = np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float32)
+ bias_0 = np.zeros(num_units, dtype=np.float32)
+
+ feat_weights_1 = np.array([[7.0], [11.0]], dtype=np.float32)
+ time_weights_1 = np.array([[13.0, 17.0, 19.0], [23.0, 29.0, 31.0]], dtype=np.float32)
+ bias_1 = np.zeros(num_units, dtype=np.float32)
mod = _load_model_from_buffer(
- _build_two_step_shared_state_rnn_model(
+ _build_two_step_shared_state_svdf_model(
batch,
input_size,
num_units,
- weights,
- recurrent_weights,
- bias,
+ rank,
+ memory_size,
+ feat_weights_0,
+ time_weights_0,
+ bias_0,
+ feat_weights_1,
+ time_weights_1,
+ bias_1,
ActivationFunctionType.NONE,
)
)
@@ -10381,35 +10386,54 @@ def test_rnn_shared_hidden_state_updates_exp_tab():
class Expected:
@R.function
def main(
- x0: R.Tensor((2, 2), dtype="float32"),
- x1: R.Tensor((2, 2), dtype="float32"),
- ) -> R.Tensor((2, 2), dtype="float32"):
- R.func_attr({"num_input": 2})
+ tvmgen_tensor_0: R.Tensor((1, 1), dtype="float32"),
+ tvmgen_tensor_6: R.Tensor((1, 1), dtype="float32"),
+ tvmgen_tensor_4: R.Tensor((1, 6), dtype="float32"),
+ ) -> R.Tensor((1, 1), dtype="float32"):
+ R.func_attr({"num_input": 3})
with R.dataflow():
- lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
+ lv: R.Tensor((1, 2, 3), dtype="float32") = R.reshape(
+ tvmgen_tensor_4, R.shape([1, 2, 3])
)
- lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv, out_dtype="void")
- lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
- lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
+ lv1: R.Tensor((1, 2, 3), dtype="float32") = R.reshape(
+ R.const(np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float32)),
+ R.shape([1, 2, 3]),
)
- lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
- lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
- lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
+ lv2: R.Tensor((1, 2, 3), dtype="float32") = R.multiply(lv, lv1)
+ lv3: R.Tensor((1, 2), dtype="float32") = R.sum(lv2, axis=[-1], keepdims=False)
+ lv4: R.Tensor((1, 1, 2), dtype="float32") = R.reshape(lv3, R.shape([1, 1, 2]))
+ lv5: R.Tensor((1, 1), dtype="float32") = R.sum( # noqa: F841
+ lv4, axis=[-1], keepdims=False
)
- lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6, out_dtype="void")
- lv8: R.Tensor((2, 2), dtype="float32") = R.add(
- lv5, R.const(np.zeros(2, dtype=np.float32))
+ lv6: R.Tensor((1, 2, 2), dtype="float32") = R.strided_slice(
+ lv,
+ (R.prim_value(2),),
+ (R.prim_value(1),),
+ (R.prim_value(3),),
+ assume_inbound=False,
)
- lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
- R.const(np.eye(2, dtype=np.float32)), axes=None
+ lv7: R.Tensor((1, 2), dtype="float32") = R.permute_dims(
+ R.const(np.array([[1.0], [2.0]], dtype=np.float32)), axes=None
)
- lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9, out_dtype="void")
- lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10)
- gv: R.Tensor((2, 2), dtype="float32") = R.add(
- lv11, R.const(np.zeros(2, dtype=np.float32))
+ lv8: R.Tensor((1, 2), dtype="float32") = R.matmul(
+ tvmgen_tensor_0,
+ lv7,
+ out_dtype="void",
+ )
+ lv9: R.Tensor((1, 2, 1), dtype="float32") = R.expand_dims(lv8, axis=[-1])
+ lv10: R.Tensor((1, 2, 3), dtype="float32") = R.concat((lv6, lv9), axis=2)
+ lv11: R.Tensor((1, 6), dtype="float32") = R.reshape(lv10, R.shape([1, 6]))
+ lv12: R.Tensor((1, 2, 3), dtype="float32") = R.reshape(lv11, R.shape([1, 2, 3]))
+ lv13: R.Tensor((1, 2, 3), dtype="float32") = R.reshape(
+ R.const(np.array([[13.0, 17.0, 19.0], [23.0, 29.0, 31.0]], dtype=np.float32)),
+ R.shape([1, 2, 3]),
+ )
+ lv14: R.Tensor((1, 2, 3), dtype="float32") = R.multiply(lv12, lv13)
+ lv15: R.Tensor((1, 2), dtype="float32") = R.sum(lv14, axis=[-1], keepdims=False)
+ lv16: R.Tensor((1, 1, 2), dtype="float32") = R.reshape(lv15, R.shape([1, 1, 2]))
+ lv17: R.Tensor((1, 1), dtype="float32") = R.sum(lv16, axis=[-1], keepdims=False)
+ gv: R.Tensor((1, 1), dtype="float32") = R.add(
+ lv17, R.const(np.zeros(1, dtype=np.float32))
)
R.output(gv)
return gv