For the integration of a new intrinsic, I would like to apply a transformation to a TIR schedule that inlines the addition of a bias into a matrix multiplication. I have created a very simple example to reproduce my problem; let's assume the following PrimFunc:
```python
@T.prim_func
def func(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj]
```

I want to transform it to achieve the following:

```python
@T.prim_func
def expected_v1(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + C[vi, vj]
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]
```

Or, ideally:

```python
@T.prim_func
def expected_v2(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                D[vi, vj] = C[vi, vj]
            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
```

As you can see, mathematically all these computations are equivalent, so I would expect there is some way of getting there. But everything I tried failed. I tried to use `compute_inline` on the multiply block, `reverse_compute_inline` on the add block, and `decompose_reduction` followed by `reverse_compute_inline`...
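To make the equivalence claim concrete, here is a plain-Python sketch (no TVM required) of the arithmetic in `func` versus `expected_v2`: adding the bias `C` after the reduction gives the same result as initializing the accumulator with `C`. The helper names `matmul_then_add` and `init_with_bias` are my own, and Python ints stand in for the int8/int32 types.

```python
# Plain-Python sketch of the arithmetic in the question (no TVM needed).
# Helper names are hypothetical; int8/int32 semantics approximated with Python ints.
import random

N = 16
random.seed(0)
A = [[random.randint(-128, 127) for _ in range(N)] for _ in range(N)]
B = [[random.randint(-128, 127) for _ in range(N)] for _ in range(N)]
C = [[random.randint(-1000, 1000) for _ in range(N)] for _ in range(N)]

def matmul_then_add(A, B, C):
    # Mirrors `func`: temp[i][j] = sum_k A[i][k] * B[j][k], then D = temp + C.
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            acc = 0
            for k in range(N):
                acc += A[i][k] * B[j][k]
            D[i][j] = acc + C[i][j]
    return D

def init_with_bias(A, B, C):
    # Mirrors `expected_v2`: the accumulator is initialized with C[i][j].
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            acc = C[i][j]
            for k in range(N):
                acc += A[i][k] * B[j][k]
            D[i][j] = acc
    return D

assert matmul_then_add(A, B, C) == init_with_bias(A, B, C)
```

This only demonstrates the arithmetic equivalence; the scheduling question is whether TIR's primitives can rewrite one form into the other.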
Could someone confirm this is indeed not possible? And if that is the case, why? These seem like valid transformations that should be possible in some way, but I am probably missing the reason why they aren't. Here is some example code showing some of what I tried (it returns the error `The consumer block tir.Block#0 to be inlined is required to have only a single producer block, and the producer block should be a complete block who has only a single consumer`):

```python
if __name__ == "__main__":
    sch = tir.Schedule(func)
    mult_block = sch.get_block("multiply")
    init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
    update_block = sch.get_block("multiply_update")
    add_block = sch.get_block("add")
    sch.cache_write(add_block, 0, "local")
    sch.reverse_compute_inline(add_block)
```

---

[Visit Topic](https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066/1) to respond.