Writing out some of my thoughts to see if there's a way to express the 
constraints using only existing TIR features.  The main goals would be as 
follows.

1. Allow simplification of expressions based on the values present in the 
padding.
2. Allow local simplifications to take advantage of non-local constraints, 
without requiring a full end-to-end analysis.
3. Specify the non-local constraints in some deducible manner that doesn't 
impose a runtime performance penalty.
   
Next, working through various options for how the constraints could be stored. 
The examples below sketch out how each option would apply to the element-wise 
operation shown here.

```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        B[i] = 2 * A[i]
```
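
As a concrete illustration of goal 1 (a hypothetical sketch, not one of the 
options below): if both buffers were stored in a padded `(4, 4)` layout, and 
the padding of the input were known to hold zero, then because `2 * 0 == 0` 
the guard over the padding could be dropped from the compute.

```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], BC: T.Buffer[(4, 4), "int32"]):
    # Before: nothing is known about AC's padding, so it must be skipped.
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            BC[io, ii] = 2 * AC[io, ii]
        else:
            BC[io, ii] = 0


@T.prim_func
def func_simplified(AC: T.Buffer[(4, 4), "int32"], BC: T.Buffer[(4, 4), "int32"]):
    # After: knowing AC's padding holds zero, and 2 * 0 == 0 matches BC's
    # required pad value, the conditional can be removed entirely.
    for io, ii in T.grid(4, 4):
        BC[io, ii] = 2 * AC[io, ii]
```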

1. Apply layout transforms on local caches.  Here, the full lifetime of a 
buffer is known.  All TIR optimizations are done prior to hoisting the cache 
and layout transformation into the graph level.
   
   - For read caches, the pad value is whatever gets conditionally written to 
the padding while generating the cache.  In the example below, `AC` could be 
recognized as being padded.
     
     ```python
     @T.prim_func
     def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
         AC = T.alloc_buffer([4, 4], "int32")
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 AC[io, ii] = A[4 * io + ii]
             else:
                 AC[io, ii] = 0
     
         for i in T.serial(14):
             B[i] = 2 * AC[i // 4, i % 4]
     ```
     
   - For write caches, the pad value is whatever is in the padding after the 
last write to the cache.  In the example below, `BC` could be recognized as 
being padded.

     ```python
     @T.prim_func
     def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
         BC = T.alloc_buffer([4, 4], "int32")
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 BC[io, ii] = 2 * A[4*io + ii]
             else:
                 BC[io, ii] = 0
     
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = BC[io, ii]
     ```

   - Downside, either of the `else` branches could be eliminated as a no-op, 
since the values written to the padding don't contribute to the output `B`.  
After that elimination (sketched below), there wouldn't be any way to 
reconstruct the pad value.
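
     For example, after such a dead-store elimination the read-cache case 
could be reduced to the following, at which point the zero pad value is no 
longer recoverable from the TIR:

     ```python
     @T.prim_func
     def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
         AC = T.alloc_buffer([4, 4], "int32")
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 AC[io, ii] = A[4 * io + ii]
             # The `else` branch that wrote zeros to the padding has been
             # removed as a no-op, so the padding now holds arbitrary values.

         for i in T.serial(14):
             B[i] = 2 * AC[i // 4, i % 4]
     ```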
     
2. When hoisting an allocation+transformation, write the pad value to the 
buffer at the start of the function from which it was hoisted.  This way, the 
pad value can still be used in local reasoning.
   
   - No change needed in producers, since they would already write the pad 
value to the buffer.
   
   - For consumers, would be represented as writing `pad_value` into the 
padding at the start of the function.
   
     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = 0
     
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
     ```
     
   - Downside, repeated unnecessary effort at the beginning of each consumer.  
Avoiding it with this representation would require knowing that the producer 
had already written `pad_value`, which is exactly the kind of non-local 
information we're trying to avoid needing.
     
3. When hoisting an allocation+transformation, write the pad value to the 
buffer at the start of the function from which it was hoisted, and write 
`T.undef()` at the end.  This way, the pad value can still be used in local 
reasoning, and no-op removal can remove the repeated writes when lowering.
   
   - No change needed in producers, since they would already write the pad 
value to the buffer.
     
   - For consumers, would be like option 2, but with an additional write of 
`T.undef()` at the end of the function.  When lowering, the write of 
`T.undef()` would allow the first write to be removed as a no-op because it is 
overwritten.  The `T.undef()` can then be removed as described in the RFC.
   
     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = 0
     
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
     
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = T.undef()
     ```
     
   - Downside, no way to distinguish between "can assume the pad value is 
zero" and "can overwrite the pad value at will".  The write of `T.undef()` 
would allow arbitrary writes to the padding to be inserted as no-ops.
     
   - Downside, wouldn't actually simplify out in cases where the pad value is 
used.  The first in a pair of repeated writes to the same location can only be 
removed if there are no reads between the writes.  After using the pad value to 
eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes 
the padding could no longer be removed.
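
     For example, in a hypothetical variant where the output is also stored 
in a padded `(4, 4)` layout, using the pad value to remove the conditional 
introduces a read of the padding between the two writes, so the initial loop 
can no longer be removed:

     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], BC: T.Buffer[(4, 4), "int32"]):
         # Establishes the pad value for local reasoning.
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = 0

         # Simplified using the pad value: the padding of AC is now read,
         # so the zero-writing loop above is no longer a dead store.
         for io, ii in T.grid(4, 4):
             BC[io, ii] = 2 * AC[io, ii]

         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = T.undef()
     ```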
     
4. Use `AssertStmt` in a loop to declare known information about the buffers.

   - No change needed in producers, since the pad value is already written out.
     
   - For consumers, would have an initial loop that asserts the pad value is 
correct.

     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 assert AC[io, ii] == 0, "padding"
     
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
     ```
     
   - Downside, assert statements have target-dependent handling.  In 
`CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops.  In `CodeGenCPU` 
and `CodeGenC`, they generate asserts.  In `CodeGenCUDA`, they aren't handled 
at all and would error out.
     
     Could work around this with a lowering pass, but identifying these 
conditions would require having a special string in the message, and packing 
structured data into strings makes me wary.
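
     A rough sketch of what such a workaround pass might look like 
(hypothetical; it keys off the magic `"padding"` message string, which is 
exactly the part that feels fragile):

     ```python
     import tvm
     from tvm import tir


     # Hypothetical lowering pass that strips the marker asserts before
     # codegen, so targets without assert support never see them.
     @tvm.tir.transform.prim_func_pass(opt_level=0)
     def StripPaddingAsserts(func, mod, ctx):
         def postorder(stmt):
             is_marker = (
                 isinstance(stmt.message, tir.StringImm)
                 and stmt.message.value == "padding"
             )
             # Replace the marker assert with its body, dropping the condition.
             return stmt.body if is_marker else None

         return func.with_body(
             tir.stmt_functor.ir_transform(
                 func.body, None, postorder, ["tir.AssertStmt"]
             )
         )
     ```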
     
5. Use `AssertStmt` with implicitly-defined variables to declare known 
information about the buffers.
   
   ```python
   @T.prim_func
   def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
       a = T.var("int32")
       b = T.var("int32")
       assert (
            AC[a, b] == 0 or (4 * a + b < 14)
            or (a < 0) or (a >= 4) or (b < 0) or (b >= 4)
       ), "padding"
   
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```
   
   - Can apply to clamped texture memory, since the variables in the 
assertion aren't restricted to the buffer bounds (see the sketch after this 
list).
     
   - Would need to recognize the specific pattern of a `BufferLoad` being 
used to define the variables that appear in the constraint.
     
   - The implicitly-defined variables can be written in current TIR, but 
leaving them undefined would ensure that the assertion never makes it into 
generated code at runtime.
   
   - Downside, implicitly-defined variables are something of a red flag.
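
   As noted above for clamped texture memory, the same pattern could describe 
clamp-to-edge reads, since `a` and `b` are not restricted to the buffer 
bounds.  A hypothetical sketch:

    ```python
    @T.prim_func
    def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
        a = T.var("int32")
        b = T.var("int32")
        # Clamp-to-edge semantics: a read outside the bounds returns the
        # nearest in-bounds element.  (Illustrative constraint only.)
        assert AC[a, b] == AC[T.max(0, T.min(a, 3)), T.max(0, T.min(b, 3))], "clamped"

        for io, ii in T.grid(4, 4):
            if 4 * io + ii < 14:
                B[4 * io + ii] = 2 * AC[io, ii]
    ```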

6. Store constraints in the function attributes, either as a dictionary or as a 
structured object.
   
   ```python
   @T.prim_func
   def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
        T.func_attr(
            {
                "buffer_constraints": [
                    {
                        "buffer": AC,
                        "predicate": lambda io, ii: 4 * io + ii < 14,
                        "pad_value": lambda io, ii: 0,
                    },
                ],
            }
        )
   
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```
   
   - Downside, requires transformations that change a buffer to be aware that 
other structures will also need to be replaced.
     
   - Downside, requires simplifications to either be passed the entire 
`PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
     
   - Downside, would break expectations of `IRMutatorWithAnalyzer`.  Any 
entry point that currently takes only a `Stmt` or `Expr` would also need to 
be given the `"buffer_constraints"` information (illustrated below).
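
     For illustration, a local rewrite that today only needs the expression 
would also need the constraint list threaded through to it (hypothetical 
second call, shown commented out):

     ```python
     import tvm
     from tvm import tir

     io = tir.Var("io", "int32")
     ii = tir.Var("ii", "int32")
     expr = 4 * io + ii < 14

     analyzer = tvm.arith.Analyzer()
     # Today, an expression can be simplified in isolation.
     simplified = analyzer.simplify(expr)
     # With option 6, every such call site would also need the function-level
     # constraints (hypothetical signature, for illustration only).
     # simplified = analyzer.simplify(expr, buffer_constraints)
     ```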
     

7. Store constraints in the `Buffer` object, either as a dictionary or as a 
structured object.
   
   ```python
   @T.prim_func
   def func(ac: T.handle, B: T.Buffer[14, "int32"]):
        AC = T.match_buffer(
            ac,
            shape=(4, 4),
            dtype="int32",
            constraints=[
                T.BufferConstraints(
                    predicate=lambda io, ii: 4 * io + ii < 14,
                    pad_value=0,
                )
            ],
        )
   
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```
   
   - Downside, introduces additional data structure in TIR.
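
     A rough sketch of what that structure might hold, written as a plain 
Python object for illustration (hypothetical; `BufferConstraints` is not an 
existing TVM class):

     ```python
     from dataclasses import dataclass
     from typing import Any, Callable, Optional

     from tvm import tir


     @dataclass
     class BufferConstraints:
         """Hypothetical structured constraint carried by a Buffer.

         predicate: maps buffer indices to a boolean expression that is
             True for indices holding real data and False for padding.
         pad_value: value stored in the padding, either a constant or a
             function of the indices; None if the padding is arbitrary.
         """

         predicate: Callable[..., tir.PrimExpr]
         pad_value: Optional[Any] = None
     ```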
