This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9808108e48 [Relax] Legalize dilated conv_transpose (#19842)
9808108e48 is described below
commit 9808108e48af413a03ec35e512939a522132176b
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Tue Jun 23 12:28:20 2026 +0800
[Relax] Legalize dilated conv_transpose (#19842)
## Why
relax.nn.conv{1,2,3}d_transpose with dilation > 1 was silently left
un-legalized by LegalizeOps and then crashed in VM codegen with an opaque error.
## How
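- Lower dilation > 1 by zero-filling (dilating) the kernel, then reusing
the existing TOPI transposed-conv compute (1D/2D/3D). This rewrite is exact:
a transposed conv with dilation d equals an undilated one whose kernel has
d - 1 zeros inserted between adjacent taps (effective size d * (k - 1) + 1),
so the output shape also matches the dilated op's.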
- Unsupported layouts (anything other than NCW/NCHW/NCDHW data with
IOW/IOHW/IODHW kernels) and calls with out_layout != data_layout are still
passed through unchanged, as before (left for downstream/BYOC codegen such as CLML).
- Add a 2D-dilation structural test.
Signed-off-by: Guan-Ming (Wesley) Chiu
<[email protected]>
---
python/tvm/relax/transform/legalize_ops/nn.py | 89 +++++++++++-----------
.../python/relax/test_transform_legalize_ops_nn.py | 52 +++++++++++++
2 files changed, 96 insertions(+), 45 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index d68426f02a..6116a41e76 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -164,24 +164,23 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
"and kernel layout other than IOW, so cannot be legalized by TOPI"
)
return call
- dilation = call.attrs.dilation
- if len(dilation) != 1 or dilation[0] != 1:
- logging.info(
- "TOPI conv1d_transpose does not support dilations other than 1, "
- "and thus cannot be legalized by TOPI"
+ strides = [int(s) for s in call.attrs.strides]
+ padding = [int(p) for p in call.attrs.padding]
+ output_padding = [int(o) for o in call.attrs.output_padding]
+ groups = int(call.attrs.groups)
+ out_dtype = call.ty.dtype
+ dilation = [int(d) for d in call.attrs.dilation]
+
+ def te_conv1d_transpose(data, kernel):
+ # Dilated transposed conv == transposed conv with a spatially dilated
(zero-filled) kernel.
+ if any(d != 1 for d in dilation):
+ kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]],
name="kernel_dilate")
+ return topi.nn.group_conv1d_transpose_ncw(
+ data, kernel, strides, padding, out_dtype, output_padding, groups
)
- return call
return bb.call_te(
- topi.nn.group_conv1d_transpose_ncw,
- call.args[0],
- call.args[1],
- stride=call.attrs.strides,
- padding=call.attrs.padding,
- out_dtype=call.ty.dtype,
- output_padding=call.attrs.output_padding,
- groups=call.attrs.groups,
- primfunc_name_hint="conv1d_transpose",
+ te_conv1d_transpose, call.args[0], call.args[1],
primfunc_name_hint="conv1d_transpose"
)
@@ -199,24 +198,23 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
"and kernel layout other than IOHW, so cannot be legalized by TOPI"
)
return call
- dilation = call.attrs.dilation
- if len(dilation) != 2 or any(d != 1 for d in dilation):
- logging.info(
- "TOPI conv2d_transpose does not support dilations other than 1, "
- "and thus cannot be legalized by TOPI"
+ strides = [int(s) for s in call.attrs.strides]
+ padding = [int(p) for p in call.attrs.padding]
+ output_padding = [int(o) for o in call.attrs.output_padding]
+ groups = int(call.attrs.groups)
+ out_dtype = call.ty.dtype
+ dilation = [int(d) for d in call.attrs.dilation]
+
+ def te_conv2d_transpose(data, kernel):
+ # Dilated transposed conv == transposed conv with a spatially dilated
(zero-filled) kernel.
+ if any(d != 1 for d in dilation):
+ kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]],
name="kernel_dilate")
+ return topi.nn.group_conv2d_transpose_nchw(
+ data, kernel, strides, padding, out_dtype, output_padding, groups
)
- return call
return bb.call_te(
- topi.nn.group_conv2d_transpose_nchw,
- call.args[0],
- call.args[1],
- stride=call.attrs.strides,
- padding=call.attrs.padding,
- out_dtype=call.ty.dtype,
- output_padding=call.attrs.output_padding,
- groups=call.attrs.groups,
- primfunc_name_hint="conv2d_transpose",
+ te_conv2d_transpose, call.args[0], call.args[1],
primfunc_name_hint="conv2d_transpose"
)
@@ -236,24 +234,25 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
"and kernel layout other than IODHW, so cannot be legalized by
TOPI"
)
return call
- dilation = call.attrs.dilation
- if len(dilation) != 3 or any(d != 1 for d in dilation):
- logging.info(
- "TOPI conv3d_transpose does not support dilations other than 1, "
- "and thus cannot be legalized by TOPI"
+ strides = [int(s) for s in call.attrs.strides]
+ padding = [int(p) for p in call.attrs.padding]
+ output_padding = [int(o) for o in call.attrs.output_padding]
+ groups = int(call.attrs.groups)
+ out_dtype = call.ty.dtype
+ dilation = [int(d) for d in call.attrs.dilation]
+
+ def te_conv3d_transpose(data, kernel):
+ # Dilated transposed conv == transposed conv with a spatially dilated
(zero-filled) kernel.
+ if any(d != 1 for d in dilation):
+ kernel = topi.nn.dilate(
+ kernel, [1, 1, dilation[0], dilation[1], dilation[2]],
name="kernel_dilate"
+ )
+ return topi.nn.group_conv3d_transpose_ncdhw(
+ data, kernel, strides, padding, out_dtype, output_padding, groups
)
- return call
return bb.call_te(
- topi.nn.group_conv3d_transpose_ncdhw,
- call.args[0],
- call.args[1],
- strides=call.attrs.strides,
- padding=call.attrs.padding,
- out_dtype=call.ty.dtype,
- output_padding=call.attrs.output_padding,
- groups=call.attrs.groups,
- primfunc_name_hint="conv3d_transpose",
+ te_conv3d_transpose, call.args[0], call.args[1],
primfunc_name_hint="conv3d_transpose"
)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 601985f7be..88621b9067 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -725,6 +725,58 @@ def test_conv2d_transpose_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_conv2d_transpose_dilation():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv2dTranspose:
+ @R.function
+ def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2,
2), "float32")):
+ gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2))
+ return gv
+
+ @I.ir_module(s_tir=True)
+ class Expected:
+ @T.prim_func(private=True, s_tir=True)
+ def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3),
T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2),
T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5),
T.int64(5)), "float32")):
+ T.func_attr({"tirx.noalias": True})
+ data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1),
T.int64(3), T.int64(3)))
+ data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1),
T.int64(7), T.int64(7)))
+ kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1),
T.int64(3), T.int64(3)))
+ kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1),
T.int64(3), T.int64(3)))
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3),
T.int64(3)):
+ with T.sblock("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2,
v_i3]
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7),
T.int64(7)):
+ with T.sblock("data_pad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ data_pad[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3
and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 -
T.int64(2)], T.float32(0.0))
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3),
T.int64(3)):
+ with T.sblock("kernel_dilate"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ kernel_dilate[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) ==
T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)],
T.float32(0.0))
+ for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3),
T.int64(3)):
+ with T.sblock("kernel_transform"):
+ v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1])
+ kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i,
v_o, T.int64(2) - v_h, T.int64(2) - v_w]
+ for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1),
T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)):
+ with T.sblock("compute"):
+ v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw =
T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw])
+ with T.init():
+ compute[v_b, v_c, v_h, v_w] = T.float32(0.0)
+ compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w]
+ data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc,
v_dh, v_dw]
+
+ @R.function
+ def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1,
1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.conv2d_transpose, (x, w), out_ty=R.Tensor((1,
1, 5, 5), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Conv2dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_max_pool2d():
# fmt: off
@tvm.script.ir_module