This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 347c1fd4ca Preserve ONNX BatchNormalization inference mode (#19818)
347c1fd4ca is described below
commit 347c1fd4cabc26a180d4f2bb49fe924722905c6d
Author: Yin Li <[email protected]>
AuthorDate: Wed Jun 17 23:19:13 2026 +0100
Preserve ONNX BatchNormalization inference mode (#19818)
### Description
Fixes #19574.
ONNX `BatchNormalization` defaults `training_mode` to the integer `0`, while
Relax `nn.batch_norm` expects a boolean `training` attribute. Passing the
integer through unchanged leaves the imported op with a non-boolean
`training` attribute, which can route a default ONNX BatchNormalization
through training-mode semantics instead of inference (running-statistics)
semantics.
This converts the ONNX attribute to `bool` before constructing
`relax.op.nn.batch_norm`, and adds an importer-level regression test
that verifies default ONNX BatchNormalization is imported with
`training is False`.
### Tests
- `python3 -m py_compile python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py`
- Static regression check confirming the bool conversion and new
frontend test
- `git diff --check`
The full TVM Python test suite could not be run in this local checkout
because `tvm_ffi` is not built/installed.
Signed-off-by: Kevin-Li-2025 <[email protected]>
Co-authored-by: Kevin-Li-2025 <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +-
tests/python/relax/test_frontend_onnx.py | 32 +++++++++++++++++++++++++
2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 11485659fb..a8cb216e26 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3524,7 +3524,7 @@ class BatchNormalization(OnnxOpConverter):
axis=1,
epsilon=epsilon,
momentum=momentum,
- training=training_mode,
+ training=bool(training_mode),
)
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 414c3d5bbf..5aff95da5a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4187,6 +4187,38 @@ def test_batch_norm():
check_correctness(model, opset=15)
+def test_batch_norm_defaults_to_inference_mode():
+ batch_norm_node = helper.make_node(
+ "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"],
epsilon=1e-2
+ )
+ graph = helper.make_graph(
+ [batch_norm_node],
+ "batch_norm_inference_attr_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4,
5]),
+ helper.make_tensor_value_info("s", TensorProto.FLOAT, [3]),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3]),
+ helper.make_tensor_value_info("mean", TensorProto.FLOAT, [3]),
+ helper.make_tensor_value_info("var", TensorProto.FLOAT, [3]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3,
4, 5])],
+ )
+ model = helper.make_model(graph,
producer_name="batch_norm_inference_attr_test")
+ model.opset_import[0].version = 15
+
+ tvm_model = from_onnx(model, opset=15, keep_params_in_input=True)
+ batch_norm_attrs = []
+
+ def visit(expr):
+ if isinstance(expr, relax.Call) and expr.op ==
tvm.ir.Op.get("relax.nn.batch_norm"):
+ batch_norm_attrs.append(expr.attrs)
+
+ relax.analysis.post_order_visit(tvm_model["main"], visit)
+
+ assert len(batch_norm_attrs) == 1
+ assert batch_norm_attrs[0].training is False
+
+
@pytest.mark.parametrize("pool_name", ["MaxPool", "AveragePool", "LpPool"])
@pytest.mark.parametrize(
"shape, auto_pad, kernel_shape, strides, pads",