This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 347c1fd4ca Preserve ONNX BatchNormalization inference mode (#19818)
347c1fd4ca is described below

commit 347c1fd4cabc26a180d4f2bb49fe924722905c6d
Author: Yin Li <[email protected]>
AuthorDate: Wed Jun 17 23:19:13 2026 +0100

    Preserve ONNX BatchNormalization inference mode (#19818)
    
    ### Description
    
    Fixes #19574.
    
    ONNX `BatchNormalization` defaults `training_mode` to the integer `0`,
    while Relax `nn.batch_norm` expects a boolean `training` attribute.
    Passing the integer through unchanged gives the imported op a non-bool
    `training` attribute. As a result, a default ONNX BatchNormalization
    node can be treated with training-mode semantics (batch statistics)
    instead of inference semantics (running mean and variance).
    
    This change converts the ONNX attribute to `bool` before constructing
    `relax.op.nn.batch_norm`. It also adds an importer-level regression test
    verifying that a default ONNX BatchNormalization node is imported with
    `training` set to `False`.
    
    ### Tests
    
    - `python3 -m py_compile python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py`
    - A static regression check confirming the `bool` conversion and the
    new frontend test
    - `git diff --check`
    
    The full TVM Python test suite could not be run in this local checkout
    because `tvm_ffi` is not built or installed.
    
    Signed-off-by: Kevin-Li-2025 <[email protected]>
    Co-authored-by: Kevin-Li-2025 <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  2 +-
 tests/python/relax/test_frontend_onnx.py        | 32 +++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 11485659fb..a8cb216e26 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3524,7 +3524,7 @@ class BatchNormalization(OnnxOpConverter):
             axis=1,
             epsilon=epsilon,
             momentum=momentum,
-            training=training_mode,
+            training=bool(training_mode),
         )
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 414c3d5bbf..5aff95da5a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4187,6 +4187,38 @@ def test_batch_norm():
     check_correctness(model, opset=15)
 
 
+def test_batch_norm_defaults_to_inference_mode():
+    batch_norm_node = helper.make_node(
+        "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"], 
epsilon=1e-2
+    )
+    graph = helper.make_graph(
+        [batch_norm_node],
+        "batch_norm_inference_attr_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4, 
5]),
+            helper.make_tensor_value_info("s", TensorProto.FLOAT, [3]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3]),
+            helper.make_tensor_value_info("mean", TensorProto.FLOAT, [3]),
+            helper.make_tensor_value_info("var", TensorProto.FLOAT, [3]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 
4, 5])],
+    )
+    model = helper.make_model(graph, 
producer_name="batch_norm_inference_attr_test")
+    model.opset_import[0].version = 15
+
+    tvm_model = from_onnx(model, opset=15, keep_params_in_input=True)
+    batch_norm_attrs = []
+
+    def visit(expr):
+        if isinstance(expr, relax.Call) and expr.op == 
tvm.ir.Op.get("relax.nn.batch_norm"):
+            batch_norm_attrs.append(expr.attrs)
+
+    relax.analysis.post_order_visit(tvm_model["main"], visit)
+
+    assert len(batch_norm_attrs) == 1
+    assert batch_norm_attrs[0].training is False
+
+
 @pytest.mark.parametrize("pool_name", ["MaxPool", "AveragePool", "LpPool"])
 @pytest.mark.parametrize(
     "shape, auto_pad, kernel_shape, strides, pads",

Reply via email to