Neat! This is a really good explanation. I think I get most of what you are explaining. (I'm still stuck on `reverse_compute_at`, which seems like a long name and is still a bit too magical for me to understand.)
In terms of splitting / reordering, my goofy thought is that my favorite construct for tensor abstractions is `vmap` in JAX. What makes programming tensors so hard is keeping 6 tensor dimensions in your head. `vmap` lets you "zoom" into an inner region, forget entirely about the outer dimension, and focus on that.

When writing TVM code for matrix multiply with double buffering, I would really like to 1) first decide on my tiling of the output, split, and assign to blocks and threads, and then 2) write a separately scoped bit of code that doesn't know about the outer construction at all. Ideally, I would create my outer scope, vmap in, so that all my buffers are automatically "reverse_compute_at" / vmapped, and then create my inner setup. I don't know if this totally works, but this would be my ideal:

```python
l_o = split_out(C, l)  # only exposes the outer loop
n_o = split_out(C, n)

with prefix(l_o, n_o, tensors=[A, B], threads=[]) as s2:
    s2.cache_read(...)   # this cache read is now local, computed at here
    l, n = s2.axes(C)    # these are now the inner split axes
    m = s2.axes(A)       # A's outer axes are now invisible
    s2.reorder(n, l)     # only touches the visible axes
```

(Maybe instead of a `with` this is an inner function like in JAX.)
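To make the vmap analogy concrete, here is a minimal JAX sketch (my own illustration, not from the original post; `inner_kernel` and the shapes are made up) of what I mean by writing the inner code without ever seeing the outer dimension:

```python
import jax
import jax.numpy as jnp

def inner_kernel(a_row, b):
    # Written as if there were no outer dimension at all:
    # a_row has shape (K,), b has shape (K, N), result has shape (N,).
    return a_row @ b

# vmap adds the outer M axis back "from outside"; the inner code never sees it.
matmul = jax.vmap(inner_kernel, in_axes=(0, None))

A = jnp.ones((8, 4))    # (M, K)
B = jnp.ones((4, 16))   # (K, N)
C = matmul(A, B)        # (M, N)
print(C.shape)          # (8, 16)
```

The scoped scheduling sketch above is the same idea: everything inside the `with` block only ever manipulates the inner axes, and the outer tiling / thread binding is handled entirely by the enclosing scope.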