Added some examples to build on top of @Lunderberg 's example.

## Transformation
The main differences between annotation and special handling are:

- Annotation is not necessary for the correctness of the program, but it may provide hints towards future optimizations.
- Without annotation, the program still runs correctly, but certain optimizations may not trigger.

### Step 0: Produce temp stages with annotation

The transformation produces temporary buffers (AC and BC), where the relation between those buffers and A, B is recorded in two blocks (preproc and postproc). Note that these additional annotations are hints for the compiler to perform future optimizations (e.g. to lift them out or cancel them). Our eventual goal could be to reason about those properties directly from the code, but annotations provide a first shortcut.

```python
@T.prim_func
def grow(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    AC = T.alloc_buffer([4, 4], "int32")
    BC = T.alloc_buffer([4, 4], "int32")
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr({"preproc": "pad"})
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]
    for i in T.grid(14):
        with T.block():
            # hint that this is a cropping operation;
            # the remaining values of BC that are not copied out
            # are assumed to be 0; if no value is provided,
            # no assumption is made
            T.block_attr({"postproc": ["crop", 0]})
            B[i] = BC[i // 4, i % 4]


@T.prim_func
def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.grid(14):
        B[i] = A[i] + 1


@R.func
def main(A: T.Tensor[14, "int32"]):
    lv0 = call_tir(grow, [A], (14))
    # an intermediate stage to show non-local reflowing
    lv1 = call_tir(addone, [lv0], (14))
    lv2 = call_tir(grow, [lv1], (14))
    ...
```

Note that the special crop annotation comes with an `assumed_value`, which is provided as part of the transformation (and we can actually prove that it is safe if our layout transformation starts from B and goes backwards).

### Step 1: Reconstruct constraint at TIR-Graph level

By looking at the PrimFunc, we know that there is a desire to split out the preproc and postproc stages to the graph level. It is totally fine for the compiler to choose not to do so, and the program is still valid. But let us say we choose to lift them out:

```python
@T.prim_func
def grow_packed(AC: T.Buffer[(4, 4), "int32"], BC: T.Buffer[(4, 4), "int32"]):
    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]


@T.prim_func
def pad(A: T.Buffer[14, "int32"], AC: T.Buffer[(4, 4), "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr({"preproc": "pad"})
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)


@T.prim_func
def crop_with_pad_assume(BC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    # note that this crop carries a pad assumption (about the other values of BC)
    for i in T.grid(14):
        with T.block():
            T.block_attr({"postproc": ["crop", 0]})
            B[i] = BC[i // 4, i % 4]


@R.func
def main(A: T.Tensor[14, "int32"]):
    lv0 = call_tir(pad, [A], (4, 4))
    lv1 = call_tir(grow_packed, [lv0], (4, 4))
    # these are the two things that we want to use for global format reflowing
    lv2 = call_tir(crop_with_pad_assume, [lv1], (14))
    lv3 = call_tir(addone, [lv2], (14))
    lv4 = call_tir(pad, [lv3], (4, 4))
    lv5 = call_tir(grow_packed, [lv4], (4, 4))
    lv6 = call_tir(crop_with_pad_assume, [lv5], (14))
```

### Step 2: Global Reflowing of layouts

Now, as a last step, let us say we do global reflowing:

- Start from reverse topological DAG order.
- Whenever we encounter a pad, we reconstruct an in-memory data structure (something like BufferConstraint), e.g. `PadMapping(constraint, pad_value=0)`.
- We try to "backprop" the PadMapping through the graph.
- Each function needs its own TIR analysis of how it flows things back. For example, in the case of `addone`, we can safely flow the PadMapping back, changing `addone` to `addone_packed` by analyzing the TIR. If the operator is an elementwise `exp` however, we need to insert a select operator (because `exp(0) = 1`), and the message to the input becomes `PadMapping(constraint, pad_value=undef)`; see the sketch after this list.
- When a `PadMapping` meets a `crop_with_pad_assume`, we can attempt to simplify and cancel them out.
- When there are branches or transpositions at the graph level, or other more complicated patterns, we might choose to materialize instead.
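To make the select-insertion case above concrete, here is a minimal sketch of what an elementwise `exp` stage could look like after a `PadMapping(constraint, pad_value=0)` has been flowed through it. This is my own illustration on top of the example; the name `exp_packed` and the exact intrinsic spellings are assumptions:

```python
@T.prim_func
def exp_packed(AC: T.Buffer[(4, 4), "float32"], BC: T.Buffer[(4, 4), "float32"]):
    for i, j in T.grid(4, 4):
        # exp(0) = 1, so the padded region would not stay 0 by itself;
        # the select keeps the output consistent with the downstream crop's
        # assumed value, while the message flowed to the input becomes
        # PadMapping(constraint, pad_value=undef)
        BC[i, j] = T.if_then_else(4 * i + j < 14, T.exp(AC[i, j]), T.float32(0))
```

And, for completeness, a hypothetical end state of `main` after reflowing, assuming the interior `crop_with_pad_assume`/`pad` pair around `addone` cancels out and `addone` becomes `addone_packed` as described above. The variable names are illustrative, and the trailing stage is written as a plain `crop` (without the pad-value assumption) because the padded region of its input may no longer be 0 after `addone_packed`:

```python
@R.func
def main(A: T.Tensor[14, "int32"]):
    lv0 = call_tir(pad, [A], (4, 4))
    # the crop_with_pad_assume/pad pair around addone cancelled out
    lv1 = call_tir(grow_packed, [lv0], (4, 4))
    lv2 = call_tir(addone_packed, [lv1], (4, 4))
    lv3 = call_tir(grow_packed, [lv2], (4, 4))
    # hypothetical crop that does not carry the pad-value assumption
    lv4 = call_tir(crop, [lv3], (14))
    ...
```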
### Discussion

There are a few key properties that are really desirable here:

- Transformations of a PrimFunc do not change the PrimFunc interface: this is really important so that we can transform a PrimFunc without worrying about how the graph interacts with it (as the interface remains the same, we can lift out the blocks as shown earlier).
- There are implicit assumptions generated (`crop_with_pad_assume`) to enable some simplification (otherwise a select is necessary, which is also not that bad). Note that the assumptions are generated under a global context (when we transform the padding we actually know that the overflowing fields are 0). But an extra amount of care is needed when we attempt to move `crop_with_pad_assume`, as it really depends on the value property of its input. The high-level gist is that we should not do that; instead, the global reflowing of layouts will reflow the `PadMapping` to `crop_with_pad_assume` and then cancel it out.

Talking about "constraints", it is also useful to talk about categories of them; roughly we can divide them into three:

- static_assert: we want to assert some invariant of the code, and it is also necessary to prove that it holds at compile time; otherwise a compilation error needs to be raised.
- (runtime) assert: we want to assert some invariant of the code; it is not necessary to prove it, but we need to do a runtime check if it cannot be proved.
- assume (as in `__builtin_assume`): we want to assert some invariant of the code; it is not necessary to prove it during compilation, and no runtime check is performed.

All three types of constraints can be helpful. In our particular case, an `assume` is being generated in `crop_with_pad_assume`, as sketched below.
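To make the `assume` flavor concrete, here is a minimal sketch of how the assumption carried by `crop_with_pad_assume` could be spelled out explicitly instead of via a `block_attr`. The `T.assume` spelling is hypothetical here (in the spirit of `__builtin_assume`) and is only meant to illustrate the semantics:

```python
@T.prim_func
def crop_with_pad_assume(BC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    # a static_assert would have to be proved at compile time (else: compile error);
    # a runtime assert would be checked at runtime when it cannot be proved;
    # an assume is never checked, but the compiler may rely on it to simplify,
    # e.g. to cancel this crop against a following pad.
    for io, ii in T.grid(4, 4):
        T.assume(4 * io + ii < 14 or BC[io, ii] == 0)
    for i in T.grid(14):
        B[i] = BC[i // 4, i % 4]
```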