This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 772857d34c [Relax][Frontend][TFLite] Add ATAN2 op and TFLite mapping
(#19485)
772857d34c is described below
commit 772857d34c1de433cd3ca18a1f18d46df5652d41
Author: as4230 <[email protected]>
AuthorDate: Fri May 1 00:04:51 2026 -0400
[Relax][Frontend][TFLite] Add ATAN2 op and TFLite mapping (#19485)
This PR adds the ATAN2 operator to the Relax TFLite frontend.
Introduces relax.op.atan2 as a new binary elementwise primitive (TOPI
broadcast op, Relax registration, legalization to topi.atan2, script
parser support) and registers ATAN2 in the TFLite convert_map. It reuses
the existing TIR primitive tvm::atan2, so this PR only adds the higher-layer plumbing.
Validation:
python -m pytest tests/python/relax/test_op_binary.py
python -m pytest tests/python/relax/test_frontend_tflite.py -k binary
Addresses the ATAN2 item under #19412.
---
include/tvm/topi/broadcast.h | 13 +++
.../tvm/relax/frontend/tflite/tflite_frontend.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/binary.py | 18 ++++
python/tvm/relax/script/builder/ir.py | 2 +
python/tvm/relax/transform/legalize_ops/binary.py | 1 +
python/tvm/topi/broadcast.py | 19 ++++
src/relax/op/tensor/binary.cc | 1 +
src/relax/op/tensor/binary.h | 3 +
src/topi/broadcast.cc | 1 +
tests/python/relax/test_frontend_tflite.py | 1 +
tests/python/relax/test_op_binary.py | 2 +
.../relax/test_transform_legalize_ops_binary.py | 117 +++++++++++++++++++++
13 files changed, 180 insertions(+)
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index f8ef2edc39..b0c6ac8f67 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -384,6 +384,19 @@ TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); });
*/
TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
+/*!
+ * \fn atan2
+ * \brief Compute element-wise atan2(A, B), i.e. atan2(y, x), with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr (y-coordinates).
+ * \param B The second tensor, or Expr (x-coordinates).
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The element-wise atan2(A, B) result (lowered via tvm::atan2).
+ */
+TOPI_DEFINE_BCAST_OP(atan2, { return tvm::atan2(a, b); });
+
/*!
* \fn left_shift
* \brief Compute A << B with auto-broadcasting.
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b7a7e42c48..155b6301f9 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -121,6 +121,7 @@ class OperatorConverter:
"ADD_N": self.convert_add_n,
"ARG_MAX": functools.partial(self._convert_arg_min_max,
relax_op=_op.argmax),
"ARG_MIN": functools.partial(self._convert_arg_min_max,
relax_op=_op.argmin),
+ "ATAN2": functools.partial(self._convert_elemwise,
relax_op=_op.atan2),
"AVERAGE_POOL_2D": functools.partial(self.convert_pool2d,
pool_type="average"),
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 6f985ef36c..473e50ed30 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -47,6 +47,7 @@ from .base import (
)
from .binary import (
add,
+ atan2,
bitwise_and,
bitwise_or,
bitwise_xor,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 939ba69275..9480612e6f 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -141,6 +141,24 @@ def power(x1: Expr, x2: Expr):
return _ffi_api.power(x1, x2) # type: ignore
+def atan2(x1: Expr, x2: Expr) -> Expr:
+    """Element-wise atan2(x1, x2) with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor (y-coordinates).
+    x2 : relax.Expr
+        The second input tensor (x-coordinates).
+
+    Returns
+    -------
+    result : relax.Expr
+        The element-wise atan2(x1, x2), in the broadcast shape of the inputs.
+    """
+    return _ffi_api.atan2(x1, x2)  # type: ignore
+
+
def subtract(x1: Expr, x2: Expr) -> Expr:
"""Subtraction with numpy-style broadcasting.
diff --git a/python/tvm/relax/script/builder/ir.py
b/python/tvm/relax/script/builder/ir.py
index f62164dbd7..84ad485a33 100644
--- a/python/tvm/relax/script/builder/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -54,6 +54,7 @@ from tvm.relax.op import (
assert_op,
astype,
atan,
+ atan2,
atanh,
bitwise_and,
bitwise_not,
@@ -813,6 +814,7 @@ __all__ = [
"assert_op",
"astype",
"atan",
+ "atan2",
"atanh",
"bitwise_and",
"bitwise_not",
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index 85e3f06440..355fed86b9 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -49,6 +49,7 @@ register_legalize("relax.floor_divide",
_binary(topi.floor_divide))
register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
register_legalize("relax.multiply", _binary(topi.multiply))
register_legalize("relax.power", _binary(topi.power))
+register_legalize("relax.atan2", _binary(topi.atan2))
register_legalize("relax.subtract", _binary(topi.subtract))
register_legalize("relax.equal", _binary(topi.equal))
register_legalize("relax.mod", _binary(topi.mod))
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index c00495b032..e97730bc88 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -249,6 +249,25 @@ def power(lhs, rhs):
return _cpp.power(lhs, rhs)
+def atan2(lhs, rhs):
+    """Element-wise atan2(lhs, rhs) with auto-broadcasting.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The left operand (y-coordinates).
+    rhs : tvm.te.Tensor or Expr
+        The right operand (x-coordinates).
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.atan2(lhs, rhs)
+
+
def left_shift(lhs, rhs):
"""Left shift with auto-broadcasting
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 71c00e09e4..07c3364a9f 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -203,6 +203,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(atan2);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b5650fad27..a0dfbd66e6 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -81,6 +81,9 @@ Expr multiply(Expr x1, Expr x2);
/*! \brief Power with numpy-style broadcasting. */
Expr power(Expr x1, Expr x2);
+/*! \brief Atan2 with numpy-style broadcasting. */
+Expr atan2(Expr x1, Expr x2);
+
/*! \brief Subtraction with numpy-style broadcasting. */
Expr subtract(Expr x1, Expr x2);
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index c90b208771..cba8e29afa 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -66,6 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum)
.TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum)
.TOPI_DEF_BCAST_OP("topi.power", topi::power)
+ .TOPI_DEF_BCAST_OP("topi.atan2", topi::atan2)
.TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift)
.TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and)
.TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or)
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 69aab2d43b..37211d337a 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -769,6 +769,7 @@ def test_fill_dynamic_dims():
(tf.divide, R.divide),
(tf.math.floormod, R.floor_mod),
(tf.math.floordiv, R.floor_divide),
+ (tf.math.atan2, R.atan2),
],
)
def test_binary(tf_op, relax_op):
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 7049e6aaef..0ac8cf1e9f 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -34,6 +34,7 @@ def test_op_correctness():
assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
assert relax.op.power(x, y).op == Op.get("relax.power")
+ assert relax.op.atan2(x, y).op == Op.get("relax.atan2")
assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
assert relax.op.mod(x, y).op == Op.get("relax.mod")
assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
@@ -71,6 +72,7 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
(relax.op.floor_divide, tirx.FloorDiv),
(relax.op.multiply, tirx.Mul),
(relax.op.power, tirx.pow),
+ (relax.op.atan2, tirx.atan2),
(relax.op.subtract, tirx.Sub),
(relax.op.maximum, tirx.Max),
(relax.op.minimum, tirx.Min),
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py
b/tests/python/relax/test_transform_legalize_ops_binary.py
index f9b2074eab..42355ba757 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -791,6 +791,123 @@ def test_power_primvalue():
tvm.ir.assert_structural_equal(Expected, After)
+def test_atan2():  # static shapes (1, 2, 3) and (4, 3, 2, 1) broadcast to (4, 3, 2, 3)
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1),
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+            gv: R.Tensor((4, 3, 2, 3), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(rxplaceholder: T.Buffer((T.int64(1), T.int64(2),
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3),
T.int64(2), T.int64(1)), "float32"), T_atan2: T.Buffer((T.int64(4), T.int64(3),
T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            # with T.sblock("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3),
T.int64(2), T.int64(3)):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3],
rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] =
T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1,
v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3,
2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((4, 3,
2, 3), dtype="float32"))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)  # lower R.atan2 in the input module to call_tir of topi.atan2
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_symbolic():  # symbolic dims a, b, c, d broadcast through R.atan2
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv: R.Tensor((a, b, c, d), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle,
var_T_atan2: T.handle):
+            T.func_attr({"tirx.noalias": True})
+            c = T.int64()
+            d = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c,
d))
+            a = T.int64()
+            b = T.int64()
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c,
T.int64(1)))
+            T_atan2 = T.match_buffer(var_T_atan2, (a, b, c, d))
+            for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3],
rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] =
T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1,
v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y:
R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c",
"d"), dtype="float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((a, b,
c, d), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)  # legalize the input module; legalizing Expected is a no-op self-compare
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_primvalue():  # R.Prim scalar rhs becomes a T.float32 PrimFunc param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.atan2(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.atan2, (x, y), R.Tensor([64, 32, 16],
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def atan2(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)],
"float32"),
+        ):
+            T.func_attr({"tirx.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.sblock("T_atan2"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.atan2(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)  # legalize the input module, then compare against Expected
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_subtract():
# fmt: off
@tvm.script.ir_module