This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 093d785e22 [Relax][PyTorch] Add atan2 converter (#19850)
093d785e22 is described below
commit 093d785e2297269cf71ea115aec1b62095ba11eb
Author: Javier De Jesus <[email protected]>
AuthorDate: Sun Jun 21 05:26:28 2026 +0200
[Relax][PyTorch] Add atan2 converter (#19850)
### Motivation
`torch.atan2` was not registered in either the ExportedProgram or the FX
frontend, so importing a model that uses it failed with an "Unsupported
function types" error. The `relax.op.atan2` operator already exists and
legalizes to `topi.atan2`, so the frontends only needed to route the op to it.
### Changes
- Register `atan2` in the FX frontend and `atan2.default` in the
  ExportedProgram frontend, reusing the shared `_binary_op` helper (the same
  pattern as the existing `maximum`/`minimum`/`logaddexp` converters).
- Add structural tests to `test_frontend_from_fx.py` and
  `test_frontend_from_exported_program.py`.
---
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 26 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 21 +++++++++++++++++
4 files changed, 50 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 6c9e3e3f5e..b96316adee 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1603,6 +1603,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add.Scalar": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
+ "atan2.default": self._binary_op(relax.op.atan2, torch.atan2),
"bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and,
operator.and_),
"bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and,
operator.and_),
"bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 66d17a5828..4932871bad 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -791,6 +791,7 @@ class TorchFXImporter(BaseFXGraphImporter):
) -> dict[torch.nn.Module | str, Callable[[fx.Node], relax.Var]]:
import operator
+ import torch # type: ignore
from torch import nn
return {
@@ -909,6 +910,7 @@ class TorchFXImporter(BaseFXGraphImporter):
# binary
"add": self._binary_op(relax.op.add, operator.add),
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
+ "atan2": self._binary_op(relax.op.atan2, torch.atan2),
"bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or,
operator.or_),
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
"div": self._div,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index ee2f4a8f8d..dac0bd1e2a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1016,6 +1016,32 @@ def test_logaddexp():
verify_model(LogAddExp(), example_args, {}, expected)
+def test_atan2():
+ class Atan2(Module):
+ def forward(self, lhs, rhs):
+ return torch.atan2(lhs, rhs)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan2(lhs,
rhs)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(1, 3, 10, 10, dtype=torch.float32),
+ torch.randn(1, 3, 10, 10, dtype=torch.float32),
+ )
+ verify_model(Atan2(), example_args, {}, expected)
+
+
def test_logical_and():
class LogicalAnd(Module):
def forward(self, lhs, rhs):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 34da69d5f0..bcb9252b89 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5335,6 +5335,27 @@ def test_min():
verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")],
{}, Expected1)
+def test_atan2():
+ class Atan2(Module):
+ def forward(self, x, y):
+ return torch.atan2(x, y)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32"),
+ inp_1: R.Tensor((256, 256), dtype="float32"),
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float32") = R.atan2(inp_0,
inp_1)
+ gv: R.Tensor((256, 256), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Atan2(), [([256, 256], "float32"), ([256, 256], "float32")],
{}, Expected1)
+
+
def test_attention():
@I.ir_module
class Expected1: