This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b0a0397e2 [Relax][TFLite] Add remaining operator tests and reverse_sequence op (#19814)
4b0a0397e2 is described below

commit 4b0a0397e2805fb1bea8dabaa0412463887859d3
Author: Hongyi Wu
AuthorDate: Wed Jun 24 03:22:52 2026 +0800

    [Relax][TFLite] Add remaining operator tests and reverse_sequence op (#19814)
    
    ## Summary
    
    This PR adds focused Relax TFLite frontend coverage for the remaining
    non-quantized builtin operators tracked by #18971:
    
    - `SQUEEZE`
    - `REVERSE_SEQUENCE`
    - `UNPACK`
    - `ZEROS_LIKE`
    
    The tests manually build minimal TFLite flatbuffers and compare the
    imported Relax IR with `tvm.ir.assert_structural_equal`. This keeps the
    coverage on the frontend importer itself, without depending on TensorFlow
    converter rewrites or constant folding.
    
    The PR also adds first-class Relax support for `reverse_sequence`. TFLite
    `REVERSE_SEQUENCE` was previously routed through:
    
    ```text
    R.call_dps_packed("topi.reverse_sequence", ...)
    ```
    
    That is not executable as a runtime packed call, because
    `topi.reverse_sequence` is a TE compute that expects TE tensors during
    lowering. The frontend now emits `R.reverse_sequence`, and `LegalizeOps`
    lowers it through TOPI to TIR:
    
    ```text
    TFLite REVERSE_SEQUENCE
      -> R.reverse_sequence
      -> LegalizeOps
      -> topi.reverse_sequence
      -> R.call_tir
    ```
    
    ## Design
    
    ### TFLite Operator Tests
    
    The new TFLite tests use hand-built flatbuffers for small importer fixtures:
    
    - `SQUEEZE` checks axis handling and direct Relax `squeeze` lowering.
    - `REVERSE_SEQUENCE` checks import to `R.reverse_sequence`, rejects the old
      `R.call_dps_packed("topi.reverse_sequence", ...)` path, compiles the
      module, and runs it with the VM.
    - `UNPACK` checks multi-output lowering through Relax tuple output handling.
    - `ZEROS_LIKE` checks direct Relax zero-like tensor creation.
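
The `UNPACK` expectation relies on the importer lowering `UNPACK` to an `R.split` along the unpack axis, followed by an `R.squeeze` of each piece. A minimal NumPy sketch of that decomposition follows. `unpack_via_split` is a hypothetical helper used only for illustration, not frontend code, and NumPy stands in for Relax:

```python
import numpy as np


def unpack_via_split(data, num, axis):
    """Hypothetical helper: mimic the importer's UNPACK lowering with NumPy.

    Split `data` into `num` equal pieces along `axis`, then squeeze that axis
    out of every piece, matching the split + squeeze pattern in the expected
    Relax IR of the UNPACK test.
    """
    pieces = np.split(data, num, axis=axis)
    return tuple(np.squeeze(piece, axis=axis) for piece in pieces)


data = np.arange(24, dtype="float32").reshape((2, 3, 4))
outputs = unpack_via_split(data, num=3, axis=1)

# Each output drops the unpacked axis: (2, 3, 4) -> three tensors of shape (2, 4).
for i, out in enumerate(outputs):
    assert out.shape == (2, 4)
    # Piece i is exactly the i-th slice along the unpacked axis.
    np.testing.assert_array_equal(out, data[:, i, :])
```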
    
    ### Relax reverse_sequence Operator
    
    The PR adds a public Relax operator:
    
    ```python
    relax.op.reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0)
    ```
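
Semantically, for every batch index `b`, the op reverses the first `seq_lengths[b]` elements along `seq_axis` and leaves the remaining elements untouched. The NumPy reference model below is a sketch written for this description, not the TOPI kernel. `reverse_sequence_ref` is a hypothetical name, and the model assumes `seq_axis != batch_axis` and `0 <= seq_lengths[b] <= data.shape[seq_axis]`. It reproduces the expected values from the TFLite `REVERSE_SEQUENCE` test:

```python
import numpy as np


def reverse_sequence_ref(data, seq_lengths, seq_axis=1, batch_axis=0):
    """Hypothetical NumPy reference model of reverse_sequence semantics.

    Assumes seq_axis != batch_axis and 0 <= seq_lengths[b] <= data.shape[seq_axis].
    """
    data = np.asarray(data)
    seq_axis %= data.ndim  # accept negative axes in [-ndim, ndim)
    batch_axis %= data.ndim
    out = data.copy()
    # Views with the batch axis first and the sequence axis second, so every
    # batch entry is indexed the same way regardless of the original layout.
    src = np.moveaxis(data, (batch_axis, seq_axis), (0, 1))
    dst = np.moveaxis(out, (batch_axis, seq_axis), (0, 1))
    for b, length in enumerate(seq_lengths):
        length = int(length)
        # Reverse only the leading `length` steps; writes go through the view into `out`.
        dst[b, :length] = src[b, :length][::-1]
    return out


# Same input and expectation as the TFLite REVERSE_SEQUENCE test.
data = np.arange(24, dtype="float32").reshape((2, 4, 3))
seq_lengths = np.array([1, 3], dtype="int32")
expected = data.copy()
expected[1, :3, :] = expected[1, :3, :][::-1]
np.testing.assert_array_equal(reverse_sequence_ref(data, seq_lengths), expected)
```

Moving the batch and sequence axes to the front keeps the loop body identical for any `seq_axis`/`batch_axis` pair, including the `seq_axis=0, batch_axis=1` combination exercised by the legalization test.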
    
    The operator uses `ReverseSequenceAttrs` with `seq_axis` and `batch_axis`.
    Type inference preserves the input tensor's shape, dtype, and vdevice, and
    validates the statically known constraints:

    - `data` must be a tensor.
    - `seq_lengths` must be a 1-D tensor.
    - `seq_lengths` dtype must be `int32` or `int64`.
    - `seq_axis` and `batch_axis` must be in `[-ndim, ndim)` when the input
      rank is known.
    - `seq_lengths.shape[0]` must match the batch-axis extent when both
      shapes are statically available.
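
The static checks can be mirrored in pure Python to make the "only when statically known" rule explicit. This is a hypothetical sketch, not TVM code. `check_reverse_sequence` is an illustrative name, symbolic dimensions are modeled as strings, the tensor-type checks on `data` and `seq_lengths` are omitted, and the integer comparison is a simplified stand-in for the arithmetic analyzer's `CanProve` used by the C++ inference:

```python
def check_reverse_sequence(data_shape, seq_lengths_shape, seq_lengths_dtype,
                           seq_axis=1, batch_axis=0):
    """Hypothetical mirror of the static reverse_sequence checks.

    data_shape / seq_lengths_shape: tuple of dims, or None if unknown.
    A dim is an int when static or a str naming a symbolic variable.
    seq_lengths_dtype: dtype string, or None if unknown.
    Returns the inferred output shape, which is always data_shape.
    """
    if seq_lengths_shape is not None and len(seq_lengths_shape) != 1:
        raise ValueError(f"seq_lengths must be 1-D, got ndim {len(seq_lengths_shape)}")
    if seq_lengths_dtype is not None and seq_lengths_dtype not in ("int32", "int64"):
        raise ValueError(f"seq_lengths must be int32 or int64, got {seq_lengths_dtype}")
    if data_shape is None:
        # Unknown rank: the axis-range and batch-extent checks are skipped.
        return None
    ndim = len(data_shape)
    for name, axis in (("seq_axis", seq_axis), ("batch_axis", batch_axis)):
        if not -ndim <= axis < ndim:
            raise ValueError(f"{name}={axis} is outside [-{ndim}, {ndim})")
    if batch_axis < 0:
        batch_axis += ndim  # normalize before indexing into data_shape
    if seq_lengths_shape is not None:
        batch_extent = data_shape[batch_axis]
        seq_extent = seq_lengths_shape[0]
        # Simplified stand-in for CanProve(seq_extent != batch_extent): only two
        # static ints that differ are treated as a provable mismatch.
        both_static = isinstance(batch_extent, int) and isinstance(seq_extent, int)
        if both_static and batch_extent != seq_extent:
            raise ValueError(
                f"seq_lengths.shape[0] is {seq_extent}, "
                f"while data.shape[{batch_axis}] is {batch_extent}"
            )
    return data_shape


# Symbolic extents pass because no mismatch can be proven.
assert check_reverse_sequence(("a", "b", 4), ("b",), "int64", seq_axis=0, batch_axis=1) == ("a", "b", 4)
```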
    
    The op is exported through Python as `relax.op.reverse_sequence` and
    through the script builder as `R.reverse_sequence`.
    
    ### Legalization
    
    `relax.reverse_sequence` is registered in `LegalizeOps` and lowered with
    `bb.call_te`:
    
    ```python
    bb.call_te(
        topi.reverse_sequence,
        data,
        seq_lengths,
        seq_axis,
        batch_axis,
        primfunc_name_hint="reverse_sequence",
    )
    ```
    
    This produces `R.call_tir` in the legalized Relax module, keeping runtime
    execution on the normal TOPI/TIR path.
    
    ### TOPI Packed Registration
    
    The Python TOPI wrapper already accepts `batch_axis`:
    
    ```python
    topi.reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0)
    ```
    
    The C++ packed registration only forwarded the first three arguments, so
    Python calls that provided `batch_axis` silently lost it before reaching
    the TOPI compute. The registration now forwards the fourth argument and
    keeps the old three-argument call form compatible by defaulting
    `batch_axis=0`.
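
The compatibility rule is arity-based defaulting. As an illustration only, a Python analogue of the updated packed lambda (`forward_reverse_sequence_args` is hypothetical and performs no TE work) shows that the old three-argument form and the new four-argument form with `batch_axis=0` forward the same values:

```python
def forward_reverse_sequence_args(*args):
    """Hypothetical analogue of the updated topi.reverse_sequence packed lambda.

    Returns the (data, seq_lengths, seq_axis, batch_axis) tuple that would be
    forwarded to the TOPI compute; no tensors are created here.
    """
    data, seq_lengths, seq_axis = args[0], args[1], int(args[2])
    # Use a fourth argument when present, otherwise default batch_axis to 0.
    batch_axis = int(args[3]) if len(args) >= 4 else 0
    return data, seq_lengths, seq_axis, batch_axis


# The old three-argument form and the new form with batch_axis=0 agree.
assert forward_reverse_sequence_args("a", "s", 1) == forward_reverse_sequence_args("a", "s", 1, 0)
```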
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `SQUEEZE` | `SqueezeOptions.SqueezeDims()` | `R.squeeze` | static squeeze axes from TFLite options |
    | `REVERSE_SEQUENCE` | `ReverseSequenceOptions.SeqDim()`, `BatchDim()` | `R.reverse_sequence` legalized to TOPI/TIR | tensor input, 1-D int32/int64 `seq_lengths`, valid `seq_axis` and `batch_axis` |
    | `UNPACK` | `UnpackOptions.Axis()`, `Num()` | `R.split` plus per-output `R.squeeze`, returned as a Relax tuple | static axis and output count from TFLite options |
    | `ZEROS_LIKE` | none | `R.zeros_like` | tensor input |
    
    ## Not Included
    
    - Quantized TFLite `REVERSE_SEQUENCE` support.
    - A runtime DPS packed implementation for `topi.reverse_sequence`.
    - Changes to TOPI compute semantics.
    - ONNX `ReverseSequence` importer support.
    
    ## Tests
    
    The tests cover both the TFLite frontend fixtures and the new Relax op:
    
    | Test | Coverage |
    |---|---|
    | `test_squeeze` | imports TFLite `SQUEEZE` to Relax `squeeze` |
    | `test_reverse_sequence` | imports TFLite `REVERSE_SEQUENCE` to `R.reverse_sequence`, avoids the old TOPI DPS packed call, compiles, and runs through the VM |
    | `test_unpack` | imports TFLite `UNPACK` as `R.split` plus per-output `R.squeeze`, returned as a Relax tuple |
    | `test_zeros_like` | imports TFLite `ZEROS_LIKE` to Relax `zeros_like` |
    | `test_op_correctness` | `relax.op.reverse_sequence(...).op` resolves to `relax.reverse_sequence` |
    | `test_reverse_sequence_infer_ty` | static shape, unknown dtype, unknown ndim, and vdevice propagation |
    | `test_reverse_sequence_infer_ty_shape_symbolic` | symbolic shape propagation |
    | `test_reverse_sequence_infer_ty_wrong_inputs` | non-tensor `seq_lengths`, wrong rank, wrong dtype, invalid axes, and static batch mismatch |
    | `test_reverse_sequence` in `test_transform_legalize_ops_manipulate.py` | `LegalizeOps` emits `R.call_tir` and exercises `seq_axis=0, batch_axis=1` |
    
    Local validation:
    
    ```bash
    python -m pytest tests/python/relax/test_op_manipulate.py \
      -k reverse_sequence -q

    python -m pytest tests/python/relax/test_transform_legalize_ops_manipulate.py \
      -k reverse_sequence -q
    
    python -m pytest --noconftest tests/python/relax/test_frontend_tflite.py \
      -k "reverse_sequence or squeeze or unpack or zeros_like" -q
    ```
    
    Result:
    
    ```text
    cmake build: passed
    py_compile: passed
    ruff format --check: 9 files already formatted
    ruff check: All checks passed
    clang-format --dry-run --Werror: passed
    pre-commit run --files: passed
    test_op_manipulate.py -k reverse_sequence: 3 passed
    test_transform_legalize_ops_manipulate.py -k reverse_sequence: 1 passed
    test_frontend_tflite.py -k "reverse_sequence or squeeze or unpack or zeros_like": 4 passed
    ```
    
    ## References
    
    - Issue #18971: TFLite non-quantized operator unit-test coverage
    - TFLite `REVERSE_SEQUENCE` builtin semantics
---
 include/tvm/relax/attrs/manipulate.h               |  17 ++
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |   8 +-
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  25 ++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/script/builder/ir.py              |   2 +
 .../tvm/relax/transform/legalize_ops/manipulate.py |  12 +
 src/relax/op/tensor/manipulate.cc                  |  91 +++++++
 src/relax/op/tensor/manipulate.h                   |  10 +
 src/topi/transform.cc                              |   3 +-
 tests/python/relax/test_frontend_tflite.py         | 272 +++++++++++++++++++++
 tests/python/relax/test_op_manipulate.py           |  72 ++++++
 .../test_transform_legalize_ops_manipulate.py      |  56 +++++
 13 files changed, 569 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 7897b860e1..cb538ea867 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -195,6 +195,23 @@ struct FlipAttrs : public AttrsNode {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, AttrsNode);
 };  // struct FlipAttrs
 
+/*! \brief Attributes used in reverse_sequence operators */
+struct ReverseSequenceAttrs : public AttrsNode {
+  int64_t seq_axis;
+  int64_t batch_axis;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<ReverseSequenceAttrs>()
+        .def_ro("seq_axis", &ReverseSequenceAttrs::seq_axis,
+                "The axis along which to reverse variable length slices.")
+        .def_ro("batch_axis", &ReverseSequenceAttrs::batch_axis,
+                "The axis that indexes the batch.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ReverseSequenceAttrs", ReverseSequenceAttrs,
+                                    AttrsNode);
+};  // struct ReverseSequenceAttrs
+
 /*! \brief Attributes used in gather_elements operators */
 struct GatherElementsAttrs : public AttrsNode {
   int64_t axis;
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 44e9773973..4bf74fe340 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -5614,10 +5614,10 @@ class OperatorConverter:
         else:
             splitted = relax.op.split(in_expr, indices_or_sections=num_unpacks, axis=unpack_axis)
             squeezed = relax.Tuple(
-                relax.Tuple(
-                    [_op.squeeze(split_item, axis=squeeze_axis) for split_item in splitted]
-                ),
-                len(splitted),
+                [
+                    _op.squeeze(relax.TupleGetItem(splitted, i), axis=squeeze_axis)
+                    for i in range(num_unpacks)
+                ]
             )
 
         return squeezed
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 473e50ed30..c116a0d996 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -108,6 +108,7 @@ from .manipulate import (
     permute_dims,
     repeat,
     reshape,
+    reverse_sequence,
     scatter_elements,
     scatter_nd,
     slice_scatter,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index e4814bc62a..4b787c265b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -460,6 +460,31 @@ def flip(data, axis):
     return _ffi_api.flip(data, axis)  # type: ignore
 
 
+def reverse_sequence(data: Expr, seq_lengths: Expr, seq_axis: int = 1, batch_axis: int = 0) -> Expr:
+    """Reverses variable length slices.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    seq_lengths : relax.Expr
+        A 1-D tensor containing sequence lengths for each batch.
+
+    seq_axis : int
+        The axis along which to reverse variable length slices.
+
+    batch_axis : int
+        The axis that indexes the batch.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.reverse_sequence(data, seq_lengths, seq_axis, batch_axis)  # type: ignore
+
+
 def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
     """Gather elements from data according to indices along the specified axis.
 
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index b4c3260bb4..b879a19b46 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -206,6 +206,11 @@ class FlipAttrs(Attrs):
     """Attributes for flip operator"""
 
 
+@tvm_ffi.register_object("relax.attrs.ReverseSequenceAttrs")
+class ReverseSequenceAttrs(Attrs):
+    """Attributes for reverse_sequence operator"""
+
+
 @tvm_ffi.register_object("relax.attrs.PadAttrs")
 class PadAttrs(Attrs):
     """Attributes used in pad operator"""
diff --git a/python/tvm/relax/script/builder/ir.py b/python/tvm/relax/script/builder/ir.py
index 4a30963eea..8c4f0191db 100644
--- a/python/tvm/relax/script/builder/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -153,6 +153,7 @@ from tvm.relax.op import (
     quantize,
     repeat,
     reshape,
+    reverse_sequence,
     right_shift,
     round,
     rsqrt,
@@ -941,6 +942,7 @@ __all__ = [
     "quantize",
     "repeat",
     "reshape",
+    "reverse_sequence",
     "rewriter",
     "right_shift",
     "rocm",
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 1f3abaaf6e..f0cc8977d4 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -170,6 +170,18 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))
 
 
+@register_legalize("relax.reverse_sequence")
+def _reverse_sequence(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.reverse_sequence,
+        call.args[0],
+        call.args[1],
+        int(call.attrs.seq_axis),
+        int(call.attrs.batch_axis),
+        primfunc_name_hint="reverse_sequence",
+    )
+
+
 @register_legalize("relax.gather_elements")
 def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index d27a56bd92..caa7300913 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -51,6 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   RepeatAttrs::RegisterReflection();
   TileAttrs::RegisterReflection();
   FlipAttrs::RegisterReflection();
+  ReverseSequenceAttrs::RegisterReflection();
   GatherElementsAttrs::RegisterReflection();
   GatherNDAttrs::RegisterReflection();
   IndexPutAttrs::RegisterReflection();
@@ -2071,6 +2072,96 @@ TVM_REGISTER_OP("relax.flip")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
     .set_attr<bool>("FPurity", true);
 
+/* relax.reverse_sequence */
+
+Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis) {
+  auto attrs = ffi::make_object<ReverseSequenceAttrs>();
+  attrs->seq_axis = seq_axis;
+  attrs->batch_axis = batch_axis;
+  static const Op& op = Op::Get("relax.reverse_sequence");
+  return Call(op, {std::move(data), std::move(seq_lengths)}, Attrs{attrs}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.reverse_sequence", reverse_sequence);
+}
+
+Type InferTypeReverseSequence(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    TVM_FFI_VISIT_THROW(ValueError, call) << "ReverseSequence op should take 2 arguments";
+  }
+  TensorType data_ty = GetInputTensorType(call, 0, ctx);
+  TensorType seq_lengths_ty = GetInputTensorType(call, 1, ctx);
+
+  if (!seq_lengths_ty->IsUnknownNdim() && seq_lengths_ty->ndim != 1) {
+    TVM_FFI_VISIT_THROW(ValueError, call)
+        << "ReverseSequence requires seq_lengths to be 1-D. However, seq_lengths has ndim "
+        << seq_lengths_ty->ndim;
+  }
+  if (!seq_lengths_ty->dtype.is_void() && !seq_lengths_ty->dtype.is_int()) {
+    TVM_FFI_VISIT_THROW(ValueError, call)
+        << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
+           "seq_lengths has dtype "
+        << seq_lengths_ty->dtype;
+  }
+  if (seq_lengths_ty->dtype.is_int() && seq_lengths_ty->dtype.bits() != 32 &&
+      seq_lengths_ty->dtype.bits() != 64) {
+    TVM_FFI_VISIT_THROW(ValueError, call)
+        << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
+           "seq_lengths has dtype "
+        << seq_lengths_ty->dtype;
+  }
+
+  const auto* attrs = call->attrs.as<ReverseSequenceAttrs>();
+  int64_t seq_axis = attrs->seq_axis;
+  int64_t batch_axis = attrs->batch_axis;
+  if (!data_ty->IsUnknownNdim()) {
+    int ndim = data_ty->ndim;
+    auto check_axis = [&](int64_t axis, ffi::String axis_name) {
+      if (axis < -ndim || axis >= ndim) {
+        TVM_FFI_VISIT_THROW(ValueError, call)
+            << "ReverseSequence requires " << axis_name
+            << " to belong to range [-ndim, ndim). However, the axis is " << axis
+            << ", while ndim is " << ndim;
+      }
+    };
+    check_axis(seq_axis, "seq_axis");
+    check_axis(batch_axis, "batch_axis");
+
+    if (batch_axis < 0) {
+      batch_axis += ndim;
+    }
+
+    if (data_ty->shape.defined() && seq_lengths_ty->shape.defined()) {
+      const auto* data_shape_ty = GetTypeAs<ShapeTypeNode>(data_ty->shape.value());
+      const auto* seq_lengths_shape_ty = GetTypeAs<ShapeTypeNode>(seq_lengths_ty->shape.value());
+      if (data_shape_ty != nullptr && seq_lengths_shape_ty != nullptr &&
+          data_shape_ty->values.defined() && seq_lengths_shape_ty->values.defined()) {
+        PrimExpr batch_extent = data_shape_ty->values.value()[batch_axis];
+        PrimExpr seq_lengths_extent = seq_lengths_shape_ty->values.value()[0];
+        if (ctx->GetAnalyzer()->CanProve(seq_lengths_extent != batch_extent)) {
+          TVM_FFI_VISIT_THROW(ValueError, call)
+              << "ReverseSequence requires seq_lengths.shape[0] to equal the batch axis extent. "
+                 "However, seq_lengths.shape[0] is "
+              << seq_lengths_extent << ", while data.shape[" << batch_axis << "] is "
+              << batch_extent;
+        }
+      }
+    }
+  }
+
+  return data_ty;
+}
+
+TVM_REGISTER_OP("relax.reverse_sequence")
+    .set_attrs_type<ReverseSequenceAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("seq_lengths", "Tensor", "The sequence length tensor.")
+    .set_attr<FInferType>("FInferType", InferTypeReverseSequence)
+    .set_attr<bool>("FPurity", true);
+
 /* relax.gather_elements */
 
 Expr gather_elements(Expr data, Expr indices, int axis) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 147e622f4d..343b0f6651 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -181,6 +181,16 @@ Expr tile(Expr data, ffi::Array<int64_t> repeats);
  */
 Expr flip(Expr data, int64_t axis);
 
+/*!
+ * \brief Reverses variable length slices along seq_axis.
+ * \param data The input tensor.
+ * \param seq_lengths A 1-D tensor containing sequence lengths for each batch.
+ * \param seq_axis The axis along which to reverse.
+ * \param batch_axis The axis that indexes the batch.
+ * \return The computed result.
+ */
+Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis);
+
 /*!
  * \brief Gather elements from a tensor using indices.
  * \param data The input tensor.
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 5e81e95c60..f0d9225fb5 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -58,8 +58,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                   })
       .def_packed("topi.reverse_sequence",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
+                    int batch_axis = args.size() >= 4 ? args[3].cast<int>() : 0;
                     *rv = reverse_sequence(args[0].cast<te::Tensor>(), args[1].cast<te::Tensor>(),
-                                           args[2].cast<int>());
+                                           args[2].cast<int>(), batch_axis);
                   })
       .def_packed("topi.reshape",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 590cc4ac45..9962732e08 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1617,6 +1617,42 @@ def test_reverse_v2():
     verify(ReverseV2, Expected)
 
 
+def test_reverse_sequence():
+    mod = _load_model_from_buffer(_build_tflite_reverse_sequence_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((2, 4, 3), dtype="float32"),
+            tvmgen_tensor_1: R.Tensor((2,), dtype="int32"),
+        ) -> R.Tensor((2, 4, 3), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 3), dtype="float32") = R.reverse_sequence(
+                    tvmgen_tensor_0, tvmgen_tensor_1, seq_axis=1, batch_axis=0
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+    ir = mod.script()
+    assert "R.reverse_sequence" in ir
+    assert 'R.call_dps_packed("topi.reverse_sequence"' not in ir
+
+    data = np.arange(24, dtype="float32").reshape((2, 4, 3))
+    seq_lengths = np.array([1, 3], dtype="int32")
+    expected = data.copy()
+    expected[1, :3, :] = expected[1, :3, :][::-1]
+
+    ex = tvm.compile(mod, tvm.target.Target("c"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    vm.set_input("main", data, seq_lengths)
+    vm.invoke_stateful("main")
+    output = vm.get_outputs("main")
+    np.testing.assert_allclose(output.numpy(), expected, rtol=1e-5, atol=1e-5)
+
+
 def test_gather():
     class Gather(tf.Module):
         @tf.function(
@@ -1674,6 +1710,73 @@ def test_gather_nd():
     verify(GatherND, Expected)
 
 
+def test_squeeze():
+    mod = _load_model_from_buffer(_build_tflite_squeeze_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(tvmgen_tensor_0: R.Tensor((1, 2, 1, 3), dtype="float32")) -> R.Tensor(
+            (2, 3), dtype="float32"
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = R.squeeze(tvmgen_tensor_0, axis=[0, 2])
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_unpack():
+    mod = _load_model_from_buffer(_build_tflite_unpack_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(tvmgen_tensor_0: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(
+            R.Tensor((2, 4), dtype="float32"),
+            R.Tensor((2, 4), dtype="float32"),
+            R.Tensor((2, 4), dtype="float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((2, 1, 4), dtype="float32"),
+                    R.Tensor((2, 1, 4), dtype="float32"),
+                    R.Tensor((2, 1, 4), dtype="float32"),
+                ) = R.split(tvmgen_tensor_0, indices_or_sections=3, axis=1)
+                lv1: R.Tensor((2, 1, 4), dtype="float32") = lv[0]
+                lv2: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((2, 1, 4), dtype="float32") = lv[1]
+                lv4: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((2, 1, 4), dtype="float32") = lv[2]
+                lv6: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv5, axis=[1])
+                gv = (lv2, lv4, lv6)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_zeros_like():
+    mod = _load_model_from_buffer(_build_tflite_zeros_like_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(tvmgen_tensor_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor(
+            (2, 3), dtype="float32"
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = R.zeros_like(tvmgen_tensor_0)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(
@@ -4029,7 +4132,11 @@ _tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters
 _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters")
 _tfl_subgraph = _get_tflite_schema_module("SubGraph")
 _tfl_tensor = _get_tflite_schema_module("Tensor")
+_tfl_reverse_sequence_options = _get_tflite_schema_module("ReverseSequenceOptions")
+_tfl_squeeze_options = _get_tflite_schema_module("SqueezeOptions")
+_tfl_unpack_options = _get_tflite_schema_module("UnpackOptions")
 _tfl_while_options = _get_tflite_schema_module("WhileOptions")
+_tfl_zeros_like_options = _get_tflite_schema_module("ZerosLikeOptions")
 
 _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
 _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
@@ -4285,6 +4392,31 @@ def _build_call_once_options(builder, init_subgraph_index):
     return _tfl_call_once_options.CallOnceOptionsEnd(builder)
 
 
+def _build_squeeze_options(builder, squeeze_dims):
+    squeeze_dims_vec = _tflite_int32_vector(
+        builder,
+        _tfl_squeeze_options.SqueezeOptionsStartSqueezeDimsVector,
+        squeeze_dims,
+    )
+    _tfl_squeeze_options.SqueezeOptionsStart(builder)
+    _tfl_squeeze_options.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims_vec)
+    return _tfl_squeeze_options.SqueezeOptionsEnd(builder)
+
+
+def _build_reverse_sequence_options(builder, seq_dim, batch_dim):
+    _tfl_reverse_sequence_options.ReverseSequenceOptionsStart(builder)
+    _tfl_reverse_sequence_options.ReverseSequenceOptionsAddSeqDim(builder, seq_dim)
+    _tfl_reverse_sequence_options.ReverseSequenceOptionsAddBatchDim(builder, batch_dim)
+    return _tfl_reverse_sequence_options.ReverseSequenceOptionsEnd(builder)
+
+
+def _build_unpack_options(builder, num, axis):
+    _tfl_unpack_options.UnpackOptionsStart(builder)
+    _tfl_unpack_options.UnpackOptionsAddNum(builder, num)
+    _tfl_unpack_options.UnpackOptionsAddAxis(builder, axis)
+    return _tfl_unpack_options.UnpackOptionsEnd(builder)
+
+
 def _get_builtin_options_type(options_name):
     if not hasattr(_tfl_builtin_options, options_name):
         pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}")
@@ -4410,6 +4542,146 @@ def test_operator_marker_unsupported(builtin_name):
         _load_model_from_buffer(_build_tflite_operator_marker_model(builtin_name))
 
 
+def _build_tflite_squeeze_model():
+    builder = flatbuffers.Builder(1024)
+
+    squeeze_opts = _build_squeeze_options(builder, [0, 2])
+    squeeze_op_code = _build_operator_code(builder, _tfl_builtin_operator.SQUEEZE)
+
+    tensors = [
+        _build_tensor(builder, 0, [1, 2, 1, 3]),
+        _build_tensor(builder, 0, [2, 3]),
+    ]
+    squeeze_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.SqueezeOptions,
+        builtin_options=squeeze_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[squeeze_op],
+        inputs=[0],
+        outputs=[1],
+    )
+    buffers = [_build_buffer(builder)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[squeeze_op_code],
+        buffers=buffers,
+    )
+
+
+def _build_tflite_reverse_sequence_model():
+    builder = flatbuffers.Builder(1024)
+
+    reverse_sequence_opts = _build_reverse_sequence_options(builder, seq_dim=1, batch_dim=0)
+    reverse_sequence_op_code = _build_operator_code(builder, _tfl_builtin_operator.REVERSE_SEQUENCE)
+
+    tensors = [
+        _build_tensor(builder, 0, [2, 4, 3]),
+        _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 0, [2, 4, 3]),
+    ]
+    reverse_sequence_op = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options_type=_tfl_builtin_options.ReverseSequenceOptions,
+        builtin_options=reverse_sequence_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[reverse_sequence_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+    buffers = [_build_buffer(builder)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[reverse_sequence_op_code],
+        buffers=buffers,
+    )
+
+
+def _build_tflite_unpack_model():
+    builder = flatbuffers.Builder(1024)
+
+    unpack_opts = _build_unpack_options(builder, num=3, axis=1)
+    unpack_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNPACK)
+
+    tensors = [
+        _build_tensor(builder, 0, [2, 3, 4]),
+        _build_tensor(builder, 0, [2, 4]),
+        _build_tensor(builder, 0, [2, 4]),
+        _build_tensor(builder, 0, [2, 4]),
+    ]
+    unpack_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1, 2, 3],
+        builtin_options_type=_tfl_builtin_options.UnpackOptions,
+        builtin_options=unpack_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[unpack_op],
+        inputs=[0],
+        outputs=[1, 2, 3],
+    )
+    buffers = [_build_buffer(builder)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[unpack_op_code],
+        buffers=buffers,
+    )
+
+
+def _build_tflite_zeros_like_model():
+    builder = flatbuffers.Builder(1024)
+
+    _tfl_zeros_like_options.ZerosLikeOptionsStart(builder)
+    zeros_like_opts = _tfl_zeros_like_options.ZerosLikeOptionsEnd(builder)
+    zeros_like_op_code = _build_operator_code(builder, _tfl_builtin_operator.ZEROS_LIKE)
+
+    tensors = [
+        _build_tensor(builder, 0, [2, 3]),
+        _build_tensor(builder, 0, [2, 3]),
+    ]
+    zeros_like_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.ZerosLikeOptions,
+        builtin_options=zeros_like_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[zeros_like_op],
+        inputs=[0],
+        outputs=[1],
+    )
+    buffers = [_build_buffer(builder)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[zeros_like_op_code],
+        buffers=buffers,
+    )
+
+
 def _run_module(mod, *inputs):
     tgt = tvm.target.Target("c")
     ex = tvm.compile(mod, tgt)
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 9a938b647d..3c308167ba 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -46,6 +46,8 @@ def test_op_correctness():
     assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum")
     assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum")
     assert relax.op.flip(x, axis=1).op == Op.get("relax.flip")
+    seq_lengths = relax.Var("seq_lengths", R.Tensor((3,), "int32"))
+    assert relax.op.reverse_sequence(x, seq_lengths).op == Op.get("relax.reverse_sequence")
     assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements")
     assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd")
 
@@ -3019,6 +3021,76 @@ def test_flip_infer_ty_wrong_inputs():
         bb.normalize(relax.op.flip(x0, axis=3))
 
 
+def test_reverse_sequence_infer_ty():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float16", ndim=3))
+    x2 = relax.Var("x", R.Tensor("int32"))
+    x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
+    s0 = relax.Var("s", R.Tensor((2,), "int32"))
+    s1 = relax.Var("s", R.Tensor("int64", ndim=1))
+
+    _check_inference(
+        bb,
+        relax.op.reverse_sequence(x0, s0, seq_axis=1, batch_axis=0),
+        relax.TensorType((2, 10, 4), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.reverse_sequence(x5, s0, seq_axis=1, batch_axis=0),
+        relax.TensorType((2, 10, 4), "float32", vdev0),
+    )
+    _check_inference(
+        bb, relax.op.reverse_sequence(x1, s1, seq_axis=0, batch_axis=1), R.Tensor("float16", ndim=3)
+    )
+    _check_inference(bb, relax.op.reverse_sequence(x2, s1), R.Tensor("int32"))
+    _check_inference(bb, relax.op.reverse_sequence(x3, s0), R.Tensor((2, 10, 4)))
+    _check_inference(bb, relax.op.reverse_sequence(x4, s1), R.Tensor(ndim=3))
+
+
+def test_reverse_sequence_infer_ty_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tirx.Var("a", "int64")
+    b = tirx.Var("b", "int64")
+    x = relax.Var("x", R.Tensor((a, b, 4), "float32"))
+    seq_lengths = relax.Var("seq_lengths", R.Tensor((b,), "int64"))
+
+    _check_inference(
+        bb,
+        relax.op.reverse_sequence(x, seq_lengths, seq_axis=0, batch_axis=1),
+        relax.TensorType((a, b, 4), "float32"),
+    )
+
+
+def test_reverse_sequence_infer_ty_wrong_inputs():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    seq_lengths = relax.Var("seq_lengths", R.Tensor((2,), "int32"))
+    seq_lengths_2d = relax.Var("seq_lengths", R.Tensor((2, 1), "int32"))
+    seq_lengths_float = relax.Var("seq_lengths", R.Tensor((2,), "float32"))
+    seq_lengths_int16 = relax.Var("seq_lengths", R.Tensor((2,), "int16"))
+    seq_lengths_mismatch = relax.Var("seq_lengths", R.Tensor((3,), "int32"))
+    not_tensor = relax.Var("seq_lengths", relax.ObjectType())
+
+    with pytest.raises(TypeError):
+        bb.normalize(relax.op.reverse_sequence(x, not_tensor))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths_2d))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths_float))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths_int16))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths, seq_axis=3))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths, batch_axis=-4))
+    with pytest.raises(ValueError):
+        bb.normalize(relax.op.reverse_sequence(x, seq_lengths_mismatch))
+
+
 def test_gather_elements_infer_ty():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 92668ab60e..45036523ac 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1363,6 +1363,62 @@ def test_flip_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_reverse_sequence():
+    # fmt: off
+    @I.ir_module(s_tir=True)
+    class ReverseSequence:
+        @R.function
+        def main(x: R.Tensor((4, 2, 3), "float32"), seq_lengths: R.Tensor((2,), "int64")):
+            gv = R.reverse_sequence(x, seq_lengths, seq_axis=0, batch_axis=1)
+            return gv
+
+    @I.ir_module(s_tir=True)
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((4, 2, 3), dtype="float32"),
+            seq_lengths: R.Tensor((2,), dtype="int64"),
+        ) -> R.Tensor((4, 2, 3), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(
+                cls.reverse_sequence,
+                (x, seq_lengths),
+                out_ty=R.Tensor((4, 2, 3), dtype="float32"),
+            )
+            return gv
+
+        @T.prim_func(private=True, s_tir=True)
+        def reverse_sequence(
+            rxplaceholder: T.Buffer((T.int64(4), T.int64(2), T.int64(3)), "float32"),
+            seq_lengths: T.Buffer((T.int64(2),), "int64"),
+            T_reverse_sequence: T.Buffer((T.int64(4), T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.func_attr({"tirx.noalias": True})
+            for ax0, ax1, ax2 in T.grid(T.int64(4), T.int64(2), T.int64(3)):
+                with T.sblock("T_reverse_sequence"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rxplaceholder[T.int64(0):T.int64(4), v_ax1, v_ax2], seq_lengths[v_ax1])
+                    T.writes(T_reverse_sequence[v_ax0, v_ax1, v_ax2])
+                    T_reverse_sequence[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+                        T.if_then_else(
+                            seq_lengths[v_ax1] <= T.int64(1) or seq_lengths[v_ax1] <= v_ax0,
+                            v_ax0,
+                            T.if_then_else(
+                                T.int64(4) < seq_lengths[v_ax1],
+                                T.int64(3) - v_ax0,
+                                seq_lengths[v_ax1] - v_ax0 - T.int64(1),
+                            ),
+                        ),
+                        v_ax1,
+                        v_ax2,
+                    ]
+
+    # fmt: on
+
+    mod = LegalizeOps()(ReverseSequence)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_scatter_elements():
     # fmt: off
     @I.ir_module(s_tir=True)


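The nested `T.if_then_else` in the expected TIR of `test_reverse_sequence` is easier to follow next to an eager reference. Below is a minimal NumPy sketch of the same index mapping, assuming TOPI-style semantics. `reverse_sequence_ref` is a hypothetical helper (not a TVM, TOPI, or Relax API), the defaults `seq_axis=1, batch_axis=0` are an assumption, and the input checks only approximate the cases exercised by `test_reverse_sequence_infer_ty_wrong_inputs` (this sketch raises `ValueError` for all of them).

```python
import numpy as np


def reverse_sequence_ref(x, seq_lengths, seq_axis=1, batch_axis=0):
    """Hypothetical NumPy reference for reverse_sequence (not a TVM API).

    Mirrors the index mapping in the legalized TIR: along ``seq_axis``, the
    first ``seq_lengths[b]`` elements of batch ``b`` are reversed and the rest
    are copied unchanged; lengths larger than the axis are clamped to it.
    """
    x = np.asarray(x)
    seq_lengths = np.asarray(seq_lengths)
    ndim = x.ndim

    # Preconditions loosely modeled on the wrong-input tests (assumed, not exhaustive).
    if seq_lengths.ndim != 1:
        raise ValueError("seq_lengths must be 1-D")
    if seq_lengths.dtype not in (np.int32, np.int64):
        raise ValueError("seq_lengths must be int32 or int64")
    if not (-ndim <= seq_axis < ndim and -ndim <= batch_axis < ndim):
        raise ValueError("axis out of range")
    seq_axis %= ndim  # normalize negative axes
    batch_axis %= ndim
    if seq_lengths.shape[0] != x.shape[batch_axis]:
        raise ValueError("len(seq_lengths) must equal the batch dimension")

    seq_dim = x.shape[seq_axis]
    out = np.empty_like(x)
    for idx in np.ndindex(*x.shape):
        i = idx[seq_axis]
        n = int(seq_lengths[idx[batch_axis]])
        if n <= 1 or n <= i:
            src = i  # outside the reversed prefix: copy as-is
        elif n > seq_dim:
            src = seq_dim - 1 - i  # length clamped: reverse the whole axis
        else:
            src = n - 1 - i  # inside the prefix: mirror within [0, n)
        src_idx = list(idx)
        src_idx[seq_axis] = src
        out[idx] = x[tuple(src_idx)]
    return out
```

For example, with `x` of shape `(4, 2, 3)`, `seq_axis=0`, `batch_axis=1` and `seq_lengths = [3, 1]`, batch 0 has rows 0 to 2 reversed and row 3 copied, while batch 1 comes back unchanged. The three branches correspond to the three arms of the nested `T.if_then_else` in `Expected.reverse_sequence`.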