This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 772857d34c [Relax][Frontend][TFLite] Add ATAN2 op and TFLite mapping (#19485)
772857d34c is described below

commit 772857d34c1de433cd3ca18a1f18d46df5652d41
Author: as4230 <[email protected]>
AuthorDate: Fri May 1 00:04:51 2026 -0400

    [Relax][Frontend][TFLite] Add ATAN2 op and TFLite mapping (#19485)
    
    This PR adds the ATAN2 operator to the Relax TFLite frontend.
    
    Introduces relax.op.atan2 as a new binary elementwise primitive (TOPI
    broadcast op, Relax registration, legalization to topi.atan2, script
    parser support) and registers ATAN2 in the TFLite convert_map. It reuses
    the existing TIR primitive tvm::atan2, so this PR only adds the higher-layer plumbing.
    
    Validation:
        python -m pytest tests/python/relax/test_op_binary.py
        python -m pytest tests/python/relax/test_frontend_tflite.py -k binary
    
    Addresses the ATAN2 item under #19412.
---
 include/tvm/topi/broadcast.h                       |  13 +++
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |   1 +
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/binary.py                      |  18 ++++
 python/tvm/relax/script/builder/ir.py              |   2 +
 python/tvm/relax/transform/legalize_ops/binary.py  |   1 +
 python/tvm/topi/broadcast.py                       |  19 ++++
 src/relax/op/tensor/binary.cc                      |   1 +
 src/relax/op/tensor/binary.h                       |   3 +
 src/topi/broadcast.cc                              |   1 +
 tests/python/relax/test_frontend_tflite.py         |   1 +
 tests/python/relax/test_op_binary.py               |   2 +
 .../relax/test_transform_legalize_ops_binary.py    | 117 +++++++++++++++++++++
 13 files changed, 180 insertions(+)

diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index f8ef2edc39..b0c6ac8f67 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -384,6 +384,19 @@ TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); });
  */
 TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
 
+/*!
+ * \fn atan2
+ * \brief Compute element-wise atan2(A, B), i.e. atan2(y, x), with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr (y-coordinates).
+ * \param B The second tensor, or Expr (x-coordinates).
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The element-wise atan2 of A and B.
+ */
+TOPI_DEFINE_BCAST_OP(atan2, { return tvm::atan2(a, b); });
+
 /*!
  * \fn left_shift
  * \brief Compute A << B with auto-broadcasting.
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b7a7e42c48..155b6301f9 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -121,6 +121,7 @@ class OperatorConverter:
             "ADD_N": self.convert_add_n,
             "ARG_MAX": functools.partial(self._convert_arg_min_max, 
relax_op=_op.argmax),
             "ARG_MIN": functools.partial(self._convert_arg_min_max, 
relax_op=_op.argmin),
+            "ATAN2": functools.partial(self._convert_elemwise, 
relax_op=_op.atan2),
             "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, 
pool_type="average"),
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
             "BATCH_MATMUL": self.convert_batch_matmul,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 6f985ef36c..473e50ed30 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -47,6 +47,7 @@ from .base import (
 )
 from .binary import (
     add,
+    atan2,
     bitwise_and,
     bitwise_or,
     bitwise_xor,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 939ba69275..9480612e6f 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -141,6 +141,24 @@ def power(x1: Expr, x2: Expr):
     return _ffi_api.power(x1, x2)  # type: ignore
 
 
+def atan2(x1: Expr, x2: Expr) -> Expr:
+    """Element-wise atan2(x1, x2) with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor (y-coordinates).
+    x2 : relax.Expr
+        The second input tensor (x-coordinates).
+
+    Returns
+    -------
+    result : relax.Expr
+        The element-wise atan2(x1, x2), broadcast to the common shape.
+    """
+    return _ffi_api.atan2(x1, x2)  # type: ignore
+
+
 def subtract(x1: Expr, x2: Expr) -> Expr:
     """Subtraction with numpy-style broadcasting.
 
diff --git a/python/tvm/relax/script/builder/ir.py 
b/python/tvm/relax/script/builder/ir.py
index f62164dbd7..84ad485a33 100644
--- a/python/tvm/relax/script/builder/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -54,6 +54,7 @@ from tvm.relax.op import (
     assert_op,
     astype,
     atan,
+    atan2,
     atanh,
     bitwise_and,
     bitwise_not,
@@ -813,6 +814,7 @@ __all__ = [
     "assert_op",
     "astype",
     "atan",
+    "atan2",
     "atanh",
     "bitwise_and",
     "bitwise_not",
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py 
b/python/tvm/relax/transform/legalize_ops/binary.py
index 85e3f06440..355fed86b9 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -49,6 +49,7 @@ register_legalize("relax.floor_divide", 
_binary(topi.floor_divide))
 register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
 register_legalize("relax.multiply", _binary(topi.multiply))
 register_legalize("relax.power", _binary(topi.power))
+register_legalize("relax.atan2", _binary(topi.atan2))
 register_legalize("relax.subtract", _binary(topi.subtract))
 register_legalize("relax.equal", _binary(topi.equal))
 register_legalize("relax.mod", _binary(topi.mod))
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index c00495b032..e97730bc88 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -249,6 +249,25 @@ def power(lhs, rhs):
     return _cpp.power(lhs, rhs)
 
 
+def atan2(lhs, rhs):
+    """Element-wise atan2(lhs, rhs) with auto-broadcasting.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The left operand (y-coordinates).
+    rhs : tvm.te.Tensor or Expr
+        The right operand (x-coordinates).
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.atan2(lhs, rhs)
+
+
 def left_shift(lhs, rhs):
     """Left shift with auto-broadcasting
 
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 71c00e09e4..07c3364a9f 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -203,6 +203,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(atan2);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b5650fad27..a0dfbd66e6 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -81,6 +81,9 @@ Expr multiply(Expr x1, Expr x2);
 /*! \brief Power with numpy-style broadcasting. */
 Expr power(Expr x1, Expr x2);
 
