For the integration of a new intrinsic, I would like to apply a transformation to 
a TIR schedule that inlines the addition of a bias into a matrix multiplication. I 
have created a very simple example to reproduce my problem; let's assume the 
following PrimFunc:

```python
from tvm.script import tir as T

@T.prim_func
def func(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
    
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj] 
```

I want to transform it to achieve the following:

```python
@T.prim_func
def expected_v1(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + C[vi, vj]
    
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]
```

Or, ideally:

```python
@T.prim_func
def expected_v2(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                D[vi, vj] = C[vi, vj]
            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
```

As you can see, all of these computations are mathematically equivalent, so I 
would expect there to be some way of getting there, but everything I tried failed. 
I tried compute_inline on the multiply block, reverse_compute_inline on the add 
block, decompose_reduction followed by reverse_compute_inline, and so on.
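
For reference, here is roughly what the first two attempts looked like (minimal sketches against the `func` above; both of them fail with scheduling errors as well):

```python
from tvm import tir

# Attempt 1: inline the reduction block ("multiply") into its consumer.
sch = tir.Schedule(func)
try:
    sch.compute_inline(sch.get_block("multiply"))
except Exception as e:
    print(e)

# Attempt 2: inline the bias addition ("add") back into its producer.
sch = tir.Schedule(func)
try:
    sch.reverse_compute_inline(sch.get_block("add"))
except Exception as e:
    print(e)
```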

Could someone confirm that this is indeed not possible? And if that is the case, 
why? These seem like valid transformations that should be possible in some way, 
but I am probably missing the reason why they aren't.

Here is some example code showing part of what I tried; it fails with the error 
`Error message: The consumer block tir.Block#0 to be inlined is required to have 
only a single producer block, and the producer block should be a complete block 
who has only a single consumer`:

```python
from tvm import tir

if __name__ == "__main__":
    sch = tir.Schedule(func)

    # Split the reduction into separate init and update blocks.
    mult_block = sch.get_block("multiply")
    init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
    update_block = sch.get_block("multiply_update")

    # Stage the output of the add block, then try to inline it back into its
    # producer; the error quoted above is raised by reverse_compute_inline.
    add_block = sch.get_block("add")
    sch.cache_write(add_block, 0, "local")
    sch.reverse_compute_inline(add_block)
```
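
For completeness, this is roughly how I inspect the IR between the scheduling steps, in case it helps (a minimal sketch, again assuming the `func` defined above):

```python
from tvm import tir

sch = tir.Schedule(func)
mult_block = sch.get_block("multiply")

# Decompose the reduction and print the resulting TVMScript to look at the
# generated init/update blocks before attempting any inlining.
sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
print(sch.mod.script())
```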
