This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 57912395a8 [Relax][PyTorch] Decompose integer pow into repeated
multiplication (#19660)
57912395a8 is described below
commit 57912395a8f99c4b12e28190b3d99c12f2638e63
Author: Javier De Jesus <[email protected]>
AuthorDate: Wed Jun 3 23:15:47 2026 +0200
[Relax][PyTorch] Decompose integer pow into repeated multiplication (#19660)
`torch.pow` on an integer tensor returns an integer result, but the
PyTorch frontend lowered it to `relax.op.power`, which fails
`LegalizeOps` with `power only applies to float` (TOPI `power` /
`tvm::pow` requires a floating-point input).
This decomposes an integer base raised to a constant non-negative
integer exponent into repeated multiplication, so the result stays
integral and matches PyTorch. Float bases and non-constant or tensor
exponents keep using `relax.op.power` unchanged. The ONNX frontend
already uses the same decomposition (`x**3 = x*x*x`).
Added structural tests covering both the FX and ExportedProgram import
paths.
Fixes #19550
---
.../frontend/torch/base_fx_graph_translator.py | 22 ++++++++++++++++++++++
.../frontend/torch/exported_program_translator.py | 2 +-
python/tvm/relax/frontend/torch/fx_translator.py | 2 +-
.../relax/test_frontend_from_exported_program.py | 22 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 22 ++++++++++++++++++++++
5 files changed, 68 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a2ebed0480..581475ebd8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -22,6 +22,7 @@
import abc
import math
+import operator
from collections.abc import Callable
from functools import reduce
@@ -523,6 +524,27 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return convert
+ def _pow(self, node: fx.Node) -> relax.Var:
+ lhs, rhs = self.retrieve_args(node)
+ # torch integer pow returns an integer tensor, but relax.op.power legalizes to
+ # TOPI power which requires floating-point inputs. Decompose an integer base with
+ # a constant non-negative integer exponent into repeated multiplication instead.
+ if (
+ isinstance(lhs, relax.Expr)
+ and isinstance(lhs.struct_info, relax.TensorStructInfo)
+ and "int" in lhs.struct_info.dtype
+ and isinstance(rhs, int)
+ and not isinstance(rhs, bool)
+ and rhs >= 0
+ ):
+ if rhs == 0:
+ return self.block_builder.emit(relax.op.ones_like(lhs))
+ result = lhs
+ for _ in range(rhs - 1):
+ result = self.block_builder.emit(relax.op.multiply(result, lhs))
+ return result
+ return self._binary_op(relax.op.power, operator.pow)(node)
+
def _div(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
inp_1 = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 26f5a5918c..976c9d45b6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1645,7 +1645,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
),
"pow.Scalar": self._binary_op(relax.op.power, operator.pow),
- "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
+ "pow.Tensor_Scalar": self._pow,
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
"sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
"sub.Scalar": self._binary_op(relax.op.subtract, operator.sub),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 9d27f62b42..867407193a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -929,7 +929,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"outer": lambda node: self.block_builder.emit(
relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
),
- "pow": self._binary_op(relax.op.power, operator.pow),
+ "pow": self._pow,
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
"rsub": self._rsub,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d1bdad7578..86471d8924 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1085,6 +1085,28 @@ def test_logical_not():
verify_model(LogicalNot(), example_args, {}, expected)
+def test_pow_integer():
+ class Pow(Module):
+ def forward(self, input):
+ return input.pow(4)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(input: R.Tensor((4,), dtype="int64")) -> R.Tuple(R.Tensor((4,), dtype="int64")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="int64") = R.multiply(input, input)
+ lv1: R.Tensor((4,), dtype="int64") = R.multiply(lv, input)
+ lv2: R.Tensor((4,), dtype="int64") = R.multiply(lv1, input)
+ gv: R.Tuple(R.Tensor((4,), dtype="int64")) = (lv2,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.tensor([-1, 1, 2, 3], dtype=torch.int64),)
+ verify_model(Pow(), example_args, {}, expected)
+
+
def test_logsoftmax():
class LogSoftmax(Module):
def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 1bf71fb6eb..abfb18cf41 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3527,6 +3527,28 @@ def test_extended_unary_ops():
verify_model(Trunc(), input_info, {}, expected_trunc)
+def test_pow_integer():
+ input_info = [([4], "int64")]
+
+ class Pow(Module):
+ def forward(self, input):
+ return input.pow(4)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(inp_0: R.Tensor((4,), dtype="int64")) -> R.Tensor((4,), dtype="int64"):
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="int64") = R.multiply(inp_0, inp_0)
+ lv1: R.Tensor((4,), dtype="int64") = R.multiply(lv, inp_0)
+ lv2: R.Tensor((4,), dtype="int64") = R.multiply(lv1, inp_0)
+ gv: R.Tensor((4,), dtype="int64") = lv2
+ R.output(gv)
+ return gv
+
+ verify_model(Pow(), input_info, {}, expected)
+
+
def test_interpolate():
input_info = [([1, 3, 10, 10], "float32")]