This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b012ed369 [Relax] Legalize nn.dropout as inference no-op (#19841)
8b012ed369 is described below

commit 8b012ed3696bbb47a922c8653ca01eaee3bd5c87
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Tue Jun 23 07:17:21 2026 +0800

    [Relax] Legalize nn.dropout as inference no-op (#19841)
    
    ## Related Issue
    
    Closed #19695
    
    ## Why
    
    Building any module containing relax.nn.dropout crashed in VM codegen
    because the op had no real legalization, and the ONNX frontend could not
    import Dropout nodes.
    
    ## How
    
    - Legalize nn.dropout to pass the input through with an all-ones mask,
    matching its (output, mask) tuple result.
    - Add and register a Dropout converter in the ONNX frontend.
    - Add a structural test for the legalization and ONNX tests checking parity with onnxruntime.
    
    Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 20 ++++++++++++
 python/tvm/relax/transform/legalize_ops/nn.py      |  8 +++--
 tests/python/relax/test_frontend_onnx.py           | 26 +++++++++++++++
 .../python/relax/test_transform_legalize_ops_nn.py | 38 ++++++++++++++++++++++
 4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 61f95e7130..8562cb60a2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3070,6 +3070,25 @@ class Identity(OnnxOpConverter):
         return inputs[0]
 
 
+class Dropout(OnnxOpConverter):
+    """Converts an onnx Dropout node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        ratio = float(attr.get("ratio", 0.5))
+        return relax.op.nn.dropout(inputs[0], ratio)
+
+    @classmethod
+    def _impl_v12(cls, bb, inputs, attr, params):
+        # Since opset 12 ratio is the optional second input rather than an attribute.
+        ratio = 0.5
+        if len(inputs) >= 2 and inputs[1] is not None:
+            const = get_constant(inputs[1], params)
+            if isinstance(const, relax.Constant):
+                ratio = float(const.data.numpy())
+        return relax.op.nn.dropout(inputs[0], ratio)
+
+
 def _onnx_resize_spatial_roi_vector(roi_full: relax.Expr, rank: int) -> relax.Expr:
     """Map ONNX ROI [starts..., ends...] to TOPI spatial ROI (drop N/C axes)."""
     return relax.op.concat(
@@ -5284,6 +5303,7 @@ def _get_convert_map():
         "ConvTranspose": ConvTranspose,
         "Flatten": Flatten,
         "Identity": Identity,
+        "Dropout": Dropout,
         "Resize": Resize,
         "Einsum": Einsum,
         "Range": Range,
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index f87c16aa0a..d68426f02a 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -697,8 +697,12 @@ def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
-    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
-    return call
+    # Dropout is a no-op at inference: pass the input through and return an all-ones mask.
+    return bb.call_te(
+        lambda x: [topi.identity(x), topi.full_like(x, 1.0)],
+        call.args[0],
+        primfunc_name_hint="dropout",
+    )
 
 
 def _te_attention(
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 82e3f997d2..721e26e792 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4254,6 +4254,32 @@ def test_maxunpool(kernel_shape, pads, strides):
     check_correctness(model, inputs={"I": indices})
 
 
+def test_dropout():
+    verify_unary("Dropout", [1, 3, 32, 32])
+    verify_unary("Dropout", [1, 3, 32, 32], opset=11, attrs={"ratio": 0.5})
+
+    # Opset 12+ passes ratio as an optional input; check it is captured into the relax op.
+    node = helper.make_node("Dropout", ["x", "ratio"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dropout_ratio_input",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4])],
+        initializer=[helper.make_tensor("ratio", TensorProto.FLOAT, [], [0.3])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 4, 4])],
+    )
+    model = helper.make_model(graph, producer_name="dropout_ratio_input")
+    model.opset_import[0].version = 13
+    mod = from_onnx(model, opset=13)
+    rates = [
+        float(b.value.attrs.rate)
+        for f in mod.functions.values()
+        for block in getattr(f.body, "blocks", [])
+        for b in block.bindings
+        if getattr(getattr(b.value, "op", None), "name", "") == "relax.nn.dropout"
+    ]
+    assert rates == pytest.approx([0.3])
+
+
 def test_flatten():
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 81648e91b4..601985f7be 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -4116,5 +4116,43 @@ def test_batch_flatten_undefined_shape():
     tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)
 
 
+def test_dropout():
+    # fmt: off
+    @tvm.script.ir_module
+    class Dropout:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")):
+            gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout(x, rate=0.5)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True, s_tir=True)
+        def dropout(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_full_like: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = x[v_i0, v_i1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("T_full_like"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads()
+                    T.writes(T_full_like[v_ax0, v_ax1])
+                    T_full_like[v_ax0, v_ax1] = T.float32(1.0)
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
+            cls = Expected
+            gv = R.call_tir(cls.dropout, (x,), out_sinfo=[R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")])
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Dropout)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to