Writing out some of my thoughts, to see if there's a way to express the constraints while only using existing TIR features. The main goals would be as follows.
1. Allow simplification of expressions based on the values present in the padding.
2. Allow local simplifications to take advantage of non-local constraints, without requiring a full end-to-end analysis.
3. Specify the non-local constraints in some deducible manner that doesn't impose a runtime performance penalty.

Next, working through various options for how the constraints could be stored. In the examples below, sketching out how each would apply to the element-wise operation that starts as below.

```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        B[i] = 2 * A[i]
```

1. Apply layout transforms on local caches. Here, the full lifetime of a buffer is known, and all TIR optimizations are done prior to hoisting the cache and layout transformation to the graph level.

   - For read caches, the pad value is whatever gets conditionally written to the padding while generating the cache. In the example below, `AC` could be recognized as being padded.

     ```python
     @T.prim_func
     def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
         AC = T.alloc_buffer([4, 4], "int32")
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 AC[io, ii] = A[4 * io + ii]
             else:
                 AC[io, ii] = 0
         for i in T.serial(14):
             B[i] = 2 * AC[i // 4, i % 4]
     ```

   - For write caches, the pad value is whatever is left in the padding after the last write to the cache. In the example below, `BC` could be recognized as being padded.

     ```python
     @T.prim_func
     def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
         BC = T.alloc_buffer([4, 4], "int32")
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 BC[io, ii] = 2 * A[4 * io + ii]
             else:
                 BC[io, ii] = 0
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = BC[io, ii]
     ```

   - Downside, either of the `else` branches could be eliminated as a no-op, since they don't contribute to the output value of `B`. After that elimination, there wouldn't be any way to reconstruct the pad value.

2. When hoisting an allocation+transformation, write the pad value to the buffer at the start of the function from which it was hoisted. This way, the pad value can still be used in local reasoning.

   - No change needed in producers, since they would already write the pad value to the buffer.
   - For consumers, this would be represented as writing `pad_value` into the padding at the start of the function.

     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = 0
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
     ```

   - Downside, repeated unnecessary effort at the beginning of each consumer. Avoiding it with this representation would require knowing that the producer had already written `pad_value`, which is exactly the information we're trying to avoid tracking.

3. When hoisting an allocation+transformation, write the pad value to the buffer at the start of the function from which it was hoisted, and write `T.undef()` at the end. This way, the pad value can still be used in local reasoning, and no-op removal can remove the repeated writes when lowering.

   - No change needed in producers, since they would already write the pad value to the buffer.
   - For consumers, this would be like option 2, but with an additional write of `T.undef()` at the end of the function. When lowering, the write of `T.undef()` would allow the first write to be removed as a no-op, because it is overwritten. The `T.undef()` can then be removed as described in the RFC.

     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = 0
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 AC[io, ii] = T.undef()
     ```

   - Downside, no way to distinguish between "can assume the pad value is zero" and "can overwrite the pad value at will".
     The writing of `T.undef()` would allow any writes to the padding to be inserted as no-ops.

   - Downside, wouldn't actually simplify out in cases where the pad value is used. The first of a pair of repeated writes to the same location can only be removed if there are no reads between the two writes. After using the pad value to eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes the padding could no longer be removed.

4. Use `AssertStmt` in a loop to declare known information about the buffers.

   - No change needed in producers, since the pad value is already written out.
   - For consumers, there would be an initial loop that asserts the pad value is correct.

     ```python
     @T.prim_func
     def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
         for io, ii in T.grid(4, 4):
             if 4 * io + ii >= 14:
                 assert AC[io, ii] == 0, "padding"
         for io, ii in T.grid(4, 4):
             if 4 * io + ii < 14:
                 B[4 * io + ii] = 2 * AC[io, ii]
     ```

   - Downside, assert statements have target-dependent handling. In `CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops. In `CodeGenCPU` and `CodeGenC`, they generate runtime asserts. In `CodeGenCUDA`, they aren't handled at all and would error out. This could be worked around with a lowering pass, but identifying these conditions would require a special string in the assert message, and packing structured data into strings makes me wary.

5. Use `AssertStmt` with implicitly-defined variables to declare known information about the buffers.

   ```python
   @T.prim_func
   def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
       a = T.var("int32")
       b = T.var("int32")
       assert (
           AC[a, b] == 0
           or (4 * a + b < 14)
           or (a < 0)
           or (a >= 4)
           or (b < 0)
           or (b >= 4)
       ), "padding"
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```

   - Can apply to clamped texture memory, since the variables in the assertion aren't restricted to the buffer bounds.
   - Would need to recognize the specific pattern of a `BufferLoad` being used to define the variables used in the constraint.
   - The implicitly-defined variables can be written in current TIR, and because they are unbound, this isn't something that would ever make it into generated code at runtime.
   - Downside, implicitly-defined variables are something of a red flag.

6. Store constraints in the function attributes, either as a dictionary or as a structured object.

   ```python
   @T.prim_func
   def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
       T.func_attr(
           "buffer_constraints",
           [
               {
                   "buffer": AC,
                   "predicate": lambda io, ii: 4 * io + ii < 14,
                   "pad_value": lambda io, ii: 0,
               },
           ],
       )
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```

   - Downside, requires transformations that change a buffer to be aware that other structures will also need to be updated.
   - Downside, requires simplifications either to be passed the entire `PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
   - Downside, would break the expectations of `IRMutatorWithAnalyzer`. The current entry point for any `Stmt` or `Expr` would need the additional information from `"buffer_constraints"`.

7. Store constraints in the `Buffer` object, either as a dictionary or as a structured object.

   ```python
   @T.prim_func
   def func(ac: T.handle, B: T.Buffer[14, "int32"]):
       AC = T.match_buffer(
           ac,
           shape=(4, 4),
           dtype="int32",
           constraints=[
               T.BufferConstraints(
                   predicate=lambda io, ii: 4 * io + ii < 14,
                   pad_value=0,
               )
           ],
       )
       for io, ii in T.grid(4, 4):
           if 4 * io + ii < 14:
               B[4 * io + ii] = 2 * AC[io, ii]
   ```

   - Downside, introduces an additional data structure in TIR.

--
Reply to this email directly or view it on GitHub:
https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163620046
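P.S. As a sanity check on the simplification that all of the options above are trying to enable (goal 1), here is a plain-Python sketch with no TVM dependency. The names `N`, `TILE`, and the list-of-lists layout are stand-ins for the buffers in the examples, not part of any proposed API. It shows that when the padding of `AC` is known to hold `pad_value = 0`, the consumer can run its element-wise compute over the full padded buffer with no bounds predicate and still produce the same valid outputs:

```python
# Plain-Python model of the padded layout from the examples above.
# N is the logical size, TILE the physical tile; AC is A repacked as
# a 4x4 buffer with pad_value = 0 written into the padding.
N, TILE = 14, 4

A = list(range(N))

# Producer: pack A into the padded cache, writing 0 into the padding.
AC = [
    [A[TILE * io + ii] if TILE * io + ii < N else 0 for ii in range(TILE)]
    for io in range(TILE)
]

# Consumer WITH the bounds predicate: only the valid region is touched.
B_masked = [2 * AC[i // TILE][i % TILE] for i in range(N)]

# Consumer WITHOUT the predicate: compute over all 16 elements, then
# read back only the valid region. This is safe because the compute
# maps the known pad value into the padding of the result (2 * 0 == 0).
BC = [[2 * AC[io][ii] for ii in range(TILE)] for io in range(TILE)]
B_full = [BC[i // TILE][i % TILE] for i in range(N)]

assert B_masked == B_full == [2 * a for a in A]
```

The same elimination is unsound if the pad value is unrecoverable, which is the failure mode in option 1's downside: a later consumer (e.g. a reduction over `BC`) would then pick up whatever garbage sits in the padding.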