Hi @tkonolige,
Thanks a lot for your help. Unfortunately, your fix didn't solve the problem. I am a bit confused, because my implementation is very close to the one for `conv2d_NCHWc_int8`:

```python
def _schedule_conv2d_NCHWc_int8(cfg, s, output):
    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # create cache stage
    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 5:
        n, f, y, x, c = s[output].op.axis
    else:
        # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks
        # are created from scratch, the real auto-tuning will still happen on the 5D output.
        n, f, y, x = s[output].op.axis

    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
    s[output].bind(bn, te.thread_axis("blockIdx.z"))
    s[output].bind(bf, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vn, te.thread_axis("vthread"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, te.thread_axis("threadIdx.z"))
        s[output].bind(tf, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile and bind reduction axes
    n, f, y, x, c = s[conv].op.axis
    rc, ry, rx, rc_block = s[conv].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2) cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2) cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2) rco, rci = cfg["tile_rc"].apply(s, conv, rc) ryo, ryi = cfg["tile_ry"].apply(s, conv, ry) rxo, rxi = cfg["tile_rx"].apply(s, conv, rx) s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block) cfg.define_reorder("reorder_inner", [rco, ryo, rxo], policy="all") cfg["reorder_inner"].apply(s, conv, [rco, ryo, rxo]) cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi]) _, rc_block = s[conv].split(rc_block, factor=4) s[conv].tensorize(rc_block, _dp4a) cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]] s[AA].compute_at(s[conv], cache_loc) s[WW].compute_at(s[conv], cache_loc) # cooperative fetching for load in [AA, WW]: print(s[load].op.axis) c = s[load].op.axis[-1] c_outer, c = s[load].split(c, factor=4) s[load].vectorize(c) fused = s[load].op.axis[:-1] + [c_outer] fused = s[load].fuse(*fused) fused, tx = s[load].split(fused, factor=n_tx) fused, ty = s[load].split(fused, factor=n_ty) fused, tz = s[load].split(fused, factor=n_tz) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) # double buffer cfg.define_knob("AA_double_buffer", [0, 1]) cfg.define_knob("WW_double_buffer", [0, 1]) if cfg["AA_double_buffer"].val: s[AA].double_buffer() if cfg["WW_double_buffer"].val: s[WW].double_buffer() # unroll cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[output].pragma(kernel_scope, "unroll_explicit", False) return s ``` --- [Visit Topic](https://discuss.tvm.apache.org/t/quantization-and-3d-convolution/8338/4) to respond. You are receiving this because you enabled mailing list mode. To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/bf29a569a4afba6486a2bead14ad489381a9fe7db60897ebb71c12655054a65a).