Thanks for all the great discussions! It is exciting that we will get a more
powerful way to handle things like padding and imperfect tiles.

Since our team relies on the s-tir code path, we are extremely interested in
the s-tir side of the story. I would very much appreciate some details on
s-tir padding. I would like to use a [127, 127, 127] matmul to illustrate my
questions :)

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def matmul(A: T.Buffer[(127, 127), "float32"], B: T.Buffer[(127, 127), "float32"], C: T.Buffer[(127, 127), "float32"]):
    for i, j, k in T.grid(127, 127, 127):
        with T.block("compute"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```

In the current s-tir state, we can already construct the padded loops and buffers
with existing primitives via the "split and then fuse" trick:
```python
s = tvm.tir.Schedule(matmul)
blk = s.get_block("compute")
i, j, k = s.get_loops(blk)
s.fuse(*s.split(i, factors=[4, 32]))
s.fuse(*s.split(j, factors=[4, 32]))
s.fuse(*s.split(k, factors=[4, 32]))
s.transform_layout(blk, "A", lambda i, k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
s.transform_layout(blk, "B", lambda k, j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
s.transform_layout(blk, "C", lambda i, j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
```
We will get (after simplification):
```python
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
    for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
        with T.block("compute"):
            vi = T.axis.spatial(127, i_0_i_1_fused)
            vj = T.axis.spatial(127, j_0_j_1_fused)
            vk = T.axis.reduce(127, k_0_k_1_fused)
            T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
Then the only thing left is the padding condition: `T.where(i_0_i_1_fused <
127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)`. I believe this is exactly
the point in the current RFC about the trade-off between over-computation and
branching. Below are some of my questions ~

1. What happens when we change to `s.transform_layout(..., pad_value=0)` (if we
want over-computation)?
   - (possible behavior 1) Insert padding-filling code as a producer block of
`compute` (a rough sketch follows at the end of this item).
     - Since the effect is immediate, maybe we do not need `BufferConstraint`
annotations afterwards?
   - (possible behavior 2) Annotate the buffers and let lowering passes handle the padding.
     - We may require `BufferConstraint` to direct the lowering passes.
   - (possible behavior 3) Pass `BufferConstraint` upwards into the graph level.
     - Thus we assume the param buffer already matches the constraint and do not
write the edge values.
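
   To make (possible behavior 1) concrete, here is a rough sketch of what the
transformed function might look like if `transform_layout(..., pad_value=0)`
inserted an explicit padding block. This is only my guess at a possible form
(the block names and exact predicate are illustrative), not something the RFC defines:

   ```python
   @T.prim_func
   def padded_matmul(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
       # hypothetical producer block filling the padded region of A with the pad value
       for i, k in T.grid(128, 128):
           with T.block("A_pad"):
               vi, vk = T.axis.remap("SS", [i, k])
               T.where(127 <= i or 127 <= k)
               A[vi, vk] = T.float32(0)
       # ... a similar "B_pad" block, then the "compute" block over T.grid(128, 128, 128)
       # without the T.where predicate, since the padded elements contribute zero ...
   ```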
   
2.  For (1.2) and (1.3), it seems that encoding the `BufferConstraint` into the
buffer object is not the only choice.
    - For s-tir (correct me if I am wrong), at least for common cases the constraint
could be treated as local with respect to the transformed block. What if we encode
the constraint just into the block, as part of its memory-access properties?
(A strawman sketch is at the end of this item.)
      We found previously that the block memory annotations `T.reads` and `T.writes`
(`BufferRegion`) have the limitation that they lose conditional-access information.
Maybe we can also combine `BufferConstraint` with `BufferRegion`?

    - For graph-level annotations, IIUC, it conceptually uses "Tensor"-typed values
instead of "Buffer". Maybe we still need another construct rather than a `Buffer`
with a `BufferConstraint` field?
      We could also consider instantiating graph-level transformations explicitly.
This is our current solution:
https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4.
 

    - Nevertheless, if we finally decide to extend the buffer node structure, I hope
the `BufferConstraint` can have an explicit lifetime in the TIR lowering, so that
storage-related passes afterwards are not affected, especially customized passes
developed by vendors.
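
    As a strawman for the block-local option above, one could imagine reusing the
existing `annotate` schedule primitive to attach the constraint as a block annotation
that a custom lowering pass understands. The annotation key and the value encoding
below are entirely made up for illustration:

    ```python
    # continuing the schedule `s` and block `blk` from the example above;
    # "buffer_constraint" is a hypothetical key, not an existing TVM convention
    s.annotate(blk, "buffer_constraint", "A:pad_value=0;B:pad_value=0;C:pad_value=0")
    ```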

3. For the reduce-axis padding mentioned in
https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301:
    - At the TIR level, since schedule primitives should preserve semantic
correctness, how do we prove that the padding along the `k` dimension must be zero,
especially when we do not know in general that the op is a "matmul"? (A small
numerical check below illustrates why zero is special here.) I think this is
important if we want to use padded `transform_layout` in auto-scheduling applications.
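
    A minimal numerical illustration (plain numpy, not TVM): zero works as the pad
value here only because it is the identity of the `+` reduction and annihilates the
multiplication, so the extra `k` slice contributes nothing. A max-reduction, for
instance, would need `-inf` instead, which is exactly the kind of fact a general
proof would have to derive from the op semantics:

    ```python
    import numpy as np

    A = np.random.rand(127, 127).astype("float32")
    B = np.random.rand(127, 127).astype("float32")

    # pad the reduction (k) dimension with zeros
    A_pad = np.zeros((127, 128), "float32"); A_pad[:, :127] = A
    B_pad = np.zeros((128, 127), "float32"); B_pad[:127, :] = B

    # the padded k = 127 slice contributes 0 * 0 = 0 to every output element
    np.testing.assert_allclose(A @ B, A_pad @ B_pad, rtol=1e-5)
    ```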

cc @Lunderberg @tqchen @vinx13 @Hzfengsy 
