Added some examples to build on top of @Lunderberg's example.

## Transformation

The main differences between annotations and special handling are:

- Annotations are not necessary for the correctness of the program, but they may
provide hints for future optimizations.
- Without annotations, the program still runs correctly, but certain
optimizations may not trigger.

### Step 0: Produce temp stages with annotation

The transformation produces temporary buffers (`AC` and `BC`), and the relation
between them and `A`, `B` is recorded in two blocks (`preproc` and `postproc`).

Note that these additional annotations are hints for the compiler to perform
future optimizations (e.g. to lift the blocks out or cancel them). Our eventual
goal could be to reason about those properties directly from the code, but
annotations provide a first shortcut.

```python
@T.prim_func
def grow(A: T.Buffer[(14,), "int32"], B: T.Buffer[(14,), "int32"]):
    AC = T.alloc_buffer([4, 4], "int32")
    BC = T.alloc_buffer([4, 4], "int32")

    # preproc: pack A into the 4x4 buffer AC, padding the 2 extra slots with 0
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr({"preproc": "pad"})
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0, dtype="int32")

    # compute on the packed layout
    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]

    # postproc: copy the 14 logical elements of BC back into B
    for i in T.serial(14):
        with T.block():
            # Hint that this is a cropping operation, where we know that the
            # remaining (padded) part of BC is 0.
            # The second entry is the assumed value of those uncovered slots;
            # if it is not provided, no assumption is made.
            T.block_attr({"postproc": ["crop", 0]})
            B[i] = BC[i // 4, i % 4]

@T.prim_func
def addone(A: T.Buffer[(14,), "int32"], B: T.Buffer[(14,), "int32"]):
    for i in T.serial(14):
        B[i] = A[i] + 1

@R.function
def main(A: R.Tensor((14,), "int32")):
    lv0 = call_tir(grow, [A], (14,))
    # an intermediate stage to show non-local reflowing
    lv1 = call_tir(addone, [lv0], (14,))
    lv2 = call_tir(grow, [lv1], (14,))
    ...
```
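To make the intended semantics of Step 0 concrete without depending on TVM, here is a minimal pure-Python model of `grow`. It is only a sketch: buffers are plain lists, and `grow_model` and `padding_values` are hypothetical helpers, not TVM APIs. It checks that pad, compute on the 4x4 layout, then crop gives the same result as doubling `A` directly, and that the two padded slots of `BC` end up holding 0.

```python
N, TILE = 14, 4  # 14 logical elements packed into a 4x4 buffer (2 padded slots)


def grow_model(A):
    # preproc ("pad"): AC[io][ii] = A[4*io + ii], or 0 past the end of A
    AC = [[A[TILE * io + ii] if TILE * io + ii < N else 0
           for ii in range(TILE)] for io in range(TILE)]
    # compute on the packed 4x4 layout
    BC = [[2 * AC[i][j] for j in range(TILE)] for i in range(TILE)]
    # postproc ("crop"): copy back only the first N logical elements
    B = [BC[i // TILE][i % TILE] for i in range(N)]
    return B, BC


def padding_values(BC):
    # the slots of BC that the crop discards
    return [BC[i // TILE][i % TILE] for i in range(N, TILE * TILE)]


B, BC = grow_model(list(range(1, N + 1)))
print(B == [2 * a for a in range(1, N + 1)])  # True
print(padding_values(BC))                      # [0, 0]
```

The second print is the fact the `["crop", 0]` annotation relies on: since `2 * 0 == 0`, padding with 0 before the multiply leaves 0 in the padded slots after it.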

Note that the special crop annotation comes with an `assumed_value` (the `0` in
`["crop", 0]`), which is provided as part of the transformation (and we can
actually prove that it is safe if our layout transformation starts from B and
goes backwards).

### Step 1: Reconstruct constraint at TIR-Graph level

By looking at the PrimFunc, we know that there is a desire to split the
preproc and postproc stages out to the graph level. It is totally fine for the
compiler to choose not to do so, and the program is still valid. But let us say
we choose to lift them out:

```python
@T.prim_func
def grow_packed(AC: T.Buffer[(4, 4), "int32"], BC: T.Buffer[(4, 4), "int32"]):
    for i, j in T.grid(4, 4):
        BC[i, j] = 2 * AC[i, j]

@T.prim_func
def pad(A: T.Buffer[(14,), "int32"], AC: T.Buffer[(4, 4), "int32"]):
    for io, ii in T.grid(4, 4):
        with T.block():
            T.block_attr({"preproc": "pad"})
            AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0, dtype="int32")

@T.prim_func
def crop_with_pad_assume(BC: T.Buffer[(4, 4), "int32"], B: T.Buffer[(14,), "int32"]):
    # Note that this crop carries a pad assumption: the slots of BC that are
    # not copied into B (the padding) are assumed to be 0
    for i in T.serial(14):
        with T.block():
            T.block_attr({"postproc": ["crop", 0]})
            B[i] = BC[i // 4, i % 4]

@R.function
def main(A: R.Tensor((14,), "int32")):
    lv0 = call_tir(pad, [A], (4, 4))
    lv1 = call_tir(grow_packed, [lv0], (4, 4))
    # The crop below (lv2) and the following pad (lv4) are the two things
    # that we want to use for global layout reflowing
    lv2 = call_tir(crop_with_pad_assume, [lv1], (14,))
    lv3 = call_tir(addone, [lv2], (14,))
    lv4 = call_tir(pad, [lv3], (4, 4))
    lv5 = call_tir(grow_packed, [lv4], (4, 4))
    lv6 = call_tir(crop_with_pad_assume, [lv5], (14,))
    return lv6
```
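A quick way to check that the lifted graph is equivalent to the original `main` is to model each call with plain Python. The sketch below assumes that Step 0's `main` consists only of the three calls shown before the `...`; the Python function names mirror the PrimFuncs, but nothing here uses TVM or Relax APIs.

```python
N, TILE = 14, 4  # logical length and tile width used throughout the example


def pad(A, value=0):
    # 1-D buffer of length N -> TILE x TILE buffer; slots past N get `value`
    return [[A[TILE * io + ii] if TILE * io + ii < N else value
             for ii in range(TILE)] for io in range(TILE)]


def grow_packed(AC):
    return [[2 * x for x in row] for row in AC]


def crop_with_pad_assume(BC):
    # keep the first N logical elements; the padded slots are dropped
    return [BC[i // TILE][i % TILE] for i in range(N)]


def addone(A):
    return [a + 1 for a in A]


def grow(A):
    # observable behaviour of the original (unlifted) grow
    return [2 * a for a in A]


def main_step0(A):
    lv0 = grow(A)
    lv1 = addone(lv0)
    return grow(lv1)


def main_step1(A):
    lv0 = pad(A)
    lv1 = grow_packed(lv0)
    lv2 = crop_with_pad_assume(lv1)
    lv3 = addone(lv2)
    lv4 = pad(lv3)
    lv5 = grow_packed(lv4)
    return crop_with_pad_assume(lv5)


A = list(range(N))
assert main_step1(A) == main_step0(A)
```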

### Step 2: Global Reflowing of layouts

Now, as a last step, let us say we do global reflowing.

- Traverse the DAG in reverse topological order.
- Whenever we encounter a `pad`, we reconstruct an in-memory data structure
(something like `BufferConstraint`, e.g. `PadMapping(constraint,
pad_value=0)`).
- We try to “backprop” the `PadMapping` throughout the graph.
- Each function needs its own TIR analysis of how it flows things back. For
example, in the case of `addone`, we can safely flow the `PadMapping` back,
changing `addone` to `addone_packed` by analyzing the TIR. If `addone` were
instead an elementwise `exp`, however, we would need to insert a select
operator (because `exp(0) = 1`), and the message passed to the input becomes
`PadMapping(constraint, pad_value=undef)`.
- When the `PadMapping` meets `crop_with_pad_assume`, we can attempt to simplify
and cancel them out.
- When there are branches, transpositions at the graph level, or other more
complicated issues, we might choose to materialize the layout transformation
(keep the explicit `pad`/`crop`) instead.
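The reflowing rules above can be sketched as a backward pass over a linear chain of producers. This is a toy model under loud assumptions: `PadMapping`, `Node`, `backprop`, and the decision strings are hypothetical names; the per-function TIR analysis is reduced to checking whether an elementwise op's scalar body maps the pad value to itself; and `UNDEF` (Python `None`) plays the role of `pad_value=undef`. Branches and transposes are not modeled and always lead to materialization.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

UNDEF = None  # models pad_value=undef: the padded slots may hold anything


@dataclass(frozen=True)
class PadMapping:
    constraint: str             # e.g. "4 * io + ii < 14"; opaque in this sketch
    pad_value: Optional[float]  # value the consumer needs in the padding, or UNDEF


@dataclass(frozen=True)
class Node:
    name: str
    kind: str                                      # "elemwise", "crop_assume", or other
    fn: Optional[Callable[[float], float]] = None  # scalar body of an elementwise op
    assumed_value: Optional[float] = None          # value a crop_with_pad_assume assumes


def backprop(chain: List[Node], msg: PadMapping) -> List[Tuple[str, str]]:
    """Reflow `msg` (built from a pad that consumes the last node of `chain`)
    backwards. `chain` is listed in topological order and visited in reverse."""
    decisions = []
    for node in reversed(chain):
        if node.kind == "elemwise":
            if msg.pad_value is UNDEF or node.fn(msg.pad_value) == msg.pad_value:
                # padding passes through unchanged: run the op on the packed layout
                decisions.append((node.name, "packed"))
            else:
                # e.g. exp(0) == 1: restore the pad value with a select, so the
                # producer's padding no longer matters
                decisions.append((node.name, "packed+select"))
                msg = PadMapping(msg.constraint, UNDEF)
        elif node.kind == "crop_assume":
            if msg.pad_value is UNDEF or msg.pad_value == node.assumed_value:
                # pad(crop(X)) == X: the crop and the pad cancel out
                decisions.append((node.name, "cancel"))
            else:
                decisions.append((node.name, "materialize"))
            return decisions
        else:
            # branches, transposes, ...: keep the explicit pad/crop
            decisions.append((node.name, "materialize"))
            return decisions
    return decisions


crop = Node("crop_with_pad_assume", "crop_assume", assumed_value=0)
msg = PadMapping("4 * io + ii < 14", pad_value=0)
print(backprop([crop, Node("double", "elemwise", fn=lambda x: 2 * x)], msg))
# [('double', 'packed'), ('crop_with_pad_assume', 'cancel')]
print(backprop([crop, Node("exp", "elemwise", fn=math.exp)], msg))
# [('exp', 'packed+select'), ('crop_with_pad_assume', 'cancel')]
```

In the `exp` case, the select keeps the output padding at 0 so the downstream pad's requirement still holds, and because the message upstream becomes `undef`, the preceding crop can still be cancelled.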

### Discussion

There are a few key properties that are really desirable here:

- Transforming a PrimFunc does not change its interface: this is really
important, because we can transform a PrimFunc without worrying about how the
graph interacts with it (as the interface remains the same, we can lift out the
blocks earlier).
- Implicit assumptions are generated (`crop_with_pad_assume`) to enable some
simplifications (otherwise a select is necessary, which is not too bad either).
Note that these assumptions are generated under a global context: when we do the
padding transformation, we actually know that the overflowing fields are 0.
However, extra care is needed if we attempt to move `crop_with_pad_assume`, as
its correctness really depends on a value property of its input. The high-level
gist is that we should not move it; instead, the global layout reflowing will
reflow the `PadMapping` to `crop_with_pad_assume` and then cancel it out.
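Why moving `crop_with_pad_assume` needs care can be seen with a small pure-Python check: rewriting `pad(crop(BC))` to `BC` is only valid when the padded slots of `BC` already hold the assumed value. The helpers below (`pad`, `crop`, `can_cancel`) are hypothetical and operate on lists; they illustrate the value property, not a TVM pass.

```python
N, TILE = 14, 4


def pad(A, value=0):
    return [[A[TILE * io + ii] if TILE * io + ii < N else value
             for ii in range(TILE)] for io in range(TILE)]


def crop(BC):
    return [BC[i // TILE][i % TILE] for i in range(N)]


def can_cancel(BC, pad_value=0):
    # pad(crop(BC)) may be replaced by BC only if BC's padded slots
    # already hold pad_value; the crop itself cannot check this.
    return pad(crop(BC), pad_value) == BC


good = pad(list(range(N)))    # padding holds 0, as crop_with_pad_assume assumes
bad = [row[:] for row in good]
bad[3][3] = 7                 # a different producer writes 7 into a padded slot
print(can_cancel(good))       # True
print(can_cancel(bad))        # False
```

Both buffers crop to the same 14 values, so the crop alone cannot tell them apart; only the producer (here, the pad that created `good`) establishes the property the cancellation relies on.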

Speaking of “constraints”, it is also useful to categorize them. Roughly, we can
divide them into three categories:

- static_assert: We want to assert some invariant of the code, and it is also
necessary to “prove” that it holds at compile time; otherwise, a compilation
error needs to be raised.
- (runtime) assert: We want to assert some invariant of the code, but it is not
necessary to “prove” that it holds at compile time; we need to insert a runtime
check if it cannot be proved.
- assume (from `__builtin_assume`): We want to state some invariant of the
code, but it is not necessary to “prove” that it holds, and no check is
inserted; the compiler simply takes it as given.

All three types of constraints can be helpful. In our particular case, an
`assume` is generated in `crop_with_pad_assume`.
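The three categories differ only in what the compiler does when it can or cannot prove the constraint. The toy decision function below encodes the list above; `Kind` and `lower_constraint` are hypothetical names, and the returned strings merely label the action a compiler would take.

```python
from enum import Enum


class Kind(Enum):
    STATIC_ASSERT = "static_assert"
    ASSERT = "assert"
    ASSUME = "assume"


def lower_constraint(kind: Kind, provable: bool) -> str:
    """Return the action for a constraint, given whether it was proved at compile time."""
    if kind is Kind.ASSUME:
        # Like __builtin_assume: taken as a fact for simplification, never checked.
        return "use_as_fact"
    if provable:
        # Either assert flavour is known to hold, so nothing needs to be emitted.
        return "elide"
    if kind is Kind.STATIC_ASSERT:
        return "compile_error"
    return "runtime_check"


# crop_with_pad_assume relies on an assume: usable even though it is not proved locally.
print(lower_constraint(Kind.ASSUME, provable=False))  # use_as_fact
```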

-- 
Reply to this email directly or view it on GitHub:
https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1164440693