This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 39535048b6 [Relax][Frontend][TFLite] Add RFFT2D op and supporting TIR kernels (#19812)
39535048b6 is described below
commit 39535048b62b5a23c941a0e98976bd53da050f1d
Author: Hongyi Wu
AuthorDate: Thu Jun 18 02:25:05 2026 +0800
[Relax][Frontend][TFLite] Add RFFT2D op and supporting TIR kernels (#19812)
## Summary
This PR adds Relax TFLite frontend support for the TFLite builtin `RFFT2D` operator (issue #19519, item C: FFT / complex operators). It follows up on upstream PR #19763, which already merged the `REAL` / `IMAG` / `COMPLEX_ABS` subset of item C; this PR completes item C with `RFFT2D` itself.
`RFFT2D` computes a 2D real FFT over the last two input axes and returns a real/imag pair tensor of shape `[..., H, W // 2 + 1, 2]`. Relax does not have a native complex64 dtype, so the pair output is represented as a `float32` tensor with a trailing axis of size 2, matching the convention PR #19763 established for `REAL` / `IMAG` / `COMPLEX_ABS`. Because `RFFT2D` takes a real `float32` input and emits a `float32` pair output, the frontend never has to materialize a COMPLEX64 in-memory representation of its own.
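A minimal NumPy sketch of the pair convention, assuming the trailing axis stores `(real, imag)` in that order, as the test helper `_complex64_to_pair` in this diff does. The `pair_*` accessors are hypothetical illustrations of what `REAL` / `IMAG` / `COMPLEX_ABS` compute on that layout, not the frontend's converters.

```python
import numpy as np

def complex64_to_pair(value):
    """Encode complex values as float32 [..., 2] (real, imag) pairs."""
    value = np.asarray(value, dtype=np.complex64)
    return np.stack([value.real, value.imag], axis=-1).astype(np.float32)

def pair_to_complex64(pair):
    """Decode a float32 [..., 2] pair tensor back to complex64."""
    pair = np.asarray(pair, dtype=np.float32)
    return (pair[..., 0] + 1j * pair[..., 1]).astype(np.complex64)

# Hypothetical accessors showing what REAL / IMAG / COMPLEX_ABS compute on the
# pair layout; they are not the frontend's converters.
def pair_real(pair):
    return pair[..., 0]

def pair_imag(pair):
    return pair[..., 1]

def pair_abs(pair):
    return np.hypot(pair[..., 0], pair[..., 1])

z = np.array([[3 + 4j, -1j]], dtype=np.complex64)
pair = complex64_to_pair(z)
print(pair.shape)      # (1, 2, 2)
print(pair_abs(pair))  # [[5. 1.]]
```

Keeping the pair axis last means every complex element occupies two adjacent float32 slots, so consumers can slice real and imaginary parts without any reshaping.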
The supported subset is the static no-padding / no-truncation path:
- the input's last two dimensions must match `fft_length`
- `fft_length` must be a length-2 integer constant (int32 or int64)
- the output shape is `input_shape[:-2] + (H, W // 2 + 1, 2)`
- sparse inputs are rejected
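The shape rule above can be checked against NumPy's own `rfft2`, which also keeps `W // 2 + 1` bins along the last axis. The helper below is a hypothetical sketch of the static subset's shape contract; it raises `ValueError` where the frontend raises `OpNotImplemented`.

```python
import numpy as np

def rfft2d_pair_output_shape(input_shape, fft_length):
    """Relax pair-output shape for the static RFFT2D subset (illustrative sketch)."""
    input_shape = tuple(int(d) for d in input_shape)
    if len(input_shape) < 2:
        raise ValueError("RFFT2D input rank must be at least 2")
    height, width = (int(v) for v in fft_length)
    # No padding / truncation: fft_length must equal the last two input dims.
    if (height, width) != input_shape[-2:]:
        raise ValueError("fft_length must match the input spatial shape")
    # Leading batch dims pass through; the trailing 2 is the (real, imag) pair axis.
    return input_shape[:-2] + (height, width // 2 + 1, 2)

x = np.random.default_rng(0).standard_normal((2, 3, 4, 5)).astype(np.float32)
shape = rfft2d_pair_output_shape(x.shape, (4, 5))
print(shape)                  # (2, 3, 4, 3, 2)
print(np.fft.rfft2(x).shape)  # (2, 3, 4, 3)
```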
## Design
### Dispatch Between Two TIR Kernels
`convert_rfft2d` selects one of two TIR primfuncs at lowering time, based on whether both spatial axes are powers of two:
```python
if _is_power_of_2(height) and _is_power_of_2(width):
    prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape, relax_output_shape)
else:
    prim_func = _build_tflite_rfft2d_primfunc(input_shape, relax_output_shape)
```
Both kernels share the same `call_tir` contract, so downstream code is
kernel-agnostic.
#### DFT Reference Kernel (`_build_tflite_rfft2d_primfunc`)
A naive O(B · H · W · H · W) DFT over the last two input axes. The outer (batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a downstream `tir.schedule` pass can parallelize it. Trig and accumulation are in float32; the result agrees with `np.fft.rfft2` to about `1e-5` absolute tolerance for typical input sizes. This kernel is the fallback whenever `H` or `W` is not a power of 2 (for example, any odd size other than 1).
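For intuition, here is a NumPy sketch of the math the reference kernel performs: float32 phase, trig, and accumulation, looping over flattened batches the way the (batch, out_y, out_x) spatial axes do. It is not the TIR source, and `naive_rfft2d_pair` is a hypothetical name.

```python
import math
import numpy as np

def naive_rfft2d_pair(x):
    """Unnormalized 2D real DFT over the last two axes, as a float32 pair tensor."""
    x = np.asarray(x, dtype=np.float32)
    *lead, height, width = x.shape
    out_width = width // 2 + 1
    flat = x.reshape(-1, height, width)  # collapse leading batch dims
    out = np.zeros((flat.shape[0], height, out_width, 2), dtype=np.float32)
    neg_two_pi = np.float32(-2.0 * math.pi)
    h32, w32 = np.float32(height), np.float32(width)
    for b in range(flat.shape[0]):
        for oy in range(height):
            for ox in range(out_width):  # only W // 2 + 1 bins are kept
                real_sum = np.float32(0.0)
                imag_sum = np.float32(0.0)
                for iy in range(height):
                    for ix in range(width):
                        # angle = -2*pi * (oy*iy/H + ox*ix/W), all in float32
                        phase_y = np.float32(oy) * np.float32(iy) / h32
                        phase_x = np.float32(ox) * np.float32(ix) / w32
                        angle = neg_two_pi * (phase_y + phase_x)
                        value = flat[b, iy, ix]
                        real_sum += value * np.cos(angle)
                        imag_sum += value * np.sin(angle)
                out[b, oy, ox, 0] = real_sum
                out[b, oy, ox, 1] = imag_sum
    return out.reshape(*lead, height, out_width, 2)

x = np.random.default_rng(0).uniform(-1, 1, size=(3, 5)).astype(np.float32)
err = np.abs(naive_rfft2d_pair(x)[..., 0] - np.fft.rfft2(x.astype(np.float64)).real).max()
print(err)  # small; bounded by float32 rounding in the phase and the sums
```

The five nested loops make the O(B · H · W' · H · W) cost explicit, which is why this path is kept as a fallback rather than the default.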
#### Cooley-Tukey FFT Kernel (`_build_tflite_rfft2d_fft_primfunc`)
An O(B · H · W · (log2(H) + log2(W))) radix-2 Cooley-Tukey FFT, dispatched when both `H` and `W` are powers of 2 (including 1 = 2^0). The primfunc source is generated as a TIR string at construction time and registered in `linecache` so `tirx.parser` can resolve it. The bit-reversal permutation and butterfly stages are precomputed in Python and inlined as direct scratch-buffer assignments, so all loop bounds and twiddle factors are compile-time literals:
1. Copy the real input into `scratch_real`; initialize `scratch_imag` to 0.
2. For each batch and each row, run an in-place 1D FFT of length `W` along the width axis using the scratch buffers.
3. For each batch and each column, run an in-place 1D FFT of length `H` along the height axis with stride `W`.
4. Write the first `W // 2 + 1` complex bins of each row to the output pair representation.
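The four steps above can be sketched in NumPy as follows, assuming float32 storage and float32-rounded twiddle factors like the generated kernel. This version uses ordinary Python loops instead of fully unrolled TIR statements, and all function names are hypothetical.

```python
import math
import numpy as np

def bit_reversal_swap_pairs(n):
    """(i, j) pairs, i < j, where j is i with its log2(n) bits reversed."""
    bits = n.bit_length() - 1
    pairs = []
    for i in range(1, n):
        j = 0
        for k in range(bits):
            if i & (1 << k):
                j |= 1 << (bits - 1 - k)
        if i < j:
            pairs.append((i, j))
    return pairs

def fft_1d_inplace(re, im, idx):
    """Iterative radix-2 FFT over the elements re[idx], im[idx], in place."""
    n = len(idx)
    for i, j in bit_reversal_swap_pairs(n):  # bit-reversal permutation
        a, b = idx[i], idx[j]
        re[a], re[b] = re[b], re[a]
        im[a], im[b] = im[b], im[a]
    m = 2
    while m <= n:  # butterfly stages of size 2, 4, ..., n
        half = m // 2
        for start in range(0, n, m):
            for k in range(half):
                angle = -2.0 * math.pi * k / m
                w_re = np.float32(math.cos(angle))  # float32 twiddle factor
                w_im = np.float32(math.sin(angle))
                a, b = idx[start + k], idx[start + k + half]
                t_re = re[b] * w_re - im[b] * w_im
                t_im = re[b] * w_im + im[b] * w_re
                u_re, u_im = re[a], im[a]
                re[a], im[a] = u_re + t_re, u_im + t_im
                re[b], im[b] = u_re - t_re, u_im - t_im
        m *= 2

def cooley_tukey_rfft2d_pair(x):
    """2D real FFT: row FFTs, then column FFTs; float32 pair output."""
    x = np.asarray(x, dtype=np.float32)
    *lead, height, width = x.shape
    for n in (height, width):
        assert n > 0 and (n & (n - 1)) == 0, "sizes must be powers of 2"
    out_width = width // 2 + 1
    flat = x.reshape(-1, height * width)
    out = np.empty((flat.shape[0], height, out_width, 2), dtype=np.float32)
    for b in range(flat.shape[0]):
        re = flat[b].copy()                              # 1. real part = input
        im = np.zeros(height * width, dtype=np.float32)  #    imag part = 0
        for row in range(height):                        # 2. row FFTs, stride 1
            fft_1d_inplace(re, im, [row * width + c for c in range(width)])
        for col in range(width):                         # 3. column FFTs, stride W
            fft_1d_inplace(re, im, [r * width + col for r in range(height)])
        out[b, ..., 0] = re.reshape(height, width)[:, :out_width]  # 4. keep W//2+1 bins
        out[b, ..., 1] = im.reshape(height, width)[:, :out_width]
    return out.reshape(*lead, height, out_width, 2)
```

The full complex transform is computed and only then truncated to `W // 2 + 1` columns, mirroring step 4; a real-input-specific FFT could avoid the redundant half, at the cost of a more complex kernel.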
The bit-reversal swap pairs and butterfly stage bodies are precomputed at primfunc-construction time because `tirx.parser` does not currently accept runtime `T.serial` bounds, and twiddle factors must be `T.float32` literals to avoid `Undefined variable` errors. The fake linecache filename (for example `<tflite_rfft2d_fft_primfunc H=8 W=8 outW=5>`) encodes the dimensions, so stack traces through the generated source are readable.
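The `linecache` trick can be shown in plain Python, without TIR. Source that is built as a string and run with `exec` has no file on disk, so `inspect.getsource` (which parsers built on `inspect.getsourcelines` rely on) normally fails; registering the text under a fake filename fixes that. This is a sketch under those assumptions: `build_from_source`, the `<demo_kernel ...>` filename, and the swap-only `kernel` body are hypothetical stand-ins for the generated primfunc.

```python
import inspect
import linecache

def build_from_source(source, fake_file, entry="kernel"):
    """Compile generated source so inspect can still find its text (sketch)."""
    lines = source.splitlines(keepends=True)
    # linecache entries are (size, mtime, lines, fullname); mtime=None marks an
    # entry with no backing file, so linecache.checkcache() leaves it alone.
    linecache.cache[fake_file] = (len(source), None, lines, fake_file)
    namespace = {}
    exec(compile(source, fake_file, "exec"), namespace)
    return namespace[entry]

# Swap pairs precomputed in Python for n = 8 and emitted as literal statements.
height, width = 8, 8
swaps = [(1, 4), (3, 6)]
body = "".join(f"    buf[{i}], buf[{j}] = buf[{j}], buf[{i}]\n" for i, j in swaps)
source = "def kernel(buf):\n" + body + "    return buf\n"

fn = build_from_source(source, f"<demo_kernel H={height} W={width}>")
print(fn(list(range(8))))                     # [0, 4, 2, 6, 1, 5, 3, 7]
print(inspect.getsource(fn).splitlines()[0])  # def kernel(buf):
```

Every index in the generated body is a literal, which is the property the frontend relies on to keep `tirx.parser` away from runtime loop bounds.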
### COMPLEX64 Pair Representation Helpers
The frontend represents TFLite COMPLEX64 tensors as float32 real/imag pairs with a trailing axis of size 2 (since Relax has no native complex64 dtype). Four small helpers centralize this rule so future complex ops can plug in without re-implementing the pair-axis layout:
- `_is_tflite_complex64_type`: checks whether a TFLite tensor type is `COMPLEX64`.
- `_unwrap_tflite_tensor`: unwraps a `TensorWrapper` to the raw `tflite.Tensor`.
- `_get_relax_tensor_dtype`: returns the Relax dtype used to represent a TFLite tensor (`"float32"` for COMPLEX64, otherwise the standard mapping).
- `_get_relax_tensor_shape`: returns the Relax shape (the TFLite shape with a trailing `(2,)` axis appended for COMPLEX64).
The three callers that construct Relax parameters from TFLite metadata (`_get_static_tensor_shape_dtype` / `_set_subgraph_input_params` / `_get_tensor_param`) now go through these helpers. The pair-axis invariant is documented on `get_tensor_value` and `get_tensor_shape`, which both work in the *raw TFLite* shape (no pair axis) for downstream callers that need to compare against the model's declared output shape.
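A plain-Python sketch of the Relax-versus-TFLite shape/dtype rule. It uses a hypothetical `FakeTfliteTensor` instead of real `tflite.Tensor` objects and a deliberately partial dtype map; the function names loosely mirror `_get_relax_tensor_shape`, `_get_relax_tensor_dtype`, and `get_tensor_shape` but are not the frontend's code.

```python
from dataclasses import dataclass

@dataclass
class FakeTfliteTensor:
    """Hypothetical stand-in for tflite.Tensor metadata (type name + shape)."""
    type_name: str
    shape: tuple

_DTYPE_MAP = {"FLOAT32": "float32", "INT32": "int32", "INT64": "int64"}  # partial

def relax_tensor_dtype(tensor):
    # COMPLEX64 is stored as float32 real/imag pairs.
    if tensor.type_name == "COMPLEX64":
        return "float32"
    return _DTYPE_MAP[tensor.type_name]

def relax_tensor_shape(tensor):
    # Python ints (not numpy scalars); rank-0 gives (); COMPLEX64 gains a trailing 2.
    shape = tuple(int(d) for d in tensor.shape)
    if tensor.type_name == "COMPLEX64":
        return shape + (2,)
    return shape

def tflite_tensor_shape(tensor):
    # Raw TFLite shape: never includes the pair axis.
    return tuple(int(d) for d in tensor.shape)

out = FakeTfliteTensor("COMPLEX64", (2, 4, 3))
print(relax_tensor_shape(out), relax_tensor_dtype(out))  # (2, 4, 3, 2) float32
print(tflite_tensor_shape(out))                          # (2, 4, 3)
```

Keeping two separate accessors makes it explicit which view a caller is using: IR construction needs the pair-axis shape, while validation against the model's declared shapes needs the raw one.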
### Boundary Validation
`convert_rfft2d` validates rank, dtype, fft_length shape, integer-ness, positivity, fft_length == input spatial shape, output shape agreement, and the absence of sparse inputs before emitting the `call_tir`. Edge cases (sparse inputs; non-constant, non-integer, non-positive, or mismatched fft_length) each produce a targeted `OpNotImplemented` diagnostic, while the degenerate 1×1 spatial size is accepted and lowered normally.
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `RFFT2D` | input `float32`, constant length-2 integer `fft_length`, `COMPLEX64` output | `call_tir` to a generated TIR kernel | static no-padding/no-truncation; `H`, `W` arbitrary; Cooley-Tukey dispatched when both are powers of 2 |
## Safety Checks
- An op without exactly two inputs and one output raises `OpNotImplemented("RFFT2D expects two inputs and one output")`.
- Non-float32 input raises `OpNotImplemented("RFFT2D input must be float32")`.
- Non-COMPLEX64 output raises `OpNotImplemented("RFFT2D output must be COMPLEX64")`.
- Sparse inputs raise `OpNotImplemented("RFFT2D does not support sparse inputs")`.
- Non-constant `fft_length` raises `OpNotImplemented("RFFT2D requires a constant fft_length")`.
- Non-integer `fft_length` raises `OpNotImplemented("RFFT2D fft_length must be an integer tensor")`.
- Wrong-shape `fft_length` (not length 2) raises `OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")`.
- Non-positive `fft_length` raises `OpNotImplemented("RFFT2D fft_length must be positive")`.
- `fft_length` not matching the input's last two dims raises `OpNotImplemented("RFFT2D currently supports fft_length matching the input spatial shape")`.
- Mismatched output shape raises `OpNotImplemented("RFFT2D output shape does not match fft_length")`.
- Input rank < 2 raises `OpNotImplemented("RFFT2D input rank must be at least 2")`.
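A sketch of the shape and `fft_length` checks listed above, applied in the same order as `convert_rfft2d` applies them. It deliberately omits the dtype, arity, and sparsity checks, takes plain shapes and arrays instead of TFLite tensors, and uses a hypothetical `OpNotImplemented` class standing in for `tvm.error.OpNotImplemented`.

```python
import numpy as np

class OpNotImplemented(Exception):
    """Hypothetical stand-in for tvm.error.OpNotImplemented."""

def validate_rfft2d(input_shape, fft_length_value, output_shape):
    """Shape / fft_length boundary checks for static RFFT2D (sketch)."""
    input_shape = tuple(int(d) for d in input_shape)
    if len(input_shape) < 2:
        raise OpNotImplemented("RFFT2D input rank must be at least 2")
    if fft_length_value is None:
        raise OpNotImplemented("RFFT2D requires a constant fft_length")
    fft_length_value = np.asarray(fft_length_value)
    # Reject floats before astype("int64") could silently truncate 2.7 -> 2.
    if not np.issubdtype(fft_length_value.dtype, np.integer):
        raise OpNotImplemented("RFFT2D fft_length must be an integer tensor")
    fft_length = fft_length_value.astype("int64")
    if fft_length.shape != (2,):
        raise OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")
    height, width = int(fft_length[0]), int(fft_length[1])
    if height <= 0 or width <= 0:
        raise OpNotImplemented("RFFT2D fft_length must be positive")
    if (height, width) != input_shape[-2:]:
        raise OpNotImplemented(
            "RFFT2D currently supports fft_length matching the input spatial shape")
    if tuple(output_shape) != input_shape[:-2] + (height, width // 2 + 1):
        raise OpNotImplemented("RFFT2D output shape does not match fft_length")
    return height, width

print(validate_rfft2d((2, 4), np.array([2, 4], np.int32), (2, 3)))  # (2, 4)
try:
    validate_rfft2d((2, 4), np.array([2.0, 4.0], np.float32), (2, 3))
except OpNotImplemented as err:
    print(err)  # RFFT2D fft_length must be an integer tensor
```

Checking integer-ness before the cast matters: `np.array([2.7]).astype("int64")` silently yields 2, which would otherwise pass every later check.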
## Not Included
- `RFFT2D` with `fft_length` not matching the input's last two dimensions (padding / truncation path).
- Other complex-data operators: `REAL` / `IMAG` / `COMPLEX_ABS` are already handled by upstream PR #19763 and are out of scope for this PR.
- A frontend guard that rejects COMPLEX64 inputs flowing into non-complex ops. Since PR #19763 provides the only other complex ops, models whose complex-valued ops are limited to `RFFT2D` (float32 input / COMPLEX64 output) and those three upstream ops are fully supported. A generic COMPLEX64 guard was intentionally left out to keep this PR scoped to `RFFT2D`.
- User-override validation for `shape_dict` / `dtype_dict` on COMPLEX64 inputs. After PR #19763, the frontend no longer auto-appends a trailing pair axis to user overrides; a user who passes the natural TFLite shape without the pair axis will now fall through to the standard metadata mismatch path.
- A higher-precision or more optimized FFT kernel (e.g. SIMD). The float32-twiddle / float32-accumulation Cooley-Tukey path matches `np.fft.rfft2` to about `1e-4`; large spatial dimensions may need a future optimized lowering or a backend-specific implementation.
## Tests
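The tests manually build minimal TFLite flatbuffers, run the frontend, and compare against `np.fft.rfft2`; edge cases are expected to raise `OpNotImplemented`. The DFT-path tests use `atol=1e-5`; the FFT-path tests use `atol=1e-4` because twiddle factors are float32 literals. Because `convert_rfft2d` dispatches on shape alone, the `[2, 4]` (int32 and int64 `fft_length`), `[2, 2, 4]`, and `[1, 1]` cases, whose spatial sizes are all powers of 2 (counting 1 as 2^0), are actually lowered through the Cooley-Tukey kernel even though the table labels them DFT.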
| Test | Path | Shape | Coverage |
|---|---|---|---|
| `test_rfft2d_static_pair_output` | DFT | `[2, 4]` | Baseline 2D, even width; also asserts the Relax script contains the `tflite_rfft2d` kernel name and pair-output struct-info |
| `test_rfft2d_static_pair_output_with_batch` | DFT | `[2, 2, 4]` | Leading batch dims preserved |
| `test_rfft2d_odd_width_pair_output` | DFT | `[3, 5]` | Odd width → `W // 2 + 1` output bins |
| `test_rfft2d_int64_fft_length` | DFT | `[2, 4]` | INT64 fft_length constant (TFLite schema allows either int32 or int64) |
| `test_rfft2d_4d_input_pair_output` | DFT | `[2, 3, 4, 5]` | 4D input with batch and odd width |
| `test_rfft2d_minimal_1x1_pair_output` | DFT | `[1, 1]` | Edge case: trivial 1×1 FFT |
| `test_rfft2d_mismatched_fft_length_unsupported` | n/a | `[2, 4]` (fft=`[4, 4]`) | fft_length != input spatial shape guard |
| `test_rfft2d_dynamic_fft_length_unsupported` | n/a | `[2, 4]` | Dynamic/non-constant fft_length guard |
| `test_rfft2d_fft_path_4x4` | **FFT** | `[4, 4]` | Smallest power-of-two (4×4) where both row and column FFTs do real work |
| `test_rfft2d_fft_path_8x8` | **FFT** | `[8, 8]` | Square 8×8 power-of-two |
| `test_rfft2d_fft_path_16x16` | **FFT** | `[16, 16]` | Larger FFT, kernel scaling check |
| `test_rfft2d_fft_path_2x2x4x8` | **FFT** | `[2, 2, 4, 8]` | 4D power-of-two with batch |
Local validation:
```bash
python -m py_compile \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k rfft2d -v
```
Result:
```text
py_compile: passed
ruff check: All checks passed
pre-commit run --files: passed
rfft2d tests: 12 passed
```
## References
- Issue [#19519](https://github.com/apache/tvm/issues/19519) item C: FFT / complex operators (`RFFT2D`, `REAL`, `IMAG`, `COMPLEX_ABS`).
- Upstream PR #19763: `[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS` (merge commit `9d6e1cf0`). This PR's RFFT2D output pair layout matches the COMPLEX64 pair representation PR #19763 established for `REAL` / `IMAG` / `COMPLEX_ABS`, so downstream ops from PR #19763 can consume RFFT2D output directly.
- Tracking issue [#19764](https://github.com/apache/tvm/issues/19764) for the longer-term native `relax.op.signal.rfft2d` / registered TOPI backend path. This PR is a frontend-local lowering for TFLite `RFFT2D`, not the native Relax signal op tracked there.
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 439 ++++++++++++++++++++-
tests/python/relax/test_frontend_tflite.py | 279 +++++++++++++
2 files changed, 706 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 3bd87d0af4..0edfc00ce9 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -309,6 +309,7 @@ class OperatorConverter:
"RESHAPE": self.convert_reshape,
"RESIZE_BILINEAR": self.convert_resize_bilinear,
"RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor,
+ "RFFT2D": self.convert_rfft2d,
"ROUND": functools.partial(self._convert_unary_elemwise,
relax_op=_op.round),
"RSQRT": functools.partial(self._convert_unary_elemwise,
relax_op=_op.rsqrt),
"REVERSE_SEQUENCE": self.convert_reverse_sequence,
@@ -1014,7 +1015,15 @@ class OperatorConverter:
# pylint: disable=no-else-return
def get_tensor_value(self, tensor_wrapper, is_sparse=False):
- """Get tensor buffer value from given tensor wrapper"""
+ """Get tensor buffer value from given tensor wrapper.
+
+ Returns the raw TFLite buffer reinterpreted via ``get_tensor_type_as_numpy``: for
+ COMPLEX64 this is a ``np.complex64`` ndarray with the TFLite shape (no pair axis).
+ This is distinct from ``get_tensor_expr``, which returns the *Relax* representation
+ (float32 real/imag pair with a trailing (2,) axis for COMPLEX64). Use
+ ``get_tensor_expr`` for Relax IR construction; use ``get_tensor_value`` only when
+ you need the literal TFLite buffer (e.g. fft_length constant parsing).
+ """
assert isinstance(tensor_wrapper, TensorWrapper)
dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
@@ -1059,6 +1068,39 @@ class OperatorConverter:
return "complex64"
raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported")
+ def _is_tflite_complex64_type(self, tensor_type):
+ """Return whether a TFLite tensor type is COMPLEX64."""
+ from tflite.TensorType import TensorType
+
+ return tensor_type == TensorType.COMPLEX64
+
+ def _unwrap_tflite_tensor(self, tensor):
+ """Return the underlying tflite.Tensor, unwrapping TensorWrapper if needed."""
+ if isinstance(tensor, TensorWrapper):
+ return tensor.tensor
+ return tensor
+
+ def _get_relax_tensor_dtype(self, tensor):
+ """Return the Relax dtype used to represent a TFLite tensor."""
+ tensor = self._unwrap_tflite_tensor(tensor)
+ tensor_type = tensor.Type() if hasattr(tensor, "Type") else tensor
+ if self._is_tflite_complex64_type(tensor_type):
+ return "float32"
+ return self.get_tensor_type_str(tensor_type)
+
+ def _get_relax_tensor_shape(self, tensor):
+ """Return the Relax shape used to represent a TFLite tensor.
+
+ For COMPLEX64 tensors, the trailing (2,) axis encodes the real/imag pair.
+ Returns an empty tuple () for rank-0 tensors. Shape elements are Python ints
+ (not numpy scalars) so the result is safe to feed into TIRX ``T.Buffer(shape, ...)``.
+ """
+ tensor = self._unwrap_tflite_tensor(tensor)
+ shape = to_int_list(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else ()
+ if self._is_tflite_complex64_type(tensor.Type()):
+ return tuple(shape) + (2,)
+ return tuple(shape)
+
def _get_shape_expr_from_tensor(self, shape_tensor, prefix):
"""Convert a TFLite shape tensor to a Relax shape expression."""
if self.has_expr(shape_tensor.tensor_idx):
@@ -2467,8 +2509,8 @@ class OperatorConverter:
"""Return static shape and dtype metadata for a TFLite tensor."""
if isinstance(tensor, TensorWrapper):
tensor = tensor.tensor
- shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else ()
- dtype = self.get_tensor_type_str(tensor.Type())
+ shape = self._get_relax_tensor_shape(tensor)
+ dtype = self._get_relax_tensor_dtype(tensor)
return shape, dtype
def _check_tensor_metadata_match(self, actual, expected, op_name,
tensor_role):
@@ -2505,8 +2547,8 @@ class OperatorConverter:
for input_index in self._indices_or_empty(subgraph.InputsAsNumpy()):
tensor = subgraph.Tensors(int(input_index))
input_name = get_tensor_name(subgraph, int(input_index))
- shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else []
- dtype = self.get_tensor_type_str(tensor.Type())
+ shape = self._get_relax_tensor_shape(tensor)
+ dtype = self._get_relax_tensor_dtype(tensor)
param = relax.Var(input_name, relax.TensorStructInfo(shape=shape,
dtype=dtype))
exp_tab.set_expr(input_name, param)
params.append(param)
@@ -2515,12 +2557,8 @@ class OperatorConverter:
def _get_tensor_param(self, tensor_wrapper):
"""Create a Relax parameter from TFLite tensor metadata."""
name = get_tensor_name(self.subgraph, tensor_wrapper.tensor_idx)
- shape = (
- tuple(tensor_wrapper.tensor.ShapeAsNumpy())
- if tensor_wrapper.tensor.ShapeLength() > 0
- else []
- )
- dtype = self.get_tensor_type_str(tensor_wrapper.tensor.Type())
+ shape = self._get_relax_tensor_shape(tensor_wrapper)
+ dtype = self._get_relax_tensor_dtype(tensor_wrapper)
return relax.Var(name, relax.TensorStructInfo(shape=shape,
dtype=dtype))
def _lower_subgraph_to_function(self, subgraph_index, function_name_hint,
op_name="CALL"):
@@ -5033,6 +5071,82 @@ class OperatorConverter:
return relax.op.memory.view(in_expr, shape=output_shape,
dtype=output_dtype)
+ def convert_rfft2d(self, op):
+ """Convert TFLite RFFT2D for static no-padding/no-truncation shapes."""
+ from tflite.TensorType import TensorType
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(input_tensors) != 2 or len(output_tensors) != 1:
+ raise tvm.error.OpNotImplemented("RFFT2D expects two inputs and one output")
+
+ data_tensor, fft_length_tensor = input_tensors
+ output_tensor = output_tensors[0]
+ if data_tensor.tensor.Type() != TensorType.FLOAT32:
+ raise tvm.error.OpNotImplemented("RFFT2D input must be float32")
+ if not self._is_tflite_complex64_type(output_tensor.tensor.Type()):
+ raise tvm.error.OpNotImplemented("RFFT2D output must be COMPLEX64")
+ if (
+ data_tensor.tensor.Sparsity() is not None
+ or fft_length_tensor.tensor.Sparsity() is not None
+ ):
+ raise tvm.error.OpNotImplemented("RFFT2D does not support sparse inputs")
+
+ input_shape = tuple(to_int_list(self.get_tensor_shape(data_tensor)))
+ tflite_output_shape = tuple(to_int_list(self.get_tensor_shape(output_tensor)))
+ if len(input_shape) < 2:
+ raise tvm.error.OpNotImplemented("RFFT2D input rank must be at least 2")
+
+ try:
+ fft_length_value = self.get_tensor_value_or_prefetched(fft_length_tensor)
+ except (ValueError, TypeError):
+ raise tvm.error.OpNotImplemented("RFFT2D requires a constant fft_length") from None
+ if fft_length_value is None:
+ raise tvm.error.OpNotImplemented("RFFT2D requires a constant fft_length")
+ # Reject non-integer fft_length tensors before astype("int64") can silently
+ # truncate (e.g. float32 2.7 -> int64 2).
+ if not np.issubdtype(fft_length_value.dtype, np.integer):
+ raise tvm.error.OpNotImplemented(
+ f"RFFT2D fft_length must be an integer tensor, got dtype {fft_length_value.dtype!r}"
+ )
+ fft_length = fft_length_value.astype("int64")
+ if tuple(fft_length.shape) != (2,):
+ raise tvm.error.OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")
+
+ height = int(fft_length[0])
+ width = int(fft_length[1])
+ if height <= 0 or width <= 0:
+ raise tvm.error.OpNotImplemented(
+ f"RFFT2D fft_length must be positive, got ({height}, {width})"
+ )
+ if height != int(input_shape[-2]) or width != int(input_shape[-1]):
+ raise tvm.error.OpNotImplemented(
+ "RFFT2D currently supports fft_length matching the input spatial shape"
+ )
+ expected_tflite_output_shape = input_shape[:-2] + (height, width // 2 + 1)
+ if tflite_output_shape != expected_tflite_output_shape:
+ raise tvm.error.OpNotImplemented("RFFT2D output shape does not match fft_length")
+
+ relax_output_shape = self._get_relax_tensor_shape(output_tensor)
+ # Dispatch: power-of-2 sizes use the O(N^2 log N) Cooley-Tukey FFT kernel;
+ # the remaining (odd / non-power-of-2) shapes fall back to the O(N^4) DFT
+ # reference. Both kernels share the same call_tir contract, so the
+ # downstream code is kernel-agnostic.
+ if _is_power_of_2(height) and _is_power_of_2(width):
+ prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape,
relax_output_shape)
+ else:
+ prim_func = _build_tflite_rfft2d_primfunc(input_shape,
relax_output_shape)
+ module_builder = self.conversion_state["module_builder"]
+ func_name = f"tflite_rfft2d_{output_tensor.tensor_idx}"
+ gv = module_builder.add_func(prim_func, func_name)
+ data_expr = self.get_tensor_expr(data_tensor)
+ call = relax.call_tir(
+ gv,
+ [data_expr],
+ relax.TensorStructInfo(relax_output_shape, "float32"),
+ )
+ return self.bb.normalize(call)
+
def convert_broadcast_args(self, op):
"""Convert TFLite BROADCAST_ARGS"""
input_tensors = self.get_input_tensors(op)
@@ -7670,7 +7784,15 @@ class OperatorConverter:
return self.exp_tab.new_const(value, dtype=type_str,
source_name=tensor.tensor.Name())
def get_tensor_shape(self, tensor_wrapper):
- """Returns tensor shape. Infers shape if the shape is empty."""
+ """Returns the TFLite tensor shape, inferring it if the TFLite shape is empty.
+
+ This returns the *raw TFLite* shape (no pair axis), even for COMPLEX64 tensors.
+ It is distinct from ``_get_relax_tensor_shape``, which returns the *Relax*
+ representation (TFLite shape with a trailing (2,) axis appended for COMPLEX64).
+ Operators that build Relax IR for COMPLEX64 inputs should use
+ ``_get_relax_tensor_shape``; operators that need the TFLite shape for validation
+ (e.g. comparing against the model's declared output shape) should use this method.
+ """
assert isinstance(tensor_wrapper, TensorWrapper), "Expecting TensorWrapper here"
return (
tensor_wrapper.tensor.ShapeAsNumpy()
@@ -7679,6 +7801,299 @@ class OperatorConverter:
)
+def _is_power_of_2(n):
+ """Return True iff ``n`` is a positive power of 2."""
+ return n > 0 and (n & (n - 1)) == 0
+
+
+def _bit_reversal_swap_pairs(n):
+ """Return the (i, j) index pairs (i < j) for the bit-reversal permutation of length n.
+
+ For a Cooley-Tukey radix-2 FFT, the input must be permuted by bit-reversing
+ each index in log2(n) bits before the butterfly stages. Precomputing the
+ swap pairs as constants is much cheaper in TIR than computing the
+ bit-reverse on the fly.
+ """
+ assert _is_power_of_2(n), f"bit-reversal requires power of 2, got {n}"
+ length = n.bit_length() - 1 # log2(n)
+ swaps = []
+ for i in range(1, n):
+ j = 0
+ for k in range(length):
+ if i & (1 << k):
+ j |= 1 << (length - 1 - k)
+ if i < j:
+ swaps.append((i, j))
+ return swaps
+
+
+def _build_tflite_rfft2d_primfunc(input_shape, output_pair_shape):
+ """Build a reference TIR kernel for TFLite RFFT2D.
+
+ The TFLite frontend represents complex tensors as float32 real/imag pairs
+ with a trailing dimension of size 2 because TVM does not have a native
+ complex64 dtype. This kernel computes the unnormalized 2-D real FFT over
+ the last two input dimensions and writes that pair representation.
+
+ All trig and accumulation are in float32, so the result agrees with
+ ``np.fft.rfft2`` to about ``1e-5`` absolute tolerance for typical input
+ sizes. Higher-precision backends should override this kernel.
+
+ Notes
+ -----
+ This is a **naive O(B * H * W * H * W) DFT**, not an FFT. For an input of
+ spatial shape (H, W) the inner sum runs H*W times per output position, and
+ there are H*W' output positions per batch (W' = W // 2 + 1). This is
+ intentionally simple for correctness validation against
+ ``np.fft.rfft2``; production use cases with large spatial dimensions should
+ override the kernel with an FFT-based implementation. The outer
+ (batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a
+ downstream ``tvm.tir.schedule`` pass can parallelize it.
+ """
+ from tvm.script.parser import tirx as T
+
+ batch = 1
+ for dim in input_shape[:-2]:
+ batch *= int(dim)
+ height = int(input_shape[-2])
+ width = int(input_shape[-1])
+ out_width = int(output_pair_shape[-2])
+ input_total = batch * height * width
+ output_complex_total = batch * height * out_width
+ neg_two_pi = np.float32(-2.0 * math.pi)
+
+ @T.prim_func(private=True, s_tir=True, check_well_formed=False)
+ def kernel(
+ data: T.Buffer(input_shape, "float32"), output:
T.Buffer(output_pair_shape, "float32")
+ ):
+ # Flat 1D aliases of the multi-dim buffers. The kernel is rank-agnostic
+ # over the leading batch dimensions, so collapsing the index space
+ # avoids special-casing 2D / 3D / 4D input shapes.
+ data_flat = T.decl_buffer((input_total,), "float32", data=data.data)
+ output_flat = T.decl_buffer((output_complex_total * 2,), "float32",
data=output.data)
+ neg_two_pi_const = T.float32(neg_two_pi)
+
+ for b_idx, out_y, out_x in T.grid(batch, height, out_width):
+ with T.sblock("rfft2d"):
+ v_b, v_oy, v_ox = T.axis.remap("SSS", [b_idx, out_y, out_x])
+ real_sum = T.float32(0)
+ imag_sum = T.float32(0)
+ input_base = v_b * height * width
+ for in_y, in_x in T.grid(height, width):
+ phase_y = T.Cast("float32", v_oy) * T.Cast("float32",
in_y) / T.float32(height)
+ phase_x = T.Cast("float32", v_ox) * T.Cast("float32",
in_x) / T.float32(width)
+ angle = neg_two_pi_const * (phase_y + phase_x)
+ value = data_flat[input_base + in_y * width + in_x]
+ real_sum = real_sum + value * T.cos(angle)
+ imag_sum = imag_sum + value * T.sin(angle)
+ flat_out_idx = ((v_b * height + v_oy) * out_width + v_ox) * 2
+ output_flat[flat_out_idx] = real_sum
+ output_flat[flat_out_idx + 1] = imag_sum
+
+ return kernel
+
+
+def _build_tflite_rfft2d_fft_primfunc(input_shape, output_pair_shape):
+ """Build a 2D Cooley-Tukey FFT TIR kernel for TFLite RFFT2D.
+
+ Precondition: both ``input_shape[-2]`` (height) and ``input_shape[-1]``
+ (width) must be positive powers of 2. The frontend dispatches to this
+ kernel via ``_is_power_of_2`` checks; the DFT reference kernel handles
+ the remaining cases (odd / non-power-of-2 sizes).
+
+ Algorithm
+ ---------
+ 1. Copy the real input into a scratch complex buffer (imag = 0) of shape
+ ``(B * H * W,)``.
+ 2. For each batch and each row, run an in-place radix-2 1D FFT of length
+ ``W`` along the width axis.
+ 3. For each batch and each column, run an in-place radix-2 1D FFT of
+ length ``H`` along the height axis (with stride ``W``).
+ 4. Write the first ``W // 2 + 1`` complex bins per row to the output
+ pair representation.
+
+ The bit-reversal permutation required by iterative Cooley-Tukey is done
+ by emitting the (i, j) swap pairs directly in the TIR source (one
+ inlined swap per pair), avoiding the need for runtime index tables.
+
+ Complexity is ``O(B * H * W * (log2(H) + log2(W)))``, vs the DFT
+ reference kernel's ``O(B * H * W * H * W)``.
+ """
+ from tvm.script.parser import tirx as T
+
+ batch = 1
+ for dim in input_shape[:-2]:
+ batch *= int(dim)
+ height = int(input_shape[-2])
+ width = int(input_shape[-1])
+ out_width = int(output_pair_shape[-2])
+ input_total = batch * height * width
+ output_complex_total = batch * height * out_width
+ # Cast to Python float so the f-string-interpolated repr is a plain number
+ # (np.float32's repr is "np.float32(...)", which the TIR parser can't resolve).
+ neg_two_pi = float(np.float32(-2.0 * math.pi))
+ log2_w = int(math.log2(width))
+ log2_h = int(math.log2(height))
+
+ if not (_is_power_of_2(height) and _is_power_of_2(width)):
+ raise ValueError(
+ f"_build_tflite_rfft2d_fft_primfunc requires power-of-2 height and width, "
+ f"got H={height}, W={width}"
+ )
+
+ # Precompute the bit-reversal swap pairs at Python level. These are
+ # constant for a given FFT length and will be inlined in the TIR source.
+ # Each emitted line is indented 16 spaces (4 levels: top → b_idx loop →
+ # sblock → row/col loop body) so it lands inside the for loop when
+ # concatenated into the primfunc source.
+ row_swap_stmts = []
+ for i, j in _bit_reversal_swap_pairs(width):
+ row_swap_stmts.append(
+ f" i_idx = row_base + {i}\n"
+ f" j_idx = row_base + {j}\n"
+ f" tmp_r = scratch_real[i_idx]\n"
+ f" scratch_real[i_idx] = scratch_real[j_idx]\n"
+ f" scratch_real[j_idx] = tmp_r\n"
+ f" tmp_i = scratch_imag[i_idx]\n"
+ f" scratch_imag[i_idx] = scratch_imag[j_idx]\n"
+ f" scratch_imag[j_idx] = tmp_i\n"
+ )
+ row_swaps_code = "".join(row_swap_stmts) if row_swap_stmts else " pass\n"
+
+ col_swap_stmts = []
+ for i, j in _bit_reversal_swap_pairs(height):
+ col_swap_stmts.append(
+ f" i_idx = col_base + {i * width}\n"
+ f" j_idx = col_base + {j * width}\n"
+ f" tmp_r = scratch_real[i_idx]\n"
+ f" scratch_real[i_idx] = scratch_real[j_idx]\n"
+ f" scratch_real[j_idx] = tmp_r\n"
+ f" tmp_i = scratch_imag[i_idx]\n"
+ f" scratch_imag[i_idx] = scratch_imag[j_idx]\n"
+ f" scratch_imag[j_idx] = tmp_i\n"
+ )
+ col_swaps_code = "".join(col_swap_stmts) if col_swap_stmts else " pass\n"
+
+ # Build the per-stage butterfly code with the stage loop fully unrolled
+ # at primfunc-construction time. After unrolling, all loop bounds
+ # (block_start, k) are compile-time integers, so the TIR parser doesn't
+ # need to reason about runtime loop extents and the scheduler can
+ # see static twiddle factors instead of runtime trig calls.
+ def _stage_stmts(stage_count, length, indent, stride=1,
base_expr="row_base"):
+ """Generate fully-unrolled Cooley-Tukey butterfly stage bodies.
+
+ ``base_expr`` is the TIR expression holding the base offset of the
+ FFT being transformed (e.g. ``"row_base"`` for rows or
+ ``"col_base"`` for columns). ``stride`` is the integer distance
+ between adjacent butterfly taps: 1 for the row-FFT (contiguous
+ elements) and ``width`` for the column-FFT (strided access).
+ """
+ out = []
+ for stage in range(1, stage_count + 1):
+ m_val = 1 << stage
+ half_val = m_val >> 1
+ for block_start in range(0, length, m_val):
+ for k in range(half_val):
+ if stride == 1:
+ a_idx = f"{base_expr} + {block_start} + {k}"
+ b_idx_expr = f"{base_expr} + {block_start} + {k} + {half_val}"
+ else:
+ a_idx = f"{base_expr} + ({block_start} + {k}) * {stride}"
+ b_idx_expr = f"{base_expr} + ({block_start} + {k} + {half_val}) * {stride}"
+ angle_val = float(np.float32(neg_two_pi * k / m_val))
+ w_real_val = float(np.float32(math.cos(angle_val)))
+ w_imag_val = float(np.float32(math.sin(angle_val)))
+ out.extend(
+ [
+ f"{indent}a_idx = {a_idx}\n",
+ f"{indent}b_idx_local = {b_idx_expr}\n",
+ f"{indent}t_real = scratch_real[b_idx_local] * T.float32({w_real_val!r}) - scratch_imag[b_idx_local] * T.float32({w_imag_val!r})\n",
+ f"{indent}t_imag = scratch_real[b_idx_local] * T.float32({w_imag_val!r}) + scratch_imag[b_idx_local] * T.float32({w_real_val!r})\n",
+ f"{indent}u_real = scratch_real[a_idx]\n",
+ f"{indent}u_imag = scratch_imag[a_idx]\n",
+ f"{indent}scratch_real[a_idx] = u_real + t_real\n",
+ f"{indent}scratch_imag[a_idx] = u_imag + t_imag\n",
+ f"{indent}scratch_real[b_idx_local] = u_real - t_real\n",
+ f"{indent}scratch_imag[b_idx_local] = u_imag - t_imag\n",
+ ]
+ )
+ return "".join(out)
+
+ row_stages_code = _stage_stmts(
+ log2_w, width, " ", stride=1, base_expr="row_base"
+ )
+ col_stages_code = _stage_stmts(
+ log2_h, height, " ", stride=width, base_expr="col_base"
+ )
+
+ # Build the primfunc source. The bit-reversal swaps are inlined (one
+ # unconditional block per (i, j) pair) and the butterfly stages are
+ # fully unrolled, so the TIR parser sees ordinary statements rather than
+ # runtime table lookups or runtime-magnitude loops. The body is wrapped
+ # in a single S-TIR block over the batch dimension so the Relax
+ # pipeline (which expects an SBlockRealize at the primfunc body) accepts
+ # this kernel.
+ primfunc_source = (
+ "from tvm.script.parser import tirx as T\n"
+ "@T.prim_func(private=True, s_tir=True, check_well_formed=False)\n"
+ "def kernel(\n"
+ f" data: T.Buffer({tuple(int(x) for x in input_shape)}, 'float32'),\n"
+ f" output: T.Buffer({tuple(int(x) for x in output_pair_shape)}, 'float32'),\n"
+ "):\n"
+ f" data_flat = T.decl_buffer(({input_total},), 'float32', data=data.data)\n"
+ f" output_flat = T.decl_buffer(({output_complex_total * 2},), 'float32', data=output.data)\n"
+ f" scratch_real = T.decl_buffer(({input_total},), 'float32')\n"
+ f" scratch_imag = T.decl_buffer(({input_total},), 'float32')\n"
+ f" for b_idx in T.serial({batch}):\n"
+ f" with T.sblock('rfft2d_fft'):\n"
+ f" v_b = T.axis.remap('S', [b_idx])\n"
+ f" # Initialize scratch from real input; imag = 0.\n"
+ f" for i in T.serial({height * width}):\n"
+ f" src = v_b * {height * width} + i\n"
+ f" scratch_real[src] = data_flat[src]\n"
+ f" scratch_imag[src] = T.float32(0)\n"
+ f" # FFT along width axis (one 1D FFT per row).\n"
+ f" for row in T.serial({height}):\n"
+ f" row_base = v_b * {height * width} + row * {width}\n"
+ f" # Bit-reversal permutation (inlined swaps).\n"
+ f"{row_swaps_code}"
+ f" # Cooley-Tukey butterfly stages (fully unrolled).\n"
+ f"{row_stages_code}"
+ f" # FFT along height axis (one 1D FFT per column, stride = width).\n"
+ f" for col in T.serial({width}):\n"
+ f" col_base = v_b * {height * width} + col\n"
+ f" # Bit-reversal permutation (strided, inlined
swaps).\n"
+ f"{col_swaps_code}"
+ f" # Cooley-Tukey butterfly stages (strided, fully unrolled).\n"
+ f"{col_stages_code}"
+ f" # Write the first out_width complex bins per row to output.\n"
+ f" for row in T.serial({height}):\n"
+ f" for out_x in T.serial({out_width}):\n"
+ f" src = v_b * {height * width} + row * {width} + out_x\n"
+ f" dst = ((v_b * {height} + row) * {out_width} + out_x) * 2\n"
+ f" output_flat[dst] = scratch_real[src]\n"
+ f" output_flat[dst + 1] = scratch_imag[src]\n"
+ )
+
+ namespace = {"T": T, "tirx": T}
+ # Register the generated source in linecache so the TIR parser (which calls
+ # inspect.getsourcelines) can find it. The fake filename is deterministic and
+ # encodes H, W, and out_width, which identifies the FFT size of the cached
+ # source for any callers who want to introspect it.
+ import linecache as _linecache
+
+ fake_file = f"<tflite_rfft2d_fft_primfunc H={height} W={width} outW={out_width}>"
+ _linecache.cache[fake_file] = (
+ len(primfunc_source.splitlines()),
+ None,
+ [line + "\n" for line in primfunc_source.splitlines()],
+ fake_file,
+ )
+ code = compile(primfunc_source, fake_file, "exec")
+ exec(code, namespace)
+ return namespace["kernel"]
+
+
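The `linecache` registration above relies on a general Python mechanism: `inspect.getsourcelines` reads source through `linecache`, so text that exists only as a string becomes visible once the cache is pre-seeded under the filename passed to `compile`. A minimal sketch with a plain Python function rather than TVMScript; `exec_with_source` is a hypothetical helper name, not part of the frontend:

```python
import inspect
import linecache


def exec_with_source(source: str, filename: str, name: str):
    """Exec `source` and return `name`, keeping its text visible to inspect."""
    lines = [line + "\n" for line in source.splitlines()]
    # Entry layout is (size, mtime, lines, fullname); mtime=None makes
    # linecache.checkcache() skip the staleness check for this entry.
    linecache.cache[filename] = (len(source), None, lines, filename)
    namespace = {}
    exec(compile(source, filename, "exec"), namespace)
    return namespace[name]


src = "def add_one(x):\n    return x + 1\n"
add_one = exec_with_source(src, "<generated add_one>", "add_one")
print(add_one(41))  # prints 42
print(inspect.getsource(add_one), end="")  # prints the two source lines
```

Without the cache entry, `inspect.getsource` raises `OSError` ("could not get source code") for the exec'd function, which is presumably the failure the frontend's registration avoids for the TIR parser.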
# Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR,
# matching tensorflow/lite/kernels/rng_util.cc.
_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e9a842b8df..01beaadd7b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4082,6 +4082,285 @@ def _run_no_input_module(mod):
return _run_module(mod)
+def _complex64_to_pair(value):
+ value = np.asarray(value, dtype=np.complex64)
+ return np.stack([value.real, value.imag], axis=-1).astype("float32")
+
+
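`_complex64_to_pair` encodes each complex value as a trailing `[real, imag]` axis, which is the layout the frontend uses in place of a complex64 dtype. When debugging a mismatch it can help to fold that axis back. A minimal sketch of the inverse; `pair_to_complex64` is a hypothetical helper, not part of the test file:

```python
import numpy as np


def pair_to_complex64(pair) -> np.ndarray:
    """Fold a trailing [real, imag] axis of size 2 back into complex64."""
    pair = np.asarray(pair, dtype=np.float32)
    if pair.shape[-1] != 2:
        raise ValueError(f"expected a trailing axis of size 2, got {pair.shape}")
    out = np.empty(pair.shape[:-1], dtype=np.complex64)
    out.real = pair[..., 0]
    out.imag = pair[..., 1]
    return out


spec = np.fft.rfft2(np.arange(8, dtype=np.float32).reshape(2, 4)).astype(np.complex64)
# Same packing as _complex64_to_pair.
pair = np.stack([spec.real, spec.imag], axis=-1).astype("float32")
print(pair.shape)  # (2, 3, 2): input_shape[:-2] + (H, W // 2 + 1, 2)
print(np.array_equal(pair_to_complex64(pair), spec))  # True
```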
+def _build_tflite_rfft2d_model(*, input_shape, fft_length, output_shape):
+ """Build a minimal TFLite RFFT2D model."""
+ builder = flatbuffers.Builder(1024)
+ builtin_op = _get_builtin_operator("RFFT2D")
+ op_code = _build_operator_code(builder, builtin_op)
+ tensors = [
+ _build_tensor(builder, 0, input_shape, tensor_type=_tfl_tensor_type.FLOAT32),
+ _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 2, output_shape, tensor_type=_tfl_tensor_type.COMPLEX64),
+ ]
+ op = _build_operator(builder, 0, [0, 1], [2])
+ subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0], outputs=[2])
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array(fft_length, dtype=np.int32).tobytes()),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_rfft2d_static_pair_output():
+ """TFLite RFFT2D emits a call_tir kernel with float32 real/imag pair output."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[2, 4],
+ fft_length=[2, 4],
+ output_shape=[2, 3],
+ )
+ )
+
+ mod_script = mod.script()
+ assert "tflite_rfft2d" in mod_script
+ assert "R.call_tir" in mod_script
+ assert 'R.Tensor((2, 3, 2), dtype="float32")' in mod_script
+
+ data = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]], dtype="float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ # atol accommodates float32 rounding: numpy's rfft2 uses a different summation
+ # order (and float64 before NumPy 2.0), while the reference TIR kernel
+ # accumulates in float32 (see _build_tflite_rfft2d_primfunc docstring).
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+ )
+
+
+def test_rfft2d_static_pair_output_with_batch():
+ """RFFT2D computes over the last two axes and preserves leading batch dimensions."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[2, 2, 4],
+ fft_length=[2, 4],
+ output_shape=[2, 2, 3],
+ )
+ )
+
+ data = np.array(
+ [
+ [[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]],
+ [[-1.0, 2.0, 0.5, -4.0], [3.5, -6.0, 7.0, 1.0]],
+ ],
+ dtype="float32",
+ )
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+ )
+
+
+def test_rfft2d_odd_width_pair_output():
+ """RFFT2D handles odd width: the output has width // 2 + 1 bins (TFLite convention)."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[3, 5],
+ fft_length=[3, 5],
+ output_shape=[3, 3], # 5 // 2 + 1 = 3
+ )
+ )
+
+ data = np.array(
+ [[1.0, -2.0, 3.0, 4.0, -5.0], [0.5, 6.0, -7.0, 8.0, 2.5], [-1.5, 4.0, 0.0, -3.0, 1.0]],
+ dtype="float32",
+ )
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ # atol accommodates the float32 reference kernel (see
+ # _build_tflite_rfft2d_primfunc docstring).
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+ )
+
+
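Keeping only `W // 2 + 1` bins loses nothing, for odd and even widths alike, because the 2D DFT of a real input is Hermitian-symmetric: `X[k1, k2] = conj(X[(-k1) mod H, (-k2) mod W])`. A short NumPy check for the odd-width 3x5 case that rebuilds the full spectrum from the half spectrum (illustrative only, not part of the test suite):

```python
import numpy as np

rng = np.random.default_rng(1)
h, w = 3, 5
x = rng.standard_normal((h, w))
full = np.fft.fft2(x)   # all W bins
half = np.fft.rfft2(x)  # W // 2 + 1 = 3 bins, the width TFLite emits

# Rebuild the missing columns k2 = W // 2 + 1 .. W - 1 from Hermitian symmetry.
rebuilt = np.empty((h, w), dtype=complex)
rebuilt[:, : w // 2 + 1] = half
for k1 in range(h):
    for k2 in range(w // 2 + 1, w):
        rebuilt[k1, k2] = np.conj(half[(-k1) % h, (-k2) % w])

print(half.shape)  # (3, 3)
print(np.allclose(rebuilt, full))  # True
```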
+def test_rfft2d_int64_fft_length():
+ """RFFT2D accepts an INT64 fft_length constant (the TFLite schema allows int32 or int64)."""
+ builder = flatbuffers.Builder(1024)
+ rfft_op_code = _build_operator_code(builder, _get_builtin_operator("RFFT2D"))
+ tensors = [
+ _build_tensor(builder, 0, [2, 4], tensor_type=_tfl_tensor_type.FLOAT32),
+ _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64),
+ _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.COMPLEX64),
+ ]
+ op = _build_operator(builder, 0, [0, 1], [2])
+ subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0], outputs=[2])
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array([2, 4], dtype=np.int64).tobytes()),
+ _build_buffer(builder),
+ ]
+ buf = _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[rfft_op_code], buffers=buffers
+ )
+ mod = _load_model_from_buffer(buf)
+
+ data = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]], dtype="float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+ )
+
+
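Because the schema allows either integer width, the constant buffer has to be decoded with the tensor's declared type. Reading the 16 bytes of an int64 `[2, 4]` as int32 yields `[2, 0, 4, 0]`. A minimal sketch of decoding the little-endian buffer; `decode_fft_length` and its string type tags are hypothetical, not the frontend's API:

```python
import numpy as np

# TFLite buffers are little-endian; map the declared tensor type to a dtype.
_FFT_LENGTH_DTYPES = {"INT32": np.dtype("<i4"), "INT64": np.dtype("<i8")}


def decode_fft_length(raw: bytes, tensor_type: str) -> tuple:
    """Decode a constant fft_length buffer into a tuple of Python ints."""
    try:
        dtype = _FFT_LENGTH_DTYPES[tensor_type]
    except KeyError:
        raise ValueError(f"fft_length must be INT32 or INT64, got {tensor_type}") from None
    values = np.frombuffer(raw, dtype=dtype)
    if values.shape != (2,):
        raise ValueError(f"fft_length must have exactly 2 elements, got {values.shape}")
    return tuple(int(v) for v in values)


# Both encodings decode to the same (H, W) pair.
print(decode_fft_length(np.array([2, 4], dtype="<i4").tobytes(), "INT32"))  # (2, 4)
print(decode_fft_length(np.array([2, 4], dtype="<i8").tobytes(), "INT64"))  # (2, 4)
```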
+def test_rfft2d_4d_input_pair_output():
+ """RFFT2D accepts 4D input and preserves leading batch dimensions."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[2, 3, 4, 5], # batch=6, H=4, W=5
+ fft_length=[4, 5],
+ output_shape=[2, 3, 4, 3], # 5 // 2 + 1 = 3
+ )
+ )
+
+ rng = np.random.RandomState(0)
+ data = (rng.randn(2, 3, 4, 5) * 0.5).astype("float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ # Each output of this 4D test sums H * W = 20 input terms; use a slightly larger
+ # atol than the smaller 2D cases (at most 15 terms, for the 3x5 input).
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-4
+ )
+
+
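The generated kernel folds every leading axis into one batch index (`batch = 2 * 3 = 6` for this test) and writes the output through the flat index `dst = ((v_b * H + row) * out_width + out_x) * 2`. A NumPy sketch, illustrative only, confirming that this layout is exactly a C-order reshape of `input_shape[:-2] + (H, W // 2 + 1, 2)`:

```python
import numpy as np

rng = np.random.default_rng(0)
lead, h, w = (2, 3), 4, 5
out_w = w // 2 + 1
x = rng.standard_normal(lead + (h, w)).astype(np.float32)

# Reference layout: rfft2 over the last two axes, packed as a real/imag pair.
spec = np.fft.rfft2(x)
pair = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32)

# Emulate the kernel's flat writes, with all leading axes folded into one batch axis.
batch = int(np.prod(lead))  # 6 for this test
spec_flat = spec.reshape(batch, h, out_w)
out_flat = np.empty(batch * h * out_w * 2, dtype=np.float32)
for v_b in range(batch):
    for row in range(h):
        for out_x in range(out_w):
            dst = ((v_b * h + row) * out_w + out_x) * 2
            out_flat[dst] = spec_flat[v_b, row, out_x].real
            out_flat[dst + 1] = spec_flat[v_b, row, out_x].imag

print(pair.shape)  # (2, 3, 4, 3, 2)
print(np.array_equal(out_flat.reshape(pair.shape), pair))  # True
```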
+def test_rfft2d_minimal_1x1_pair_output():
+ """RFFT2D on a [1, 1] input: the only output is the DC component (sum of inputs)."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[1, 1],
+ fft_length=[1, 1],
+ output_shape=[1, 1],
+ )
+ )
+
+ data = np.array([[3.5]], dtype="float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+ )
+
+
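The docstring's claim follows from the DFT definition: at bin `(0, 0)` every twiddle factor is 1, so `X[0, 0]` is the plain sum of the inputs, and for a 1x1 input that sum is the single element. A quick NumPy check (illustrative only):

```python
import numpy as np

x = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]])
spec = np.fft.rfft2(x)
# Bin (0, 0) has every twiddle factor equal to 1, so it is the plain sum.
print(np.isclose(spec[0, 0], x.sum()))  # True (the sum is 18.0)

# With a single element, the sum is the element itself.
print(np.fft.rfft2(np.array([[3.5]])))
```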
+def test_rfft2d_fft_path_8x8():
+ """RFFT2D on a square 8x8 input exercises the Cooley-Tukey FFT dispatch path."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[8, 8],
+ fft_length=[8, 8],
+ output_shape=[8, 5],
+ )
+ )
+
+ np.random.seed(0xCAFE)
+ data = np.random.randn(8, 8).astype("float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ # The FFT path uses float32 twiddles (cos/sin) and float32 butterfly
+ # accumulation, so the error vs. numpy's reference (computed in float64
+ # before NumPy 2.0) is in the 1e-4 range on these random inputs.
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+ )
+
+
+def test_rfft2d_fft_path_4x4():
+ """RFFT2D on a 4x4 input: smallest case where both row and column FFTs do real work."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[4, 4],
+ fft_length=[4, 4],
+ output_shape=[4, 3],
+ )
+ )
+
+ np.random.seed(0xFEED)
+ data = np.random.randn(4, 4).astype("float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+ )
+
+
+def test_rfft2d_fft_path_2x2x4x8():
+ """RFFT2D on a 4D input with power-of-2 height/width exercises the FFT path with batch."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[2, 2, 4, 8],
+ fft_length=[4, 8],
+ output_shape=[2, 2, 4, 5],
+ )
+ )
+
+ np.random.seed(0xBEEF)
+ data = np.random.randn(2, 2, 4, 8).astype("float32")
+ expected = np.fft.rfft2(data, axes=(-2, -1)).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+ )
+
+
+def test_rfft2d_fft_path_16x16():
+ """RFFT2D on a 16x16 input: a larger FFT to check that the unrolled kernel scales."""
+ mod = _load_model_from_buffer(
+ _build_tflite_rfft2d_model(
+ input_shape=[16, 16],
+ fft_length=[16, 16],
+ output_shape=[16, 9],
+ )
+ )
+
+ np.random.seed(0xDEAD)
+ data = np.random.randn(16, 16).astype("float32")
+ expected = np.fft.rfft2(data).astype(np.complex64)
+ np.testing.assert_allclose(
+ _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+ )
+
+
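Because the kernel fully unrolls both the bit-reversal swaps and the butterfly stages, the generated source grows with the FFT size. A sketch of the per-transform counts, assuming one swap per index pair `(i, j)` with `i < j = bitreverse(i)` and `N / 2` butterflies in each of `log2(N)` stages. The helper names are hypothetical, and the generator's exact statement count may differ:

```python
def bit_reversal_swaps(n: int) -> list:
    """(i, j) pairs with i < j = bit-reverse(i); n must be a power of two."""
    if n <= 0 or n & (n - 1):
        raise ValueError(f"n must be a positive power of two, got {n}")
    bits = n.bit_length() - 1
    swaps = []
    for i in range(n):
        j = 0
        for b in range(bits):
            if (i >> b) & 1:
                j |= 1 << (bits - 1 - b)
        if i < j:  # each pair once; palindromic indices stay in place
            swaps.append((i, j))
    return swaps


def butterfly_count(n: int) -> int:
    """Radix-2 Cooley-Tukey: log2(n) stages of n // 2 butterflies each."""
    return (n // 2) * (n.bit_length() - 1)


print(bit_reversal_swaps(8))  # [(1, 4), (3, 6)]
for n in (4, 8, 16):
    print(n, len(bit_reversal_swaps(n)), butterfly_count(n))
# 4 1 4
# 8 2 12
# 16 6 32
```

Each 1D transform is emitted once for the row loop and once for the column loop, so a 16x16 kernel carries roughly twice the 16-point counts above.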
+def test_rfft2d_mismatched_fft_length_unsupported():
+ """RFFT2D padding/truncation cases raise OpNotImplemented until explicitly implemented."""
+ buf = _build_tflite_rfft2d_model(
+ input_shape=[2, 4],
+ fft_length=[4, 4],
+ output_shape=[4, 3],
+ )
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="fft_length"):
+ from_tflite(tflite_model)
+
+
+def test_rfft2d_dynamic_fft_length_unsupported():
+ """RFFT2D requires fft_length to be a constant tensor."""
+ builder = flatbuffers.Builder(1024)
+ rfft_op_code = _build_operator_code(builder, _get_builtin_operator("RFFT2D"))
+ tensors = [
+ _build_tensor(builder, 0, [2, 4], tensor_type=_tfl_tensor_type.FLOAT32),
+ _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.COMPLEX64),
+ ]
+ op = _build_operator(builder, 0, [0, 1], [2])
+ subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0, 1], outputs=[2])
+ buf = _finish_tflite_model(
+ builder,
+ subgraph=subgraph,
+ operator_codes=[rfft_op_code],
+ buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)],
+ )
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="requires a constant fft_length"):
+ from_tflite(tflite_model)
+
+
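The two guard tests above pin down the supported static subset listed in the summary: a constant, length-2 `fft_length` equal to the input's last two dimensions. A sketch of those checks as a standalone function. It is hypothetical: it raises `ValueError` where the frontend raises `tvm.error.OpNotImplemented`, and `None` stands in for a non-constant `fft_length`:

```python
def check_static_rfft2d(input_shape, fft_length):
    """Validate the static RFFT2D subset and return the pair output shape."""
    if fft_length is None:  # stand-in for a non-constant fft_length tensor
        raise ValueError("RFFT2D requires a constant fft_length")
    fft_length = tuple(int(v) for v in fft_length)
    if len(fft_length) != 2:
        raise ValueError(f"fft_length must have 2 elements, got {fft_length}")
    if len(input_shape) < 2:
        raise ValueError("RFFT2D input must have rank >= 2")
    h, w = (int(d) for d in input_shape[-2:])
    if fft_length != (h, w):
        raise ValueError(
            f"fft_length {fft_length} must equal the last two input dims {(h, w)}; "
            "padding and truncation are not supported"
        )
    return tuple(input_shape[:-2]) + (h, w // 2 + 1, 2)


print(check_static_rfft2d([2, 3, 4, 5], [4, 5]))  # (2, 3, 4, 3, 2)
for shape, length in (([2, 4], [4, 4]), ([2, 4], None)):
    try:
        check_static_rfft2d(shape, length)
    except ValueError as err:
        print("rejected:", err)
```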
def _build_tflite_call_model(
call_subgraph_index=1,
callee_inputs=None,