Thanks for all the great discussions! It is so exciting that we will have a more powerful ability to handle things like padding and imperfect tiles.
Since our team relies on the s-tir code path, we are extremely interested in the story for s-tir. I would very much appreciate some details on s-tir padding. Let me use a [127, 127, 127] matmul to illustrate my questions :)

```python
@T.prim_func
def matmul(A: T.Buffer[(127, 127), "float32"], B: T.Buffer[(127, 127), "float32"], C: T.Buffer[(127, 127), "float32"]):
    for i, j, k in T.grid(127, 127, 127):
        with T.block("compute"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] += A[vi, vk] * B[vk, vj]
```

In the current s-tir state, we can construct the padded loops and buffers with existing primitives using the "split and then fuse" trick:

```python
s = tvm.tir.Schedule(matmul)
blk = s.get_block("compute")
i, j, k = s.get_loops(blk)
s.fuse(*s.split(i, factors=[4, 32]))
s.fuse(*s.split(j, factors=[4, 32]))
s.fuse(*s.split(k, factors=[4, 32]))
s.transform_layout(blk, "A", lambda i, k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
s.transform_layout(blk, "B", lambda k, j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
s.transform_layout(blk, "C", lambda i, j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
```

We will get (after simplification):

```python
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
    for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
        with T.block("compute"):
            vi = T.axis.spatial(127, i_0_i_1_fused)
            vj = T.axis.spatial(127, j_0_j_1_fused)
            vk = T.axis.reduce(127, k_0_k_1_fused)
            T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```

The only thing left is the condition for padding: `T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)`. I believe this is exactly the point in the current RFC about the trade-off between over-computation and branching. Below are my questions:

1. What happens when we change to `s.transform_layout(..., pad_value=0)` (if we want over-computation)?
   - (possible behavior 1) Insert padding-filling code as a producer block of `compute` (see the hand-written sketch at the end of this post).
     - Since the effect is immediate, maybe we do not need `BufferConstraint` annotations afterwards?
   - (possible behavior 2) Annotate the buffers and let lowering passes handle them.
     - We may require `BufferConstraint` to direct the lowering passes.
   - (possible behavior 3) Pass `BufferConstraint` upwards into the graph level.
     - That is, assume the param buffer matches the constraint and do not write the edge values.
2. For (1.2) and (1.3), it seems that encoding the `BufferConstraint` into the buffer object is not the only choice.
   - For s-tir (correct me if I am wrong), at least in common cases the constraint could be treated as local with respect to the transformed block. What if we encode the constraint just into the block, as part of its memory-access properties? We found previously that the block memory annotations `T.reads` and `T.writes` (`BufferRegion`) have the limitation that they lose conditional access information. Maybe we can also combine `BufferConstraint` with `BufferRegion`?
   - For graph-level annotations, IIUC, the graph level conceptually uses "Tensor"-typed values instead of "Buffer". Maybe we still need another construct rather than a `Buffer` with a `BufferConstraint` field? We could also consider instantiating the graph-level transformation explicitly. This is our current solution: https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4.
   - Nevertheless, if we finally decide to extend the buffer node structure, I hope we can have an explicit lifetime for the `BufferConstraint` in TIR lowering, so that storage-related passes afterwards do not have to bother with it, especially customized passes developed by vendors.
3. For the reduction-axis padding mentioned in https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301:
   - At the TIR level, since schedule primitives should preserve semantic correctness, how do we prove that the padding along the `k` dimension can only be zero, especially when we do not know that the op is a "matmul" in general? (A small numeric check of why zero is required is sketched below.) I think this is important if we want to use padded `transform_layout` in auto-schedule-style applications.

cc @Lunderberg @tqchen @vinx13 @Hzfengsy
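For reference, below is a hand-written sketch of what "possible behavior 1" from question 1 might look like. This is only my guess at the produced IR, not the actual output of the primitive; the function name `fill_a_pad`, the buffer name `A_pad`, and the exact structure are assumptions for illustration.

```python
@T.prim_func
def fill_a_pad(A: T.Buffer[(127, 127), "float32"], A_pad: T.Buffer[(128, 128), "float32"]):
    # Hypothetical producer block materialized by transform_layout(..., pad_value=0):
    # copy the valid 127x127 region and write the pad value 0 at the edges.
    for i, k in T.grid(128, 128):
        with T.block("A_pad"):
            vi, vk = T.axis.remap("SS", [i, k])
            A_pad[vi, vk] = T.if_then_else(vi < 127 and vk < 127, A[vi, vk], T.float32(0))
```

And for question 3, here is a minimal numeric sketch (plain numpy, not any TVM API; the helper `padded_matmul` is made up for illustration) of why the reduction-axis padding must be the additive identity 0: every padded position contributes `A_pad[i, k] * B_pad[k, j]` to the sum over `k`, so any non-zero pad value perturbs the valid region of `C`.

```python
import numpy as np

n, padded = 127, 128
A = np.random.rand(n, n).astype("float32")
B = np.random.rand(n, n).astype("float32")
C_ref = A @ B

def padded_matmul(pad_value):
    # Embed A and B into padded x padded buffers filled with pad_value,
    # run the full padded matmul, then slice out the valid region.
    A_pad = np.full((padded, padded), pad_value, dtype="float32")
    B_pad = np.full((padded, padded), pad_value, dtype="float32")
    A_pad[:n, :n], B_pad[:n, :n] = A, B
    return (A_pad @ B_pad)[:n, :n]

# Zero padding leaves the valid region untouched; any other pad value does not.
assert np.allclose(padded_matmul(0.0), C_ref, atol=1e-4)
assert not np.allclose(padded_matmul(1.0), C_ref, atol=1e-4)
```

For a general reduction, the pad value would have to make the padded iterations contribute the identity of the combiner, which is exactly the property that seems hard to prove automatically without knowing the op's semantics.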