This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf6e101793 [Relax][PyTorch] Add Pad Op Support for Exported Program 
and FX graph (#17821)
bf6e101793 is described below

commit bf6e1017934f0207e79b0639202e6db999633459
Author: Deivanayaki S <[email protected]>
AuthorDate: Fri Apr 18 11:24:49 2025 +0530

    [Relax][PyTorch] Add Pad Op Support for Exported Program and FX graph 
(#17821)
    
    * add pad op support into frontend pipelines
    
    fixing end of files formatting issue
    
    fixing trailing space issues
    
    update the docstring for pad mode in nn file
    
    fixing lint issues
    
    remove trailing whitespaces
    
    fix lint format issues in test script
    
    fix lint issue in pad file import statement
    
    modify docstring of pad function
    
    fixing dtype error in unity check
    
    fixing lint issues in pad.py file
    
    resolve arg mismatch error
    
    resolved error while handling pad value attr
    
    fix dtype of pad value attribute
    
    add helper function for different pad mode
    
    test script enhanced to check different pad mode
    
    remove trailing whitespaces in test script
    
    add docstring for helper function
    
    update test script
    
    * fix pad op arg handling in fx graph
    
    * fix issue by updated the retrieval of value arg
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     |  18 +++
 .../frontend/torch/exported_program_translator.py  |   1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 python/tvm/relax/op/nn/nn.py                       |  15 +-
 python/tvm/relax/transform/legalize_ops/nn.py      |  31 ++--
 python/tvm/topi/nn/pad.py                          | 174 ++++++++++++++++++++-
 .../relax/test_frontend_from_exported_program.py   |  89 +++++++++++
 tests/python/relax/test_frontend_from_fx.py        |  89 +++++++++++
 8 files changed, 401 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7a971c00cd..3ea70df9a1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -901,6 +901,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
 
+    def _pad(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        pad = node.args[1]
+        mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", 
"constant")
+        value = node.args[3] if len(node.args) > 3 else 
node.kwargs.get("value", 0.0)
+        value = 0.0 if value is None else value
+
+        # Calculate symmetric padding width for each dimension
+        # and applying them in reverse order to match the input dimensions.
+        input_ndim = x.struct_info.ndim
+        pad_width = [0] * (input_ndim * 2)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        reversed_pairs = list(reversed(pad_pairs))
+        flattened = [value for pair in reversed_pairs for value in pair]
+        pad_width[-len(flattened) :] = flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, 
value))
+
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
         transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 
