Thanks for your comments:)

[quote="areusch, post:3, topic:11807"]
could you say more here? is this a Relay-level thing or a TIR thing? presuming 
you’ve implemented this as a pass, how do you plan to ensure that the 
Relay-level pass makes the same scheduling decision as the TIR pass?
[/quote]

Perhaps I can use a toy Conv2d example to describe it:

    fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
      %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(32, 3, 7, 7), int8], Primitive=1) {
        nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
      };
      %conv_fn(%arg0, %nn.conv2d_arg)
    }
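
(For reference, a rough sketch of how one could build this fused module by hand with the public relay APIs; in the real flow it is produced by the `FuseOps` pass:)

```python
import tvm
from tvm import relay

# Outer function params
arg0 = relay.var("arg0", shape=(1, 3, 224, 224), dtype="int8")
conv2d_arg = relay.var("nn.conv2d_arg", shape=(32, 3, 7, 7), dtype="int8")

# Inner "primitive" function (normally produced by FuseOps)
data = relay.var("data", shape=(1, 3, 224, 224), dtype="int8")
weight = relay.var("weight", shape=(32, 3, 7, 7), dtype="int8")
conv = relay.nn.conv2d(data, weight, padding=(1, 1, 1, 1),
                       kernel_size=(7, 7), out_dtype="int32")
conv_fn = relay.Function([data, weight], conv).with_attr("Primitive", 1)

mod = tvm.IRModule.from_expr(
    relay.Function([arg0, conv2d_arg], conv_fn(arg0, conv2d_arg)))
```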

and the corresponding PrimFunc for the primitive call `%conv_fn` would look like:
```python
@T.prim_func
def main(x: T.Buffer[...], weight: T.Buffer[(32, 3, 7, 7), "int8"], y: T.Buffer[...]) -> None:
    # body
```
Assume that, to utilize the specific hardware, we want to arrange the input/output channels into 4×4 tiles. Two extra notes:
- We do not know the "best" weight layout until a TIR schedule/tuning is done.
- The required layout is beyond common representations like "OIHW", "OHWI", etc.

The TIR schedule part would do the following transformations on `weight`:

```python
o, i, h, w = s.get_read_buffer_axes(conv_block)
o_outer, o_inner = s.buffer_split(o, factor=4)  # [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
i_outer, i_inner = s.buffer_split(i, factor=4)  # [8, 4, 3, 7, 7] -> [8, 4, 1, 4, 7, 7]
s.buffer_reorder(o_outer, i_outer, o_inner, i_inner, h, w)  # [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7]
```

Above we use a set of extended TensorIR primitives, but they can be seen simply as sugar over the in-progress schedule primitive `transform_layout`: https://github.com/apache/tvm-rfcs/pull/39
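
For instance, the three steps above could roughly be folded into a single index remapping. This is only a sketch against the interface proposed in the RFC; the exact signature, and how the padded input-channel axis (3 -> 4) is handled, may differ in the final API:

```python
# Sketch against the `transform_layout` interface proposed in RFC #39.
s.transform_layout(
    conv_block,
    ("read", 1),  # the weight buffer of the conv block
    index_map=lambda o, i, h, w: (o // 4, i // 4, o % 4, i % 4, h, w),
)  # [32, 3, 7, 7] -> [8, 1, 4, 4, 7, 7]
```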

The point is that they are not arbitrary index remappings (compared with a general `transform_layout`). We ensure that every such schedule step has an exact equivalent Relay transformation.

In the TIR schedule phase, we trace every layout change on function parameter buffers (we can do that since these primitives are what we implement), generate the corresponding Relay transform (and reverse transform) for each step, and finally compose them into a single layout-transform (and reverse-transform) function in Relay.

For the example above, it would be:

- `s.buffer_split(o, factor=4)` 
  - x ->  relay.reshape(x, [-1, 4, 3, 7, 7])
  - (reverse) x -> relay.reshape(x, [32, 3, 7, 7])

- `s.buffer_split(i, factor=4)`
  - x -> relay.reshape(relay.nn.pad(x, [..., (0, 1), ...]), [8, 4, -1, 4, 7, 7])
  - (reverse) x -> relay.strided_slice(relay.reshape(x, [8, 4, 4, 7, 7]), begin=..., end=...)

- `s.buffer_reorder(...)`
  - x -> relay.transpose(x, [...])
  - (reverse) x -> relay.transpose(x, [...])

Finally, all transforms (and reverse transforms) are composed into two `relay.Function` objects used to rewrite Relay-level layouts. Each accepts the original Relay params and returns the tuple of updated params:

    fn (%p0: Tensor[..., int8], %p1: Tensor[(32, 3, 7, 7), int8]) {
      %0 = reshape(%p1, newshape=[...]);
      %1 = nn.pad(%0, pad_width=[...]);
      %2 = reshape(%1, newshape=[...]);
      %3 = transpose(%2, axes=[...]);
      (%p0, %3)
    }

and the reverse direction is:

    fn (%p0: Tensor[..., int8], %p1: Tensor[(8, 1, 4, 4, 7, 7), int8]) {
      %0 = transpose(%p1, axes=[...]);
      %1 = reshape(%0, newshape=[...]);
      %2 = strided_slice(%1, begin=[...], end=[...], strides=[...]);
      %3 = reshape(%2, newshape=[32, 3, 7, 7]);
      (%p0, %3)
    }
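
Internally this is just bookkeeping: each extended schedule primitive appends a (forward, reverse) pair of Relay-expression builders, and they are composed at the end. A minimal sketch of that idea follows; the trace container and helper names are hypothetical, only the shapes match the example above, and the sketch handles just the weight path (the real composed functions also pass through `%p0`):

```python
from tvm import relay

# Hypothetical trace collected while the schedule primitives run on the
# weight buffer: one (forward, reverse) Relay builder per step.
trace = [
    # s.buffer_split(o, factor=4)
    (lambda x: relay.reshape(x, [-1, 4, 3, 7, 7]),
     lambda x: relay.reshape(x, [32, 3, 7, 7])),
    # s.buffer_split(i, factor=4): pads the input-channel axis from 3 to 4
    (lambda x: relay.reshape(
         relay.nn.pad(x, [(0, 0), (0, 0), (0, 1), (0, 0), (0, 0)]),
         [8, 4, -1, 4, 7, 7]),
     lambda x: relay.strided_slice(relay.reshape(x, [8, 4, 4, 7, 7]),
                                   begin=[0, 0, 0, 0, 0],
                                   end=[8, 4, 3, 7, 7])),
    # s.buffer_reorder(...): swapping axes 1 and 2 is its own inverse
    (lambda x: relay.transpose(x, [0, 2, 1, 3, 4, 5]),
     lambda x: relay.transpose(x, [0, 2, 1, 3, 4, 5])),
]

def compose(var, builders):
    expr = var
    for build in builders:
        expr = build(expr)
    return expr

# Forward: original weight layout -> device layout
p1 = relay.var("p1", shape=(32, 3, 7, 7), dtype="int8")
forward_fn = relay.Function([p1], compose(p1, [f for f, _ in trace]))

# Reverse: device layout -> original layout (reverses applied in opposite order)
q1 = relay.var("p1", shape=(8, 1, 4, 4, 7, 7), dtype="int8")
reverse_fn = relay.Function([q1], compose(q1, [r for _, r in reversed(trace)]))
```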
   

A Relay pass can now perform a "pre"-schedule for each primitive function, fetch the layout-transform functions from the schedule result, and perform the Relay-level layout update. Finally, an extra `FoldConstant` pass can typically eliminate all the extra transformations outside the primitive calls.

    fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
      %0 = reshape(%nn.conv2d_arg, newshape=[...]);
      %1 = nn.pad(%0, pad_width=[...]);
      %2 = reshape(%1, newshape=[...]);
      %3 = transpose(%2, axes=[...]);
      %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(8, 1, 4, 4, 7, 7), int8], Primitive=1, DevicePrimFuncKey=873487) {
        %4 = transpose(%weight, axes=[...]);
        %5 = reshape(%4, newshape=[...]);
        %6 = strided_slice(%5, begin=[...], end=[...], strides=[...]);
        %7 = reshape(%6, newshape=[32, 3, 7, 7]);
        nn.conv2d(%data, %7, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
      };
      %conv_fn(%arg0, %3)
    }

The actual params are transformed before the call into `%conv_fn`, and the formal params are reverse-transformed within `%conv_fn`'s body. The reason we need the reverse transforms is that we currently cannot represent a "lowered" function call in Relay (correct me if I am wrong). It is a workaround that keeps the primitive function body valid, that is, the Relay module after the pass can still be safely evaluated on a CPU.
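
To make the `FoldConstant` step concrete, here is a usage sketch. It assumes the rewritten module is `mod` and the weight value is already known; the zero-filled array is just a placeholder:

```python
import numpy as np
import tvm
from tvm import relay

# Bind the weight param to its constant value, then fold the out-of-call
# reshape/pad/transpose chain into a pre-transformed constant.
weight_np = np.zeros((32, 3, 7, 7), dtype="int8")  # placeholder weight value
mod["main"] = relay.build_module.bind_params_by_name(
    mod["main"], {"nn.conv2d_arg": weight_np})
mod = relay.transform.FoldConstant()(mod)
```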

Everything described above only targets weights (free tensors) for now. We check that a tensor produced/consumed by other Relay calls does not get transformed. For input and output layouts, we find that Relay's `ConvertLayout` can cover the current demands. However, I think there is no essential difference between "applicable functions to transform a layout" and a simple tag like "NCHW" on an input/output, so it should be possible to rewrite inputs/outputs with the same mechanism.
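
For example, driving the existing pass for the input/output side is straightforward (standard `ConvertLayout` usage; the chosen layouts below are just placeholders):

```python
import tvm
from tvm import relay

# Standard ConvertLayout usage; the desired layouts here are placeholders.
desired_layouts = {"nn.conv2d": ["NHWC", "default"]}
with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.ConvertLayout(desired_layouts)(mod)
```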

One remaining issue is that we have to hack the `CompileEngine` (now `TECompiler`) to cache and reuse the previously scheduled PrimFuncs. I would be very glad to know whether existing mechanisms (like `relay_to_tir`?) can help us :slight_smile:  cc @areusch




