This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fca86a6872 [Fix] Stabilize layer_norm variance computation with two-pass reduction (#19643)
fca86a6872 is described below

commit fca86a6872dacea0978f1d49eb4e22935c2b0820
Author: ConvolutedDog <[email protected]>
AuthorDate: Sun May 31 13:45:46 2026 +0800

    [Fix] Stabilize layer_norm variance computation with two-pass reduction (#19643)
    
    This PR will fix https://github.com/apache/tvm/issues/19592.
    
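    LayerNorm could produce NaN on large-value, small-variance inputs due to
    catastrophic cancellation in var = E[x^2] - E[x]^2: the two terms are
    nearly equal, so after float32 rounding their difference can become
    negative, and rsqrt(var + eps) then returns NaN.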
    
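    Switch to a numerically stable two-pass formulation, where N is the number
    of elements along the normalized axes (both passes accumulate in float32;
    float16 inputs are cast up first):
    
      - pass1 computes mean via sum(x) / N
      - pass2 computes variance via sum((x - mean)^2) / N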
---
 include/tvm/topi/nn/layer_norm.h                   |  73 ++++---
 tests/python/relax/test_frontend_onnx.py           |  38 ++++
 .../python/relax/test_transform_legalize_ops_nn.py | 224 ++++++++++++---------
 3 files changed, 211 insertions(+), 124 deletions(-)

diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index 873a5fd1b2..d74bbce23f 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -25,6 +25,7 @@
 #define TVM_TOPI_NN_LAYER_NORM_H_
 
 #include <tvm/te/operation.h>
+#include <tvm/topi/reduction.h>
 #include <tvm/topi/tags.h>
 
 #include <string>
@@ -59,17 +60,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& 
gamma, const Tensor&
   TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == 
DataType::Float(16))
       << "layer_norm: only support float32 and float16 for now";
   bool is_float16 = data_type == DataType::Float(16);
-  // sum x and x^2
+  // Two-pass algorithm for improved numerical stability:
+  //   pass1: mean = E[x]
+  //   pass2: var = E[(x - mean)^2]
   auto ndim = data->shape.size();
   TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
   auto target_shape =
       MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, 
/*atleast1d=*/false);
-  auto func = MakeTupleSumReducer();
 
-  auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
-                  &data](const ffi::Array<Var>& indices) {
+  auto make_eval_range = [&real_axis, &reduce_axes,
+                          ndim](const ffi::Array<Var>& non_reduce_indices) {
     ffi::Array<PrimExpr> eval_range;
     int arg_counter = 0;
     int red_counter = 0;
@@ -80,34 +82,51 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& 
gamma, const Tensor&
         eval_range.push_back(reduce_axes[red_counter]);
         red_counter++;
       } else {
-        eval_range.push_back(indices[arg_counter]);
+        eval_range.push_back(non_reduce_indices[arg_counter]);
         arg_counter++;
       }
     }
-    auto square = [is_float16](const PrimExpr& x) {
-      if (is_float16) {
-        return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
-      }
-      return x * x;
-    };
-    if (is_float16) {
-      return func({Cast(DataType::Float(32), data(eval_range)), 
square(data(eval_range))},
-                  reduce_axes, nullptr);
-    } else {
-      return func({data(eval_range), square(data(eval_range))}, reduce_axes, 
nullptr);
-    }
+    return eval_range;
   };
 
-  auto temp_x_x2 =
-      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", 
kCommReduce);
+  Tensor temp_sum = te::compute(
+      target_shape,
+      [is_float16, &data, &reduce_axes, &make_eval_range](const 
ffi::Array<Var>& indices) {
+        auto eval_range = make_eval_range(indices);
+        PrimExpr x = data(eval_range);
+        if (is_float16) {
+          x = Cast(DataType::Float(32), x);
+        }
+        return sum(x, reduce_axes);
+      },
+      data->op->name + "_sum", kCommReduce);
 
-  auto temp_x = temp_x_x2[0];
-  auto temp_x2 = temp_x_x2[1];
-
-  auto reduce_extent = make_const(data->dtype, 1);
+  DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype;
+  PrimExpr reduce_extent = make_const(reduce_dtype, 1);
   for (int i : real_axis) {
     reduce_extent *= data->shape[i];
   }
+  Tensor temp_mean = te::compute(
+      target_shape,
+      [&temp_sum, &reduce_extent](const ffi::Array<Var>& indices) {
+        return temp_sum(indices) / reduce_extent;
+      },
+      data->op->name + "_mean", kInjective);
+
+  Tensor temp_var_sum = te::compute(
+      target_shape,
+      [is_float16, &data, &reduce_axes, &make_eval_range,
+       &temp_mean](const ffi::Array<Var>& indices) {
+        auto eval_range = make_eval_range(indices);
+        PrimExpr x = data(eval_range);
+        if (is_float16) {
+          x = Cast(DataType::Float(32), x);
+        }
+        PrimExpr diff = x - temp_mean(indices);
+        return sum(diff * diff, reduce_axes);
+      },
+      data->op->name + "_var_sum", kCommReduce);
+
   auto layer_norm_func = [&](const ffi::Array<Var>& indices) {
     ffi::Array<Var> reduce_indices, non_reduce_indices;
     for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
@@ -117,9 +136,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& 
gamma, const Tensor&
         non_reduce_indices.push_back(indices[i]);
       }
     }
-    auto mean = temp_x(non_reduce_indices) / reduce_extent;
-    auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
-    auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + 
make_const(var->dtype, epsilon));
+    auto mean = temp_mean(non_reduce_indices);
+    auto var = temp_var_sum(non_reduce_indices) / reduce_extent;
+    auto layer_norm = (data(indices) - mean) * rsqrt(var + 
make_const(var->dtype, epsilon));
     if (is_float16) {
       layer_norm = Cast(DataType::Float(16), layer_norm);
     }
@@ -129,7 +148,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& 
gamma, const Tensor&
     }
     return layer_norm;
   };
-  return tvm::te::compute(data->shape, layer_norm_func, name, tag);
+  return te::compute(data->shape, layer_norm_func, name, tag);
 }
 
 }  // namespace nn
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 4278812436..7ee10993a4 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2309,6 +2309,44 @@ def test_layer_norm_with_nd_gamma_beta():
     check_correctness(model)
 
 
+def test_layer_norm_numerical_stability():
+    """Numerical stability test for https://github.com/apache/tvm/issues/19592."""
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=-1, epsilon=1e-5
+    )
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_numerical_stability",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [4]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]),
+        ],
+    )
+    model = helper.make_model(graph, producer_name="layer_norm_numerical_stability")
+
+    input_array = np.array([[80000.0, 80001.0, 80002.0, 80003.0]], dtype=np.float32)
+    scale_array = np.ones(4, dtype=np.float32)
+    bias_array = np.zeros(4, dtype=np.float32)
+    inputs = {"input": input_array, "scale": scale_array, "bias": bias_array}
+
+    # ONNX Runtime also returns NaN for large-value, small-variance inputs, so compare
+    # against a float64 two-pass reference (mean, then E[(x - mean)^2]) instead of ORT.
+    x64 = input_array.astype(np.float64)
+    mean = x64.mean(axis=-1, keepdims=True)
+    var = ((x64 - mean) ** 2).mean(axis=-1, keepdims=True)
+    expected = (x64 - mean) / np.sqrt(var + 1e-5) * scale_array + bias_array
+    expected = expected.astype(np.float32)
+
+    tvm_output = run_in_tvm(model, inputs=inputs, ir_version=9, opset=17)
+
+    assert np.isfinite(tvm_output.numpy()).all()
+    tvm.testing.assert_allclose(tvm_output.numpy(), expected, rtol=1e-5, atol=1e-5)
+
+
 def test_rms_norm():
     # Basic test: default axis=-1
     rms_norm_node = helper.make_node("RMSNormalization", ["input", "scale"], 
["Y"], epsilon=1e-05)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 6badc7fc33..4a708b5da1 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2734,28 +2734,40 @@ def test_layer_norm():
             return gv
 
         @T.prim_func(private=True, s_tir=True)
-        def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), 
T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), 
"float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float32")):
+        def layer_norm(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float32"), gamma: T.Buffer((T.int64(4), T.int64(5)), "float32"), 
beta: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: 
T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
             T.func_attr({"tirx.noalias": True})
-            rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([T.int64(2), 
T.int64(3)], dtype="float32")
-            rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([T.int64(2), 
T.int64(3)], dtype="float32")
-            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
-                with T.sblock("rxplaceholder_red_temp"):
-                    ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3])
-                    T.reads(rxplaceholder[ax0, ax1, k2, k3])
-                    T.writes(rxplaceholder_red_temp_v0[ax0, ax1], 
rxplaceholder_red_temp_v1[ax0, ax1])
+            # with T.sblock("root"):
+            x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                with T.sblock("x_sum"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
+                    T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(x_sum[v_ax0, v_ax1])
+                    with T.init():
+                        x_sum[v_ax0, v_ax1] = T.float32(0.0)
+                    x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] + x[v_ax0, 
v_ax1, v_k2, v_k3]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("x_mean"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x_sum[v_ax0, v_ax1])
+                    T.writes(x_mean[v_ax0, v_ax1])
+                    x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] / 
T.float32(20.0)
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                with T.sblock("x_var_sum"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
+                    T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1])
+                    T.writes(x_var_sum[v_ax0, v_ax1])
                     with T.init():
-                        rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0)
-                        rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0)
-                    v_rxplaceholder_red_temp_v0: T.let[T.float32] = 
rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3]
-                    v_rxplaceholder_red_temp_v1: T.let[T.float32] = 
rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] * 
rxplaceholder[ax0, ax1, k2, k3]
-                    rxplaceholder_red_temp_v0[ax0, ax1] = 
v_rxplaceholder_red_temp_v0
-                    rxplaceholder_red_temp_v1[ax0, ax1] = 
v_rxplaceholder_red_temp_v1
-            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                        x_var_sum[v_ax0, v_ax1] = T.float32(0.0)
+                    x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] + 
(x[v_ax0, v_ax1, v_k2, v_k3] - x_mean[v_ax0, v_ax1]) * (x[v_ax0, v_ax1, v_k2, 
v_k3] - x_mean[v_ax0, v_ax1])
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.sblock("T_layer_norm"):
-                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
-                    T.reads(rxplaceholder[ax0, ax1, ax2, ax3], 
rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], 
rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3])
-                    T.writes(T_layer_norm[ax0, ax1, ax2, ax3])
-                    T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, 
ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) * 
T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] / T.float32(20) - 
rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20) * 
(rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) + T.float32(1e-05), 
dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3]
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0, 
v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
+                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, 
v_ax1, v_ax2, v_ax3] - x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1] 
/ T.float32(20.0) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + 
beta[v_ax2, v_ax3]
     # fmt: on
     mod = LegalizeOps()(LayerNorm)
     tvm.ir.assert_structural_equal(mod, Expected)
@@ -2780,26 +2792,36 @@ def test_layer_norm_1d():
         def layer_norm(x: T.Buffer((T.int64(3),), "float32"), 
layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: 
T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), 
"float32")):
             T.func_attr({"tirx.noalias": True})
             # with T.sblock("root"):
-            x_red_temp_v0 = T.sblock_alloc_buffer(())
-            x_red_temp_v1 = T.sblock_alloc_buffer(())
+            x_sum = T.sblock_alloc_buffer(())
+            x_mean = T.sblock_alloc_buffer(())
+            x_var_sum = T.sblock_alloc_buffer(())
             for k0 in range(T.int64(3)):
-                with T.sblock("x_red_temp"):
+                with T.sblock("x_sum"):
                     v_k0 = T.axis.reduce(T.int64(3), k0)
                     T.reads(x[v_k0])
-                    T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
+                    T.writes(x_sum[()])
+                    with T.init():
+                        x_sum[()] = T.float32(0.0)
+                    x_sum[()] = x_sum[()] + x[v_k0]
+            with T.sblock("x_mean"):
+                vi = T.axis.spatial(1, T.int64(0))
+                T.reads(x_sum[()])
+                T.writes(x_mean[()])
+                x_mean[()] = x_sum[()] / T.float32(3.0)
+            for k0 in range(T.int64(3)):
+                with T.sblock("x_var_sum"):
+                    v_k0 = T.axis.reduce(T.int64(3), k0)
+                    T.reads(x[v_k0], x_mean[()])
+                    T.writes(x_var_sum[()])
                     with T.init():
-                        x_red_temp_v0[()] = T.float32(0.0)
-                        x_red_temp_v1[()] = T.float32(0.0)
-                    v_x_red_temp_v0: T.let[T.float32] = x_red_temp_v0[()] + 
x[v_k0]
-                    v_x_red_temp_v1: T.let[T.float32] = x_red_temp_v1[()] + 
x[v_k0] * x[v_k0]
-                    x_red_temp_v0[()] = v_x_red_temp_v0
-                    x_red_temp_v1[()] = v_x_red_temp_v1
+                        x_var_sum[()] = T.float32(0.0)
+                    x_var_sum[()] = x_var_sum[()] + (x[v_k0] - x_mean[()]) * 
(x[v_k0] - x_mean[()])
             for ax0 in range(T.int64(3)):
                 with T.sblock("T_layer_norm"):
                     v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], 
layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
+                    T.reads(x[v_ax0], x_mean[()], x_var_sum[()], 
layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
                     T.writes(T_layer_norm[v_ax0])
-                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] / 
T.float32(3)) * T.rsqrt(x_red_temp_v1[()] / T.float32(3) - x_red_temp_v0[()] / 
T.float32(3) * (x_red_temp_v0[()] / T.float32(3)) + 
T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + 
layer_norm_bias[v_ax0]
+                    T_layer_norm[v_ax0] = (x[v_ax0] - x_mean[()]) * 
T.rsqrt(x_var_sum[()] / T.float32(3.0) + T.float32(1.0000000000000001e-05)) * 
layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
 
         @R.function
         def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: 
R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), 
dtype="float32")) -> R.Tensor((3,), dtype="float32"):
@@ -2827,47 +2849,45 @@ def test_layer_norm_fp16():
     @I.ir_module(s_tir=True)
     class Expected:
         @T.prim_func(private=True, s_tir=True)
-        def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
+        def layer_norm(
+            x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), 
"float16"),
+            gamma: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+            beta: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+            T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float16"),
+        ):
             T.func_attr({"tirx.noalias": True})
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2), 
T.int64(3), T.int64(4), T.int64(5)), "float16")
-            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), 
T.int64(5)), "float16")
-            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), 
T.int64(5)), "float16")
-            T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2), 
T.int64(3), T.int64(4), T.int64(5)), "float16")
-            with T.sblock("root"):
-                T.reads()
-                T.writes()
-                rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((T.int64(2), 
T.int64(3)))
-                rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer((T.int64(2), 
T.int64(3)))
-                for ax0 in range(T.int64(2)):
-                    for ax1 in range(T.int64(3)):
-                        for k2 in range(T.int64(4)):
-                            for k3 in range(T.int64(5)):
-                                with T.sblock("rxplaceholder_red_temp"):
-                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_k2 = T.axis.reduce(T.int64(4), k2)
-                                    v_k3 = T.axis.reduce(T.int64(5), k3)
-                                    T.reads(rxplaceholder[v_ax0, v_ax1, v_k2, 
v_k3])
-                                    T.writes(rxplaceholder_red_temp_v0[v_ax0, 
v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
-                                    with T.init():
-                                        rxplaceholder_red_temp_v0[v_ax0, 
v_ax1] = T.float32(0)
-                                        rxplaceholder_red_temp_v1[v_ax0, 
v_ax1] = T.float32(0)
-                                    v_rxplaceholder_red_temp_v0: 
T.let[T.float32] = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", 
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
-                                    v_rxplaceholder_red_temp_v1: 
T.let[T.float32] = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", 
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32", 
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
-                                    rxplaceholder_red_temp_v0[v_ax0, v_ax1] = 
v_rxplaceholder_red_temp_v0
-                                    rxplaceholder_red_temp_v1[v_ax0, v_ax1] = 
v_rxplaceholder_red_temp_v1
-                for ax0 in range(T.int64(2)):
-                    for ax1 in range(T.int64(3)):
-                        for ax2 in range(T.int64(4)):
-                            for ax3 in range(T.int64(5)):
-                                with T.sblock("T_layer_norm"):
-                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(4), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(5), ax3)
-                                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, 
v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1], 
rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3], 
rxplaceholder_2[v_ax2, v_ax3])
-                                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3]) 
- rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * 
T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / 
T.Cast("float32", T.float16(4) * T.float16(5)) - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * 
T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast [...]
+            # with T.sblock("root"):
+            x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                with T.sblock("x_sum"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
+                    T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(x_sum[v_ax0, v_ax1])
+                    with T.init():
+                        x_sum[v_ax0, v_ax1] = T.float32(0.0)
+                    x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] + 
T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3])
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("x_mean"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x_sum[v_ax0, v_ax1])
+                    T.writes(x_mean[v_ax0, v_ax1])
+                    x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] / 
T.float32(20.0)
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                with T.sblock("x_var_sum"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
+                    T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1])
+                    T.writes(x_var_sum[v_ax0, v_ax1])
+                    with T.init():
+                        x_var_sum[v_ax0, v_ax1] = T.float32(0.0)
+                    x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] + 
(T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1]) * 
(T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1])
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.sblock("T_layer_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0, 
v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
+                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.Cast("float16", (T.Cast("float32", x[v_ax0, v_ax1, v_ax2, v_ax3]) - 
x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1] / T.float32(20.0) + 
T.float32(1.0000000000000001e-05))) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]
 
         @R.function
         def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma: 
R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) -> 
R.Tensor((2, 3, 4, 5), dtype="float16"):
@@ -2901,35 +2921,45 @@ def test_layer_norm_symbolic():
             return gv
 
         @T.prim_func(private=True, s_tir=True)
-        def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
+        def layer_norm(var_x: T.handle, var_gamma: T.handle, var_beta: 
T.handle, var_T_layer_norm: T.handle):
             T.func_attr({"tirx.noalias": True})
-            f = T.int64()
-            n = T.int64()
-            s = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], 
dtype="float32")
-            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], 
dtype="float32")
-            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], 
dtype="float32")
-            T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f], 
dtype="float32")
-            rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([n], 
dtype="float32")
-            rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([n], 
dtype="float32")
-            for i0, i1, i2 in T.grid(n, s, f):
-                with T.sblock("rxplaceholder_red_temp"):
-                    ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2])
-                    T.reads(rxplaceholder[ax0, k1, k2])
-                    T.writes(rxplaceholder_red_temp_v0[ax0], 
rxplaceholder_red_temp_v1[ax0])
+            n, s, f = T.int64(), T.int64(), T.int64()
+            x = T.match_buffer(var_x, (n, s, f))
+            gamma = T.match_buffer(var_gamma, (s, f))
+            beta = T.match_buffer(var_beta, (s, f))
+            T_layer_norm = T.match_buffer(var_T_layer_norm, (n, s, f))
+            # with T.sblock("root"):
+            x_sum = T.sblock_alloc_buffer((n,))
+            x_mean = T.sblock_alloc_buffer((n,))
+            x_var_sum = T.sblock_alloc_buffer((n,))
+            for ax0, k1, k2 in T.grid(n, s, f):
+                with T.sblock("x_sum"):
+                    v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+                    T.reads(x[v_ax0, v_k1, v_k2])
+                    T.writes(x_sum[v_ax0])
                     with T.init():
-                        rxplaceholder_red_temp_v0[ax0] = T.float32(0)
-                        rxplaceholder_red_temp_v1[ax0] = T.float32(0)
-                    v_rxplaceholder_red_temp_v0: T.let[T.float32] = 
rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2]
-                    v_rxplaceholder_red_temp_v1: T.let[T.float32] = 
rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] * 
rxplaceholder[ax0, k1, k2]
-                    rxplaceholder_red_temp_v0[ax0] = 
v_rxplaceholder_red_temp_v0
-                    rxplaceholder_red_temp_v1[ax0] = 
v_rxplaceholder_red_temp_v1
-            for i0, i1, i2 in T.grid(n, s, f):
+                        x_sum[v_ax0] = T.float32(0.0)
+                    x_sum[v_ax0] = x_sum[v_ax0] + x[v_ax0, v_k1, v_k2]
+            for ax0 in range(n):
+                with T.sblock("x_mean"):
+                    v_ax0 = T.axis.spatial(n, ax0)
+                    T.reads(x_sum[v_ax0])
+                    T.writes(x_mean[v_ax0])
+                    x_mean[v_ax0] = x_sum[v_ax0] / (T.Cast("float32", s) * 
T.Cast("float32", f))
+            for ax0, k1, k2 in T.grid(n, s, f):
+                with T.sblock("x_var_sum"):
+                    v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+                    T.reads(x[v_ax0, v_k1, v_k2], x_mean[v_ax0])
+                    T.writes(x_var_sum[v_ax0])
+                    with T.init():
+                        x_var_sum[v_ax0] = T.float32(0.0)
+                    x_var_sum[v_ax0] = x_var_sum[v_ax0] + (x[v_ax0, v_k1, 
v_k2] - x_mean[v_ax0]) * (x[v_ax0, v_k1, v_k2] - x_mean[v_ax0])
+            for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.sblock("T_layer_norm"):
-                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
-                    T.reads(rxplaceholder[ax0, ax1, ax2], 
rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], 
rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2])
-                    T.writes(T_layer_norm[ax0, ax1, ax2])
-                    T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1, 
ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * 
T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] / 
(T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0] 
/ (T.Cast("float32", s) * T.Cast("float32", f)) * 
(rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", 
f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] + 
rxplacehol [...]
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(x[v_ax0, v_ax1, v_ax2], x_mean[v_ax0], 
x_var_sum[v_ax0], gamma[v_ax1, v_ax2], beta[v_ax1, v_ax2])
+                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2])
+                    T_layer_norm[v_ax0, v_ax1, v_ax2] = (x[v_ax0, v_ax1, 
v_ax2] - x_mean[v_ax0]) * T.rsqrt(x_var_sum[v_ax0] / (T.Cast("float32", s) * 
T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax1, 
v_ax2] + beta[v_ax1, v_ax2]
     # fmt: on
     mod = LegalizeOps()(LayerNorm)
     tvm.ir.assert_structural_equal(mod, Expected)

Reply via email to