This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 39535048b6 [Relax][Frontend][TFLite] Add RFFT2D op and supporting TIR kernels (#19812)
39535048b6 is described below

commit 39535048b62b5a23c941a0e98976bd53da050f1d
Author: Hongyi Wu <[email protected]>
AuthorDate: Thu Jun 18 02:25:05 2026 +0800

    [Relax][Frontend][TFLite] Add RFFT2D op and supporting TIR kernels (#19812)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for the TFLite builtin
    `RFFT2D` operator (issue #19519 item C: FFT / complex operators). It is
    the follow-up to upstream PR #19763, which already merged the `REAL` /
    `IMAG` / `COMPLEX_ABS` subset of item C; this PR completes item C with
    `RFFT2D` itself.
    
    `RFFT2D` computes a 2D real FFT over the last two input axes and returns
    a real/imag pair tensor of shape `[..., H, W // 2 + 1, 2]`. Relax does
    not have a native complex64 dtype, so the pair output is represented as
    a `float32` tensor with a trailing axis of size 2, matching the
    convention PR #19763 established for `REAL` / `IMAG` / `COMPLEX_ABS`.
    `RFFT2D` accepts a real `float32` input and emits a `float32` pair
    output, so the frontend does not need to materialize any COMPLEX64
    in-memory representation of its own.
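
As an illustration of the pair layout described above, the following NumPy sketch converts a complex `np.fft.rfft2` result into a float32 tensor with a trailing axis of size 2. The helper name `to_pair_layout` is hypothetical and is not part of the frontend.

```python
import numpy as np


def to_pair_layout(complex_array):
    """Stack real and imaginary parts on a new trailing axis of size 2.

    Hypothetical helper for illustration only; it is not a frontend API.
    """
    stacked = np.stack([complex_array.real, complex_array.imag], axis=-1)
    return stacked.astype(np.float32)


x = np.arange(8, dtype=np.float32).reshape(2, 4)  # H = 2, W = 4
spectrum = np.fft.rfft2(x)                        # complex, shape (H, W // 2 + 1)
pair = to_pair_layout(spectrum)                   # float32, shape (H, W // 2 + 1, 2)
print(spectrum.shape, pair.shape, pair.dtype)
```

Index `[..., 0]` holds the real part and `[..., 1]` the imaginary part, which is the layout the text says `REAL` / `IMAG` / `COMPLEX_ABS` already consume.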
    
    The supported subset is the static no-padding / no-truncation path:
    
    - the input's last two dimensions must match `fft_length`
    - `fft_length` must be a length-2 integer constant (int32 or int64)
    - the output shape is `input_shape[:-2] + (H, W // 2 + 1, 2)`
    - sparse inputs are rejected
    
    ## Design
    
    ### Dispatch Between Two TIR Kernels
    
    `convert_rfft2d` selects one of two TIR primfuncs at lowering time,
    based on whether the spatial axes are powers of two:
    
    ```python
    if _is_power_of_2(height) and _is_power_of_2(width):
        prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape, relax_output_shape)
    else:
        prim_func = _build_tflite_rfft2d_primfunc(input_shape, relax_output_shape)
    ```
    
    Both kernels share the same `call_tir` contract, so downstream code is
    kernel-agnostic.
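
To make the dispatch rule concrete, here is a small standalone sketch. The bit trick mirrors `_is_power_of_2` from the diff; `pick_kernel` and its `"fft"` / `"dft"` labels are hypothetical names used only for this example.

```python
def is_power_of_2(n):
    """True iff n is a positive power of 2 (same bit trick as the diff)."""
    return n > 0 and (n & (n - 1)) == 0


def pick_kernel(height, width):
    """Hypothetical label for the kernel the dispatch rule would select."""
    if is_power_of_2(height) and is_power_of_2(width):
        return "fft"
    return "dft"


for shape in [(1, 1), (2, 4), (3, 5), (4, 5), (8, 8)]:
    print(shape, pick_kernel(*shape))
```

Under this rule 1 and 2 count as powers of two, so spatial sizes such as `(1, 1)` and `(2, 4)` select the Cooley-Tukey kernel, while `(3, 5)` and `(4, 5)` fall back to the DFT kernel.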
    
    #### DFT Reference Kernel (`_build_tflite_rfft2d_primfunc`)
    
    A naive O(B · H · W · H · W) DFT over the last two input axes. The outer
    (batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a
    downstream `tir.schedule` pass can parallelize it. Trig and accumulation
    are in float32; the result agrees with `np.fft.rfft2` to about `1e-5`
    absolute tolerance for typical input sizes. This kernel is the fallback
    for any spatial size where `H` or `W` is not a power of 2.
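
The direct summation this kernel performs can be sketched in plain Python. This is a reference sketch under stated assumptions, not the TIR kernel: the function name `naive_rfft2_pairs` is hypothetical, and unlike the kernel it accumulates in Python floats (float64) before storing float32.

```python
import math

import numpy as np


def naive_rfft2_pairs(x):
    """Direct 2D DFT over the last two axes, keeping W // 2 + 1 bins.

    Hypothetical reference helper; accumulates in float64, stores float32.
    """
    x = np.asarray(x, dtype=np.float32)
    height, width = x.shape[-2:]
    out_width = width // 2 + 1
    flat = x.reshape(-1, height, width)  # collapse leading batch dims
    out = np.zeros((flat.shape[0], height, out_width, 2), dtype=np.float32)
    for b in range(flat.shape[0]):
        for oy in range(height):
            for ox in range(out_width):
                real_sum = 0.0
                imag_sum = 0.0
                for iy in range(height):
                    for ix in range(width):
                        angle = -2.0 * math.pi * (oy * iy / height + ox * ix / width)
                        value = float(flat[b, iy, ix])
                        real_sum += value * math.cos(angle)
                        imag_sum += value * math.sin(angle)
                out[b, oy, ox] = (real_sum, imag_sum)
    return out.reshape(x.shape[:-2] + (height, out_width, 2))


rng = np.random.default_rng(0)
data = rng.standard_normal((2, 3, 5)).astype(np.float32)
result = naive_rfft2_pairs(data)
reference = np.fft.rfft2(data.astype(np.float64))
np.testing.assert_allclose(result[..., 0], reference.real, atol=1e-5)
np.testing.assert_allclose(result[..., 1], reference.imag, atol=1e-5)
```

The five nested loops make the O(B · H · W · H · W) cost visible: every output bin revisits every input sample.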
    
    #### Cooley-Tukey FFT Kernel (`_build_tflite_rfft2d_fft_primfunc`)
    
    An O(B · H · W · (log2(H) + log2(W))) radix-2 Cooley-Tukey FFT,
    dispatched when both `H` and `W` are positive powers of 2. The primfunc
    source is generated as a TIR string at construction time and registered
    in `linecache` so `tirx.parser` can resolve it. The bit-reversal
    permutation and butterfly stages are precomputed in Python and inlined
    as direct scratch-buffer assignments, so all loop bounds and twiddle
    factors are compile-time literals:
    
    1. Copy the real input into `scratch_real`; initialize `scratch_imag` to
       0.
    2. For each batch and each row, run an in-place 1D FFT of length `W`
       along the width axis using scratch buffers.
    3. For each batch and each column, run an in-place 1D FFT of length `H`
       along the height axis with stride `W`.
    4. Write the first `W // 2 + 1` complex bins per row to the output pair
       representation.
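
The four steps above can be sketched in NumPy for a single 2D input. This is an illustrative sketch only: the function names are hypothetical, twiddle factors are computed at run time in double precision rather than inlined as float32 literals, and no TIR is generated.

```python
import cmath

import numpy as np


def bit_reversal_swap_pairs(n):
    """(i, j) pairs with i < j for the bit-reversal permutation of length n."""
    bits = n.bit_length() - 1  # log2(n) for a power of 2
    pairs = []
    for i in range(1, n):
        j = int(format(i, f"0{bits}b")[::-1], 2)
        if i < j:
            pairs.append((i, j))
    return pairs


def fft_1d_inplace(values):
    """Iterative radix-2 FFT on a list of complex numbers (power-of-2 length)."""
    n = len(values)
    for i, j in bit_reversal_swap_pairs(n):
        values[i], values[j] = values[j], values[i]
    m = 2
    while m <= n:
        half = m // 2
        for block_start in range(0, n, m):
            for k in range(half):
                twiddle = cmath.exp(-2j * cmath.pi * k / m)
                t = twiddle * values[block_start + k + half]
                u = values[block_start + k]
                values[block_start + k] = u + t
                values[block_start + k + half] = u - t
        m *= 2


def rfft2_cooley_tukey(x):
    """Row FFTs, then column FFTs, then keep W // 2 + 1 bins as pairs."""
    height, width = x.shape
    scratch = np.asarray(x, dtype=np.float64).astype(np.complex128)  # step 1
    for row in range(height):                                        # step 2
        line = list(scratch[row, :])
        fft_1d_inplace(line)
        scratch[row, :] = line
    for col in range(width):                                         # step 3
        line = list(scratch[:, col])
        fft_1d_inplace(line)
        scratch[:, col] = line
    kept = scratch[:, : width // 2 + 1]                              # step 4
    return np.stack([kept.real, kept.imag], axis=-1).astype(np.float32)


rng = np.random.default_rng(0)
data = rng.standard_normal((4, 8)).astype(np.float32)
pairs = rfft2_cooley_tukey(data)
reference = np.fft.rfft2(data.astype(np.float64))
np.testing.assert_allclose(pairs[..., 0], reference.real, atol=1e-4)
np.testing.assert_allclose(pairs[..., 1], reference.imag, atol=1e-4)
```

For length 8, the swap pairs are `(1, 4)` and `(3, 6)`; lengths 1 and 2 need no swaps at all.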
    
    The bit-reversal swap pairs and butterfly stage bodies are precomputed
    at primfunc-construction time because `tirx.parser` does not currently
    accept runtime `T.serial` bounds, and twiddle factors must be
    `T.float32` literals to avoid `Undefined variable` errors. The synthetic
    `linecache` filename (for example
    `<tflite_rfft2d_fft_primfunc H=8 W=8 outW=5>`) is dimension-aware, so
    generated-source stack traces are readable.
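
The general technique of registering generated source under a synthetic filename can be sketched with the standard library alone. This is a generic Python illustration, not the TVM code path: `load_generated_function`, the filename pattern, and the generated `scaled` function are all hypothetical.

```python
import linecache


def load_generated_function(source, filename, func_name):
    """Compile generated source and make its lines visible to linecache.

    Hypothetical helper. The entry uses mtime=None so that
    linecache.checkcache() does not evict it for lacking a file on disk.
    """
    lines = source.splitlines(keepends=True)
    linecache.cache[filename] = (len(source), None, lines, filename)
    namespace = {}
    exec(compile(source, filename, "exec"), namespace)
    return namespace[func_name]


height, width = 8, 8
filename = f"<generated_kernel H={height} W={width}>"  # dimension-aware name
source = f"def scaled(x):\n    return x * {height * width}\n"
scaled = load_generated_function(source, filename, "scaled")

print(scaled(2))
print(linecache.getline(filename, 2).strip())
```

Because tooling that reads source through `linecache` (tracebacks, for instance) looks lines up by filename, embedding the dimensions in the name identifies which generated variant a frame came from.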
    
    ### COMPLEX64 Pair Representation Helpers
    
    The frontend represents TFLite COMPLEX64 tensors as float32 real/imag
    pairs with a trailing axis of size 2 (since Relax has no native
    complex64 dtype). Four small helpers centralize the rule so future
    complex ops can plug in without re-implementing the pair-axis layout:
    
    - `_is_tflite_complex64_type`: checks whether a TFLite tensor type is
      `COMPLEX64`.
    - `_unwrap_tflite_tensor`: unwraps a `TensorWrapper` to the raw
      `tflite.Tensor`.
    - `_get_relax_tensor_dtype`: returns the Relax dtype used to represent
      a TFLite tensor (`"float32"` for COMPLEX64, otherwise the standard
      mapping).
    - `_get_relax_tensor_shape`: returns the Relax shape (TFLite shape with
      a trailing `(2,)` axis appended for COMPLEX64).
    
    The three callers that construct Relax parameters from TFLite metadata
    (`_get_static_tensor_shape_dtype` / `_set_subgraph_input_params` /
    `_get_tensor_param`) now go through these helpers. The pair-axis
    invariant is documented on `get_tensor_value` and `get_tensor_shape`,
    which both return the *raw TFLite* shape (no pair axis) for downstream
    callers that need to compare against the model's declared output shape.
    
    ### Boundary Validation
    
    `convert_rfft2d` validates rank, dtype, `fft_length` shape,
    integer-ness, positivity, `fft_length` == input spatial shape, output
    shape agreement, and the absence of sparse inputs before emitting the
    `call_tir`. Unsupported cases (sparse inputs, non-constant or
    non-integer `fft_length`, zero/negative `fft_length`, mismatched
    `fft_length`) each produce a targeted `OpNotImplemented` diagnostic.
    The trivial 1×1 spatial size passes these checks and is lowered
    normally.
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `RFFT2D` | input `float32`, constant length-2 integer `fft_length`, `COMPLEX64` output | `call_tir` to a generated TIR kernel | static no-padding/no-truncation; `H`, `W` arbitrary; Cooley-Tukey dispatched when both are powers of 2 |
    
    ## Safety Checks
    
    - An operator without exactly two inputs and one output raises
      `OpNotImplemented("RFFT2D expects two inputs and one output")`.
    - Non-float32 input raises
      `OpNotImplemented("RFFT2D input must be float32")`.
    - Non-COMPLEX64 output raises
      `OpNotImplemented("RFFT2D output must be COMPLEX64")`.
    - Sparse inputs raise
      `OpNotImplemented("RFFT2D does not support sparse inputs")`.
    - Non-constant `fft_length` raises
      `OpNotImplemented("RFFT2D requires a constant fft_length")`.
    - Non-integer `fft_length` raises
      `OpNotImplemented("RFFT2D fft_length must be an integer tensor")`;
      the message also reports the offending dtype.
    - Wrong-shape `fft_length` (not length 2) raises
      `OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")`.
    - Non-positive `fft_length` raises
      `OpNotImplemented("RFFT2D fft_length must be positive")`; the message
      also reports the offending `(height, width)` values.
    - `fft_length` not matching the input's last two dims raises
      `OpNotImplemented("RFFT2D currently supports fft_length matching the input spatial shape")`.
    - Mismatched output shape raises
      `OpNotImplemented("RFFT2D output shape does not match fft_length")`.
    - Input rank < 2 raises
      `OpNotImplemented("RFFT2D input rank must be at least 2")`.
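
The `fft_length` checks listed above can be summarized in a standalone sketch. It follows the order of the checks in the diff but is not the frontend code: `UnsupportedRfft2d` is a hypothetical stand-in for `tvm.error.OpNotImplemented`, and `expected_output_shape` is a hypothetical helper.

```python
import numpy as np


class UnsupportedRfft2d(Exception):
    """Hypothetical stand-in for tvm.error.OpNotImplemented."""


def expected_output_shape(input_shape, fft_length):
    """Validate fft_length and return the raw (no pair axis) output shape."""
    if len(input_shape) < 2:
        raise UnsupportedRfft2d("input rank must be at least 2")
    fft_length = np.asarray(fft_length)
    # Check the dtype first so a float such as 2.7 is never truncated to 2.
    if not np.issubdtype(fft_length.dtype, np.integer):
        raise UnsupportedRfft2d(f"fft_length must be integer, got {fft_length.dtype!r}")
    if fft_length.shape != (2,):
        raise UnsupportedRfft2d("fft_length must be a length-2 tensor")
    height, width = int(fft_length[0]), int(fft_length[1])
    if height <= 0 or width <= 0:
        raise UnsupportedRfft2d(f"fft_length must be positive, got ({height}, {width})")
    if (height, width) != tuple(int(d) for d in input_shape[-2:]):
        raise UnsupportedRfft2d("fft_length must match the input spatial shape")
    return tuple(input_shape[:-2]) + (height, width // 2 + 1)


print(expected_output_shape((2, 3, 4, 5), np.array([4, 5], dtype="int32")))
```

The Relax-side shape is this result with a trailing `2` appended for the real/imag pair.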
    
    ## Not Included
    
    - `RFFT2D` with `fft_length` not matching the input's last two
      dimensions (padding / truncation path).
    - Other complex-data operators: `REAL` / `IMAG` / `COMPLEX_ABS` are
      already handled by upstream PR #19763 and are out of scope for this
      PR.
    - A frontend guard that rejects COMPLEX64 inputs flowing into
      non-complex ops. With PR #19763 providing the only other complex ops,
      models that contain only `RFFT2D` (with float32 input / COMPLEX64
      output) and the three upstream ops are fully supported. A generic
      COMPLEX64 guard was intentionally not added here to keep this PR
      scoped to `RFFT2D`.
    - User-override validation for `shape_dict` / `dtype_dict` on COMPLEX64
      inputs. After PR #19763, the frontend no longer auto-appends a
      trailing pair axis to user overrides; a user passing the natural
      TFLite shape without the pair axis will now fall through to the
      standard metadata mismatch path.
    - Higher-precision or backend-optimized FFT kernels (e.g. SIMD). The
      float32 twiddle / float32 accumulation paths match `np.fft.rfft2` to
      `~1e-4` on the Cooley-Tukey path; large spatial dimensions may need a
      future optimized lowering or backend-specific implementation.
    
    ## Tests
    
    The tests manually build minimal TFLite flatbuffers, run the frontend,
    and compare against `np.fft.rfft2`. Unsupported configurations are
    expected to raise `OpNotImplemented`. The `test_rfft2d_fft_path_*`
    tests use `atol=1e-4` because twiddle factors are float32 literals; the
    remaining numeric tests use `atol=1e-5`.
    
    | Test | Path (per the `_is_power_of_2` dispatch) | Shape | Coverage |
    |---|---|---|---|
    | `test_rfft2d_static_pair_output` | FFT | `[2, 4]` | Baseline 2D, even width; also asserts Relax script contains the `tflite_rfft2d` kernel name and pair-output struct-info |
    | `test_rfft2d_static_pair_output_with_batch` | FFT | `[2, 2, 4]` | Leading batch dims preserved |
    | `test_rfft2d_odd_width_pair_output` | DFT | `[3, 5]` | Odd width → `W // 2 + 1` output bins |
    | `test_rfft2d_int64_fft_length` | FFT | `[2, 4]` | INT64 fft_length constant (TFLite schema allows either int32 or int64) |
    | `test_rfft2d_4d_input_pair_output` | DFT | `[2, 3, 4, 5]` | 4D input with batch and odd width |
    | `test_rfft2d_minimal_1x1_pair_output` | FFT | `[1, 1]` | Edge case: trivial 1×1 FFT |
    | `test_rfft2d_mismatched_fft_length_unsupported` | n/a | `[2, 4]` (fft=`[4, 4]`) | fft_length != input spatial shape guard |
    | `test_rfft2d_dynamic_fft_length_unsupported` | n/a | `[2, 4]` | Dynamic/non-constant fft_length guard |
    | `test_rfft2d_fft_path_4x4` | **FFT** | `[4, 4]` | Smallest power-of-two (4×4) where both row and column FFTs do real work |
    | `test_rfft2d_fft_path_8x8` | **FFT** | `[8, 8]` | Square 8×8 power-of-two |
    | `test_rfft2d_fft_path_16x16` | **FFT** | `[16, 16]` | Larger FFT, kernel scaling check |
    | `test_rfft2d_fft_path_2x2x4x8` | **FFT** | `[2, 2, 4, 8]` | 4D power-of-two with batch |
    
    Local validation:
    
    ```bash
    python -m py_compile \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k rfft2d -v
    ```
    
    Result:
    
    ```text
    py_compile: passed
    ruff check: All checks passed
    pre-commit run --files: passed
    rfft2d tests: 12 passed
    ```
    
    ## References
    
    - Issue [#19519](https://github.com/apache/tvm/issues/19519) item C:
      FFT / complex operators (`RFFT2D`, `REAL`, `IMAG`, `COMPLEX_ABS`).
    - Upstream PR #19763:
      `[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS`
      (merge commit `9d6e1cf0`). This PR's RFFT2D output pair layout matches
      the COMPLEX64 pair representation PR #19763 established for `REAL` /
      `IMAG` / `COMPLEX_ABS`, so downstream ops from PR #19763 can consume
      RFFT2D output directly.
    - Tracking issue [#19764](https://github.com/apache/tvm/issues/19764)
      for the longer-term native `relax.op.signal.rfft2d` / registered TOPI
      backend path. This PR is a frontend-local lowering for TFLite
      `RFFT2D`, not the native Relax signal op tracked there.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 439 ++++++++++++++++++++-
 tests/python/relax/test_frontend_tflite.py         | 279 +++++++++++++
 2 files changed, 706 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 3bd87d0af4..0edfc00ce9 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -309,6 +309,7 @@ class OperatorConverter:
             "RESHAPE": self.convert_reshape,
             "RESIZE_BILINEAR": self.convert_resize_bilinear,
             "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor,
+            "RFFT2D": self.convert_rfft2d,
             "ROUND": functools.partial(self._convert_unary_elemwise, relax_op=_op.round),
             "RSQRT": functools.partial(self._convert_unary_elemwise, relax_op=_op.rsqrt),
             "REVERSE_SEQUENCE": self.convert_reverse_sequence,
@@ -1014,7 +1015,15 @@ class OperatorConverter:
 
     # pylint: disable=no-else-return
     def get_tensor_value(self, tensor_wrapper, is_sparse=False):
-        """Get tensor buffer value from given tensor wrapper"""
+        """Get tensor buffer value from given tensor wrapper.
+
+        Returns the raw TFLite buffer reinterpreted via ``get_tensor_type_as_numpy``: for
+        COMPLEX64 this is a ``np.complex64`` ndarray with the TFLite shape (no pair axis).
+        This is distinct from ``get_tensor_expr``, which returns the *Relax* representation
+        (float32 real/imag pair with a trailing (2,) axis for COMPLEX64). Use
+        ``get_tensor_expr`` for Relax IR construction; use ``get_tensor_value`` only when
+        you need the literal TFLite buffer (e.g. fft_length constant parsing).
+        """
         assert isinstance(tensor_wrapper, TensorWrapper)
 
         dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
@@ -1059,6 +1068,39 @@ class OperatorConverter:
             return "complex64"
         raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported")
 
+    def _is_tflite_complex64_type(self, tensor_type):
+        """Return whether a TFLite tensor type is COMPLEX64."""
+        from tflite.TensorType import TensorType
+
+        return tensor_type == TensorType.COMPLEX64
+
+    def _unwrap_tflite_tensor(self, tensor):
+        """Return the underlying tflite.Tensor, unwrapping TensorWrapper if needed."""
+        if isinstance(tensor, TensorWrapper):
+            return tensor.tensor
+        return tensor
+
+    def _get_relax_tensor_dtype(self, tensor):
+        """Return the Relax dtype used to represent a TFLite tensor."""
+        tensor = self._unwrap_tflite_tensor(tensor)
+        tensor_type = tensor.Type() if hasattr(tensor, "Type") else tensor
+        if self._is_tflite_complex64_type(tensor_type):
+            return "float32"
+        return self.get_tensor_type_str(tensor_type)
+
+    def _get_relax_tensor_shape(self, tensor):
+        """Return the Relax shape used to represent a TFLite tensor.
+
+        For COMPLEX64 tensors, the trailing (2,) axis encodes the real/imag pair.
+        Returns an empty tuple () for rank-0 tensors. Shape elements are Python ints
+        (not numpy scalars) so the result is safe to feed into TIRX ``T.Buffer(shape, ...)``.
+        """
+        tensor = self._unwrap_tflite_tensor(tensor)
+        shape = to_int_list(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else ()
+        if self._is_tflite_complex64_type(tensor.Type()):
+            return tuple(shape) + (2,)
+        return tuple(shape)
+
     def _get_shape_expr_from_tensor(self, shape_tensor, prefix):
         """Convert a TFLite shape tensor to a Relax shape expression."""
         if self.has_expr(shape_tensor.tensor_idx):
@@ -2467,8 +2509,8 @@ class OperatorConverter:
         """Return static shape and dtype metadata for a TFLite tensor."""
         if isinstance(tensor, TensorWrapper):
             tensor = tensor.tensor
-        shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else ()
-        dtype = self.get_tensor_type_str(tensor.Type())
+        shape = self._get_relax_tensor_shape(tensor)
+        dtype = self._get_relax_tensor_dtype(tensor)
         return shape, dtype
 
     def _check_tensor_metadata_match(self, actual, expected, op_name, tensor_role):
@@ -2505,8 +2547,8 @@ class OperatorConverter:
         for input_index in self._indices_or_empty(subgraph.InputsAsNumpy()):
             tensor = subgraph.Tensors(int(input_index))
             input_name = get_tensor_name(subgraph, int(input_index))
-            shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else []
-            dtype = self.get_tensor_type_str(tensor.Type())
+            shape = self._get_relax_tensor_shape(tensor)
+            dtype = self._get_relax_tensor_dtype(tensor)
             param = relax.Var(input_name, relax.TensorStructInfo(shape=shape, dtype=dtype))
             exp_tab.set_expr(input_name, param)
             params.append(param)
@@ -2515,12 +2557,8 @@ class OperatorConverter:
     def _get_tensor_param(self, tensor_wrapper):
         """Create a Relax parameter from TFLite tensor metadata."""
         name = get_tensor_name(self.subgraph, tensor_wrapper.tensor_idx)
-        shape = (
-            tuple(tensor_wrapper.tensor.ShapeAsNumpy())
-            if tensor_wrapper.tensor.ShapeLength() > 0
-            else []
-        )
-        dtype = self.get_tensor_type_str(tensor_wrapper.tensor.Type())
+        shape = self._get_relax_tensor_shape(tensor_wrapper)
+        dtype = self._get_relax_tensor_dtype(tensor_wrapper)
         return relax.Var(name, relax.TensorStructInfo(shape=shape, dtype=dtype))
 
     def _lower_subgraph_to_function(self, subgraph_index, function_name_hint, op_name="CALL"):
@@ -5033,6 +5071,82 @@ class OperatorConverter:
 
         return relax.op.memory.view(in_expr, shape=output_shape, dtype=output_dtype)
 
+    def convert_rfft2d(self, op):
+        """Convert TFLite RFFT2D for static no-padding/no-truncation shapes."""
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 2 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented("RFFT2D expects two inputs and one output")
+
+        data_tensor, fft_length_tensor = input_tensors
+        output_tensor = output_tensors[0]
+        if data_tensor.tensor.Type() != TensorType.FLOAT32:
+            raise tvm.error.OpNotImplemented("RFFT2D input must be float32")
+        if not self._is_tflite_complex64_type(output_tensor.tensor.Type()):
+            raise tvm.error.OpNotImplemented("RFFT2D output must be COMPLEX64")
+        if (
+            data_tensor.tensor.Sparsity() is not None
+            or fft_length_tensor.tensor.Sparsity() is not None
+        ):
+            raise tvm.error.OpNotImplemented("RFFT2D does not support sparse inputs")
+
+        input_shape = tuple(to_int_list(self.get_tensor_shape(data_tensor)))
+        tflite_output_shape = tuple(to_int_list(self.get_tensor_shape(output_tensor)))
+        if len(input_shape) < 2:
+            raise tvm.error.OpNotImplemented("RFFT2D input rank must be at least 2")
+
+        try:
+            fft_length_value = self.get_tensor_value_or_prefetched(fft_length_tensor)
+        except (ValueError, TypeError):
+            raise tvm.error.OpNotImplemented("RFFT2D requires a constant fft_length") from None
+        if fft_length_value is None:
+            raise tvm.error.OpNotImplemented("RFFT2D requires a constant fft_length")
+        # Reject non-integer fft_length tensors before astype("int64") can silently
+        # truncate (e.g. float32 2.7 -> int64 2).
+        if not np.issubdtype(fft_length_value.dtype, np.integer):
+            raise tvm.error.OpNotImplemented(
+                f"RFFT2D fft_length must be an integer tensor, got dtype {fft_length_value.dtype!r}"
+            )
+        fft_length = fft_length_value.astype("int64")
+        if tuple(fft_length.shape) != (2,):
+            raise tvm.error.OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")
+
+        height = int(fft_length[0])
+        width = int(fft_length[1])
+        if height <= 0 or width <= 0:
+            raise tvm.error.OpNotImplemented(
+                f"RFFT2D fft_length must be positive, got ({height}, {width})"
+            )
+        if height != int(input_shape[-2]) or width != int(input_shape[-1]):
+            raise tvm.error.OpNotImplemented(
+                "RFFT2D currently supports fft_length matching the input spatial shape"
+            )
+        expected_tflite_output_shape = input_shape[:-2] + (height, width // 2 + 1)
+        if tflite_output_shape != expected_tflite_output_shape:
+            raise tvm.error.OpNotImplemented("RFFT2D output shape does not match fft_length")
+
+        relax_output_shape = self._get_relax_tensor_shape(output_tensor)
+        # Dispatch: power-of-2 sizes use the O(N^2 log N) Cooley-Tukey FFT kernel;
+        # the remaining (odd / non-power-of-2) shapes fall back to the O(N^4) DFT
+        # reference. Both kernels share the same call_tir contract, so the
+        # downstream code is kernel-agnostic.
+        if _is_power_of_2(height) and _is_power_of_2(width):
+            prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape, relax_output_shape)
+        else:
+            prim_func = _build_tflite_rfft2d_primfunc(input_shape, relax_output_shape)
+        module_builder = self.conversion_state["module_builder"]
+        func_name = f"tflite_rfft2d_{output_tensor.tensor_idx}"
+        gv = module_builder.add_func(prim_func, func_name)
+        data_expr = self.get_tensor_expr(data_tensor)
+        call = relax.call_tir(
+            gv,
+            [data_expr],
+            relax.TensorStructInfo(relax_output_shape, "float32"),
+        )
+        return self.bb.normalize(call)
+
     def convert_broadcast_args(self, op):
         """Convert TFLite BROADCAST_ARGS"""
         input_tensors = self.get_input_tensors(op)
@@ -7670,7 +7784,15 @@ class OperatorConverter:
         return self.exp_tab.new_const(value, dtype=type_str, source_name=tensor.tensor.Name())
 
     def get_tensor_shape(self, tensor_wrapper):
-        """Returns tensor shape. Infers shape if the shape is empty."""
+        """Returns the TFLite tensor shape, inferring it if the TFLite shape is empty.
+
+        This returns the *raw TFLite* shape (no pair axis), even for COMPLEX64 tensors.
+        It is distinct from ``_get_relax_tensor_shape``, which returns the *Relax*
+        representation (TFLite shape with a trailing (2,) axis appended for COMPLEX64).
+        Operators that build Relax IR for COMPLEX64 inputs should use
+        ``_get_relax_tensor_shape``; operators that need the TFLite shape for validation
+        (e.g. comparing against the model's declared output shape) should use this method.
+        """
         assert isinstance(tensor_wrapper, TensorWrapper), "Expecting TensorWrapper here"
         return (
             tensor_wrapper.tensor.ShapeAsNumpy()
@@ -7679,6 +7801,299 @@ class OperatorConverter:
         )
 
 
+def _is_power_of_2(n):
+    """Return True iff ``n`` is a positive power of 2."""
+    return n > 0 and (n & (n - 1)) == 0
+
+
+def _bit_reversal_swap_pairs(n):
+    """Return the (i, j) index pairs (i < j) for the bit-reversal permutation of length n.
+
+    For a Cooley-Tukey radix-2 FFT, the input must be permuted by bit-reversing
+    each index in log2(n) bits before the butterfly stages. Precomputing the
+    swap pairs as constants is much cheaper in TIR than computing the
+    bit-reverse on the fly.
+    """
+    assert _is_power_of_2(n), f"bit-reversal requires power of 2, got {n}"
+    length = n.bit_length() - 1  # log2(n)
+    swaps = []
+    for i in range(1, n):
+        j = 0
+        for k in range(length):
+            if i & (1 << k):
+                j |= 1 << (length - 1 - k)
+        if i < j:
+            swaps.append((i, j))
+    return swaps
+
+
+def _build_tflite_rfft2d_primfunc(input_shape, output_pair_shape):
+    """Build a reference TIR kernel for TFLite RFFT2D.
+
+    The TFLite frontend represents complex tensors as float32 real/imag pairs
+    with a trailing dimension of size 2 because TVM does not have a native
+    complex64 dtype. This kernel computes the unnormalized 2-D real FFT over
+    the last two input dimensions and writes that pair representation.
+
+    All trig and accumulation are in float32, so the result agrees with
+    ``np.fft.rfft2`` to about ``1e-5`` absolute tolerance for typical input
+    sizes. Higher-precision backends should override this kernel.
+
+    Notes
+    -----
+    This is a **naive O(B * H * W * H * W) DFT**, not an FFT. For an input of
+    spatial shape (H, W) the inner sum runs H*W times per output position, and
+    there are H*W' output positions per batch (W' = W // 2 + 1). This is
+    intentionally simple for correctness validation against
+    ``np.fft.rfft2``; production use cases with large spatial dimensions should
+    override the kernel with an FFT-based implementation. The outer
+    (batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a
+    downstream ``tvm.tir.schedule`` pass can parallelize it.
+    """
+    from tvm.script.parser import tirx as T
+
+    batch = 1
+    for dim in input_shape[:-2]:
+        batch *= int(dim)
+    height = int(input_shape[-2])
+    width = int(input_shape[-1])
+    out_width = int(output_pair_shape[-2])
+    input_total = batch * height * width
+    output_complex_total = batch * height * out_width
+    neg_two_pi = np.float32(-2.0 * math.pi)
+
+    @T.prim_func(private=True, s_tir=True, check_well_formed=False)
+    def kernel(
+        data: T.Buffer(input_shape, "float32"), output: T.Buffer(output_pair_shape, "float32")
+    ):
+        # Flat 1D aliases of the multi-dim buffers. The kernel is rank-agnostic
+        # over the leading batch dimensions, so collapsing the index space
+        # avoids special-casing 2D / 3D / 4D input shapes.
+        data_flat = T.decl_buffer((input_total,), "float32", data=data.data)
+        output_flat = T.decl_buffer((output_complex_total * 2,), "float32", data=output.data)
+        neg_two_pi_const = T.float32(neg_two_pi)
+
+        for b_idx, out_y, out_x in T.grid(batch, height, out_width):
+            with T.sblock("rfft2d"):
+                v_b, v_oy, v_ox = T.axis.remap("SSS", [b_idx, out_y, out_x])
+                real_sum = T.float32(0)
+                imag_sum = T.float32(0)
+                input_base = v_b * height * width
+                for in_y, in_x in T.grid(height, width):
+                    phase_y = T.Cast("float32", v_oy) * T.Cast("float32", in_y) / T.float32(height)
+                    phase_x = T.Cast("float32", v_ox) * T.Cast("float32", in_x) / T.float32(width)
+                    angle = neg_two_pi_const * (phase_y + phase_x)
+                    value = data_flat[input_base + in_y * width + in_x]
+                    real_sum = real_sum + value * T.cos(angle)
+                    imag_sum = imag_sum + value * T.sin(angle)
+                flat_out_idx = ((v_b * height + v_oy) * out_width + v_ox) * 2
+                output_flat[flat_out_idx] = real_sum
+                output_flat[flat_out_idx + 1] = imag_sum
+
+    return kernel
+
+
+def _build_tflite_rfft2d_fft_primfunc(input_shape, output_pair_shape):
+    """Build a 2D Cooley-Tukey FFT TIR kernel for TFLite RFFT2D.
+
+    Precondition: both ``input_shape[-2]`` (height) and ``input_shape[-1]``
+    (width) must be positive powers of 2. The frontend dispatches to this
+    kernel via ``_is_power_of_2`` checks; the DFT reference kernel handles
+    the remaining cases (odd / non-power-of-2 sizes).
+
+    Algorithm
+    ---------
+    1. Copy the real input into a scratch complex buffer (imag = 0) of shape
+       ``(B * H * W,)``.
+    2. For each batch and each row, run an in-place radix-2 1D FFT of length
+       ``W`` along the width axis.
+    3. For each batch and each column, run an in-place radix-2 1D FFT of
+       length ``H`` along the height axis (with stride ``W``).
+    4. Write the first ``W // 2 + 1`` complex bins per row to the output
+       pair representation.
+
+    The bit-reversal permutation required by iterative Cooley-Tukey is done
+    by emitting the (i, j) swap pairs directly in the TIR source (one
+    inlined swap per pair), avoiding the need for runtime index tables.
+
+    Complexity is ``O(B * H * W * (log2(H) + log2(W)))``, vs the DFT
+    reference kernel's ``O(B * H * W * H * W)``.
+    """
+    from tvm.script.parser import tirx as T
+
+    batch = 1
+    for dim in input_shape[:-2]:
+        batch *= int(dim)
+    height = int(input_shape[-2])
+    width = int(input_shape[-1])
+    out_width = int(output_pair_shape[-2])
+    input_total = batch * height * width
+    output_complex_total = batch * height * out_width
+    # Cast to Python float so the f-string-interpolated repr is a plain number
+    # (np.float32's repr is "np.float32(...)", which the TIR parser can't resolve).
+    neg_two_pi = float(np.float32(-2.0 * math.pi))
+    log2_w = int(math.log2(width))
+    log2_h = int(math.log2(height))
+
+    if not (_is_power_of_2(height) and _is_power_of_2(width)):
+        raise ValueError(
+            f"_build_tflite_rfft2d_fft_primfunc requires power-of-2 height and width, "
+            f"got H={height}, W={width}"
+        )
+
+    # Precompute the bit-reversal swap pairs at Python level. These are
+    # constant for a given FFT length and will be inlined in the TIR source.
+    # Each emitted line is indented 16 spaces (4 levels: top → b_idx loop →
+    # sblock → row/col loop body) so it lands inside the for loop when
+    # concatenated into the primfunc source.
+    row_swap_stmts = []
+    for i, j in _bit_reversal_swap_pairs(width):
+        row_swap_stmts.append(
+            f"                i_idx = row_base + {i}\n"
+            f"                j_idx = row_base + {j}\n"
+            f"                tmp_r = scratch_real[i_idx]\n"
+            f"                scratch_real[i_idx] = scratch_real[j_idx]\n"
+            f"                scratch_real[j_idx] = tmp_r\n"
+            f"                tmp_i = scratch_imag[i_idx]\n"
+            f"                scratch_imag[i_idx] = scratch_imag[j_idx]\n"
+            f"                scratch_imag[j_idx] = tmp_i\n"
+        )
+    row_swaps_code = "".join(row_swap_stmts) if row_swap_stmts else "                pass\n"
+
+    col_swap_stmts = []
+    for i, j in _bit_reversal_swap_pairs(height):
+        col_swap_stmts.append(
+            f"                i_idx = col_base + {i * width}\n"
+            f"                j_idx = col_base + {j * width}\n"
+            f"                tmp_r = scratch_real[i_idx]\n"
+            f"                scratch_real[i_idx] = scratch_real[j_idx]\n"
+            f"                scratch_real[j_idx] = tmp_r\n"
+            f"                tmp_i = scratch_imag[i_idx]\n"
+            f"                scratch_imag[i_idx] = scratch_imag[j_idx]\n"
+            f"                scratch_imag[j_idx] = tmp_i\n"
+        )
+    col_swaps_code = "".join(col_swap_stmts) if col_swap_stmts else "                pass\n"
+
+    # Build the per-stage butterfly code with the stage loop fully unrolled
+    # at primfunc-construction time. After unrolling, all loop bounds
+    # (block_start, k) are compile-time integers, so the TIR parser doesn't
+    # need to reason about runtime loop extents and the scheduler can
+    # see static twiddle factors instead of runtime trig calls.
+    def _stage_stmts(stage_count, length, indent, stride=1, base_expr="row_base"):
+        """Generate fully-unrolled Cooley-Tukey butterfly stage bodies.
+
+        ``base_expr`` is the TIR expression holding the base offset of the
+        FFT being transformed (e.g. ``"row_base"`` for rows or
+        ``"col_base"`` for columns). ``stride`` is the integer distance
+        between adjacent butterfly taps: 1 for the row-FFT (contiguous
+        elements) and ``width`` for the column-FFT (strided access).
+        """
+        out = []
+        for stage in range(1, stage_count + 1):
+            m_val = 1 << stage
+            half_val = m_val >> 1
+            for block_start in range(0, length, m_val):
+                for k in range(half_val):
+                    if stride == 1:
+                        a_idx = f"{base_expr} + {block_start} + {k}"
+                        b_idx_expr = f"{base_expr} + {block_start} + {k} + {half_val}"
+                    else:
+                        a_idx = f"{base_expr} + ({block_start} + {k}) * {stride}"
+                        b_idx_expr = f"{base_expr} + ({block_start} + {k} + {half_val}) * {stride}"
+                    angle_val = float(np.float32(neg_two_pi * k / m_val))
+                    w_real_val = float(np.float32(math.cos(angle_val)))
+                    w_imag_val = float(np.float32(math.sin(angle_val)))
+                    out.extend(
+                        [
+                            f"{indent}a_idx = {a_idx}\n",
+                            f"{indent}b_idx_local = {b_idx_expr}\n",
+                            f"{indent}t_real = scratch_real[b_idx_local] * T.float32({w_real_val!r}) - scratch_imag[b_idx_local] * T.float32({w_imag_val!r})\n",
+                            f"{indent}t_imag = scratch_real[b_idx_local] * T.float32({w_imag_val!r}) + scratch_imag[b_idx_local] * T.float32({w_real_val!r})\n",
+                            f"{indent}u_real = scratch_real[a_idx]\n",
+                            f"{indent}u_imag = scratch_imag[a_idx]\n",
+                            f"{indent}scratch_real[a_idx] = u_real + t_real\n",
+                            f"{indent}scratch_imag[a_idx] = u_imag + t_imag\n",
+                            f"{indent}scratch_real[b_idx_local] = u_real - t_real\n",
+                            f"{indent}scratch_imag[b_idx_local] = u_imag - t_imag\n",
+                        ]
+                    )
+        return "".join(out)
+
+    row_stages_code = _stage_stmts(
+        log2_w, width, "                ", stride=1, base_expr="row_base"
+    )
+    col_stages_code = _stage_stmts(
+        log2_h, height, "                ", stride=width, base_expr="col_base"
+    )
+
+    # Build the primfunc source. The bit-reversal swaps are inlined (one
+    # unconditional block per (i, j) pair) and the butterfly stages are
+    # fully unrolled, so the TIR parser sees ordinary statements rather than
+    # runtime table lookups or runtime-magnitude loops. The body is wrapped
+    # in a single S-TIR block over the batch dimension so the Relax
+    # pipeline (which expects an SBlockRealize at the primfunc body) accepts
+    # this kernel.
+    primfunc_source = (
+        "from tvm.script.parser import tirx as T\n"
+        "@T.prim_func(private=True, s_tir=True, check_well_formed=False)\n"
+        "def kernel(\n"
+        f"    data: T.Buffer({tuple(int(x) for x in input_shape)}, 'float32'),\n"
+        f"    output: T.Buffer({tuple(int(x) for x in output_pair_shape)}, 'float32'),\n"
+        "):\n"
+        f"    data_flat = T.decl_buffer(({input_total},), 'float32', data=data.data)\n"
+        f"    output_flat = T.decl_buffer(({output_complex_total * 2},), 'float32', data=output.data)\n"
+        f"    scratch_real = T.decl_buffer(({input_total},), 'float32')\n"
+        f"    scratch_imag = T.decl_buffer(({input_total},), 'float32')\n"
+        f"    for b_idx in T.serial({batch}):\n"
+        f"        with T.sblock('rfft2d_fft'):\n"
+        f"            v_b = T.axis.remap('S', [b_idx])\n"
+        f"            # Initialize scratch from real input; imag = 0.\n"
+        f"            for i in T.serial({height * width}):\n"
+        f"                src = v_b * {height * width} + i\n"
+        f"                scratch_real[src] = data_flat[src]\n"
+        f"                scratch_imag[src] = T.float32(0)\n"
+        f"            # FFT along width axis (one 1D FFT per row).\n"
+        f"            for row in T.serial({height}):\n"
+        f"                row_base = v_b * {height * width} + row * {width}\n"
+        f"                # Bit-reversal permutation (inlined swaps).\n"
+        f"{row_swaps_code}"
+        f"                # Cooley-Tukey butterfly stages (fully unrolled).\n"
+        f"{row_stages_code}"
+        f"            # FFT along height axis (one 1D FFT per column, stride = width).\n"
+        f"            for col in T.serial({width}):\n"
+        f"                col_base = v_b * {height * width} + col\n"
+        f"                # Bit-reversal permutation (strided, inlined swaps).\n"
+        f"{col_swaps_code}"
+        f"                # Cooley-Tukey butterfly stages (strided, fully unrolled).\n"
+        f"{col_stages_code}"
+        f"            # Write the first out_width complex bins per row to output.\n"
+        f"            for row in T.serial({height}):\n"
+        f"                for out_x in T.serial({out_width}):\n"
+        f"                    src = v_b * {height * width} + row * {width} + out_x\n"
+        f"                    dst = ((v_b * {height} + row) * {out_width} + out_x) * 2\n"
+        f"                    output_flat[dst] = scratch_real[src]\n"
+        f"                    output_flat[dst + 1] = scratch_imag[src]\n"
+    )
+
+    namespace = {"T": T, "tirx": T}
+    # Register the generated source in linecache so the TIR parser (which calls
+    # inspect.getsourcelines) can find it. The fake filename is stable across
+    # calls; it encodes the FFT height, width, and output width so that callers
+    # who introspect the generated source can tell the cache entries apart.
+    import linecache as _linecache
+
+    fake_file = f"<tflite_rfft2d_fft_primfunc H={height} W={width} outW={out_width}>"
+    _linecache.cache[fake_file] = (
+        len(primfunc_source.splitlines()),
+        None,
+        [line + "\n" for line in primfunc_source.splitlines()],
+        fake_file,
+    )
+    code = compile(primfunc_source, fake_file, "exec")
+    exec(code, namespace)
+    return namespace["kernel"]
+
+
 # Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR,
 # matching tensorflow/lite/kernels/rng_util.cc.
 _STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e9a842b8df..01beaadd7b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4082,6 +4082,285 @@ def _run_no_input_module(mod):
     return _run_module(mod)
 
 
+def _complex64_to_pair(value):
+    value = np.asarray(value, dtype=np.complex64)
+    return np.stack([value.real, value.imag], axis=-1).astype("float32")
+
+
+def _build_tflite_rfft2d_model(*, input_shape, fft_length, output_shape):
+    """Build a minimal TFLite RFFT2D model."""
+    builder = flatbuffers.Builder(1024)
+    builtin_op = _get_builtin_operator("RFFT2D")
+    op_code = _build_operator_code(builder, builtin_op)
+    tensors = [
+        _build_tensor(builder, 0, input_shape, tensor_type=_tfl_tensor_type.FLOAT32),
+        _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 2, output_shape, tensor_type=_tfl_tensor_type.COMPLEX64),
+    ]
+    op = _build_operator(builder, 0, [0, 1], [2])
+    subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0], outputs=[2])
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, np.array(fft_length, dtype=np.int32).tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+    )
+
+
+def test_rfft2d_static_pair_output():
+    """TFLite RFFT2D emits a call_tir kernel with float32 real/imag pair output."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[2, 4],
+            fft_length=[2, 4],
+            output_shape=[2, 3],
+        )
+    )
+
+    mod_script = mod.script()
+    assert "tflite_rfft2d" in mod_script
+    assert "R.call_tir" in mod_script
+    assert 'R.Tensor((2, 3, 2), dtype="float32")' in mod_script
+
+    data = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]], dtype="float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    # atol accommodates the float32 reference kernel: numpy's rfft2 internally uses
+    # float64, while the reference TIR kernel accumulates in float32 (see
+    # _build_tflite_rfft2d_primfunc docstring).
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+    )
+
+
+def test_rfft2d_static_pair_output_with_batch():
+    """RFFT2D computes over the last two axes and preserves leading batch dimensions."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[2, 2, 4],
+            fft_length=[2, 4],
+            output_shape=[2, 2, 3],
+        )
+    )
+
+    data = np.array(
+        [
+            [[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]],
+            [[-1.0, 2.0, 0.5, -4.0], [3.5, -6.0, 7.0, 1.0]],
+        ],
+        dtype="float32",
+    )
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+    )
+
+
+def test_rfft2d_odd_width_pair_output():
+    """RFFT2D handles odd width: output has width//2 + 1 bins (TFLite convention)."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[3, 5],
+            fft_length=[3, 5],
+            output_shape=[3, 3],  # 5 // 2 + 1 = 3
+        )
+    )
+
+    data = np.array(
+        [[1.0, -2.0, 3.0, 4.0, -5.0], [0.5, 6.0, -7.0, 8.0, 2.5], [-1.5, 4.0, 0.0, -3.0, 1.0]],
+        dtype="float32",
+    )
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    # atol accommodates the float32 reference kernel (see
+    # _build_tflite_rfft2d_primfunc docstring).
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+    )
+
+
+def test_rfft2d_int64_fft_length():
+    """RFFT2D accepts INT64 fft_length constant (TFLite schema allows either int32 or int64)."""
+    builder = flatbuffers.Builder(1024)
+    rfft_op_code = _build_operator_code(builder, _get_builtin_operator("RFFT2D"))
+    tensors = [
+        _build_tensor(builder, 0, [2, 4], tensor_type=_tfl_tensor_type.FLOAT32),
+        _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64),
+        _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.COMPLEX64),
+    ]
+    op = _build_operator(builder, 0, [0, 1], [2])
+    subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0], outputs=[2])
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, np.array([2, 4], dtype=np.int64).tobytes()),
+        _build_buffer(builder),
+    ]
+    buf = _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[rfft_op_code], buffers=buffers
+    )
+    mod = _load_model_from_buffer(buf)
+
+    data = np.array([[1.0, -2.0, 3.0, 4.0], [5.0, 6.0, -7.0, 8.0]], dtype="float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+    )
+
+
+def test_rfft2d_4d_input_pair_output():
+    """RFFT2D accepts 4D input and preserves leading batch dimensions."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[2, 3, 4, 5],  # batch=6, H=4, W=5
+            fft_length=[4, 5],
+            output_shape=[2, 3, 4, 3],  # 5 // 2 + 1 = 3
+        )
+    )
+
+    rng = np.random.RandomState(0)
+    data = (rng.randn(2, 3, 4, 5) * 0.5).astype("float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    # 4D test accumulates 20 inner terms per output; use a slightly larger atol
+    # than the 2D case (which accumulates 4-8 terms).
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-4
+    )
+
+
+def test_rfft2d_minimal_1x1_pair_output():
+    """RFFT2D on a [1, 1] input: the only output is the DC component (sum of inputs)."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[1, 1],
+            fft_length=[1, 1],
+            output_shape=[1, 1],
+        )
+    )
+
+    data = np.array([[3.5]], dtype="float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-5, atol=1e-5
+    )
+
+
+def test_rfft2d_fft_path_8x8():
+    """RFFT2D on a square 8x8 input exercises the Cooley-Tukey FFT dispatch path."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[8, 8],
+            fft_length=[8, 8],
+            output_shape=[8, 5],
+        )
+    )
+
+    np.random.seed(0xCAFE)
+    data = np.random.randn(8, 8).astype("float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    # The FFT path uses float32 twiddles (cos/sin) and float32 butterfly
+    # accumulation, so the error vs. numpy's float64 reference is in the
+    # 1e-4 range on these random inputs.
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+    )
+
+
+def test_rfft2d_fft_path_4x4():
+    """RFFT2D on a 4x4 input: smallest case where both row and column FFTs do real work."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[4, 4],
+            fft_length=[4, 4],
+            output_shape=[4, 3],
+        )
+    )
+
+    np.random.seed(0xFEED)
+    data = np.random.randn(4, 4).astype("float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+    )
+
+
+def test_rfft2d_fft_path_2x2x4x8():
+    """RFFT2D on a 4D input with power-of-2 height/width exercises the FFT path with batch."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[2, 2, 4, 8],
+            fft_length=[4, 8],
+            output_shape=[2, 2, 4, 5],
+        )
+    )
+
+    np.random.seed(0xBEEF)
+    data = np.random.randn(2, 2, 4, 8).astype("float32")
+    expected = np.fft.rfft2(data, axes=(-2, -1)).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+    )
+
+
+def test_rfft2d_fft_path_16x16():
+    """RFFT2D on a 16x16 input: a larger FFT to check that the unrolled kernel scales."""
+    mod = _load_model_from_buffer(
+        _build_tflite_rfft2d_model(
+            input_shape=[16, 16],
+            fft_length=[16, 16],
+            output_shape=[16, 9],
+        )
+    )
+
+    np.random.seed(0xDEAD)
+    data = np.random.randn(16, 16).astype("float32")
+    expected = np.fft.rfft2(data).astype(np.complex64)
+    np.testing.assert_allclose(
+        _run_module(mod, data), _complex64_to_pair(expected), rtol=1e-4, atol=1e-4
+    )
+
+
+def test_rfft2d_mismatched_fft_length_unsupported():
+    """RFFT2D padding/truncation cases are guarded until explicitly implemented."""
+    buf = _build_tflite_rfft2d_model(
+        input_shape=[2, 4],
+        fft_length=[4, 4],
+        output_shape=[4, 3],
+    )
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="fft_length"):
+        from_tflite(tflite_model)
+
+
+def test_rfft2d_dynamic_fft_length_unsupported():
+    """RFFT2D requires fft_length to be a constant tensor."""
+    builder = flatbuffers.Builder(1024)
+    rfft_op_code = _build_operator_code(builder, _get_builtin_operator("RFFT2D"))
+    tensors = [
+        _build_tensor(builder, 0, [2, 4], tensor_type=_tfl_tensor_type.FLOAT32),
+        _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.COMPLEX64),
+    ]
+    op = _build_operator(builder, 0, [0, 1], [2])
+    subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0, 1], outputs=[2])
+    buf = _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[rfft_op_code],
+        buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)],
+    )
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="requires a constant fft_length"):
+        from_tflite(tflite_model)
+
+
 def _build_tflite_call_model(
     call_subgraph_index=1,
     callee_inputs=None,

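For readers who want to check the unrolled kernel's arithmetic outside TVM, the following is a minimal NumPy sketch of the structure the generated source in the diff above follows: a bit-reversal permutation, then radix-2 butterfly stages with float32 twiddles, applied along each row and then along each column, keeping the first `W // 2 + 1` bins as a real/imag pair. The names `bit_reversal_swap_pairs`, `fft_radix2_inplace`, and `rfft2_pair` are illustrative and not part of the commit. In particular, `bit_reversal_swap_pairs` is only an assumption about what `_bit_reversal_swap_pairs` returns, since that helper's body is not shown in this part of the diff.

```python
import numpy as np


def bit_reversal_swap_pairs(n):
    """Hypothetical stand-in: (i, j) pairs with i < j swapped by bit reversal."""
    bits = n.bit_length() - 1
    pairs = []
    for i in range(n):
        j = int(format(i, f"0{bits}b")[::-1], 2) if bits else 0
        if i < j:
            pairs.append((i, j))
    return pairs


def fft_radix2_inplace(real, imag):
    """Iterative radix-2 FFT over 1D float32 views, modified in place."""
    n = real.shape[0]
    assert n & (n - 1) == 0, "length must be a power of two"
    for i, j in bit_reversal_swap_pairs(n):
        real[i], real[j] = real[j], real[i]
        imag[i], imag[j] = imag[j], imag[i]
    for stage in range(1, n.bit_length()):
        m = 1 << stage
        half = m >> 1
        for block_start in range(0, n, m):
            for k in range(half):
                # Twiddle factor rounded to float32, as in the generated kernel.
                angle = np.float32(-2.0 * np.pi * k / m)
                w_r, w_i = np.cos(angle), np.sin(angle)
                a = block_start + k
                b = a + half
                t_r = real[b] * w_r - imag[b] * w_i
                t_i = real[b] * w_i + imag[b] * w_r
                u_r, u_i = real[a], imag[a]
                real[a], imag[a] = u_r + t_r, u_i + t_i
                real[b], imag[b] = u_r - t_r, u_i - t_i


def rfft2_pair(data):
    """Row FFTs, then column FFTs, then keep the first W // 2 + 1 bins."""
    real = np.array(data, dtype=np.float32)  # copy, so the input is untouched
    imag = np.zeros_like(real)
    height, width = real.shape
    for r in range(height):
        fft_radix2_inplace(real[r], imag[r])
    for c in range(width):
        fft_radix2_inplace(real[:, c], imag[:, c])  # strided column views
    out_width = width // 2 + 1
    return np.stack([real[:, :out_width], imag[:, :out_width]], axis=-1)
```

Calling `rfft2_pair` on an 8x8 float32 array and comparing against `np.fft.rfft2` with `rtol=1e-4, atol=1e-4` mirrors the tolerance used by the FFT-path tests above. The sketch handles a single 2D slice only; the kernel in the diff additionally loops over the flattened batch dimension.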
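The `linecache` registration at the end of the first file's diff can be tried in isolation. The sketch below uses only the standard library and a made-up function (`add_one`) and filename; it assumes nothing about TVM's parser beyond what the code comment in the diff states, namely that source lookup goes through `inspect`. The helper name `exec_generated` is hypothetical.

```python
import inspect
import linecache


def exec_generated(source, func_name, fake_file):
    """Compile and run `source`, keeping its text retrievable via inspect."""
    lines = source.splitlines(keepends=True)
    # Cache entry layout: (size, mtime, lines, fullname). With mtime=None,
    # linecache.checkcache() leaves the entry alone even though no file
    # named `fake_file` exists on disk.
    linecache.cache[fake_file] = (len(source), None, lines, fake_file)
    namespace = {}
    exec(compile(source, fake_file, "exec"), namespace)
    return namespace[func_name]


source = "def add_one(x):\n    return x + 1\n"
add_one = exec_generated(source, "add_one", "<generated add_one>")
recovered = inspect.getsource(add_one)  # served from the linecache entry
```

Without the cache entry, `inspect.getsource` would raise `OSError` for a function compiled from a string, which is why the commit registers the generated primfunc source before calling `compile` and `exec`.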