# [RFC] CSE

## Motivation

**C**ommon **S**ubexpression **E**limination is a compiler optimization that
effectively avoids repeated computation of the same expression. It is needed in
the AutoDiff submodule to improve its performance. As an example, consider a
compute definition that is as simple as the tanh activation function:

```
Y = tanh(X)
```

whose gradient definition is

```
dX = (1 - tanh(X) * tanh(X)) * dY
```

We can clearly see that `tanh(X)` is evaluated **3 times** in this example:
once in the forward pass, and twice in the backward pass. Although such
re-evaluation might not be a big deal here (because tanh is a relatively
lightweight operator), its cost grows as the operators become more complicated,
as the examples later in this RFC will show.
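
To make the count concrete, the tanh example can be written as a minimal C++ sketch on scalars instead of tensors. The global counter `g_tanh_calls` and the functions `CountedTanh`, `Forward`, `BackwardNaive`, and `BackwardCSE` are hypothetical names used only to observe how often tanh runs: three times without CSE, once when the backward pass reuses the forward output `Y`.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical counter, used only to observe how often tanh is evaluated.
int g_tanh_calls = 0;

double CountedTanh(double x) {
  ++g_tanh_calls;
  return std::tanh(x);
}

// Forward pass: Y = tanh(X).
double Forward(double x) { return CountedTanh(x); }

// Backward pass without CSE: tanh(X) is recomputed twice.
double BackwardNaive(double x, double dy) {
  return (1.0 - CountedTanh(x) * CountedTanh(x)) * dy;
}

// Backward pass with CSE: the stashed forward output Y replaces tanh(X).
double BackwardCSE(double y, double dy) {
  assert(y >= -1.0 && y <= 1.0);  // Y must be a tanh output.
  return (1.0 - y * y) * dy;
}
```

Feeding the stashed `Y` into the backward pass is the rewrite that CSE should discover automatically: `tanh(X)` is the common subexpression, and `Y` already holds its value.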

## Key Ideas

To illustrate our idea better, we consider the compute definition for the
BatchNorm operator and its gradient on the placeholder γ:

$$
y = \operatorname{BatchNorm}(x) = \hat{x}\gamma + \beta = \frac{x - \operatorname{E}[x]}{\sqrt{\operatorname{VAR}[x] + \varepsilon}}\gamma + \beta,
$$

$$
d\gamma_c = \sum_{b,h,w} dY_{b,c,h,w}\,\hat{x}_{b,c,h,w}
$$

To avoid re-evaluating the normalized x (denoted as x_hat), we must complete the
following steps:

1. **Analysis**

   We need to determine that x_hat is the **largest** common subexpression
   between the forward and the backward passes.
1. **Transform (Forward)**

   We need to set x_hat as one of the outputs, so that it can be stashed in
   memory by the forward pass.
1. **Transform (Backward)**

   We need to replace the x_hat in the backward compute definition with a
   placeholder, whose value will be fed by the one that we stashed in the
   previous step.
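
The transform steps can be sketched numerically for a single BatchNorm channel, with the analysis step done by hand (x_hat is already known to be the shared subexpression). This is a minimal C++ sketch, assuming the (b, h, w) values of one channel are flattened into one vector so that γ and β are scalars; all struct and function names are hypothetical. `BatchNormForward` plays the role of Transform (Forward) by returning x_hat alongside y, and `DGammaCSE` plays the role of Transform (Backward) by taking x_hat as an input instead of recomputing it, as `DGammaNaive` does.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward outputs: y plus the stashed x_hat (Transform (Forward)).
struct ForwardResult {
  std::vector<double> y;
  std::vector<double> x_hat;
};

// x_hat = (x - E[x]) / sqrt(VAR[x] + eps) over one channel's values.
std::vector<double> Normalize(const std::vector<double>& x, double eps) {
  assert(!x.empty());
  const double n = static_cast<double>(x.size());
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= n;
  double var = 0.0;
  for (double v : x) var += (v - mean) * (v - mean);
  var /= n;
  std::vector<double> x_hat(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    x_hat[i] = (x[i] - mean) / std::sqrt(var + eps);
  return x_hat;
}

// y = x_hat * gamma + beta; x_hat is returned as an extra output.
ForwardResult BatchNormForward(const std::vector<double>& x, double gamma,
                               double beta, double eps) {
  ForwardResult r;
  r.x_hat = Normalize(x, eps);
  r.y.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) r.y[i] = r.x_hat[i] * gamma + beta;
  return r;
}

// dgamma = sum(dY * x_hat) without CSE: x_hat is re-evaluated from x.
double DGammaNaive(const std::vector<double>& x, const std::vector<double>& dy,
                   double eps) {
  assert(x.size() == dy.size());
  const std::vector<double> x_hat = Normalize(x, eps);
  double dgamma = 0.0;
  for (std::size_t i = 0; i < x_hat.size(); ++i) dgamma += dy[i] * x_hat[i];
  return dgamma;
}

// Transform (Backward): x_hat is a placeholder fed from the forward stash.
double DGammaCSE(const std::vector<double>& x_hat, const std::vector<double>& dy) {
  assert(x_hat.size() == dy.size());
  double dgamma = 0.0;
  for (std::size_t i = 0; i < x_hat.size(); ++i) dgamma += dy[i] * x_hat[i];
  return dgamma;
}
```

Both backward variants produce the same dγ; the CSE variant skips the mean, variance, and square-root work entirely, which is the saving the automated transform is meant to deliver.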

As one might notice, these steps are in essence inferring the feature maps
(a.k.a. the backward dependencies). In legacy machine learning frameworks (e.g.,
MXNet, TensorFlow) and deep learning libraries (e.g., cuDNN), this is done
manually, in a hard-coded way. Here, by contrast, we do it automatically: the
feature maps are chosen so that the amount of computation needed to evaluate
the backward gradients is minimized.

## Implementation Details

To implement CSE, we need to go through the following steps:

1. Tensor AutoInliner

   The tensor inliner is needed to simplify the tensor expressions generated by
   the AD pass and has already been implemented as part of `autodiff/ad_util`.
   As an example, consider the raw compute definition from the AD pass on the γ
   gradient of the BatchNorm operator:

   ```
   extracted_tensor[b, h, w, -c] = 
           (X[b, -c, h, w] - E_X[-c]) / sqrt(VAR_X[-c] + ε)
   dγ[c] += dY[b, c, h, w] × extracted_tensor[b, h, w, -c]
   ```

   As we can see, the introduction of the intermediate tensor `extracted_tensor`
   complicates the optimization: not only does it have a different indexing
   order from X (and dY), but it also has a reverse index (i.e., `-c`).
   Furthermore, each access to `extracted_tensor` adds an extra level of
   indirection to the tensor expression comparison (described later).
   Therefore, the first step of CSE is to automatically inline all the
   injective computes into their respective consumers.
1. CSE Optimizer
   
   ```C++
   // Note that for simplicity, some low-level details are omitted here.
   class CSEOptimizer {
    private:
     Tensor* src_;
     TensorExprTree src_tensor_expr_tree_, tgt_tensor_expr_tree_;  ///< described later
    private:
     /// @brief Optimize the expression (to a placeholder) if
     ///
     ///            src_tensor_expr_tree_.Find(tgt_tensor_expr_tree_.at(expr))
     ///
     ///        Dispatch to the following operation nodes:
     ///          - Call
     ///          - +, -, *, /
     ///          - Reduce
     ///          - Int/FloatImm
     Expr Optimize(const Expr& expr);
    public:
     void Optimize(Tensor* const src, Tensor* const tgt) {
       src_tensor_expr_tree_.Visit(*src);
       tgt_tensor_expr_tree_.Visit(*tgt);
       Expr new_body = Optimize(tgt->op->body);
       *tgt = ComputeOpNode::make(new_body, ...);
     }
   };
   ```

   The CSE optimizer takes a `src` tensor and optimizes a `tgt` tensor against
   it. It internally stores a tensor expression tree for each of them
   (described later) and operates on the body statement of the target. For
   each expression in that body, it looks up the corresponding subtree in the
   target expression tree and replaces the expression with a placeholder if the
   source expression tree contains the same subtree.
1. Tensor Expression Tree
   
   ![](upload://9f1wY2L4NWKi1MVuC5VE4I2QC9D.png)

   ```C++
   class TensorExprTree {
    private:
     /*! \brief Construct a tensor expression, whose operator type is \p node
      *         and shape is \p axis.
      *
      *         Dispatch to the following operation nodes:
      *           - Call
      *           - ProducerLoad
      *           - +, -, *, /
      *           - Reduce
      *           - Int/FloatImm
      */
     TensorExprPtr Construct(const ObjectRef& node, const Array<IterVar>& axis);
    public:
     /*! \brief Visit `tensor`'s body statement to construct the expression tree.
      */
     void Visit(const Tensor& tensor);
     /*! \brief Find a tensor expression in the tree.
      */
     bool Find(const TensorExpr& expr) const;
   };
   ```

   The tensor expression tree, constructed from a tensor, is a tree-like
   structure that represents the expressions the tensor evaluates. It is very
   similar to the legacy NNVM graph. The tree can search its subtrees to
   determine whether one of them matches a given tensor expression, which, as
   shown above, is what the CSE optimizer relies on.
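
Step 1 (the tensor AutoInliner) can be illustrated with a toy expression IR. This is a minimal C++ sketch, assuming a hypothetical `Expr`/`ComputeDef` representation rather than the actual `autodiff/ad_util` code. An injective producer is stored as its index parameters plus a body; every load of that producer is replaced by the body with the parameters substituted by the consumer's indices. Reverse indices such as `-c` are left out for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression IR (hypothetical): variables, tensor loads, binary ops, calls.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
struct Expr {
  enum Kind { kVar, kLoad, kBinary, kCall } kind;
  std::string name;           // variable, tensor, operator, or callee name
  std::vector<ExprPtr> args;  // load indices, binary operands, call argument
};

ExprPtr Make(Expr::Kind k, std::string n, std::vector<ExprPtr> a = {}) {
  return std::make_shared<const Expr>(Expr{k, std::move(n), std::move(a)});
}
ExprPtr Var(const std::string& n) { return Make(Expr::kVar, n); }
ExprPtr Load(const std::string& t, std::vector<ExprPtr> idx) {
  return Make(Expr::kLoad, t, std::move(idx));
}
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return Make(Expr::kBinary, op, {a, b});
}
ExprPtr Call(const std::string& f, ExprPtr a) { return Make(Expr::kCall, f, {a}); }

// An injective compute definition: out[params...] = body.
struct ComputeDef {
  std::vector<std::string> params;
  ExprPtr body;
};

// Replace variables in `e` according to `env`.
ExprPtr Substitute(const ExprPtr& e, const std::map<std::string, ExprPtr>& env) {
  if (e->kind == Expr::kVar) {
    auto it = env.find(e->name);
    return it == env.end() ? e : it->second;
  }
  std::vector<ExprPtr> args;
  for (const ExprPtr& a : e->args) args.push_back(Substitute(a, env));
  return Make(e->kind, e->name, std::move(args));
}

// Inline every load of a tensor defined in `defs` into its consumer.
ExprPtr Inline(const ExprPtr& e, const std::map<std::string, ComputeDef>& defs) {
  std::vector<ExprPtr> args;
  for (const ExprPtr& a : e->args) args.push_back(Inline(a, defs));
  if (e->kind == Expr::kLoad) {
    auto it = defs.find(e->name);
    if (it != defs.end()) {
      const ComputeDef& def = it->second;
      assert(def.params.size() == args.size());
      std::map<std::string, ExprPtr> env;
      for (std::size_t i = 0; i < args.size(); ++i) env[def.params[i]] = args[i];
      // The producer body may itself load other injective tensors.
      return Inline(Substitute(def.body, env), defs);
    }
  }
  return Make(e->kind, e->name, std::move(args));
}

std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar: return e->name;
    case Expr::kBinary:
      return "(" + ToString(e->args[0]) + " " + e->name + " " + ToString(e->args[1]) + ")";
    case Expr::kCall: return e->name + "(" + ToString(e->args[0]) + ")";
    case Expr::kLoad: {
      std::string s = e->name + "[";
      for (std::size_t i = 0; i < e->args.size(); ++i)
        s += (i ? ", " : "") + ToString(e->args[i]);
      return s + "]";
    }
  }
  return "";
}
```

For example, defining `extracted_tensor[i, j, k, l]` as `(X[i, l, j, k] - E_X[l]) / sqrt(VAR_X[l] + eps)` (the parameter names are made up) and inlining the consumer `dY[b, c, h, w] * extracted_tensor[b, h, w, c]` yields `(dY[b, c, h, w] * ((X[b, c, h, w] - E_X[c]) / sqrt((VAR_X[c] + eps))))`: `X` and `dY` now share one index order, and the extra indirection is gone.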
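
The CSE optimizer and the tensor expression tree can be sketched together on a toy tree. This is a minimal C++ sketch, not the proposed TVM code: the class names mirror the RFC, but the interfaces are simplified assumptions. Tensor loads are opaque leaf strings, structural matching uses a serialized key per subtree, `Reduce` nodes are left out, and the `stashN` placeholder names are made up. Because `Rewrite` checks a node before its children, the largest matching subtree is the one replaced, which corresponds to the "largest common subexpression" requirement of the Analysis step.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy expression node (hypothetical): leaves are loads/variables as opaque
// strings; inner nodes are +, -, *, / or calls such as sqrt.
struct Node;
using NodePtr = std::shared_ptr<const Node>;
struct Node {
  std::string op;
  std::vector<NodePtr> kids;
};

NodePtr Leaf(const std::string& s) { return std::make_shared<const Node>(Node{s, {}}); }
NodePtr Op(const std::string& op, std::vector<NodePtr> kids) {
  return std::make_shared<const Node>(Node{op, std::move(kids)});
}

// Structural key: two subtrees match iff their keys are equal.
std::string Key(const NodePtr& n) {
  if (n->kids.empty()) return n->op;
  std::string s = n->op + "(";
  for (std::size_t i = 0; i < n->kids.size(); ++i)
    s += (i ? ", " : "") + Key(n->kids[i]);
  return s + ")";
}

// Records every subtree of a tensor's body so that Find() is a set lookup.
class TensorExprTree {
 public:
  void Visit(const NodePtr& n) {
    subtrees_.insert(Key(n));
    for (const NodePtr& k : n->kids) Visit(k);
  }
  bool Find(const NodePtr& n) const { return subtrees_.count(Key(n)) > 0; }

 private:
  std::unordered_set<std::string> subtrees_;
};

// Replaces the largest subtrees of `tgt` that also occur in `src` with
// placeholders. Leaves are never replaced: stashing a plain load saves nothing.
class CSEOptimizer {
 public:
  NodePtr Optimize(const NodePtr& src, const NodePtr& tgt) {
    assert(src && tgt);
    src_tree_ = TensorExprTree();
    stashed_.clear();
    src_tree_.Visit(src);
    return Rewrite(tgt);
  }
  // Subexpressions the forward pass must output, indexed by placeholder id.
  const std::vector<std::string>& stashed() const { return stashed_; }

 private:
  NodePtr Rewrite(const NodePtr& n) {
    if (!n->kids.empty() && src_tree_.Find(n)) {
      const std::string key = Key(n);
      std::size_t id = 0;
      while (id < stashed_.size() && stashed_[id] != key) ++id;
      if (id == stashed_.size()) stashed_.push_back(key);  // reuse duplicates
      return Leaf("stash" + std::to_string(id));
    }
    std::vector<NodePtr> kids;
    for (const NodePtr& k : n->kids) kids.push_back(Rewrite(k));
    return Op(n->op, std::move(kids));
  }

  TensorExprTree src_tree_;
  std::vector<std::string> stashed_;
};
```

With the forward body `x_hat * gamma + beta` as `src` and the backward body `dY * x_hat` as `tgt`, the optimizer returns `*(dY, stash0)` and records x_hat (not its smaller piece `x - E_X`) as the single value to stash. For the tanh example it returns `*(-(1, *(stash0, stash0)), dY)`, reusing one placeholder for both occurrences of `tanh(X)`.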

@yzhliu




