## Background and Motivation
Currently, TVM uses the `tir.Simplify` pass to remove redundant expressions such as nested, equivalent if-conditions. For example, given a simple softmax operation like

```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((floordiv(floordiv((threadIdx.x + (blockIdx.x*1024)), 257), 10) < 2), dtype=bool) {
        if @tir.likely((floordiv((threadIdx.x + (blockIdx.x*1024)), 257) < 20), dtype=bool) {
          if @tir.likely(((threadIdx.x + (blockIdx.x*1024)) < 5140), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
      // ...
```

`tir.Simplify` simplifies this to

```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((((blockIdx.x*1024) + threadIdx.x) < 5140), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
      }
      // ...
```

where the three equivalent conditions are merged into one (since `2 * 10 * 257 = 5140`). However, things are different when the input has a dynamic shape. The current `tir.Simplify` fails to remove the same redundancy for an input with a dynamic shape, because the analyzers used by this pass (specifically the `RewriteSimplifier` and the `CanonicalSimplifier`) lack corresponding rules for this non-constant situation. In the rest of this post we continue to use this simple softmax example to discuss the problem.

We try to fix this problem by adding more rules to both `RewriteSimplifier` and `CanonicalSimplifier`. This is still an experimental idea; if you find something wrong or improper, feel free to correct us in this thread.

## Proposal

We present our proposal by solving this problem for the simple softmax example. As shown in [this post](https://discuss.tvm.apache.org/t/pre-rfc-dynamic-shape-use-sizevar-instead-of-var-when-convert-any-in-the-getshape-function/10625), we can eliminate some redundant expressions by introducing sign information into tensor shapes.
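To make "sign information" concrete: when shapes are `SizeVar`s, a bound analysis only knows that each dimension is non-negative and must propagate that fact through arithmetic. The toy sketch below is plain C++, not TVM's `ConstIntBoundAnalyzer`; `NonNegBound`, `SizeVarBound`, `ConstBound`, `AddBound`, and `MulBound` are hypothetical names, and it only handles non-negative intervals `[min, max]` under `+` and `*`.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical toy interval [min, max] for values known to be non-negative.
// kInf stands for "no upper bound".
constexpr int64_t kInf = std::numeric_limits<int64_t>::max();

struct NonNegBound {
  int64_t min;
  int64_t max;
};

// Saturating arithmetic on non-negative operands, so that kInf stays kInf.
int64_t SatAdd(int64_t a, int64_t b) { return a > kInf - b ? kInf : a + b; }
int64_t SatMul(int64_t a, int64_t b) {
  if (a == 0 || b == 0) return 0;
  return a > kInf / b ? kInf : a * b;
}

// A SizeVar only tells us that its value is >= 0.
NonNegBound SizeVarBound() { return {0, kInf}; }
// A non-negative constant c is bounded by [c, c].
NonNegBound ConstBound(int64_t c) { return {c, c}; }

// + and * are monotonic on non-negative inputs, so the result bounds
// are obtained by combining the corresponding endpoints.
NonNegBound AddBound(NonNegBound a, NonNegBound b) {
  return {SatAdd(a.min, b.min), SatAdd(a.max, b.max)};
}
NonNegBound MulBound(NonNegBound a, NonNegBound b) {
  return {SatMul(a.min, b.min), SatMul(a.max, b.max)};
}
```

Applied to `d0*d1*d2`, the sketch yields `[0, kInf]`: the product is provably non-negative, but zero cannot be ruled out. Adding `511` (as in the `blockIdx.x` extent) lifts the lower bound to `511`. This "non-negative but possibly zero" limitation is what the `>= 0` condition discussed later in this post runs into.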
But some other redundant expressions are not covered by that solution. The `tir.Simplify` pass can eliminate them when the input's shape is static, as shown before. In the dynamic case, the following reasons prevent `tir.Simplify` from removing this redundancy:

1. `RewriteSimplifier` only has rules for `IntImm` operands to simplify `floordiv(x, c1) < c2` to `x < c1 * c2`.
2. After simplifying `floordiv(x, c1) < c2` to `x < c1 * c2`, we directly get a new constant `c3 = c1 * c2`, provided `c1` and `c2` are `IntImm`s. But if we are given variables (or, even worse, expressions), we cannot tell that `v1 * v2` and `v2 * v1` are the same.

For clarity, we use the `SizeVar`s `d0`, `d1`, `d2`, and `d3` for the shape in our simple softmax example. The output of the current `tir.Simplify` is

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0), dtype=bool) {
        if @tir.likely((floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```

To simplify `floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)` and `floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0` to `((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)`, we add a new rule to `PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op);`:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
  // ...
  PVar<PrimExpr> x, y, z, s1, s2;
  // ...
  // floordiv(x, s1) < s2  =>  x < s1 * s2, guarded by a non-negative divisor
  TVM_TRY_REWRITE_IF(floordiv(x, s1) < s2, x < s1 * s2,
                     analyzer_->const_int_bound(s1.Eval())->min_value >= 0);
  // ...
}
```

Here comes our first worry. The `IntImm` rules of this kind use `> 0` as the condition instead of `>= 0`, for example:

```c++
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0);
```

For the `PrimExpr` version, we can only get non-negativity information from the `ConstIntBoundAnalyzer` (this bound information comes from simple facts such as `SizeVar + SizeVar >= 0` or `SizeVar * SizeVar >= 0`). Although a zero divisor is an invalid case anyway, the transformation is then not strictly equivalent and may hide some run-time errors. [This post](https://discuss.tvm.apache.org/t/discuss-embed-more-bound-information-into-var-or-expr/4079) shows some possible solutions for this issue, but there may be simpler ones (if you have any ideas, please share them with us).
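The equivalence behind the new rule, and why the sign of the divisor matters, can be checked by brute force outside TVM. The sketch below re-implements floor division (rounding toward negative infinity, as TIR's `floordiv` does) in plain C++; `FloorDiv`, `RuleAgrees`, and `RuleHoldsForPositiveDivisors` are hypothetical helpers, not TVM functions, and the search ranges are arbitrary.

```cpp
#include <cstdint>

// Floor division: rounds toward negative infinity.
// Precondition: b != 0; there is no meaningful result for b == 0.
int64_t FloorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;  // C++ integer division truncates toward zero
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;  // adjust toward -inf
  return q;
}

// Do both sides of the rewrite floordiv(x, s1) < s2  =>  x < s1 * s2 agree?
bool RuleAgrees(int64_t x, int64_t s1, int64_t s2) {
  return (FloorDiv(x, s1) < s2) == (x < s1 * s2);
}

// Exhaustively checks the rewrite for every strictly positive divisor
// s1 in [1, s1_max], every s2 in [s2_min, s2_max], and x in [x_min, x_max].
bool RuleHoldsForPositiveDivisors(int64_t x_min, int64_t x_max, int64_t s1_max,
                                  int64_t s2_min, int64_t s2_max) {
  for (int64_t s1 = 1; s1 <= s1_max; ++s1)
    for (int64_t s2 = s2_min; s2 <= s2_max; ++s2)
      for (int64_t x = x_min; x <= x_max; ++x)
        if (!RuleAgrees(x, s1, s2)) return false;
  return true;
}
```

For example, `RuleHoldsForPositiveDivisors(-64, 64, 16, -8, 8)` returns `true`, while `RuleAgrees(0, -1, 1)` returns `false` (`floordiv(0, -1) = 0 < 1`, but `0 < -1` is false). A zero divisor cannot be evaluated at all, so a guard of `min_value >= 0` admits exactly the one case the rewrite cannot justify, whereas a `> 0` guard would exclude it.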
After adding this rule to the rewrite simplifier, we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d1*d0))), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d0*d1))), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```

Now the question becomes how to recognize that `(d2*(d1*d0))`, `(d2*(d0*d1))`, and `((d0*d1)*d2)` are the same value. We add a rule to `PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op);` to obtain a canonical form of multiplication:

```c++
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
  // ...
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);
  // ...
  // var * expr => expr * var
  if (a.as<VarNode>() && !b.as<VarNode>()) {
    std::swap(a, b);
  }
  // if given var * var or expr * expr, use their
  // structural hash values to sort the operands
  // (after the swap above, expr * var is the only mixed case left,
  // and this condition excludes it)
  if (a.as<VarNode>() || !b.as<VarNode>()) {
    auto ah = StructuralHash()(a);
    auto bh = StructuralHash()(b);
    if (ah > bh) {
      std::swap(a, b);
    }
  }
  // ...
}
```

I think using the structural hash value for sorting is a bit ugly, but I currently have no better idea. After adding this rule to the canonical simplifier, we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
```

Next, we need a way to remove these literally equivalent expressions.
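Both steps, canonicalizing the products and then dropping conditions that have become literally equal, can be illustrated in a toy setting. The sketch below is not TVM code: `Expr`, `Var`, `Mul`, `Canonical`, and `CountDistinct` are hypothetical, and it orders factors by variable name as a stand-in for the structural-hash ordering above. It flattens a product into its factors, sorts them, and rebuilds a left-associated product, so the three bounds above map to one string.

```cpp
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: either a leaf variable or a binary multiplication.
struct Expr {
  std::string var;                 // non-empty for a leaf
  std::shared_ptr<Expr> lhs, rhs;  // set for a multiplication
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& name) {
  return std::make_shared<Expr>(Expr{name, nullptr, nullptr});
}
ExprPtr Mul(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{"", std::move(a), std::move(b)});
}

// Multiplication is associative and commutative, so only the multiset
// of leaf factors matters; collect them in any order.
void CollectFactors(const ExprPtr& e, std::vector<std::string>* out) {
  if (!e->var.empty()) {
    out->push_back(e->var);
    return;
  }
  CollectFactors(e->lhs, out);
  CollectFactors(e->rhs, out);
}

// Canonical form: factors sorted by name, printed left-associated,
// e.g. "((d0*d1)*d2)".
std::string Canonical(const ExprPtr& e) {
  std::vector<std::string> factors;
  CollectFactors(e, &factors);
  std::sort(factors.begin(), factors.end());
  std::string s = factors[0];
  for (size_t i = 1; i < factors.size(); ++i) s = "(" + s + "*" + factors[i] + ")";
  return s;
}

// Number of conditions left after dropping canonically identical ones.
size_t CountDistinct(const std::vector<ExprPtr>& exprs) {
  std::set<std::string> seen;
  for (const auto& e : exprs) seen.insert(Canonical(e));
  return seen.size();
}
```

With `d0`, `d1`, `d2` as leaves, `Canonical(Mul(d2, Mul(d1, d0)))` is `"((d0*d1)*d2)"`, and `CountDistinct` over the three bounds returns `1`. Flattening the whole product is a stronger normalization than the local operand swap in the `MulNode` rule, which only orders the two operands of a single multiplication; any deterministic total order on factors would serve the same purpose as the name-based one used here.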
Actually, `RewriteSimplify` already has such a mechanism:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // add condition context to if_then_else
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();
  // ...
  ExprDeepEqual expr_equal;
  if (op->op.same_as(tir::builtin::likely())) {
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, op->args[0])) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
```

However, every `constraint` in `literal_constraints_` has already been processed by `CanonicalSimplify` when the constraint is entered:

```c++
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);  // HERE
  PrimExpr real_condition = condition;
  static auto op_likely = Op::Get("tir.likely");
  if (auto call = condition.as<CallNode>()) {
    if (call->op.same_as(op_likely)) {
      real_condition = call->args[0];
    }
  }
  Stmt then_case, else_case;
  {
    With<ConstraintContext> ctx(analyzer_, real_condition);
    then_case = this->VisitStmt(op->then_case);
  }
  if (op->else_case.defined()) {
    With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
    else_case = this->VisitStmt(op->else_case);
  }
  // ...
}
```

while `op->args[0]` has not, since we are still inside `RewriteSimplify` and `CanonicalSimplify` runs after it:

```c++
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
  if (tir::is_const_int(expr)) return expr;
  PrimExpr res = expr;
  for (int i = 0; i < steps; ++i) {
    res = this->rewrite_simplify(res);  // RewriteSimplify
    if (tir::is_const_int(res) || ++i == steps) return res;  // is ++i proper here?
    res = this->canonical_simplify(res);  // CanonicalSimplify
    if (tir::is_const_int(res)) return res;
  }
  return res;
}
```

As a result, `op->args[0]` looks something like `((threadIdx.x: int32 + (blockIdx.x: int32*512)) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`, while the constraint looks like `(((blockIdx.x: int32*512) + threadIdx.x: int32) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`, so the deep-equality check fails even though both express the same condition. To solve this problem, we can perform a canonical simplification before the comparison:

```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // ...
  ExprDeepEqual expr_equal;
  if (op->op.same_as(tir::builtin::likely())) {
    // bring the condition into the same canonical form as the stored constraints
    PrimExpr condition = analyzer_->canonical_simplify(op->args[0]);
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, condition)) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
```

After that we get:

```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
      }
```

Again, this is only an experimental idea, and some issues remain to be solved. If you have better ideas, please feel free to suggest them below.