Neat! This is a really good explanation. I think I get almost everything you are 
explaining. (I'm still stuck on `reverse_compute_at`, which seems like a long 
name and is still a bit too magical for me to understand.)

In terms of splitting / reordering, my goofy thought is that my favorite 
construct for tensor abstractions is `vmap` in JAX. What makes tensor 
programming so hard is keeping 6 tensor dimensions in your head. vmap lets you 
"zoom" into the inner computation, forget entirely about the outer dimension, 
and focus on just that.
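
For concreteness, here's the kind of thing I mean in JAX (just a toy illustration):

```python
import jax
import jax.numpy as jnp

# Written for a single example: one vector of activations, one weight matrix.
def per_example(x, W):
    return jnp.tanh(W @ x)

# vmap adds the batch dimension back "from the outside"; inside per_example
# I never have to think about it.
batched = jax.vmap(per_example, in_axes=(0, None))

xs = jnp.ones((32, 16))   # batch of 32 vectors
W = jnp.ones((8, 16))
ys = batched(xs, W)       # shape (32, 8)
```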

When writing TVM code for a matrix multiply with double buffering, I would 
really like to 1) first decide on my tiling of the output, split, and assign to 
blocks and threads, and then 2) write a separately scoped bit of code that 
doesn't know about the outer construction at all. Ideally, I would create my 
outer scope, vmap in, have all my buffers automatically "reverse_compute_at" / 
vmapped, and then create my inner setup.
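
For contrast, this is roughly what I end up writing today with the te schedule 
primitives (a rough sketch from memory; I've simplified the factors and left 
out thread binding, so details may be off). The point is that the "inner" 
cache_read / compute_at part still has to name an outer axis from step 1:

```python
import tvm
from tvm import te

# A plain matmul in te (sizes are just placeholders for the sketch).
N = 1024
A = te.placeholder((N, N), name="A")
B = te.placeholder((N, N), name="B")
k = te.reduce_axis((0, N), name="k")
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)

# Step 1: tile the output and bind the outer tile axes to blocks.
i, j = s[C].op.axis
i_o, i_i = s[C].split(i, factor=32)
j_o, j_i = s[C].split(j, factor=32)
k_o, k_i = s[C].split(s[C].op.reduce_axis[0], factor=8)
s[C].reorder(i_o, j_o, k_o, i_i, j_i, k_i)
s[C].bind(i_o, te.thread_axis("blockIdx.x"))
s[C].bind(j_o, te.thread_axis("blockIdx.y"))

# Step 2: the "inner" schedule. Notice it still has to name the outer axis
# k_o from step 1 in compute_at, so it can't be written without knowing
# about the outer construction.
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
s[AA].compute_at(s[C], k_o)
s[BB].compute_at(s[C], k_o)
s[AA].double_buffer()
s[BB].double_buffer()
```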

I don't know if this totally works but this would be my ideal:

```python
l_o = split_out(C, l)  # only exposes the outer axis
n_o = split_out(C, n)

with prefix(l_o, n_o, tensors=[A, B], threads=[]) as s2:
    s2.cache_read(...)  # this cache_read is now computed at this scope
    l, n = s2.axes(C)   # these are now the inner split axes
    m = s2.axes(A)      # A's outer axes are now invisible
    s2.reorder(n, l)    # only touches the visible axes
```

(Maybe instead of a `with` block this is an inner function, like in JAX.)
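
Something like this, with the same made-up `prefix` API (pure pseudocode, just to sketch the shape):

```python
def inner(s2):
    # the body only ever sees the inner axes; the outer tiling is invisible,
    # like the function you hand to jax.vmap
    s2.cache_read(...)
    l, n = s2.axes(C)
    s2.reorder(n, l)

prefix(l_o, n_o, tensors=[A, B], threads=[])(inner)
```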




