This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf6e101793 [Relax][PyTorch] Add Pad Op Support for Exported Program
and FX graph (#17821)
bf6e101793 is described below
commit bf6e1017934f0207e79b0639202e6db999633459
Author: Deivanayaki S <[email protected]>
AuthorDate: Fri Apr 18 11:24:49 2025 +0530
[Relax][PyTorch] Add Pad Op Support for Exported Program and FX graph
(#17821)
* add pad op support into frontend pipelines
fixing end-of-file formatting issues
fixing trailing space issues
update the docstring for pad mode in nn file
fixing lint issues
remove trailing whitespaces
fix lint format issues in test script
fix lint issue in pad file import statement
modify docstring of pad function
fixing dtype error in unity check
fixing lint issues in pad.py file
resolve arg mismatch error
resolved error while handling pad value attr
fix dtype of pad value attribute
add helper function for different pad mode
test script enhanced to check different pad mode
remove trailing whitespaces in test script
add docstring for helper function
update test script
* fix pad op arg handling in fx graph
* fix issue by updating the retrieval of the value arg
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 18 +++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
python/tvm/relax/op/nn/nn.py | 15 +-
python/tvm/relax/transform/legalize_ops/nn.py | 31 ++--
python/tvm/topi/nn/pad.py | 174 ++++++++++++++++++++-
.../relax/test_frontend_from_exported_program.py | 89 +++++++++++
tests/python/relax/test_frontend_from_fx.py | 89 +++++++++++
8 files changed, 401 insertions(+), 17 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7a971c00cd..3ea70df9a1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -901,6 +901,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self._max_pool2d_impl(x, kernel_size, stride, padding,
dilation, ceil_mode)
+    def _pad(self, node: fx.Node) -> relax.Var:
+        """Convert a PyTorch ``pad`` call into ``relax.op.nn.pad``.
+
+        PyTorch lists padding amounts starting from the *last* dimension, as
+        ``(last_before, last_after, second_last_before, second_last_after, ...)``,
+        while ``relax.op.nn.pad`` takes a flat ``(before, after)`` list covering every
+        dimension from first to last. Dimensions PyTorch does not mention get no padding.
+        """
+        x = self.env[node.args[0]]
+        pad = list(node.args[1])
+        mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant")
+        value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0)
+        # A missing/None value (e.g. for non-constant modes) is mapped to 0.0.
+        value = 0.0 if value is None else value
+
+        input_ndim = x.struct_info.ndim
+        if len(pad) % 2 != 0:
+            raise ValueError(f"pad must hold an even number of values, but got {pad}")
+        # A negative ndim means the rank is unknown, so the upper bound cannot be checked.
+        if input_ndim >= 0 and len(pad) > 2 * input_ndim:
+            raise ValueError(
+                f"pad {pad} covers more than the {input_ndim} dimensions of the input"
+            )
+
+        # Reverse the (before, after) pairs so they follow first-to-last dimension order,
+        # then left-fill with zeros for the leading dimensions PyTorch left unpadded.
+        # (An empty pad therefore yields all zeros rather than an empty pad_width.)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        flattened = [amount for pair in reversed(pad_pairs) for amount in pair]
+        pad_width = [0] * max(2 * input_ndim - len(flattened), 0) + flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
+
def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1,
3])
query = transpose_S_H(self.env[node.args[0]])
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4084e35de5..9064de37f0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"log1p.default": self._log1p,
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
+ "pad.default": self._pad,
"prelu.default": self._prelu,
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 4ef0b05aca..e6b1fdd223 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -649,6 +649,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"logical_not": self._unary_op(relax.op.logical_not),
"log_softmax": self._log_softmax,
"neg": self._unary_op(relax.op.negative),
+ "pad": self._pad,
"prelu": self._prelu,
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9d9eb3ef48..e201b596f9 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -515,9 +515,9 @@ def conv2d_transpose(
def pad(
data: Expr,
- pad_width: Tuple[Tuple[int, int], ...],
+ pad_width: Union[List[int], Tuple[int, ...]],
pad_mode: Optional[str] = "constant",
- pad_value: Optional[Union[float, Expr]] = 0.0,
+ pad_value: Optional[float] = 0.0,
):
r"""Padding
@@ -528,14 +528,15 @@ def pad(
----------
data: relax.Expr
The input data to the operator
- pad_width: Tuple[Tuple[int, int], ...], required
+ pad_width: Union[List[int], Tuple[int, ...]], required
Number of values padded to the edges of each axis, in the format
of ((before_1, after_1), ..., (before_N, after_N))
pad_mode: Optional[str]
- 'constant', 'edge', or 'reflect'
- 'constant' pads with constant_value pad_value
- 'edge' pads using the edge values of the input array
- 'reflect' pads by reflecting values with respect to the edge
+ 'constant', 'reflect', 'replicate', 'circular'
+ 'constant' pads with constant value pad_value
+ 'reflect' pads by mirroring values excluding the edge
+ 'replicate' pads by repeating the edge values.
+ 'circular' pads by looping values from the other side
Default is 'constant'
pad_value: Optional[Union[float, Expr]]
The value used for padding. Default is 0.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index 5d942e5f64..6a6f0ed6cb 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
@register_legalize("relax.nn.pad")
def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
-    # Unpack pad_width into two separate lists for topi.
+    """Legalize ``relax.nn.pad`` by dispatching on ``pad_mode`` to a topi kernel.
+
+    ``pad_width`` is a flat ``[before_0, after_0, before_1, after_1, ...]`` list; it is
+    split into the separate ``pad_before`` / ``pad_after`` lists that topi expects.
+    Only the "constant" mode reads ``pad_value``.
+    """
+    pad_mode = str(call.attrs.pad_mode)
    pad_widths = call.attrs.pad_width
    pad_before = pad_widths[::2]
    pad_after = pad_widths[1::2]
-    return bb.call_te(
-        topi.nn.pad,
-        call.args[0],
-        pad_before=pad_before,
-        pad_after=pad_after,
-        pad_value=call.attrs.pad_value,
-        primfunc_name_hint="pad",
-    )
+    # Modes that only remap output indices back into the input tensor. "edge" is the
+    # older name for edge-value (replicate) padding and is accepted as an alias.
+    index_mapping_pads = {
+        "reflect": topi.nn.reflect_pad,
+        "replicate": topi.nn.replicate_pad,
+        "edge": topi.nn.replicate_pad,
+        "circular": topi.nn.circular_pad,
+    }
+    if pad_mode in index_mapping_pads:
+        return bb.call_te(
+            index_mapping_pads[pad_mode], call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    # Refuse unknown modes instead of silently producing constant padding.
+    if pad_mode != "constant":
+        raise ValueError(
+            f"Unsupported pad_mode '{pad_mode}' for relax.nn.pad; expected one of "
+            f"'constant', {', '.join(repr(m) for m in index_mapping_pads)}"
+        )
+    return bb.call_te(
+        topi.nn.pad,
+        call.args[0],
+        pad_before=pad_before,
+        pad_after=pad_after,
+        pad_value=call.attrs.pad_value,
+        primfunc_name_hint="pad",
+    )
@register_legalize("relax.nn.max_pool1d")
diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py
index 8833ef38d6..a3a7379d8d 100644
--- a/python/tvm/topi/nn/pad.py
+++ b/python/tvm/topi/nn/pad.py
@@ -19,14 +19,46 @@ from __future__ import absolute_import as _abs
import tvm
from tvm import te
+from tvm.tir import if_then_else
from .. import tag
from ..utils import equal_const_int
+def get_padded_shape(data, pad_before, pad_after=None):
+    """Compute the output shape of a tensor after padding.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor to be padded.
+
+    pad_before : list / tuple of n ints
+        Amount of padding added before the start of each dimension.
+
+    pad_after : list / tuple of n ints, optional
+        Amount of padding added after the end of each dimension.
+        Defaults to ``pad_before`` when not given (or empty).
+
+    Returns
+    -------
+    out_shape : tuple
+        The simplified padded extent of every dimension.
+
+    Raises
+    ------
+    ValueError
+        If ``pad_before`` or ``pad_after`` does not have one entry per dimension of ``data``.
+    """
+    ndim = data.ndim
+    pad_after = pad_after if pad_after else pad_before
+
+    for arg_name, widths in (("pad_before", pad_before), ("pad_after", pad_after)):
+        if len(widths) != ndim:
+            raise ValueError(f"{arg_name} length {len(widths)} != input dims {ndim}")
+
+    analyzer = tvm.arith.Analyzer()
+    return tuple(
+        analyzer.simplify(extent + before + after)
+        for extent, before, after in zip(data.shape, pad_before, pad_after)
+    )
+
+
@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput",
attrs=None):
- """Pad Input with zeros.
+ """Pad Input with using pad values.
Parameters
----------
@@ -145,3 +177,143 @@ def mirror_pad(data, pad_before, pad_after=None,
mode="SYMMETRIC", name="MirrorP
return data(*mapped_tuple)
return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"):
+    """
+    Apply reflect padding to the input tensor.
+
+    Values are mirrored around the first/last element of each dimension without
+    repeating that element, e.g. padding ``[a, b, c]`` by one on both sides gives
+    ``[b, a, b, c, b]``.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Reflect-padded tensor.
+
+    Raises
+    ------
+    ValueError
+        If the padding lists do not have one entry per dimension, or if a statically
+        known, non-zero padding amount is not smaller than its dimension size (the
+        mirrored index would then fall outside the input).
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+    pad_after = pad_after if pad_after else pad_before
+
+    # Reject padding that would read out of bounds when sizes are compile-time constants;
+    # symbolic sizes cannot be checked here.
+    for axis, (size, before, after) in enumerate(zip(data.shape, pad_before, pad_after)):
+        if all(isinstance(v, (int, tvm.tir.IntImm)) for v in (size, before, after)):
+            largest = max(int(before), int(after))
+            if largest > 0 and largest >= int(size):
+                raise ValueError(
+                    f"reflect padding on axis {axis} must be smaller than the dimension size "
+                    f"{int(size)}, but got ({int(before)}, {int(after)})"
+                )
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            size = data.shape[i]
+            orig_idx = indices[i] - pad_before[i]
+            reflected_idx = if_then_else(
+                orig_idx < 0,
+                -orig_idx,  # reflect from the start (edge element not repeated)
+                if_then_else(
+                    orig_idx >= size,
+                    (2 * size - 2) - orig_idx,  # reflect from the end (edge not repeated)
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(reflected_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"):
+    """
+    Apply replicate padding (edge padding) to the input tensor.
+
+    Out-of-range positions take the value of the nearest edge element, e.g. padding
+    ``[a, b, c]`` by two on both sides gives ``[a, a, a, b, c, c, c]``.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Replicate-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            size = data.shape[i]
+            orig_idx = indices[i] - pad_before[i]
+            clamped_idx = if_then_else(
+                orig_idx < 0,
+                # Zero in the index's own dtype, so int64 shapes need no implicit cast.
+                tvm.tir.const(0, orig_idx.dtype),  # replicate the first element
+                if_then_else(
+                    orig_idx >= size,
+                    size - 1,  # replicate the last element
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(clamped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"):
+    """
+    Apply circular padding (wrap around) to the input tensor.
+
+    Out-of-range positions read from the opposite side of the dimension, e.g. padding
+    ``[a, b, c]`` by one on both sides gives ``[c, a, b, c, a]``.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Circular-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            size = data.shape[i]
+            orig_idx = indices[i] - pad_before[i]
+            # floormod maps negative indices into [0, size) for any padding amount,
+            # unlike indexmod, which assumes a non-negative operand.
+            wrapped_idx = tvm.tir.floormod(orig_idx, size)
+            index_tuple.append(wrapped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 8db9684999..4c60fcd651 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1850,6 +1850,95 @@ def test_conv3d():
verify_model(model, example_args, binding, expected2)
+def test_pad():
+    """Check that torch.nn.functional.pad maps onto R.nn.pad for every supported mode."""
+
+    class PadModel(torch.nn.Module):
+        def __init__(self, pad, mode="constant", value=0.0):
+            super().__init__()
+            self.pad = pad
+            self.mode = mode
+            self.value = value
+
+        def forward(self, x):
+            if self.mode == "constant":
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value)
+            else:
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode)
+
+    @tvm.script.ir_module
+    class expected_constant:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # A non-zero fill value checks that the user-supplied value reaches R.nn.pad.
+    @tvm.script.ir_module
+    class expected_constant_value:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=1.5,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_reflect:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="reflect",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_replicate:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="replicate",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_circular:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="circular",
+                    pad_value=0.0,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant)
+    verify_model(PadModel(pad=[1, 1, 2, 2], value=1.5), example_args, {}, expected_constant_value)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)
+
+
def test_einsum():
class Einsum1(Module):
def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 4a2ca336e1..53c925e14e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -503,6 +503,95 @@ def test_conv3d():
verify_model(model, input_info, binding, expected2)
+def test_pad():
+    """Check that torch.nn.functional.pad maps onto R.nn.pad for every supported mode."""
+
+    class PadModel(torch.nn.Module):
+        def __init__(self, pad, mode="constant", value=0.0):
+            super().__init__()
+            self.pad = pad
+            self.mode = mode
+            self.value = value
+
+        def forward(self, x):
+            if self.mode == "constant":
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value)
+            else:
+                return torch.nn.functional.pad(x, self.pad, mode=self.mode)
+
+    @tvm.script.ir_module
+    class expected_constant:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # A non-zero fill value checks that the user-supplied value reaches R.nn.pad.
+    @tvm.script.ir_module
+    class expected_constant_value:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="constant",
+                    pad_value=1.5,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_reflect:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="reflect",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_replicate:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="replicate",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_circular:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 14, 12), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                    x,
+                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
+                    pad_mode="circular",
+                    pad_value=0.0,
+                )
+                gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_infos = [([1, 3, 10, 10], "float32")]
+    verify_model(PadModel(pad=[1, 1, 2, 2]), input_infos, {}, expected_constant)
+    verify_model(PadModel(pad=[1, 1, 2, 2], value=1.5), input_infos, {}, expected_constant_value)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), input_infos, {}, expected_reflect)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), input_infos, {}, expected_replicate)
+    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular)
+
+
def test_linear():
# nn.Linear
class Dense1(Module):