## Background and Motivation

Currently, TVM uses the `tir.Simplify` pass to remove redundant expressions
such as nested, equivalent if-conditions. For example, given a simple softmax
operation like

```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((floordiv(floordiv((threadIdx.x + (blockIdx.x*1024)), 257), 10) < 2), dtype=bool) {
        if @tir.likely((floordiv((threadIdx.x + (blockIdx.x*1024)), 257) < 20), dtype=bool) {
          if @tir.likely(((threadIdx.x + (blockIdx.x*1024)) < 5140), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
      // ...
```

`tir.Simplify` will simplify this to

```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((((blockIdx.x*1024) + threadIdx.x) < 5140), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
      }
      // ...
```

where the three equivalent conditions are merged into one.
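As a quick sanity check of why the three static guards collapse into one, the
C++ sketch below brute-forces every thread index the launch configuration can
produce (6 blocks of 1024 threads) and counts how often the guards disagree.
`floordiv` and `count_mismatches` are helpers written only for this
illustration; they are not TVM APIs.

```cpp
// Floor division (rounds toward negative infinity); C++ '/' truncates toward zero.
long long floordiv(long long a, long long b) {
  long long q = a / b;
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// Counts thread indices t in [0, n) for which the three guards of the
// static softmax example do not all agree.
int count_mismatches(long long n) {
  int mismatches = 0;
  for (long long t = 0; t < n; ++t) {
    bool outer = floordiv(floordiv(t, 257), 10) < 2;  // shape [2, 10, 257]
    bool middle = floordiv(t, 257) < 20;               // 2 * 10 = 20
    bool inner = t < 5140;                             // 2 * 10 * 257 = 5140
    if (outer != inner || middle != inner) ++mismatches;
  }
  return mismatches;
}
```

`count_mismatches(6 * 1024)` returns `0`: because `257 * 10 * 2 == 257 * 20 == 5140`,
all three guards describe the same set of indices.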

However, things are different when the input has a dynamic shape. The current
`tir.Simplify` fails to simplify such inputs because the analyzers it relies on
(specifically the `RewriteSimplifier` and the `CanonicalSimplifier`) lack the
corresponding rules for this "non-constant" case. In the rest of this post we
continue to use the simple softmax example to discuss the problem, and we try
to fix it by adding more rules to both `RewriteSimplifier` and
`CanonicalSimplifier`. This is still an experimental idea; if you find anything
wrong or improper, feel free to correct us directly in this thread.



## Proposal 

We illustrate our proposal by solving this problem for the simple softmax
example. As shown in [this 
post](https://discuss.tvm.apache.org/t/pre-rfc-dynamic-shape-use-sizevar-instead-of-var-when-convert-any-in-the-getshape-function/10625),
we can eliminate some redundant expressions by introducing sign information
into tensor shapes. However, some redundant expressions are not covered by that
solution, even though `tir.Simplify` removes them when the input shape is
static, as shown above. For the dynamic case, the reasons that prevent
`tir.Simplify` from simplifying these redundancies are:

1. `RewriteSimplifier` only has the rule that simplifies `floordiv(x, c1) < c2`
to `x < c1 * c2` for `IntImm` operands, i.e. when `c1` and `c2` are constants.
2. After simplifying `floordiv(x, c1) < c2` to `x < c1 * c2`, we directly get a
new constant `c3 = c1 * c2`, provided that `c1` and `c2` are `IntImm`. But if
we are given variables (or, even worse, expressions), the product stays
symbolic, and the simplifier cannot tell that `v1 * v2` and `v2 * v1` are the
same value.
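The first point relies on the identity `floordiv(x, c1) < c2` if and only if
`x < c1 * c2`, which holds for any positive divisor. The C++ sketch below checks
it exhaustively on a small range and shows that it breaks for a negative
divisor, which is why the constant rule needs a positivity guard on `c1`. The
helper names are ours and purely illustrative.

```cpp
// Floor division (rounds toward negative infinity); C++ '/' truncates toward zero.
long long floordiv(long long a, long long b) {
  long long q = a / b;
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// Exhaustively checks  floordiv(x, c1) < c2  <=>  x < c1 * c2
// for c1 in [c1_lo, c1_hi] (zero skipped) and x, c2 in [-r, r].
// Returns true if no counterexample is found.
bool rewrite_is_sound(long long c1_lo, long long c1_hi, long long r) {
  for (long long c1 = c1_lo; c1 <= c1_hi; ++c1) {
    if (c1 == 0) continue;  // floordiv by zero is undefined
    for (long long c2 = -r; c2 <= r; ++c2)
      for (long long x = -r; x <= r; ++x)
        if ((floordiv(x, c1) < c2) != (x < c1 * c2)) return false;
  }
  return true;
}
```

`rewrite_is_sound(1, 20, 20)` returns `true`, while `rewrite_is_sound(-20, -1, 20)`
returns `false` (for example `x = 1, c1 = -1, c2 = 0`: `floordiv(1, -1) = -1 < 0`
holds, but `1 < 0` does not).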



For clarity, we use the `SizeVar`s `d0`, `d1`, `d2`, and `d3` for the shape in
our simple softmax example. The output of the current `tir.Simplify` is

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0), dtype=bool) {
        if @tir.likely((floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```

To simplify `floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)` and
`floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0` to
`((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)`, we add a new rule to
`PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op);`:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
    // ...
    PVar<PrimExpr> x, y, z, s1, s2;
    // ...
    TVM_TRY_REWRITE_IF(floordiv(x, s1) < s2, x < s1 * s2,
                       analyzer_->const_int_bound(s1.Eval())->min_value >= 0);
    // ...
}
```
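Why the rule is sound when `s1` is strictly positive, and how applying it twice
produces the factor orders seen in the next output, can be worked out directly.
Writing $t$ for `(blockIdx.x*512) + threadIdx.x`, the key step uses
$\lfloor y \rfloor \le k \iff y < k + 1$, which only turns into a bound on $x$
when $s_1 > 0$:

```latex
\begin{aligned}
\left\lfloor x / s_1 \right\rfloor < s_2
  &\iff \left\lfloor x / s_1 \right\rfloor \le s_2 - 1
   \iff x \le s_1 (s_2 - 1) + (s_1 - 1)
   \iff x < s_1 s_2 && (s_1 > 0) \\
\left\lfloor \lfloor t / d_2 \rfloor / d_1 \right\rfloor < d_0
  &\iff \lfloor t / d_2 \rfloor < d_1 d_0
   \iff t < d_2 (d_1 d_0) && (d_1, d_2 > 0)
\end{aligned}
```

Because the rewrite pattern produces `s1 * s2` with the divisor on the left, the
outer guard becomes `d2*(d1*d0)` and the middle guard `d2*(d0*d1)`, while the
innermost guard keeps its original `((d0*d1)*d2)`.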

Here comes our first worry. The corresponding `IntImm` version of this rule is

```c++
TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0);
```

where the guard is `> 0` instead of `>= 0`. For the `PrimExpr` version, the
`ConstIntBoundAnalyzer` can only tell us that `s1` is non-negative (this bound
comes from simple facts such as `SizeVar + SizeVar >= 0` or
`SizeVar * SizeVar >= 0`). Although `s1 == 0` is an invalid case anyway, the
transformation is then no longer an equivalence and may hide run-time errors
such as a division by zero. [This 
post](https://discuss.tvm.apache.org/t/discuss-embed-more-bound-information-into-var-or-expr/4079)
shows some possible solutions for this issue, but there may be simpler ones (if
you have any ideas, please share them with us).
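To make the `s1 == 0` concern concrete, the sketch below models the guard
before and after the rewrite, using a division that reports a zero divisor as
`std::nullopt`. This is a modelling choice for the illustration only (TVM does
not evaluate guards this way); the function names are hypothetical.

```cpp
#include <optional>

// Floor division that reports a zero divisor as std::nullopt instead of trapping.
std::optional<long long> checked_floordiv(long long a, long long b) {
  if (b == 0) return std::nullopt;               // floordiv(a, 0) is undefined
  long long q = a / b;                           // C++ '/' truncates toward zero,
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;   // so round down when signs differ
  return q;
}

// Guard before the rewrite: floordiv(x, s1) < s2 (std::nullopt means undefined).
std::optional<bool> original_guard(long long x, long long s1, long long s2) {
  std::optional<long long> q = checked_floordiv(x, s1);
  if (!q) return std::nullopt;
  return *q < s2;
}

// Guard after the proposed rewrite: x < s1 * s2 (always evaluates to a bool).
bool rewritten_guard(long long x, long long s1, long long s2) {
  return x < s1 * s2;
}
```

With `s1 = 0`, `original_guard(5, 0, 3)` is undefined, yet `rewritten_guard(5, 0, 3)`
is simply `false` and `rewritten_guard(-1, 0, 3)` is `true`, so an invalid
extent would silently change control flow instead of failing. With `s1 = 2`
both versions agree.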



After adding this rule to the rewrite simplifier, we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d1*d0))), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d0*d1))), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```

Now the question becomes how to recognize that `(d2*(d1*d0))`, `(d2*(d0*d1))`,
and `((d0*d1)*d2)` are the same product. We add a rule to `PrimExpr
CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op);` to obtain a
canonical form for multiplication:

```c++
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
  // ...
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // ...

  // var * expr => expr * var
  if (a.as<VarNode>() && !b.as<VarNode>()) {
    std::swap(a, b);
  }
  
  // if given var * var or expr * expr, use their
  // structural hash value to sort
  if (a.as<VarNode>() || !b.as<VarNode>()) {
    auto ah = StructuralHash()(a);
    auto bh = StructuralHash()(b);
    if (ah > bh) {
      std::swap(a, b);
    }
  }
  
  // ...
}
```

I think using the structural hash value for sorting is a bit ugly, but I
currently have no better idea. After adding this rule to the canonical
simplifier, we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```
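The hash-based swap in the `MulNode` rule only orders the two operands of a
single multiplication, so some re-associations are not unified: for example,
`(d0*d2)*d1` and `(d0*d1)*d2` are both `expr * var` and are left unchanged. One
possible alternative, sketched below on a hypothetical toy IR rather than TVM's
`PrimExpr`, is to flatten a product chain into its factors, sort them with a
deterministic key (here simply the variable name), and rebuild a
left-associated product. Non-variable factors would still need a total order
over whole expressions, which is where a structural hash or a structural
comparison would come back in.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR: a leaf variable (name set, no children)
// or a binary multiplication (lhs and rhs set).
struct Expr {
  std::string name;
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr var(const std::string& n) { return std::make_shared<Expr>(Expr{n, nullptr, nullptr}); }
ExprPtr mul(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{"", a, b}); }

// Collects the leaf factors of a product tree, discarding its association.
void flatten(const ExprPtr& e, std::vector<std::string>& out) {
  if (!e->lhs) {
    out.push_back(e->name);
    return;
  }
  flatten(e->lhs, out);
  flatten(e->rhs, out);
}

// Canonical form: factors sorted by name and rebuilt left-associated,
// printed in the same style as the TIR dumps in this post.
std::string canonical_product(const ExprPtr& e) {
  std::vector<std::string> factors;
  flatten(e, factors);
  std::sort(factors.begin(), factors.end());
  std::string s = factors[0];
  for (std::size_t i = 1; i < factors.size(); ++i) s = "(" + s + "*" + factors[i] + ")";
  return s;
}
```

All of `d2*(d1*d0)`, `d2*(d0*d1)`, `(d0*d2)*d1`, and `(d0*d1)*d2` map to
`((d0*d1)*d2)` under this scheme.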

Next, we need a way to remove these literally identical conditions.
`RewriteSimplify` actually already has such a mechanism:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // add condition context to if_then_else
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();

  // ...

  ExprDeepEqual expr_equal;
  if (op->op.same_as(tir::builtin::likely())) {
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, op->args[0])) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
```

However, every `constraint` in `literal_constraints_` has already been
processed by `CanonicalSimplify` by the time the constraint is entered:

```c++
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition); // HERE
  PrimExpr real_condition = condition;
  static auto op_likely = Op::Get("tir.likely");

  if (auto call = condition.as<CallNode>()) {
    if (call->op.same_as(op_likely)) {
      real_condition = call->args[0];
    }
  }

  Stmt then_case, else_case;
  {
    With<ConstraintContext> ctx(analyzer_, real_condition);
    then_case = this->VisitStmt(op->then_case);
  }
  if (op->else_case.defined()) {
    With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
    else_case = this->VisitStmt(op->else_case);
  }
  
  // ...
}
```

whereas `op->args[0]` has not, because we are still inside `RewriteSimplify`
and `CanonicalSimplify` only runs after it:

```c++
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
  if (tir::is_const_int(expr)) return expr;
  PrimExpr res = expr;
  for (int i = 0; i < steps; ++i) {
    res = this->rewrite_simplify(res);                       // RewriteSimplify
    if (tir::is_const_int(res) || ++i == steps) return res;  // is ++i proper here?
    res = this->canonical_simplify(res);                     // CanonicalSimplify
    if (tir::is_const_int(res)) return res;
  }
  return res;
}
```
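Regarding the `++i` question: the loop counts each individual simplifier call
as one step, so `i` advances twice per iteration. A small C++ model of the same
control flow makes the resulting schedule visible. It ignores the early exits on
constant results and only records which simplifier would run; the function name
is hypothetical.

```cpp
#include <string>

// Mirrors the loop structure of Analyzer::Simplify shown above, recording
// 'R' for a rewrite_simplify call and 'C' for a canonical_simplify call.
std::string simplify_schedule(int steps) {
  std::string trace;
  for (int i = 0; i < steps; ++i) {
    trace += 'R';
    if (++i == steps) return trace;  // same step accounting as the original
    trace += 'C';
  }
  return trace;
}
```

With this accounting, `steps == 2` runs exactly one rewrite pass followed by one
canonical pass (`"RC"`), and an odd `steps` ends with a rewrite pass (`"RCR"` for 3).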

As a result, `op->args[0]` looks something like `((threadIdx.x: int32 + 
(blockIdx.x: int32*512)) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`
while the constraint looks like `(((blockIdx.x: int32*512) + threadIdx.x: 
int32) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`. To solve this
problem, we can run a canonical simplification before the comparison:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
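  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();
  if (op == nullptr) return ret;  // the call was simplified into something else
  // ...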
  ExprDeepEqual expr_equal;
  if (op->op.same_as(tir::builtin::likely())) {
    auto condition = analyzer_->canonical_simplify(op->args[0]);
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, condition)) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
```
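The effect of this change can be modelled with a deliberately simplified C++
sketch: a condition is reduced to a list of summands compared against a bound,
`normalize` stands in for `canonical_simplify`, and exact equality stands in for
`ExprDeepEqual`. All names here are hypothetical; the sketch only illustrates
why normalizing the query lets an already-normalized constraint match.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical toy condition: (sum of terms) < bound.
struct LtCond {
  std::vector<std::string> terms;
  std::string bound;
};

// Stand-in for canonical_simplify: put the summands in a fixed order.
LtCond normalize(LtCond c) {
  std::sort(c.terms.begin(), c.terms.end());
  return c;
}

// Exact structural match, playing the role of ExprDeepEqual.
bool same(const LtCond& a, const LtCond& b) {
  return a.terms == b.terms && a.bound == b.bound;
}

// Returns true if likely(cond) is implied by a recorded literal constraint.
// The recorded constraints are assumed to be stored in normalized form.
bool proven_by_literal(const std::vector<LtCond>& constraints, const LtCond& cond,
                       bool normalize_first) {
  LtCond query = normalize_first ? normalize(cond) : cond;
  for (const LtCond& c : constraints)
    if (same(c, query)) return true;
  return false;
}
```

Without normalization, a query written as `threadIdx.x + blockIdx.x*512` misses
a stored constraint `blockIdx.x*512 + threadIdx.x` with the same bound; with
normalization, the lookup succeeds.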

After that we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
      }
```



Again, this is only an experimental idea, and some issues remain to be solved.
If you have better ideas, please feel free to share them below.




