This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new da52d7d0f6 fix: Support 5D volumetric inputs in ONNX GridSample
frontend converter (#19816)
da52d7d0f6 is described below
commit da52d7d0f62e2831212763189870caf447c85bb0
Author: Matt Van Horn <[email protected]>
AuthorDate: Thu Jun 18 12:24:23 2026 -0700
fix: Support 5D volumetric inputs in ONNX GridSample frontend converter
(#19816)
## Summary
The Relax ONNX frontend's GridSample._impl_v16 converter unconditionally
permutes the grid from ONNX [N,H,W,2] to TVM [N,2,H,W] and calls
image.grid_sample with layout="NCHW". For 5D volumetric inputs
([N,C,D,H,W] with grid [N,D,H,W,3]) this crashes at permute_dims with an
InternalError ('PermuteDims expects the number of input axes to equal
the ndim of the input tensor.').
## Changes
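In GridSample._impl_v16, read data.struct_info.ndim and dispatch on
rank. For ndim==4, keep the existing permute_dims(grid,[0,3,1,2]) +
grid_sample(layout="NCHW"). For ndim==5, permute the grid with
permute_dims(grid,[0,4,1,2,3]) and call grid_sample(layout="NCDHW").
5D input with mode='cubic', and any rank other than 4 or 5, raise
NotImplementedError.
In InferStructInfoGridSample (src/relax/op/image/resize.cc), pick the
target layout ("NCDHW" or "NCHW") from attrs->layout and take the output
spatial extents from grid dims [2:] of the TVM-layout grid.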
Fixes #19688
---------
Co-authored-by: Matt Van Horn <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 28 ++++-
src/relax/op/image/resize.cc | 35 +++++--
tests/python/relax/test_frontend_onnx.py | 134 +++++++++++++++++++++++-
3 files changed, 181 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index a8cb216e26..3cfe7c892c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -5084,15 +5084,35 @@ class GridSample(OnnxOpConverter):
align_corners = bool(attr.get("align_corners", 0))
- # ONNX grid shape: [N, H_out, W_out, 2]
- # TVM grid shape: [N, 2, H_out, W_out]
- grid = relax.op.permute_dims(grid, [0, 3, 1, 2])
+ if hasattr(data.struct_info, "ndim"):
+ ndim = data.struct_info.ndim
+ else:
+ ndim = len(data.struct_info.shape)
+
+ if ndim == 5 and method == "bicubic":
+ raise NotImplementedError(
+ "5D (volumetric) GridSample with mode='cubic' is not supported "
+ "(TOPI 3D grid_sample supports only bilinear and nearest)."
+ )
+
+ if ndim == 4:
+ # ONNX grid shape: [N, H_out, W_out, 2]
+ # TVM grid shape: [N, 2, H_out, W_out]
+ grid = relax.op.permute_dims(grid, [0, 3, 1, 2])
+ layout = "NCHW"
+ elif ndim == 5:
+ # ONNX grid shape: [N, D_out, H_out, W_out, 3]
+ # TVM grid shape: [N, 3, D_out, H_out, W_out]
+ grid = relax.op.permute_dims(grid, [0, 4, 1, 2, 3])
+ layout = "NCDHW"
+ else:
+ raise NotImplementedError(f"GridSample only supports 4D or 5D input, got {ndim}D.")
return relax.op.image.grid_sample(
data,
grid,
method=method,
- layout="NCHW",
+ layout=layout,
padding_mode=padding_mode,
align_corners=align_corners,
)
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 1b84f3dfc8..653ea04c63 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -305,14 +305,19 @@ StructInfo InferStructInfoGridSample(const Call& call,
const BlockBuilder& ctx)
}
const auto* attrs = call->attrs.as<GridSampleAttrs>();
- auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,
- /*tgt_layout=*/"NCHW",
- /*tensor_name=*/"data");
+
+ // grid_sample supports both 2D (NCHW) and 3D (NCDHW) sampling. The frontend
+ // sets attrs->layout to "NCDHW" for the volumetric case; everything else is
+ // treated as the 2D NCHW path so existing behavior is preserved.
+ const bool is_ncdhw = (attrs->layout == "NCDHW");
+
+ auto [data_layout, data2tgt] =
+ CheckTensorLayout(call, ctx, attrs->layout,
+ /*tgt_layout=*/is_ncdhw ? "NCDHW" : "NCHW",
+ /*tensor_name=*/"data");
DataType out_dtype = data_sinfo->dtype;
- // Output shape: [N, C, grid_H, grid_W]
- // grid shape for NCHW layout input is [N, H_out, W_out, 2]
ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
const auto* grid_shape = grid_sinfo->shape.as<ShapeExprNode>();
@@ -321,13 +326,21 @@ StructInfo InferStructInfoGridSample(const Call& call,
const BlockBuilder& ctx)
return TensorStructInfo(out_dtype, data_layout.ndim(),
data_sinfo->vdevice);
}
- ffi::Array<PrimExpr> data_NCHW_shape =
data2NCHW.ForwardShape(data_shape.value()->values);
- // grid is [N, H_out, W_out, 2], output is [N, C, H_out, W_out]
- ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
- out_NCHW_shape.Set(2, grid_shape->values[1]); // H_out
- out_NCHW_shape.Set(3, grid_shape->values[2]); // W_out
+ ffi::Array<PrimExpr> data_tgt_shape =
data2tgt.ForwardShape(data_shape.value()->values);
+ ffi::Array<PrimExpr> out_tgt_shape(data_tgt_shape);
+ if (is_ncdhw) {
+ // grid (TVM layout) is [N, 3, D_out, H_out, W_out], output is
+ // [N, C, D_out, H_out, W_out]; the spatial extents are grid->values[2:].
+ out_tgt_shape.Set(2, grid_shape->values[2]); // D_out
+ out_tgt_shape.Set(3, grid_shape->values[3]); // H_out
+ out_tgt_shape.Set(4, grid_shape->values[4]); // W_out
+ } else {
+ // grid (TVM layout) is [N, 2, H_out, W_out], output is [N, C, H_out, W_out]
+ out_tgt_shape.Set(2, grid_shape->values[2]); // H_out
+ out_tgt_shape.Set(3, grid_shape->values[3]); // W_out
+ }
- ffi::Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+ ffi::Array<PrimExpr> out_shape = data2tgt.BackwardShape(out_tgt_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype,
data_sinfo->vdevice);
}
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 5aff95da5a..57f780868c 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5629,7 +5629,6 @@ def test_affine_grid():
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])
def test_grid_sample(mode, padding_mode, align_corners):
- # Only testing 2D (NCHW) as that's what TVM currently supports
x_shape = [1, 3, 4, 4]
grid_shape = [1, 2, 2, 2]
out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]
@@ -5668,6 +5667,139 @@ def test_grid_sample(mode, padding_mode, align_corners):
)
+@pytest.mark.parametrize("mode", ["bilinear", "nearest"])
+@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
+@pytest.mark.parametrize("align_corners", [0, 1])
+def test_grid_sample_5d(mode, padding_mode, align_corners):
+ x_shape = [1, 1, 4, 4, 4]
+ grid_shape = [1, 4, 4, 4, 3]
+ out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2],
grid_shape[3]]
+
+ node = helper.make_node(
+ "GridSample",
+ inputs=["X", "grid"],
+ outputs=["Y"],
+ mode=mode,
+ padding_mode=padding_mode,
+ align_corners=align_corners,
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "grid_sample_5d_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+ ],
+ )
+
+ rng = np.random.default_rng(0)
+ grid_data = rng.uniform(-1.25, 1.25, grid_shape).astype("float32")
+ x_data = rng.uniform(-1, 1, x_shape).astype("float32")
+
+ model = helper.make_model(graph, producer_name="grid_sample_5d_test")
+ check_correctness(
+ model,
+ inputs={"grid": grid_data, "X": x_data},
+ opset=16,
+ rtol=1e-5,
+ atol=1e-5,
+ )
+
+
+def test_grid_sample_5d_cubic_unsupported():
+ x_shape = [1, 1, 4, 4, 4]
+ grid_shape = [1, 2, 3, 5, 3]
+ out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2],
grid_shape[3]]
+
+ node = helper.make_node(
+ "GridSample",
+ inputs=["X", "grid"],
+ outputs=["Y"],
+ mode="cubic",
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "grid_sample_5d_cubic_unsupported_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+ ],
+ )
+
+ model = helper.make_model(graph,
producer_name="grid_sample_5d_cubic_unsupported_test")
+ with pytest.raises(
+ NotImplementedError,
+ match="5D .*GridSample with mode='cubic' is not supported",
+ ):
+ from_onnx(model, opset=16, keep_params_in_input=True)
+
+
+def test_grid_sample_4d_non_square_output_shape():
+ x_shape = [1, 3, 4, 4]
+ grid_shape = [1, 3, 5, 2]
+ out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]
+
+ node = helper.make_node(
+ "GridSample",
+ inputs=["X", "grid"],
+ outputs=["Y"],
+ mode="bilinear",
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "grid_sample_4d_non_square_output_shape_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+ ],
+ )
+
+ model = helper.make_model(graph,
producer_name="grid_sample_4d_non_square_output_shape_test")
+ tvm_model = from_onnx(model, opset=16, keep_params_in_input=True)
+ inferred_shape = tuple(dim.value for dim in
tvm_model["main"].ret_struct_info.shape.values)
+ assert inferred_shape == tuple(out_shape)
+
+
+def test_grid_sample_unsupported_rank():
+ x_shape = [1, 3, 4]
+ grid_shape = [1, 4, 2]
+
+ node = helper.make_node(
+ "GridSample",
+ inputs=["X", "grid"],
+ outputs=["Y"],
+ mode="bilinear",
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "grid_sample_unsupported_rank_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, x_shape),
+ ],
+ )
+
+ model = helper.make_model(graph,
producer_name="grid_sample_unsupported_rank_test")
+ with pytest.raises(NotImplementedError, match="GridSample only supports 4D or 5D input"):
+ from_onnx(model, opset=16, keep_params_in_input=True)
+
+
def test_grid_sample_linear_mode_translation():
"""Test that ONNX mode='linear' is correctly translated to 'bilinear'.