3])
         query = transpose_S_H(self.env[node.args[0]])
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4084e35de5..9064de37f0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "pad.default": self._pad,
             "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 4ef0b05aca..e6b1fdd223 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -649,6 +649,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "pad": self._pad,
             "prelu": self._prelu,
             "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9d9eb3ef48..e201b596f9 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -515,9 +515,9 @@ def conv2d_transpose(
 
 def pad(
     data: Expr,
-    pad_width: Tuple[Tuple[int, int], ...],
+    pad_width: Union[List[int], Tuple[int, ...]],
     pad_mode: Optional[str] = "constant",
-    pad_value: Optional[Union[float, Expr]] = 0.0,
+    pad_value: Optional[float] = 0.0,
 ):
     r"""Padding
 
@@ -528,14 +528,15 @@ def pad(
     ----------
     data: relax.Expr
         The input data to the operator
-    pad_width: Tuple[Tuple[int, int], ...], required
+    pad_width: Union[List[int], Tuple[int, ...]], required
         Number of values padded to the edges of each axis, in the format
         of ((before_1, after_1), ..., (before_N, after_N))
     pad_mode: Optional[str]
-        'constant', 'edge', or 'reflect'
-        'constant' pads with constant_value pad_value
-        'edge' pads using the edge values of the input array
-        'reflect' pads by reflecting values with respect to the edge
+        'constant', 'reflect', 'replicate', 'circular'
+        'constant' pads with constant value pad_value
+        'reflect' pads by mirroring values excluding the edge
+        'replicate' pads by repeating the edge values.
+        'circular' pads by looping values from the other side
         Default is 'constant'
     pad_value: Optional[Union[float, Expr]]
         The value used for padding. Default is 0.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 5d942e5f64..6a6f0ed6cb 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
 
 @register_legalize("relax.nn.pad")
 def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
-    # Unpack pad_width into two separate lists for topi.
+    pad_mode = call.attrs.pad_mode
     pad_widths = call.attrs.pad_width
     pad_before = pad_widths[::2]
     pad_after = pad_widths[1::2]
-    return bb.call_te(
-        topi.nn.pad,
-        call.args[0],
-        pad_before=pad_before,
-        pad_after=pad_after,
-        pad_value=call.attrs.pad_value,
-        primfunc_name_hint="pad",
-    )
+    if pad_mode == "reflect":
+        return bb.call_te(
+            topi.nn.reflect_pad, call.args[0], pad_before=pad_before, 
pad_after=pad_after
+        )
+    elif pad_mode == "replicate":
+        return bb.call_te(
+            topi.nn.replicate_pad, call.args[0], pad_before=pad_before, 
pad_after=pad_after
+        )
+    elif pad_mode == "circular":
+        return bb.call_te(
+            topi.nn.circular_pad, call.args[0], pad_before=pad_before, 
pad_after=pad_after
+        )
+    else:
+        return bb.call_te(
+            topi.nn.pad,
+            call.args[0],
+            pad_before=pad_before,
+            pad_after=pad_after,
+            pad_value=call.attrs.pad_value,
+            primfunc_name_hint="pad",
+        )
 
 
 @register_legalize("relax.nn.max_pool1d")
diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py
index 8833ef38d6..a3a7379d8d 100644
--- a/python/tvm/topi/nn/pad.py
+++ b/python/tvm/topi/nn/pad.py
@@ -19,14 +19,46 @@ from __future__ import absolute_import as _abs
 
 import tvm
 from tvm import te
+from tvm.tir import if_then_else
 
 from .. import tag
 from ..utils import equal_const_int
 
 
+def get_padded_shape(data, pad_before, pad_after=None):
+    """
+    Calculates the output shape of a tensor after applying padding.
+
+    Args:
+        data (tvm.te.Tensor): The input tensor to which padding is applied.
+        pad_before : list / tuple of n ints
+            Pad width on each dimension to pad the before the axis begin.
+        pad_after : list / tuple of n ints, optional
+            Pad width each dimension to pad the after the axis end.
+
+    Raises:
+        ValueError: If `pad_before` or `pad_after` lengths mismatch with 
`data` dimensions.
+
+    Returns:
+        tuple: A tuple representing the padded shape of the tensor.
+    """
+    n = data.ndim
+    pad_after = pad_after if pad_after else pad_before
+
+    if len(pad_before) != n:
+        raise ValueError(f"pad_before length {len(pad_before)} != input dims 
{n}")
+    if len(pad_after) != n:
+        raise ValueError(f"pad_after length {len(pad_after)} != input dims 
{n}")
+
+    ana = tvm.arith.Analyzer()
+    out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + 
pad_after[i]) for i in range(n))
+
+    return out_shape
+
+
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", 
attrs=None):
-    """Pad Input with zeros.
+    """Pad Input with using pad values.
 
     Parameters
     ----------
@@ -145,3 +177,143 @@ def mirror_pad(data, pad_before, pad_after=None, 
mode="SYMMETRIC", name="MirrorP
         return data(*mapped_tuple)
 
     return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"):
+    """
+    Apply reflect padding to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Reflect-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+
+            reflected_idx = if_then_else(
+                orig_idx < 0,
+                -orig_idx,  # reflect from start (no repeat)
+                if_then_else(
+                    orig_idx >= size,
+                    (2 * size - 2) - orig_idx,  # reflect from end
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(reflected_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"):
+    """
+    Apply replicate padding (edge padding) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Replicate-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+            clamped_idx = if_then_else(
+                orig_idx < 0,
+                tvm.tir.const(0, "int32"),  # replicate first element
+                if_then_else(
+                    orig_idx >= size,
+                    size - 1,  # replicate last element
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(clamped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"):
+    """
+    Apply circular padding (wrap around) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Circular-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+            wrapped_idx = tvm.tir.indexmod(orig_idx + size, size)
+            index_tuple.append(wrapped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 8db9684999..4c60fcd651 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1850,6 +1850,95 @@ def test_conv3d():
     verify_model(model, example_args, binding, expected2)
 
 
+def test_pad():
+    class PadModel(torch.nn.Module):
+        def __init__(self, pad, mode="constant", value=0.0):
+            super().__init__()
+            self.pad = pad
+            self.mode = mode
+            self.value = value
+
+        def forward(self, x):
+            if self.mode == "constant":
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode, 
value=self.value)
+            else:
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode)
+
+    @tvm.script.ir_module
+    class expected_constant:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_reflect:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="reflect",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_replicate:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="replicate",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_circular:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="circular",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, 
expected_constant)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, 
expected_reflect)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, 
{}, expected_replicate)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, 
{}, expected_circular)
+
+
 def test_einsum():
     class Einsum1(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 4a2ca336e1..53c925e14e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -503,6 +503,95 @@ def test_conv3d():
     verify_model(model, input_info, binding, expected2)
 
 
+def test_pad():
+    class PadModel(torch.nn.Module):
+        def __init__(self, pad, mode="constant", value=0.0):
+            super().__init__()
+            self.pad = pad
+            self.mode = mode
+            self.value = value
+
+        def forward(self, x):
+            if self.mode == "constant":
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode, 
value=self.value)
+            else:
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode)
+
+    @tvm.script.ir_module
+    class expected_constant:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_reflect:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="reflect",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_replicate:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="replicate",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_circular:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="circular",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_infos = [([1, 3, 10, 10], "float32")]
+    verify_model(PadModel(pad=[1, 1, 2, 2]), input_infos, {}, 
expected_constant)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), input_infos, {}, 
expected_reflect)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), input_infos, 
{}, expected_replicate)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, 
expected_circular)
+
+
 def test_linear():
     # nn.Linear
     class Dense1(Module):

Reply via email to