This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e3250795b [Relax][PyTorch] Cast non-bool inputs to bool in logical_and converter (#19679)
0e3250795b is described below

commit 0e3250795bf450438a2224b6e2d76d5568d4873e
Author: Javier De Jesus <[email protected]>
AuthorDate: Mon Jun 8 22:33:19 2026 +0200

    [Relax][PyTorch] Cast non-bool inputs to bool in logical_and converter (#19679)
    
    ### Motivation
    
    `torch.logical_and` accepts input tensors of any dtype (treating any
    nonzero element as `True`) and always returns a `bool` tensor.
    
    The PyTorch frontend did not produce that `bool` result. The
    ExportedProgram frontend lowered `logical_and.default` with
    `self._binary_op(relax.op.logical_and, operator.and_)`, which kept the
    operand dtype and emitted `relax.op.logical_and` on non-bool inputs (for
    example `float32`). `relax.op.logical_and` requires boolean inputs and
    otherwise fails `LegalizeOps` in the TOPI `logical_and`. The FX frontend
    did not register `logical_and` at all, so the op was unsupported there.
    
    ### Changes
    
    - Add a shared `_logical_and` converter in `BaseFXGraphImporter` that
      casts non-bool operands to `bool` before applying
      `relax.op.logical_and`. Bool operands are passed through unchanged (no
      redundant cast).
    - Point the `logical_and.default` (ExportedProgram) registration at the
      new converter, and add a `logical_and` (FX) registration that was
      previously missing, matching the existing `logical_not` converter.
    - Add a standalone `test_logical_and` to both the FX and ExportedProgram
      test suites asserting the corrected IR (`astype` to bool on each
      operand, then `logical_and`, producing a `bool` output).
    
    ### Notes
    
    The cast to `bool` lowers to an elementwise nonzero test, so it matches
    PyTorch's "nonzero is True" semantics for float, integer, and NaN
    inputs.
---
 .../frontend/torch/base_fx_graph_translator.py     | 11 +++++++++
 .../frontend/torch/exported_program_translator.py  |  2 +-
 python/tvm/relax/frontend/torch/fx_translator.py   |  1 +
 .../relax/test_frontend_from_exported_program.py   | 28 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 25 +++++++++++++++++++
 5 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 91b6a3a171..4c0edcb0f2 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -391,6 +391,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
-1)
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _logical_and(self, node: fx.Node) -> relax.Var:
+        lhs = self.env[node.args[0]]
+        rhs = self.env[node.args[1]]
+        # torch.logical_and accepts any dtype (treating nonzero as True) and 
returns bool, but
+        # relax.op.logical_and requires boolean inputs, so cast non-bool 
inputs to bool first.
+        if lhs.struct_info.dtype != "bool":
+            lhs = self.block_builder.emit(relax.op.astype(lhs, "bool"))
+        if rhs.struct_info.dtype != "bool":
+            rhs = self.block_builder.emit(relax.op.astype(rhs, "bool"))
+        return self.block_builder.emit(relax.op.logical_and(lhs, rhs))
+
     def _logical_not(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         # torch.logical_not accepts any dtype (treating nonzero as True) and 
returns bool, but
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 976c9d45b6..7924a2305c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1552,7 +1552,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log10.default": self._log10,
             "log1p.default": self._log1p,
             "logical_not.default": self._logical_not,
-            "logical_and.default": self._binary_op(relax.op.logical_and, 
operator.and_),
+            "logical_and.default": self._logical_and,
             "log_softmax.int": self._log_softmax,
             "_log_softmax.default": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 867407193a..4af86068d7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -875,6 +875,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "log2": self._log2,
             "log10": self._log10,
             "log1p": self._log1p,
+            "logical_and": self._logical_and,
             "logical_not": self._logical_not,
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 86471d8924..fa2d793f29 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1062,6 +1062,34 @@ def test_logaddexp():
     verify_model(LogAddExp(), example_args, {}, expected)
 
 
+def test_logical_and():
+    class LogicalAnd(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_and(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, 
dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, 
dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.logical_and(lv, lv1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(LogicalAnd(), example_args, {}, expected)
+
+
 def test_logical_not():
     class LogicalNot(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index abfb18cf41..94cdf43773 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3527,6 +3527,31 @@ def test_extended_unary_ops():
     verify_model(Trunc(), input_info, {}, expected_trunc)
 
 
+def test_logical_and():
+    input_info = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+
+    class LogicalAnd(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_and(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, 
dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, 
dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.logical_and(lv, lv1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv2
+                R.output(gv)
+            return gv
+
+    verify_model(LogicalAnd(), input_info, {}, expected)
+
+
 def test_pow_integer():
     input_info = [([4], "int64")]
 

Reply via email to