I saw the following comment in nn.softmax:

    This operator can be optimized away for inference

For now, the BERT performance bottleneck is related to softmax.
What is the meaning of this comment, and how can this op be optimized away?
The IR looks like the following:

    %1579 = fn (%p0218: Tensor[(128, 12, 128, 128), float32], Primitive=1, 
hash="2bf1f4aba825ef91") -> Tensor[(128, 12, 128, 128), float32] {
      nn.softmax(%p0218) /* ty=Tensor[(128, 12, 128, 128), float32] */
    };
    %1580 = %1579(%1578) /* ty=Tensor[(128, 12, 128, 128), float32] */;
    %1581 = fn (%p0217: Tensor[(128, 12, 128, 128), float32], Primitive=1, 
relay.reshape_only=1, hash="cb182a507ee11eec") -> Tensor[(1536, 128, 128), 
float32] {
      reshape(%p0217, newshape=[-1, 128, 128]) /* ty=Tensor[(1536, 128, 128), 
float32] */
    };
    %1594 = fn (%p0296: Tensor[(16384, 768), int8], %p1234: Tensor[(768, 768), 
int8], %p2153: Tensor[(16384, 1), int32], %p379: int32, %p479: Tensor[(768), 
int32], %p553: int32, %p653: float32, Primitive=1, hash="fc3402299aaf823e") -> 
Tensor[(16384, 768), float32] {
      %1586 = nn.dense(%p0296, %p1234, units=768, out_dtype="int32") /* 
ty=Tensor[(16384, 768), int32] */;
      %1587 = multiply(%p379, %p479) /* ty=Tensor[(768), int32] */;
      %1588 = subtract(%p553, %1587) /* ty=Tensor[(768), int32] */;
      %1589 = subtract(%1586, %p2153) /* ty=Tensor[(16384, 768), int32] */;
      %1590 = expand_dims(%1588, axis=0) /* ty=Tensor[(1, 768), int32] */;
      %1591 = add(%1589, %1590) /* ty=Tensor[(16384, 768), int32] */;
      %1592 = cast(%1591, dtype="float32") /* ty=Tensor[(16384, 768), float32] 
*/;
      %1593 = multiply(%p653, 0.00167063f /* ty=float32 */) /* ty=float32 */;
      multiply(%1592, %1593) /* ty=Tensor[(16384, 768), float32] */
    };
    %1595 = %1594(%1545, meta[relay.Constant][33] /* ty=Tensor[(768, 768), 
int8] */, %1553, %1551, meta[relay.Constant][34] /* ty=Tensor[(768), int32] */, 
%1554, %1540) /* ty=Tensor[(16384, 768), float32] */;
    %1596 = fn (%p0295: Tensor[(16384, 768), float32], %p1233: Tensor[(768), 
float32], Primitive=1, hash="9fb7058ef5653f52") -> Tensor[(1536, 64, 128), 
float32] {
      %1582 = reshape(%p0295, newshape=[128, 128, 768]) /* ty=Tensor[(128, 128, 
768), float32] */;
      %1583 = add(%1582, %p1233) /* ty=Tensor[(128, 128, 768), float32] */;
      %1584 = reshape(%1583, newshape=[128, 128, 12, 64]) /* ty=Tensor[(128, 
128, 12, 64), float32] */;
      %1585 = transpose(%1584, axes=[0, 2, 3, 1]) /* ty=Tensor[(128, 12, 64, 
128), float32] */;
      reshape(%1585, newshape=[-1, 64, 128]) /* ty=Tensor[(1536, 64, 128), 
float32] */
    };
    %1597 = %1581(%1580) /* ty=Tensor[(1536, 128, 128), float32] */;
    %1598 = %1596(%1595, meta[relay.Constant][35] /* ty=Tensor[(768), float32] 
*/) /* ty=Tensor[(1536, 64, 128), float32] */;
    %1599 = fn (%p0216: Tensor[(1536, 128, 128), float32], %p1185: 
Tensor[(1536, 64, 128), float32], Primitive=1, hash="ee1827ff1631f589") -> 
Tensor[(1536, 128, 64), float32] {
      nn.batch_matmul(%p0216, %p1185, transpose_b=True) /* ty=Tensor[(1536, 
128, 64), float32] */
    };
    %1600 = %1599(%1597, %1598) /* ty=Tensor[(1536, 128, 64), float32] */;

Can this case be optimized?





---
[Visit 
Topic](https://discuss.tvm.apache.org/t/performance-can-inference-optimized-away-softmax/12166/1)
 to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.apache.org/email/unsubscribe/c47812dbcc143baa3cc682c6b916a883f8855ca13fdf583a2bd7807b8163cfa1).

Reply via email to