This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch revert-19842-fix/conv-transpose-dilation-legalize in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 660a337c6b4a49a4b7ab8d1388a3b3efb9ed4741 Author: Tianqi Chen <[email protected]> AuthorDate: Tue Jun 23 08:49:51 2026 -0400 Revert "[Relax] Legalize dilated conv_transpose (#19842)" This reverts commit 9808108e48af413a03ec35e512939a522132176b. --- python/tvm/relax/transform/legalize_ops/nn.py | 89 +++++++++++----------- .../python/relax/test_transform_legalize_ops_nn.py | 52 ------------- 2 files changed, 45 insertions(+), 96 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 6116a41e76..d68426f02a 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -164,23 +164,24 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr: "and kernel layout other than IOW, so cannot be legalized by TOPI" ) return call - strides = [int(s) for s in call.attrs.strides] - padding = [int(p) for p in call.attrs.padding] - output_padding = [int(o) for o in call.attrs.output_padding] - groups = int(call.attrs.groups) - out_dtype = call.ty.dtype - dilation = [int(d) for d in call.attrs.dilation] - - def te_conv1d_transpose(data, kernel): - # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. - if any(d != 1 for d in dilation): - kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]], name="kernel_dilate") - return topi.nn.group_conv1d_transpose_ncw( - data, kernel, strides, padding, out_dtype, output_padding, groups + dilation = call.attrs.dilation + if len(dilation) != 1 or dilation[0] != 1: + logging.info( + "TOPI conv1d_transpose does not support dilations other than 1, " + "and thus cannot be legalized by TOPI" ) + return call return bb.call_te( - te_conv1d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv1d_transpose" + topi.nn.group_conv1d_transpose_ncw, + call.args[0], + call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + out_dtype=call.ty.dtype, + output_padding=call.attrs.output_padding, + groups=call.attrs.groups, + primfunc_name_hint="conv1d_transpose", ) @@ -198,23 +199,24 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: "and kernel layout other than IOHW, so cannot be legalized by TOPI" ) return call - strides = [int(s) for s in call.attrs.strides] - padding = [int(p) for p in call.attrs.padding] - output_padding = [int(o) for o in call.attrs.output_padding] - groups = int(call.attrs.groups) - out_dtype = call.ty.dtype - dilation = [int(d) for d in call.attrs.dilation] - - def te_conv2d_transpose(data, kernel): - # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. - if any(d != 1 for d in dilation): - kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]], name="kernel_dilate") - return topi.nn.group_conv2d_transpose_nchw( - data, kernel, strides, padding, out_dtype, output_padding, groups + dilation = call.attrs.dilation + if len(dilation) != 2 or any(d != 1 for d in dilation): + logging.info( + "TOPI conv2d_transpose does not support dilations other than 1, " + "and thus cannot be legalized by TOPI" ) + return call return bb.call_te( - te_conv2d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv2d_transpose" + topi.nn.group_conv2d_transpose_nchw, + call.args[0], + call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + out_dtype=call.ty.dtype, + output_padding=call.attrs.output_padding, + groups=call.attrs.groups, + primfunc_name_hint="conv2d_transpose", ) @@ -234,25 +236,24 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr: "and kernel layout other than IODHW, so cannot be legalized by TOPI" ) return call - strides = [int(s) for s in call.attrs.strides] - padding = [int(p) for p in call.attrs.padding] - output_padding = [int(o) for o in call.attrs.output_padding] - groups = int(call.attrs.groups) - out_dtype = call.ty.dtype - dilation = [int(d) for d in call.attrs.dilation] - - def te_conv3d_transpose(data, kernel): - # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. - if any(d != 1 for d in dilation): - kernel = topi.nn.dilate( - kernel, [1, 1, dilation[0], dilation[1], dilation[2]], name="kernel_dilate" - ) - return topi.nn.group_conv3d_transpose_ncdhw( - data, kernel, strides, padding, out_dtype, output_padding, groups + dilation = call.attrs.dilation + if len(dilation) != 3 or any(d != 1 for d in dilation): + logging.info( + "TOPI conv3d_transpose does not support dilations other than 1, " + "and thus cannot be legalized by TOPI" ) + return call return bb.call_te( - te_conv3d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv3d_transpose" + topi.nn.group_conv3d_transpose_ncdhw, + call.args[0], + call.args[1], + strides=call.attrs.strides, + padding=call.attrs.padding, + out_dtype=call.ty.dtype, + output_padding=call.attrs.output_padding, + groups=call.attrs.groups, + primfunc_name_hint="conv3d_transpose", ) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 88621b9067..601985f7be 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -725,58 +725,6 @@ def test_conv2d_transpose_symbolic(): tvm.ir.assert_structural_equal(mod, Expected) -def test_conv2d_transpose_dilation(): - # fmt: off - @tvm.script.ir_module - class Conv2dTranspose: - @R.function - def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2, 2), "float32")): - gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2)) - return gv - - @I.ir_module(s_tir=True) - class Expected: - @T.prim_func(private=True, s_tir=True) - def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2), T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tirx.noalias": True}) - data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) - data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(7), T.int64(7))) - kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) - kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): - with T.sblock("data_dilate"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7), T.int64(7)): - with T.sblock("data_pad"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3 and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0.0)) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): - with T.sblock("kernel_dilate"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - kernel_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) == T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)], T.float32(0.0)) - for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): - with T.sblock("kernel_transform"): - v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1]) - kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w] - for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)): - with T.sblock("compute"): - v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw]) - with T.init(): - compute[v_b, v_c, v_h, v_w] = T.float32(0.0) - compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dh, v_dw] - - @R.function - def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1, 1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"): - cls = Expected - gv = R.call_tir(cls.conv2d_transpose, (x, w), out_ty=R.Tensor((1, 1, 5, 5), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Conv2dTranspose) - tvm.ir.assert_structural_equal(mod, Expected) - - def test_max_pool2d(): # fmt: off @tvm.script.ir_module