+/*! \brief Atan2 with numpy-style broadcasting. */
+Expr atan2(Expr x1, Expr x2);
+
 /*! \brief Subtraction with numpy-style broadcasting. */
 Expr subtract(Expr x1, Expr x2);
 
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index c90b208771..cba8e29afa 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -66,6 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum)
       .TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum)
       .TOPI_DEF_BCAST_OP("topi.power", topi::power)
+      .TOPI_DEF_BCAST_OP("topi.atan2", topi::atan2)
       .TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift)
       .TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and)
       .TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or)
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 69aab2d43b..37211d337a 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -769,6 +769,7 @@ def test_fill_dynamic_dims():
         (tf.divide, R.divide),
         (tf.math.floormod, R.floor_mod),
         (tf.math.floordiv, R.floor_divide),
+        (tf.math.atan2, R.atan2),
     ],
 )
 def test_binary(tf_op, relax_op):
diff --git a/tests/python/relax/test_op_binary.py 
b/tests/python/relax/test_op_binary.py
index 7049e6aaef..0ac8cf1e9f 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -34,6 +34,7 @@ def test_op_correctness():
     assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
     assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
     assert relax.op.power(x, y).op == Op.get("relax.power")
+    assert relax.op.atan2(x, y).op == Op.get("relax.atan2")
     assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
     assert relax.op.mod(x, y).op == Op.get("relax.mod")
     assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
@@ -71,6 +72,7 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     (relax.op.floor_divide, tirx.FloorDiv),
     (relax.op.multiply, tirx.Mul),
     (relax.op.power, tirx.pow),
+    (relax.op.atan2, tirx.atan2),
     (relax.op.subtract, tirx.Sub),
     (relax.op.maximum, tirx.Max),
     (relax.op.minimum, tirx.Min),
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py 
b/tests/python/relax/test_transform_legalize_ops_binary.py
index f9b2074eab..42355ba757 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -791,6 +791,123 @@ def test_power_primvalue():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_atan2():
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), 
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+            gv: R.Tensor((4, 3, 2, 3), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), 
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), 
T.int64(2), T.int64(1)), "float32"), T_atan2: T.Buffer((T.int64(4), T.int64(3), 
T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            # with T.sblock("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), 
T.int64(2), T.int64(3)):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], 
rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, 
v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 
2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((4, 3, 
2, 3), dtype="float32"))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)  # lowers R.atan2 to call_tir of a private prim_func
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv: R.Tensor((a, b, c, d), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_atan2: T.handle):
+            T.func_attr({"tirx.noalias": True})
+            c = T.int64()
+            d = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d))
+            a = T.int64()
+            b = T.int64()
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1)))
+            T_atan2 = T.match_buffer(var_T_atan2, (a, b, c, d))
+            for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c", "d"), dtype="float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)  # legalize the input module, not Expected
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.atan2(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.atan2, (x, y), R.Tensor([64, 32, 16], 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def atan2(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], 
"float32"),
+        ):
+            T.func_attr({"tirx.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.sblock("T_atan2"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.atan2(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)  # R.Prim operand y becomes a T.float32 scalar argument
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_subtract():
     # fmt: off
     @tvm.script.ir_module

Reply via email to