Thanks for your comments :)
[quote="areusch, post:3, topic:11807"]
could you say more here? is this a Relay-level thing or a TIR thing? presuming you’ve implemented this as a pass, how do you plan to ensure that the Relay-level pass makes the same scheduling decision as the TIR pass?
[/quote]

Perhaps I can walk through a made-up Conv2d example to describe it:

```
fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(32, 3, 7, 7), int8], Primitive=1) {
    nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
  };
  %conv_fn(%arg0, %nn.conv2d_arg)
}
```

The corresponding PrimFunc for the primitive call `%conv_fn` would look like:

```python
@T.prim_func
def main(x: T.Buffer[...], weight: T.Buffer[(32, 3, 7, 7), "int8"], y: T.Buffer[...]) -> None:
    # body
```

Assume that, to utilize the specific hardware, we want to arrange the I/O channels into 4x4 tiles. Two extra notes:

- We only get to know the "best" weight layout after a TIR schedule/tuning is done.
- The required layout is out of the scope of common representations like "OIHW", "OHWI", etc.

The TIR schedule part would do the following transformation on `weight`:

```python
o, i, h, w = s.get_read_buffer_axes(conv_block)
o_outer, o_inner = s.buffer_split(o, factor=4)  # [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
i_outer, i_inner = s.buffer_split(i, factor=4)  # [8, 4, 3, 7, 7] -> [8, 4, 1, 4, 7, 7]
s.buffer_reorder(o_outer, o_inner, i_outer, i_inner, h, w)  # [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7]
```

Above we use a set of extended TensorIR primitives, but they can just be seen as sugar over the ongoing `transform_layout` schedule primitive: https://github.com/apache/tvm-rfcs/pull/39

The point is that they are not arbitrary index remappings (compared to a general `transform_layout`). We ensure every such schedule step has an exact equivalent Relay transformation. In the TIR schedule phase, we trace every buffer layout change on the function param buffers (we can do that since these primitives are what we implement), generate the transform (and reverse transform) in Relay at each step, and finally compose them into single layout-transform (and reverse-transform) functions in Relay. For the example above, it would be:

- `s.buffer_split(o, factor=4)`
  - x -> relay.reshape(x, [-1, 4, 3, 7, 7])
  - (reverse) x -> relay.reshape(x, [32, 3, 7, 7])
- `s.buffer_split(i, factor=4)`
  - x -> relay.reshape(relay.nn.pad(x, [..., (0, 1), ...]), [8, 4, -1, 4, 7, 7])
  - (reverse) x -> relay.strided_slice(relay.reshape(x, [8, 4, 4, 7, 7]), begin=..., end=...)
- `s.buffer_reorder(...)`
  - x -> relay.transpose(x, [...])
  - (reverse) x -> relay.transpose(x, [...])

Finally, all transforms (and reverse transforms) are composed into two `relay.Function` objects that rewrite the Relay-level layouts. Each accepts the original Relay params and returns the updated params tuple:

```
fn (%p0: Tensor[..., int8], %p1: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%p1, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  (%p0, %3)
}
```

and the reverse direction is:

```
fn (%p0: Tensor[..., int8], %p1: Tensor[(8, 4, 1, 4, 7, 7), int8]) {
  %0 = transpose(%p1, axes=[...]);
  %1 = reshape(%0, newshape=[...]);
  %2 = strided_slice(%1, begin=[...], end=[...], strides=[...]);
  %3 = reshape(%2, newshape=[32, 3, 7, 7]);
  (%p0, %3)
}
```

A Relay pass can now perform a "pre"-schedule for each primitive function, fetch the layout-transform functions from the schedule result, and perform the Relay-level layout update.
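To make the composition step concrete, below is a minimal sketch of what the composed forward-transform function amounts to when expressed with the public Relay API. This is not our actual pass code: the shapes, padding, and transpose axes are purely illustrative for the Conv2d example above, and `build_forward_transform` is a made-up name.

```python
import tvm
from tvm import relay


def build_forward_transform():
    """Compose the per-step weight transforms into one relay.Function that
    takes the original params and returns the updated params tuple."""
    p0 = relay.var("p0", shape=(1, 3, 224, 224), dtype="int8")
    p1 = relay.var("p1", shape=(32, 3, 7, 7), dtype="int8")
    # buffer_split(o, factor=4): [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
    x = relay.reshape(p1, newshape=(8, 4, 3, 7, 7))
    # buffer_split(i, factor=4): pad the input-channel dim 3 -> 4, then split it
    x = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (0, 1), (0, 0), (0, 0)))
    x = relay.reshape(x, newshape=(8, 4, 1, 4, 7, 7))
    # buffer_reorder(...): [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7] (axes illustrative)
    x = relay.transpose(x, axes=(0, 2, 1, 3, 4, 5))
    # Return the unchanged data param together with the re-laid-out weight.
    return relay.Function([p0, p1], relay.Tuple([p0, x]))


mod = tvm.IRModule.from_expr(build_forward_transform())
print(mod)
```

The reverse-transform function is built the same way by emitting the per-step reverse ops in the opposite order.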
Finally, an extra `FoldConstant` pass can typically eliminate all the extra transformations left outside of the primitive calls (a small sketch of that step is at the end of this post). After the layout rewrite, the module looks like:

```
fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%nn.conv2d_arg, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(8, 4, 1, 4, 7, 7), int8], Primitive=1, DevicePrimFuncKey=873487) {
    %4 = transpose(%weight, axes=[...]);
    %5 = reshape(%4, newshape=[...]);
    %6 = strided_slice(%5, begin=[...], end=[...], strides=[...]);
    %7 = reshape(%6, newshape=[32, 3, 7, 7]);
    nn.conv2d(%data, %7, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
  };
  %conv_fn(%arg0, %3)
}
```

The actual params are transformed before the call into `%conv_fn`, and the formal params are reverse-transformed within `%conv_fn`'s body. The reason we need the reverse transforms is that we currently cannot represent a "lowered" function call in Relay (correct me if I am wrong); it is a workaround to keep a valid primitive function body, so that the Relay module after the pass can still be safely evaluated on a CPU.

Everything described so far only targets weights (free tensors). We check that a tensor produced/consumed by other Relay calls does not get transformed. For input and output layouts, we find Relay's `ConvertLayout` can cover the current demands. However, I think there is no essential difference between "applicable functions that transform the layout" and a simple tag like "NCHW" on an input/output, so it should be possible to rewrite inputs/outputs with the same mechanism.

One remaining issue is that we have to hack the `CompileEngine` (now `TECompiler`) to cache and reuse the previously scheduled PrimFuncs. Very glad to know if existing mechanisms (like `relay_to_tir`?) can help us :slight_smile:

cc @areusch
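P.S. For reference, a minimal sketch of the constant-folding step mentioned above, assuming the rewritten module is `mod` and the weights are available in a `params` dict (both names are assumptions for illustration, not our actual code):

```python
from tvm import relay

# Assumed setup: `mod` is the Relay module after the layout-rewrite pass and
# `params` maps parameter names (e.g. "nn.conv2d_arg") to their weight values.
# Binding the weights as constants lets FoldConstant pre-compute the
# reshape/pad/transpose chain that now sits outside the primitive call.
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = relay.transform.FoldConstant()(mod)
```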