This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b0a0397e2 [Relax][TFLite] Add remaining operator tests and reverse_sequence op (#19814)
4b0a0397e2 is described below
commit 4b0a0397e2805fb1bea8dabaa0412463887859d3
Author: Hongyi Wu
AuthorDate: Wed Jun 24 03:22:52 2026 +0800
[Relax][TFLite] Add remaining operator tests and reverse_sequence op (#19814)
## Summary
This PR adds focused Relax TFLite frontend coverage for the remaining
non-quantized builtin operators tracked by #18971:
- `SQUEEZE`
- `REVERSE_SEQUENCE`
- `UNPACK`
- `ZEROS_LIKE`
The tests manually build minimal TFLite flatbuffers and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. This keeps the coverage on the frontend importer itself, without depending on TensorFlow converter rewrites or constant folding.
The PR also adds first-class Relax support for `reverse_sequence`. TFLite `REVERSE_SEQUENCE` was previously routed through:
```text
R.call_dps_packed("topi.reverse_sequence", ...)
```
That is not executable as a runtime packed call: `topi.reverse_sequence` is a TE compute, and its packed registration expects TE tensors during lowering rather than runtime arrays. The frontend now emits `R.reverse_sequence`, and `LegalizeOps` lowers it through TOPI to TIR:
```text
TFLite REVERSE_SEQUENCE
-> R.reverse_sequence
-> LegalizeOps
-> topi.reverse_sequence
-> R.call_tir
```
## Design
### TFLite Operator Tests
The new TFLite tests use hand-built flatbuffers as small importer fixtures:
- `SQUEEZE` checks axis handling and direct Relax `squeeze` lowering.
- `REVERSE_SEQUENCE` checks import to `R.reverse_sequence`, asserts that the old `R.call_dps_packed("topi.reverse_sequence", ...)` path no longer appears in the imported IR, compiles the module, and runs it with the VM.
- `UNPACK` checks multi-output lowering through Relax tuple output handling.
- `ZEROS_LIKE` checks direct Relax zero-like tensor creation.
### Relax reverse_sequence Operator
The PR adds a public Relax operator:
```python
relax.op.reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0)
```
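For reference, the per-batch semantics can be sketched in NumPy. This is a minimal illustration, not TVM code: `reverse_sequence_ref` is a hypothetical helper, and its handling of lengths `<= 1` (slice left unchanged) and of lengths longer than the `seq_axis` extent (clamped to a full reversal) is inferred from the legalized TIR in the test diff below rather than from a specification.

```python
import numpy as np


def reverse_sequence_ref(data, seq_lengths, seq_axis=1, batch_axis=0):
    """Hypothetical NumPy reference for reverse_sequence (illustration only).

    For each batch index b, reverses the first seq_lengths[b] entries along
    seq_axis. Lengths <= 1 leave the slice unchanged; lengths larger than the
    seq_axis extent are clamped to a full reversal (assumed TOPI behavior).
    """
    data = np.asarray(data)
    out = data.copy()
    ndim = data.ndim
    seq_axis %= ndim  # normalize negative axes
    batch_axis %= ndim
    extent = data.shape[seq_axis]
    for b, length in enumerate(np.asarray(seq_lengths).tolist()):
        n = min(int(length), extent)
        if n <= 1:
            continue
        # Use slices (not integer indexing) so no axis is dropped and
        # seq_axis keeps its position for np.flip.
        index = [slice(None)] * ndim
        index[batch_axis] = slice(b, b + 1)
        index[seq_axis] = slice(0, n)
        index = tuple(index)
        out[index] = np.flip(data[index], axis=seq_axis)
    return out


# Same fixture as the TFLite VM test: batch 1 reverses its first 3 steps.
data = np.arange(24, dtype="float32").reshape((2, 4, 3))
result = reverse_sequence_ref(data, np.array([1, 3], dtype="int32"))
```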
The operator uses `ReverseSequenceAttrs` with `seq_axis` and `batch_axis`. Type inference preserves the input tensor's shape, dtype, and vdevice, and validates the statically known constraints:
- `data` must be a tensor.
- `seq_lengths` must be a 1-D tensor.
- `seq_lengths` dtype must be `int32` or `int64`.
- `seq_axis` and `batch_axis` must be in `[-ndim, ndim)` when the input rank is known.
- `seq_lengths.shape[0]` must match the batch-axis extent; a mismatch is rejected when both shapes are statically available and the extents can be proven unequal.
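The static checks above can be sketched in plain Python. This is an illustration under simplifying assumptions, not the C++ `InferTypeReverseSequence`: `check_reverse_sequence` is a hypothetical name, `None` stands for an unknown shape or dtype, an unknown `data` shape is treated as unknown rank, and shapes are concrete integers (the real pass uses the arithmetic analyzer, so symbolic extents are only rejected when provably unequal).

```python
from typing import Optional, Sequence, Tuple


def check_reverse_sequence(
    data_shape: Optional[Sequence[int]],
    seq_lengths_shape: Optional[Sequence[int]],
    seq_lengths_dtype: Optional[str],
    seq_axis: int = 1,
    batch_axis: int = 0,
) -> Optional[Tuple[int, ...]]:
    """Hypothetical sketch of the static checks; returns the output shape."""
    if seq_lengths_shape is not None and len(seq_lengths_shape) != 1:
        raise ValueError(f"seq_lengths must be 1-D, got ndim {len(seq_lengths_shape)}")
    if seq_lengths_dtype is not None and seq_lengths_dtype not in ("int32", "int64"):
        raise ValueError(f"seq_lengths must be int32 or int64, got {seq_lengths_dtype}")
    if data_shape is None:
        return None  # unknown rank: axis and extent checks are skipped
    ndim = len(data_shape)
    for name, axis in (("seq_axis", seq_axis), ("batch_axis", batch_axis)):
        if not -ndim <= axis < ndim:
            raise ValueError(f"{name}={axis} is outside [-{ndim}, {ndim})")
    if seq_lengths_shape is not None:
        batch_extent = data_shape[batch_axis % ndim]  # normalize a negative batch_axis
        if seq_lengths_shape[0] != batch_extent:
            raise ValueError(
                f"seq_lengths.shape[0]={seq_lengths_shape[0]} does not match "
                f"batch extent {batch_extent}"
            )
    # The output type is the input type: the shape is returned unchanged.
    return tuple(data_shape)
```

The cases mirror `test_reverse_sequence_infer_ty_wrong_inputs`: a 2-D or non-integer `seq_lengths`, an out-of-range axis, or a `(3,)` length vector against a batch extent of 2 all raise `ValueError`.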
The op is exported through Python as `relax.op.reverse_sequence` and through the script builder as `R.reverse_sequence`.
### Legalization
`relax.reverse_sequence` is registered in `LegalizeOps` and lowered with `bb.call_te`:
```python
bb.call_te(
    topi.reverse_sequence,
    data,
    seq_lengths,
    seq_axis,
    batch_axis,
    primfunc_name_hint="reverse_sequence",
)
```
This produces `R.call_tir` in the legalized Relax module, keeping runtime execution on the normal TOPI/TIR path.
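The element-wise index selection in the legalized kernel can be paraphrased as a scalar function. This is a hand transcription of the nested `T.if_then_else` in the expected TIR of `test_reverse_sequence` in `test_transform_legalize_ops_manipulate.py` (where the reversed axis has extent 4), generalized to an arbitrary `extent`; `source_index` is a hypothetical name, not a TOPI function.

```python
def source_index(i: int, seq_len: int, extent: int) -> int:
    """Position along seq_axis that output position i reads from.

    `extent` is the size of seq_axis and `seq_len` the batch's sequence
    length; the branches follow the expected TIR (hypothetically generalized).
    """
    if seq_len <= 1 or seq_len <= i:
        return i  # outside the reversed prefix: copy through unchanged
    if extent < seq_len:
        return extent - 1 - i  # over-long length: reverse the whole axis
    return seq_len - 1 - i  # inside the prefix: mirror within [0, seq_len)


# Gather one batch's sequence through the index map with seq_len = 3.
seq = ["a", "b", "c", "d"]
reversed_prefix = [seq[source_index(i, 3, len(seq))] for i in range(len(seq))]
```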
### TOPI Packed Registration
The Python TOPI wrapper already accepts `batch_axis`:
```python
topi.reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0)
```
The C++ packed registration only forwarded the first three arguments, so Python calls that provided `batch_axis` would drop it before reaching the TOPI compute. The registration now forwards the fourth argument and keeps the old three-argument call form compatible by defaulting `batch_axis=0`.
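The compatibility rule can be modeled in Python as a minimal sketch. Both functions are hypothetical stand-ins for the C++ packed lambda before and after this change; they only return the arguments that would reach the TOPI compute, and they assume the C++ `reverse_sequence` falls back to `batch_axis=0` when the argument is not passed.

```python
def forward_args_old(*args):
    # Before: only data, seq_lengths and seq_axis were forwarded, so any
    # caller-provided batch_axis was dropped and the assumed default 0 applied.
    return args[0], args[1], int(args[2]), 0


def forward_args_new(*args):
    # After: forward args[3] when present, otherwise keep the default of 0.
    batch_axis = int(args[3]) if len(args) >= 4 else 0
    return args[0], args[1], int(args[2]), batch_axis


old = forward_args_old("data", "seq_lengths", 0, 1)
new = forward_args_new("data", "seq_lengths", 0, 1)
```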
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `SQUEEZE` | `SqueezeOptions.SqueezeDims()` | `R.squeeze` | static squeeze axes from TFLite options |
| `REVERSE_SEQUENCE` | `ReverseSequenceOptions.SeqDim()`, `BatchDim()` | `R.reverse_sequence`, legalized to TOPI/TIR | tensor input, 1-D int32/int64 `seq_lengths`, valid `seq_axis` and `batch_axis` |
| `UNPACK` | `UnpackOptions.Axis()`, `Num()` | `R.split` plus a per-output `R.squeeze`, returned as a Relax tuple | static axis and output count from TFLite options |
| `ZEROS_LIKE` | none | `R.zeros_like` | tensor input |
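For quick intuition, NumPy analogues of the three shape-manipulation lowerings can be run on the test fixtures' shapes. This is illustration only, not the Relax lowering itself: `unpack_via_split` is a hypothetical helper that mimics the frontend's `split` followed by a per-output `squeeze` of the unpacked axis.

```python
import numpy as np

# SQUEEZE with SqueezeDims [0, 2] on a (1, 2, 1, 3) input, as in the fixture.
x = np.zeros((1, 2, 1, 3), dtype="float32")
squeezed = np.squeeze(x, axis=(0, 2))  # shape (2, 3)


def unpack_via_split(data, num, axis):
    """UNPACK as: split into `num` pieces along `axis`, then squeeze `axis`."""
    pieces = np.split(data, num, axis=axis)
    return tuple(np.squeeze(piece, axis=axis) for piece in pieces)


# UNPACK with num=3, axis=1 on a (2, 3, 4) input: three (2, 4) outputs.
y = np.arange(24, dtype="float32").reshape((2, 3, 4))
unpacked = unpack_via_split(y, num=3, axis=1)

# ZEROS_LIKE: same shape and dtype, filled with zeros.
zeros = np.zeros_like(y)
```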
## Not Included
- Quantized TFLite `REVERSE_SEQUENCE` support.
- A runtime DPS packed implementation for `topi.reverse_sequence`.
- Changes to TOPI compute semantics.
- ONNX `ReverseSequence` importer support.
## Tests
The tests cover both the TFLite frontend fixtures and the new Relax op:
| Test | Coverage |
|---|---|
| `test_squeeze` | imports TFLite `SQUEEZE` to Relax `squeeze` |
| `test_reverse_sequence` | imports TFLite `REVERSE_SEQUENCE` to `R.reverse_sequence`, avoids the old TOPI DPS packed call, compiles, and runs through the VM |
| `test_unpack` | imports TFLite `UNPACK` as multi-output Relax tuple handling |
| `test_zeros_like` | imports TFLite `ZEROS_LIKE` to Relax `zeros_like` |
| `test_op_correctness` | `relax.op.reverse_sequence(...).op` resolves to `relax.reverse_sequence` |
| `test_reverse_sequence_infer_ty` | static shape, unknown dtype, unknown ndim, and vdevice propagation |
| `test_reverse_sequence_infer_ty_shape_symbolic` | symbolic shape propagation |
| `test_reverse_sequence_infer_ty_wrong_inputs` | non-tensor `seq_lengths`, wrong rank, wrong dtype, invalid axes, and static batch mismatch |
| `test_reverse_sequence` in `test_transform_legalize_ops_manipulate.py` | `LegalizeOps` emits `R.call_tir` and exercises `seq_axis=0, batch_axis=1` |
Local validation:
```bash
python -m pytest tests/python/relax/test_op_manipulate.py \
-k reverse_sequence -q
python -m pytest tests/python/relax/test_transform_legalize_ops_manipulate.py \
-k reverse_sequence -q
python -m pytest --noconftest tests/python/relax/test_frontend_tflite.py \
-k "reverse_sequence or squeeze or unpack or zeros_like" -q
```
Result:
```text
cmake build: passed
py_compile: passed
ruff format --check: 9 files already formatted
ruff check: All checks passed
clang-format --dry-run --Werror: passed
pre-commit run --files: passed
test_op_manipulate.py -k reverse_sequence: 3 passed
test_transform_legalize_ops_manipulate.py -k reverse_sequence: 1 passed
test_frontend_tflite.py -k "reverse_sequence or squeeze or unpack or zeros_like": 4 passed
```
## References
- Issue #18971: TFLite non-quantized operator unit-test coverage
- TFLite `REVERSE_SEQUENCE` builtin semantics
---
include/tvm/relax/attrs/manipulate.h | 17 ++
.../tvm/relax/frontend/tflite/tflite_frontend.py | 8 +-
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 25 ++
python/tvm/relax/op/op_attrs.py | 5 +
python/tvm/relax/script/builder/ir.py | 2 +
.../tvm/relax/transform/legalize_ops/manipulate.py | 12 +
src/relax/op/tensor/manipulate.cc | 91 +++++++
src/relax/op/tensor/manipulate.h | 10 +
src/topi/transform.cc | 3 +-
tests/python/relax/test_frontend_tflite.py | 272 +++++++++++++++++++++
tests/python/relax/test_op_manipulate.py | 72 ++++++
.../test_transform_legalize_ops_manipulate.py | 56 +++++
13 files changed, 569 insertions(+), 5 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 7897b860e1..cb538ea867 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -195,6 +195,23 @@ struct FlipAttrs : public AttrsNode {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, AttrsNode);
}; // struct FlipAttrs
+/*! \brief Attributes used in reverse_sequence operators */
+struct ReverseSequenceAttrs : public AttrsNode {
+ int64_t seq_axis;
+ int64_t batch_axis;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<ReverseSequenceAttrs>()
+ .def_ro("seq_axis", &ReverseSequenceAttrs::seq_axis,
+ "The axis along which to reverse variable length slices.")
+ .def_ro("batch_axis", &ReverseSequenceAttrs::batch_axis,
+ "The axis that indexes the batch.");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ReverseSequenceAttrs", ReverseSequenceAttrs,
+ AttrsNode);
+}; // struct ReverseSequenceAttrs
+
/*! \brief Attributes used in gather_elements operators */
struct GatherElementsAttrs : public AttrsNode {
int64_t axis;
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 44e9773973..4bf74fe340 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -5614,10 +5614,10 @@ class OperatorConverter:
else:
splitted = relax.op.split(in_expr, indices_or_sections=num_unpacks, axis=unpack_axis)
squeezed = relax.Tuple(
- relax.Tuple(
- [_op.squeeze(split_item, axis=squeeze_axis) for split_item in splitted]
- ),
- len(splitted),
+ [
+ _op.squeeze(relax.TupleGetItem(splitted, i), axis=squeeze_axis)
+ for i in range(num_unpacks)
+ ]
)
return squeezed
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 473e50ed30..c116a0d996 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -108,6 +108,7 @@ from .manipulate import (
permute_dims,
repeat,
reshape,
+ reverse_sequence,
scatter_elements,
scatter_nd,
slice_scatter,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index e4814bc62a..4b787c265b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -460,6 +460,31 @@ def flip(data, axis):
return _ffi_api.flip(data, axis) # type: ignore
+def reverse_sequence(data: Expr, seq_lengths: Expr, seq_axis: int = 1, batch_axis: int = 0) -> Expr:
+ """Reverses variable length slices.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor.
+
+ seq_lengths : relax.Expr
+ A 1-D tensor containing sequence lengths for each batch.
+
+ seq_axis : int
+ The axis along which to reverse variable length slices.
+
+ batch_axis : int
+ The axis that indexes the batch.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.reverse_sequence(data, seq_lengths, seq_axis, batch_axis) # type: ignore
+
+
def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
"""Gather elements from data according to indices along the specified axis.
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index b4c3260bb4..b879a19b46 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -206,6 +206,11 @@ class FlipAttrs(Attrs):
"""Attributes for flip operator"""
+@tvm_ffi.register_object("relax.attrs.ReverseSequenceAttrs")
+class ReverseSequenceAttrs(Attrs):
+ """Attributes for reverse_sequence operator"""
+
+
@tvm_ffi.register_object("relax.attrs.PadAttrs")
class PadAttrs(Attrs):
"""Attributes used in pad operator"""
diff --git a/python/tvm/relax/script/builder/ir.py b/python/tvm/relax/script/builder/ir.py
index 4a30963eea..8c4f0191db 100644
--- a/python/tvm/relax/script/builder/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -153,6 +153,7 @@ from tvm.relax.op import (
quantize,
repeat,
reshape,
+ reverse_sequence,
right_shift,
round,
rsqrt,
@@ -941,6 +942,7 @@ __all__ = [
"quantize",
"repeat",
"reshape",
+ "reverse_sequence",
"rewriter",
"right_shift",
"rocm",
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 1f3abaaf6e..f0cc8977d4 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -170,6 +170,18 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))
+@register_legalize("relax.reverse_sequence")
+def _reverse_sequence(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.reverse_sequence,
+ call.args[0],
+ call.args[1],
+ int(call.attrs.seq_axis),
+ int(call.attrs.batch_axis),
+ primfunc_name_hint="reverse_sequence",
+ )
+
+
@register_legalize("relax.gather_elements")
def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index d27a56bd92..caa7300913 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -51,6 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
RepeatAttrs::RegisterReflection();
TileAttrs::RegisterReflection();
FlipAttrs::RegisterReflection();
+ ReverseSequenceAttrs::RegisterReflection();
GatherElementsAttrs::RegisterReflection();
GatherNDAttrs::RegisterReflection();
IndexPutAttrs::RegisterReflection();
@@ -2071,6 +2072,96 @@ TVM_REGISTER_OP("relax.flip")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
.set_attr<bool>("FPurity", true);
+/* relax.reverse_sequence */
+
+Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis) {
+ auto attrs = ffi::make_object<ReverseSequenceAttrs>();
+ attrs->seq_axis = seq_axis;
+ attrs->batch_axis = batch_axis;
+ static const Op& op = Op::Get("relax.reverse_sequence");
+ return Call(op, {std::move(data), std::move(seq_lengths)}, Attrs{attrs}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.reverse_sequence", reverse_sequence);
+}
+
+Type InferTypeReverseSequence(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 2) {
+ TVM_FFI_VISIT_THROW(ValueError, call) << "ReverseSequence op should take 2 arguments";
+ }
+ TensorType data_ty = GetInputTensorType(call, 0, ctx);
+ TensorType seq_lengths_ty = GetInputTensorType(call, 1, ctx);
+
+ if (!seq_lengths_ty->IsUnknownNdim() && seq_lengths_ty->ndim != 1) {
+ TVM_FFI_VISIT_THROW(ValueError, call)
+ << "ReverseSequence requires seq_lengths to be 1-D. However, seq_lengths has ndim "
+ << seq_lengths_ty->ndim;
+ }
+ if (!seq_lengths_ty->dtype.is_void() && !seq_lengths_ty->dtype.is_int()) {
+ TVM_FFI_VISIT_THROW(ValueError, call)
+ << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
+ "seq_lengths has dtype "
+ << seq_lengths_ty->dtype;
+ }
+ if (seq_lengths_ty->dtype.is_int() && seq_lengths_ty->dtype.bits() != 32 &&
+ seq_lengths_ty->dtype.bits() != 64) {
+ TVM_FFI_VISIT_THROW(ValueError, call)
+ << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
+ "seq_lengths has dtype "
+ << seq_lengths_ty->dtype;
+ }
+
+ const auto* attrs = call->attrs.as<ReverseSequenceAttrs>();
+ int64_t seq_axis = attrs->seq_axis;
+ int64_t batch_axis = attrs->batch_axis;
+ if (!data_ty->IsUnknownNdim()) {
+ int ndim = data_ty->ndim;
+ auto check_axis = [&](int64_t axis, ffi::String axis_name) {
+ if (axis < -ndim || axis >= ndim) {
+ TVM_FFI_VISIT_THROW(ValueError, call)
+ << "ReverseSequence requires " << axis_name
+ << " to belong to range [-ndim, ndim). However, the axis is " << axis
+ << ", while ndim is " << ndim;
+ }
+ };
+ check_axis(seq_axis, "seq_axis");
+ check_axis(batch_axis, "batch_axis");
+
+ if (batch_axis < 0) {
+ batch_axis += ndim;
+ }
+
+ if (data_ty->shape.defined() && seq_lengths_ty->shape.defined()) {
+ const auto* data_shape_ty = GetTypeAs<ShapeTypeNode>(data_ty->shape.value());
+ const auto* seq_lengths_shape_ty = GetTypeAs<ShapeTypeNode>(seq_lengths_ty->shape.value());
+ if (data_shape_ty != nullptr && seq_lengths_shape_ty != nullptr &&
+ data_shape_ty->values.defined() && seq_lengths_shape_ty->values.defined()) {
+ PrimExpr batch_extent = data_shape_ty->values.value()[batch_axis];
+ PrimExpr seq_lengths_extent = seq_lengths_shape_ty->values.value()[0];
+ if (ctx->GetAnalyzer()->CanProve(seq_lengths_extent != batch_extent)) {
+ TVM_FFI_VISIT_THROW(ValueError, call)
+ << "ReverseSequence requires seq_lengths.shape[0] to equal the batch axis extent. "
+ "However, seq_lengths.shape[0] is "
+ << seq_lengths_extent << ", while data.shape[" << batch_axis << "] is "
+ << batch_extent;
+ }
+ }
+ }
+ }
+
+ return data_ty;
+}
+
+TVM_REGISTER_OP("relax.reverse_sequence")
+ .set_attrs_type<ReverseSequenceAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("seq_lengths", "Tensor", "The sequence length tensor.")
+ .set_attr<FInferType>("FInferType", InferTypeReverseSequence)
+ .set_attr<bool>("FPurity", true);
+
/* relax.gather_elements */
Expr gather_elements(Expr data, Expr indices, int axis) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 147e622f4d..343b0f6651 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -181,6 +181,16 @@ Expr tile(Expr data, ffi::Array<int64_t> repeats);
*/
Expr flip(Expr data, int64_t axis);
+/*!
+ * \brief Reverses variable length slices along seq_axis.
+ * \param data The input tensor.
+ * \param seq_lengths A 1-D tensor containing sequence lengths for each batch.
+ * \param seq_axis The axis along which to reverse.
+ * \param batch_axis The axis that indexes the batch.
+ * \return The computed result.
+ */
+Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis);
+
/*!
* \brief Gather elements from a tensor using indices.
* \param data The input tensor.
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 5e81e95c60..f0d9225fb5 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -58,8 +58,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def_packed("topi.reverse_sequence",
[](ffi::PackedArgs args, ffi::Any* rv) {
+ int batch_axis = args.size() >= 4 ? args[3].cast<int>() : 0;
*rv = reverse_sequence(args[0].cast<te::Tensor>(),
args[1].cast<te::Tensor>(),
- args[2].cast<int>());
+ args[2].cast<int>(), batch_axis);
})
.def_packed("topi.reshape",
[](ffi::PackedArgs args, ffi::Any* rv) {
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 590cc4ac45..9962732e08 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1617,6 +1617,42 @@ def test_reverse_v2():
verify(ReverseV2, Expected)
+def test_reverse_sequence():
+ mod = _load_model_from_buffer(_build_tflite_reverse_sequence_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ tvmgen_tensor_0: R.Tensor((2, 4, 3), dtype="float32"),
+ tvmgen_tensor_1: R.Tensor((2,), dtype="int32"),
+ ) -> R.Tensor((2, 4, 3), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 3), dtype="float32") = R.reverse_sequence(
+ tvmgen_tensor_0, tvmgen_tensor_1, seq_axis=1, batch_axis=0
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+ ir = mod.script()
+ assert "R.reverse_sequence" in ir
+ assert 'R.call_dps_packed("topi.reverse_sequence"' not in ir
+
+ data = np.arange(24, dtype="float32").reshape((2, 4, 3))
+ seq_lengths = np.array([1, 3], dtype="int32")
+ expected = data.copy()
+ expected[1, :3, :] = expected[1, :3, :][::-1]
+
+ ex = tvm.compile(mod, tvm.target.Target("c"))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ vm.set_input("main", data, seq_lengths)
+ vm.invoke_stateful("main")
+ output = vm.get_outputs("main")
+ np.testing.assert_allclose(output.numpy(), expected, rtol=1e-5, atol=1e-5)
+
+
def test_gather():
class Gather(tf.Module):
@tf.function(
@@ -1674,6 +1710,73 @@ def test_gather_nd():
verify(GatherND, Expected)
+def test_squeeze():
+ mod = _load_model_from_buffer(_build_tflite_squeeze_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(tvmgen_tensor_0: R.Tensor((1, 2, 1, 3), dtype="float32")) -> R.Tensor(
+ (2, 3), dtype="float32"
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 3), dtype="float32") = R.squeeze(tvmgen_tensor_0, axis=[0, 2])
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_unpack():
+ mod = _load_model_from_buffer(_build_tflite_unpack_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(tvmgen_tensor_0: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(
+ R.Tensor((2, 4), dtype="float32"),
+ R.Tensor((2, 4), dtype="float32"),
+ R.Tensor((2, 4), dtype="float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((2, 1, 4), dtype="float32"),
+ R.Tensor((2, 1, 4), dtype="float32"),
+ R.Tensor((2, 1, 4), dtype="float32"),
+ ) = R.split(tvmgen_tensor_0, indices_or_sections=3, axis=1)
+ lv1: R.Tensor((2, 1, 4), dtype="float32") = lv[0]
+ lv2: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[1])
+ lv3: R.Tensor((2, 1, 4), dtype="float32") = lv[1]
+ lv4: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv3, axis=[1])
+ lv5: R.Tensor((2, 1, 4), dtype="float32") = lv[2]
+ lv6: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv5, axis=[1])
+ gv = (lv2, lv4, lv6)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_zeros_like():
+ mod = _load_model_from_buffer(_build_tflite_zeros_like_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(tvmgen_tensor_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor(
+ (2, 3), dtype="float32"
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 3), dtype="float32") = R.zeros_like(tvmgen_tensor_0)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
class Conv2DModule(tf.Module):
@tf.function(
@@ -4029,7 +4132,11 @@ _tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters
_tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters")
_tfl_subgraph = _get_tflite_schema_module("SubGraph")
_tfl_tensor = _get_tflite_schema_module("Tensor")
+_tfl_reverse_sequence_options = _get_tflite_schema_module("ReverseSequenceOptions")
+_tfl_squeeze_options = _get_tflite_schema_module("SqueezeOptions")
+_tfl_unpack_options = _get_tflite_schema_module("UnpackOptions")
_tfl_while_options = _get_tflite_schema_module("WhileOptions")
+_tfl_zeros_like_options = _get_tflite_schema_module("ZerosLikeOptions")
_tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
_tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
@@ -4285,6 +4392,31 @@ def _build_call_once_options(builder, init_subgraph_index):
return _tfl_call_once_options.CallOnceOptionsEnd(builder)
+def _build_squeeze_options(builder, squeeze_dims):
+ squeeze_dims_vec = _tflite_int32_vector(
+ builder,
+ _tfl_squeeze_options.SqueezeOptionsStartSqueezeDimsVector,
+ squeeze_dims,
+ )
+ _tfl_squeeze_options.SqueezeOptionsStart(builder)
+ _tfl_squeeze_options.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims_vec)
+ return _tfl_squeeze_options.SqueezeOptionsEnd(builder)
+
+
+def _build_reverse_sequence_options(builder, seq_dim, batch_dim):
+ _tfl_reverse_sequence_options.ReverseSequenceOptionsStart(builder)
+ _tfl_reverse_sequence_options.ReverseSequenceOptionsAddSeqDim(builder, seq_dim)
+ _tfl_reverse_sequence_options.ReverseSequenceOptionsAddBatchDim(builder, batch_dim)
+ return _tfl_reverse_sequence_options.ReverseSequenceOptionsEnd(builder)
+
+
+def _build_unpack_options(builder, num, axis):
+ _tfl_unpack_options.UnpackOptionsStart(builder)
+ _tfl_unpack_options.UnpackOptionsAddNum(builder, num)
+ _tfl_unpack_options.UnpackOptionsAddAxis(builder, axis)
+ return _tfl_unpack_options.UnpackOptionsEnd(builder)
+
+
def _get_builtin_options_type(options_name):
if not hasattr(_tfl_builtin_options, options_name):
pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}")
@@ -4410,6 +4542,146 @@ def test_operator_marker_unsupported(builtin_name):
_load_model_from_buffer(_build_tflite_operator_marker_model(builtin_name))
+def _build_tflite_squeeze_model():
+ builder = flatbuffers.Builder(1024)
+
+ squeeze_opts = _build_squeeze_options(builder, [0, 2])
+ squeeze_op_code = _build_operator_code(builder, _tfl_builtin_operator.SQUEEZE)
+
+ tensors = [
+ _build_tensor(builder, 0, [1, 2, 1, 3]),
+ _build_tensor(builder, 0, [2, 3]),
+ ]
+ squeeze_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options_type=_tfl_builtin_options.SqueezeOptions,
+ builtin_options=squeeze_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[squeeze_op],
+ inputs=[0],
+ outputs=[1],
+ )
+ buffers = [_build_buffer(builder)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[squeeze_op_code],
+ buffers=buffers,
+ )
+
+
+def _build_tflite_reverse_sequence_model():
+ builder = flatbuffers.Builder(1024)
+
+ reverse_sequence_opts = _build_reverse_sequence_options(builder, seq_dim=1, batch_dim=0)
+ reverse_sequence_op_code = _build_operator_code(builder, _tfl_builtin_operator.REVERSE_SEQUENCE)
+
+ tensors = [
+ _build_tensor(builder, 0, [2, 4, 3]),
+ _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 0, [2, 4, 3]),
+ ]
+ reverse_sequence_op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options_type=_tfl_builtin_options.ReverseSequenceOptions,
+ builtin_options=reverse_sequence_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[reverse_sequence_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[reverse_sequence_op_code],
+ buffers=buffers,
+ )
+
+
+def _build_tflite_unpack_model():
+ builder = flatbuffers.Builder(1024)
+
+ unpack_opts = _build_unpack_options(builder, num=3, axis=1)
+ unpack_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNPACK)
+
+ tensors = [
+ _build_tensor(builder, 0, [2, 3, 4]),
+ _build_tensor(builder, 0, [2, 4]),
+ _build_tensor(builder, 0, [2, 4]),
+ _build_tensor(builder, 0, [2, 4]),
+ ]
+ unpack_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1, 2, 3],
+ builtin_options_type=_tfl_builtin_options.UnpackOptions,
+ builtin_options=unpack_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[unpack_op],
+ inputs=[0],
+ outputs=[1, 2, 3],
+ )
+ buffers = [_build_buffer(builder)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[unpack_op_code],
+ buffers=buffers,
+ )
+
+
+def _build_tflite_zeros_like_model():
+ builder = flatbuffers.Builder(1024)
+
+ _tfl_zeros_like_options.ZerosLikeOptionsStart(builder)
+ zeros_like_opts = _tfl_zeros_like_options.ZerosLikeOptionsEnd(builder)
+ zeros_like_op_code = _build_operator_code(builder, _tfl_builtin_operator.ZEROS_LIKE)
+
+ tensors = [
+ _build_tensor(builder, 0, [2, 3]),
+ _build_tensor(builder, 0, [2, 3]),
+ ]
+ zeros_like_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options_type=_tfl_builtin_options.ZerosLikeOptions,
+ builtin_options=zeros_like_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[zeros_like_op],
+ inputs=[0],
+ outputs=[1],
+ )
+ buffers = [_build_buffer(builder)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[zeros_like_op_code],
+ buffers=buffers,
+ )
+
+
def _run_module(mod, *inputs):
tgt = tvm.target.Target("c")
ex = tvm.compile(mod, tgt)
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 9a938b647d..3c308167ba 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -46,6 +46,8 @@ def test_op_correctness():
assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum")
assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum")
assert relax.op.flip(x, axis=1).op == Op.get("relax.flip")
+ seq_lengths = relax.Var("seq_lengths", R.Tensor((3,), "int32"))
+ assert relax.op.reverse_sequence(x, seq_lengths).op == Op.get("relax.reverse_sequence")
assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements")
assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd")
@@ -3019,6 +3021,76 @@ def test_flip_infer_ty_wrong_inputs():
bb.normalize(relax.op.flip(x0, axis=3))
+def test_reverse_sequence_infer_ty():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float16", ndim=3))
+ x2 = relax.Var("x", R.Tensor("int32"))
+ x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
+ s0 = relax.Var("s", R.Tensor((2,), "int32"))
+ s1 = relax.Var("s", R.Tensor("int64", ndim=1))
+
+ _check_inference(
+ bb,
+ relax.op.reverse_sequence(x0, s0, seq_axis=1, batch_axis=0),
+ relax.TensorType((2, 10, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.reverse_sequence(x5, s0, seq_axis=1, batch_axis=0),
+ relax.TensorType((2, 10, 4), "float32", vdev0),
+ )
+ _check_inference(
+ bb, relax.op.reverse_sequence(x1, s1, seq_axis=0, batch_axis=1), R.Tensor("float16", ndim=3)
+ )
+ _check_inference(bb, relax.op.reverse_sequence(x2, s1), R.Tensor("int32"))
+ _check_inference(bb, relax.op.reverse_sequence(x3, s0), R.Tensor((2, 10, 4)))
+ _check_inference(bb, relax.op.reverse_sequence(x4, s1), R.Tensor(ndim=3))
+
+
+def test_reverse_sequence_infer_ty_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tirx.Var("a", "int64")
+ b = tirx.Var("b", "int64")
+ x = relax.Var("x", R.Tensor((a, b, 4), "float32"))
+ seq_lengths = relax.Var("seq_lengths", R.Tensor((b,), "int64"))
+
+ _check_inference(
+ bb,
+ relax.op.reverse_sequence(x, seq_lengths, seq_axis=0, batch_axis=1),
+ relax.TensorType((a, b, 4), "float32"),
+ )
+
+
+def test_reverse_sequence_infer_ty_wrong_inputs():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ seq_lengths = relax.Var("seq_lengths", R.Tensor((2,), "int32"))
+ seq_lengths_2d = relax.Var("seq_lengths", R.Tensor((2, 1), "int32"))
+ seq_lengths_float = relax.Var("seq_lengths", R.Tensor((2,), "float32"))
+ seq_lengths_int16 = relax.Var("seq_lengths", R.Tensor((2,), "int16"))
+ seq_lengths_mismatch = relax.Var("seq_lengths", R.Tensor((3,), "int32"))
+ not_tensor = relax.Var("seq_lengths", relax.ObjectType())
+
+ with pytest.raises(TypeError):
+ bb.normalize(relax.op.reverse_sequence(x, not_tensor))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths_2d))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths_float))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths_int16))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths, seq_axis=3))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths, batch_axis=-4))
+ with pytest.raises(ValueError):
+ bb.normalize(relax.op.reverse_sequence(x, seq_lengths_mismatch))
+
+
def test_gather_elements_infer_ty():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 92668ab60e..45036523ac 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1363,6 +1363,62 @@ def test_flip_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_reverse_sequence():
+ # fmt: off
+ @I.ir_module(s_tir=True)
+ class ReverseSequence:
+ @R.function
+ def main(x: R.Tensor((4, 2, 3), "float32"), seq_lengths: R.Tensor((2,), "int64")):
+ gv = R.reverse_sequence(x, seq_lengths, seq_axis=0, batch_axis=1)
+ return gv
+
+ @I.ir_module(s_tir=True)
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((4, 2, 3), dtype="float32"),
+ seq_lengths: R.Tensor((2,), dtype="int64"),
+ ) -> R.Tensor((4, 2, 3), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(
+ cls.reverse_sequence,
+ (x, seq_lengths),
+ out_ty=R.Tensor((4, 2, 3), dtype="float32"),
+ )
+ return gv
+
+ @T.prim_func(private=True, s_tir=True)
+ def reverse_sequence(
+ rxplaceholder: T.Buffer((T.int64(4), T.int64(2), T.int64(3)), "float32"),
+ seq_lengths: T.Buffer((T.int64(2),), "int64"),
+ T_reverse_sequence: T.Buffer((T.int64(4), T.int64(2), T.int64(3)), "float32"),
+ ):
+ T.func_attr({"tirx.noalias": True})
+ for ax0, ax1, ax2 in T.grid(T.int64(4), T.int64(2), T.int64(3)):
+ with T.sblock("T_reverse_sequence"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[T.int64(0):T.int64(4), v_ax1, v_ax2], seq_lengths[v_ax1])
+ T.writes(T_reverse_sequence[v_ax0, v_ax1, v_ax2])
+ T_reverse_sequence[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+ T.if_then_else(
+ seq_lengths[v_ax1] <= T.int64(1) or seq_lengths[v_ax1] <= v_ax0,
+ v_ax0,
+ T.if_then_else(
+ T.int64(4) < seq_lengths[v_ax1],
+ T.int64(3) - v_ax0,
+ seq_lengths[v_ax1] - v_ax0 - T.int64(1),
+ ),
+ ),
+ v_ax1,
+ v_ax2,
+ ]
+
+ # fmt: on
+
+ mod = LegalizeOps()(ReverseSequence)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_scatter_elements():
# fmt: off
@I.ir_module(s_tir=True)