Hello, I am trying to fuse one layer's convolution computation, together with its
ReLU result, into the next layer's convolution computation. I tried two methods:
one is to use a te.sum expression as an argument of another te.sum, and the other
is to use s.compute_inline(), but both fail. I would like to know whether it is
possible to combine two reduce stages (te.sum) into one reduce stage in TE. If
not, can Relay or TIR express this functionality? Here is the current TIR
without fusion:

        primfn(args: handle, arg_type_ids: handle, num_args: int32, 
out_ret_value: handle, out_ret_tcode: handle, resource_handle: handle) -> int32`
      attr = {"target": meta[Target][0], "tir.noalias": True, "global_symbol": 
"myfunc_fusion", "from_legacy_te_schedule": True, "tir.is_entry_func": True, 
"calling_conv": 1} {
      assert((num_args == 4), "myfunc_fusion: num_args should be 4")
      let arg0: handle = @tir.tvm_struct_get(args, 0, 12, dtype=handle)
      let arg0.code: int32 = (int32*)arg_type_ids[0]
      let arg1: handle = @tir.tvm_struct_get(args, 1, 12, dtype=handle)
      let arg1.code: int32 = (int32*)arg_type_ids[1]
      let arg2: handle = @tir.tvm_struct_get(args, 2, 12, dtype=handle)
      let arg2.code: int32 = (int32*)arg_type_ids[2]
      let arg3: handle = @tir.tvm_struct_get(args, 3, 12, dtype=handle)
      let arg3.code: int32 = (int32*)arg_type_ids[3]
      let A: Pointer(float32) = @tir.tvm_struct_get(arg0, 0, 1, dtype=handle)
      attr [A] "storage_alignment" = 128;
      let arg0.shape: handle = @tir.tvm_struct_get(arg0, 0, 2, dtype=handle)
      let arg0.strides: handle = @tir.tvm_struct_get(arg0, 0, 3, dtype=handle)
      let dev_id: int32 = @tir.tvm_struct_get(arg0, 0, 9, dtype=int32)
      let W: Pointer(float32) = @tir.tvm_struct_get(arg1, 0, 1, dtype=handle)
      attr [W] "storage_alignment" = 128;
      let arg1.shape: handle = @tir.tvm_struct_get(arg1, 0, 2, dtype=handle)
      let arg1.strides: handle = @tir.tvm_struct_get(arg1, 0, 3, dtype=handle)
      let W_2: Pointer(float32) = @tir.tvm_struct_get(arg2, 0, 1, dtype=handle)
      attr [W_2] "storage_alignment" = 128;
      let arg2.shape: handle = @tir.tvm_struct_get(arg2, 0, 2, dtype=handle)
      let arg2.strides: handle = @tir.tvm_struct_get(arg2, 0, 3, dtype=handle)
      let C: Pointer(float32) = @tir.tvm_struct_get(arg3, 0, 1, dtype=handle)
      attr [C] "storage_alignment" = 128;
      let arg3.shape: handle = @tir.tvm_struct_get(arg3, 0, 2, dtype=handle)
      let arg3.strides: handle = @tir.tvm_struct_get(arg3, 0, 3, dtype=handle)
      assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || 
(arg0.code == 4)), "myfunc_fusion: Expect arg[0] to be pointer")
      assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || 
(arg1.code == 4)), "myfunc_fusion: Expect arg[1] to be pointer")
      assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || 
(arg2.code == 4)), "myfunc_fusion: Expect arg[2] to be pointer")
      assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) || 
(arg3.code == 4)), "myfunc_fusion: Expect arg[3] to be pointer")
      assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is 
expected to equal 4")
      assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is 
expected to equal 4")
      assert((((@tir.tvm_struct_get(arg0, 0, 5, dtype=uint8) == 2u8) && 
(@tir.tvm_struct_get(arg0, 0, 6, dtype=uint8) == 32u8)) && 
(@tir.tvm_struct_get(arg0, 0, 7, dtype=uint16) == 1u16)), "arg0.dtype is 
expected to be float32")
      assert((56 == cast(int32, (int64*)arg0.shape[0])), "Argument 
arg0.shape[0] has an unsatisfied constraint: (56 == int32(arg0.shape[0]))")
      assert((56 == cast(int32, (int64*)arg0.shape[1])), "Argument 
arg0.shape[1] has an unsatisfied constraint: (56 == int32(arg0.shape[1]))")
      assert((64 == cast(int32, (int64*)arg0.shape[2])), "Argument 
arg0.shape[2] has an unsatisfied constraint: (64 == int32(arg0.shape[2]))")
      assert((3 == cast(int32, (int64*)arg0.shape[3])), "Argument arg0.shape[3] 
has an unsatisfied constraint: (3 == int32(arg0.shape[3]))")
       {
        if !@tir.isnullptr(arg0.strides, dtype=bool) {
          assert(((((1 == cast(int32, (int64*)arg0.strides[3])) && (3 == 
cast(int32, (int64*)arg0.strides[2]))) && (192 == cast(int32, 
(int64*)arg0.strides[1]))) && (10752 == cast(int32, (int64*)arg0.strides[0]))), 
"arg0.strides: expected to be compact array")
          0
        }
        assert((0u64 == @tir.tvm_struct_get(arg0, 0, 8, dtype=uint64)), 
"Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == 
tir.tvm_struct_get(arg0, 0, 8))")
        assert((1 == @tir.tvm_struct_get(arg0, 0, 10, dtype=int32)), "Argument 
arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 
0, 10))")
        assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim 
is expected to equal 4")
        assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim 
is expected to equal 4")
        assert((((@tir.tvm_struct_get(arg1, 0, 5, dtype=uint8) == 2u8) && 
(@tir.tvm_struct_get(arg1, 0, 6, dtype=uint8) == 32u8)) && 
(@tir.tvm_struct_get(arg1, 0, 7, dtype=uint16) == 1u16)), "arg1.dtype is 
expected to be float32")
        assert((3 == cast(int32, (int64*)arg1.shape[0])), "Argument 
arg1.shape[0] has an unsatisfied constraint: (3 == int32(arg1.shape[0]))")
        assert((3 == cast(int32, (int64*)arg1.shape[1])), "Argument 
arg1.shape[1] has an unsatisfied constraint: (3 == int32(arg1.shape[1]))")
        assert((64 == cast(int32, (int64*)arg1.shape[2])), "Argument 
arg1.shape[2] has an unsatisfied constraint: (64 == int32(arg1.shape[2]))")
        assert((64 == cast(int32, (int64*)arg1.shape[3])), "Argument 
arg1.shape[3] has an unsatisfied constraint: (64 == int32(arg1.shape[3]))")
         {
          if !@tir.isnullptr(arg1.strides, dtype=bool) {
            assert(((((1 == cast(int32, (int64*)arg1.strides[3])) && (64 == 
cast(int32, (int64*)arg1.strides[2]))) && (4096 == cast(int32, 
(int64*)arg1.strides[1]))) && (12288 == cast(int32, (int64*)arg1.strides[0]))), 
"arg1.strides: expected to be compact array")
            0
          }
          assert((0u64 == @tir.tvm_struct_get(arg1, 0, 8, dtype=uint64)), 
"Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == 
tir.tvm_struct_get(arg1, 0, 8))")
          assert((1 == @tir.tvm_struct_get(arg1, 0, 10, dtype=int32)), 
"Argument arg1.device_type has an unsatisfied constraint: (1 == 
tir.tvm_struct_get(arg1, 0, 10))")
          assert((dev_id == @tir.tvm_struct_get(arg1, 0, 9, dtype=int32)), 
"Argument arg1.device_id has an unsatisfied constraint: (dev_id == 
tir.tvm_struct_get(arg1, 0, 9))")
          assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), 
"arg2.ndim is expected to equal 4")
          assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), 
"arg2.ndim is expected to equal 4")
          assert((((@tir.tvm_struct_get(arg2, 0, 5, dtype=uint8) == 2u8) && 
(@tir.tvm_struct_get(arg2, 0, 6, dtype=uint8) == 32u8)) && 
(@tir.tvm_struct_get(arg2, 0, 7, dtype=uint16) == 1u16)), "arg2.dtype is 
expected to be float32")
          assert((3 == cast(int32, (int64*)arg2.shape[0])), "Argument 
arg2.shape[0] has an unsatisfied constraint: (3 == int32(arg2.shape[0]))")
          assert((3 == cast(int32, (int64*)arg2.shape[1])), "Argument 
arg2.shape[1] has an unsatisfied constraint: (3 == int32(arg2.shape[1]))")
          assert((64 == cast(int32, (int64*)arg2.shape[2])), "Argument 
arg2.shape[2] has an unsatisfied constraint: (64 == int32(arg2.shape[2]))")
          assert((64 == cast(int32, (int64*)arg2.shape[3])), "Argument 
arg2.shape[3] has an unsatisfied constraint: (64 == int32(arg2.shape[3]))")
           {
            if !@tir.isnullptr(arg2.strides, dtype=bool) {
              assert(((((1 == cast(int32, (int64*)arg2.strides[3])) && (64 == 
cast(int32, (int64*)arg2.strides[2]))) && (4096 == cast(int32, 
(int64*)arg2.strides[1]))) && (12288 == cast(int32, (int64*)arg2.strides[0]))), 
"arg2.strides: expected to be compact array")
              0
            }
            assert((0u64 == @tir.tvm_struct_get(arg2, 0, 8, dtype=uint64)), 
"Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == 
tir.tvm_struct_get(arg2, 0, 8))")
            assert((1 == @tir.tvm_struct_get(arg2, 0, 10, dtype=int32)), 
"Argument arg2.device_type has an unsatisfied constraint: (1 == 
tir.tvm_struct_get(arg2, 0, 10))")
            assert((dev_id == @tir.tvm_struct_get(arg2, 0, 9, dtype=int32)), 
"Argument arg2.device_id has an unsatisfied constraint: (dev_id == 
tir.tvm_struct_get(arg2, 0, 9))")
            assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), 
"arg3.ndim is expected to equal 4")
            assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), 
"arg3.ndim is expected to equal 4")
            assert((((@tir.tvm_struct_get(arg3, 0, 5, dtype=uint8) == 2u8) && 
(@tir.tvm_struct_get(arg3, 0, 6, dtype=uint8) == 32u8)) && 
(@tir.tvm_struct_get(arg3, 0, 7, dtype=uint16) == 1u16)), "arg3.dtype is 
expected to be float32")
            assert((54 == cast(int32, (int64*)arg3.shape[0])), "Argument 
arg3.shape[0] has an unsatisfied constraint: (54 == int32(arg3.shape[0]))")
            assert((54 == cast(int32, (int64*)arg3.shape[1])), "Argument 
arg3.shape[1] has an unsatisfied constraint: (54 == int32(arg3.shape[1]))")
            assert((64 == cast(int32, (int64*)arg3.shape[2])), "Argument 
arg3.shape[2] has an unsatisfied constraint: (64 == int32(arg3.shape[2]))")
            assert((3 == cast(int32, (int64*)arg3.shape[3])), "Argument 
arg3.shape[3] has an unsatisfied constraint: (3 == int32(arg3.shape[3]))")
             {
              if !@tir.isnullptr(arg3.strides, dtype=bool) {
                assert(((((1 == cast(int32, (int64*)arg3.strides[3])) && (3 == 
cast(int32, (int64*)arg3.strides[2]))) && (192 == cast(int32, 
(int64*)arg3.strides[1]))) && (10368 == cast(int32, (int64*)arg3.strides[0]))), 
"arg3.strides: expected to be compact array")
                0
              }
              assert((0u64 == @tir.tvm_struct_get(arg3, 0, 8, dtype=uint64)), 
"Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == 
tir.tvm_struct_get(arg3, 0, 8))")
              assert((1 == @tir.tvm_struct_get(arg3, 0, 10, dtype=int32)), 
"Argument arg3.device_type has an unsatisfied constraint: (1 == 
tir.tvm_struct_get(arg3, 0, 10))")
              assert((dev_id == @tir.tvm_struct_get(arg3, 0, 9, dtype=int32)), 
"Argument arg3.device_id has an unsatisfied constraint: (dev_id == 
tir.tvm_struct_get(arg3, 0, 9))")
              attr [0] "compute_scope" = "myfunc_fusion_compute_";
              attr [R: Pointer(global float32)] "storage_alignment" = 128 {
                let R = @tir.TVMBackendAllocWorkspace(1, dev_id, 2239488u64, 2, 
32, dtype=handle)
                 {
                  if @tir.isnullptr(R, dtype=bool) {
                    @tir.tvm_throw_last_error(, dtype=int32)
                  }
                  allocate(B: Pointer(global float32), float32, [1]), 
storage_scope = global {
                    for (yy: int32, 0, 54) {
                      for (xx: int32, 0, 54) {
                        for (cc: int32, 0, 64) {
                          for (batch: int32, 0, 3) {
                            B[0] = 0f32
                            for (ry: int32, 0, 3) {
                              for (rx: int32, 0, 3) {
                                for (rc: int32, 0, 64) {
                                  B[0] = @tir.call_llvm_pure_intrin(134u32, 
3u32, (float32*)A[((((((yy*10752) + (ry*10752)) + (xx*192)) + (rx*192)) + 
(rc*3)) + batch)], (float32*)W[((((ry*12288) + (rx*4096)) + (rc*64)) + cc)], 
(float32*)B[0], dtype=float32)
                                }
                              }
                            }
                            R[((((yy*10368) + (xx*192)) + (cc*3)) + batch)] = 
max(0f32, (float32*)B[0])
                          }
                        }
                      }
                    }
                    for (yy_1: int32, 0, 54) {
                      for (xx_1: int32, 0, 54) {
                        for (ff: int32, 0, 64) {
                          for (nn: int32, 0, 3) {
                            C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = 
0f32
                            for (ry_2: int32, 0, 3) {
                              for (rx_2: int32, 0, 3) {
                                for (rc_2: int32, 0, 64) {
                                  C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + 
nn)] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)R[((((((yy_1*10368) + 
(ry_2*10368)) + (xx_1*192)) + (rx_2*192)) + (rc_2*3)) + nn)], 
(float32*)W_2[((((ry_2*12288) + (rx_2*4096)) + (rc_2*64)) + ff)], 
(float32*)C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)], dtype=float32)
                                }
                              }
                            }
                          }
                        }
                      }
                    }
                  }
                }
                if (@tir.TVMBackendFreeWorkspace(1, dev_id, R, dtype=int32) != 
0) {
                  @tir.tvm_throw_last_error(, dtype=int32)
                }
              }
            }
          }
        }
      }
    }





---
[Visit Topic](https://discuss.tvm.apache.org/t/can-one-reduce-stage-fuse-into-another-reduce-stage/12367/1) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/75e39fd21722d84c7a39212d0f263c73b6db84798a8e4a1929f107fc6ce53cb1).

Reply via email to