|  | 
| 28 | 28 | #include <tvm/tir/op.h> | 
| 29 | 29 | #include <tvm/tir/transform.h> | 
| 30 | 30 | 
 | 
| 31 |  | -#include <functional> | 
| 32 | 31 | #include <limits> | 
| 33 | 32 | #include <unordered_set> | 
| 34 | 33 | 
 | 
|  | 
| 37 | 36 | 
 | 
| 38 | 37 | namespace tvm { | 
| 39 | 38 | namespace tl { | 
| 40 |  | - | 
| 41 | 39 | using namespace tir; | 
| 42 | 40 | 
 | 
| 43 | 41 | class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { | 
| @@ -110,38 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { | 
| 110 | 108 |     const DataType &dtype = op->dtype; | 
| 111 | 109 |     ICHECK(dtype.is_int() || dtype.is_uint()); | 
| 112 | 110 | 
 | 
| 113 |  | -    auto is_ceildiv_numerator = [&]() { | 
| 114 |  | -      std::vector<std::pair<PrimExpr, int>> terms; | 
| 115 |  | -      std::function<void(const PrimExpr &, int)> collect = | 
| 116 |  | -          [&](const PrimExpr &expr, int sign) { | 
| 117 |  | -            if (const auto *add = expr.as<AddNode>()) { | 
| 118 |  | -              collect(add->a, sign); | 
| 119 |  | -              collect(add->b, sign); | 
| 120 |  | -            } else if (const auto *sub = expr.as<SubNode>()) { | 
| 121 |  | -              collect(sub->a, sign); | 
| 122 |  | -              collect(sub->b, -sign); | 
| 123 |  | -            } else { | 
| 124 |  | -              terms.emplace_back(expr, sign); | 
| 125 |  | -            } | 
| 126 |  | -          }; | 
| 127 |  | -      collect(op->a, 1); | 
| 128 |  | -      int den_coeff = 0; | 
| 129 |  | -      int64_t const_term = 0; | 
| 130 |  | -      for (const auto &term : terms) { | 
| 131 |  | -        const PrimExpr &expr = term.first; | 
| 132 |  | -        int sign = term.second; | 
| 133 |  | -        if (const auto *imm = expr.as<IntImmNode>()) { | 
| 134 |  | -          const_term += static_cast<int64_t>(sign) * imm->value; | 
| 135 |  | -        } else if (analyzer_->CanProveEqual(expr, op->b)) { | 
| 136 |  | -          den_coeff += sign; | 
| 137 |  | -        } | 
| 138 |  | -      } | 
| 139 |  | -      return den_coeff == 1 && const_term == -1; | 
| 140 |  | -    }(); | 
| 141 |  | - | 
| 142 |  | -    if (support_bitwise_op_ && !is_ceildiv_numerator && | 
| 143 |  | -        is_const_power_of_two_integer(op->b, &shift) && | 
| 144 |  | -        analyzer_->CanProveGreaterEqual(op->a, 0)) { | 
|  | 111 | +    // lower (a + 31) // 512 to (a + 31) >> 5 | 
|  | 112 | +    if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { | 
| 145 | 113 |       // lower to right shift if possible. | 
| 146 | 114 |       return op->a >> make_const(dtype, shift); | 
| 147 | 115 |     } | 
| @@ -192,8 +160,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { | 
| 192 | 160 |         //     == truncdiv(a + b*c, b) - c | 
| 193 | 161 |         IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); | 
| 194 | 162 |         PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); | 
| 195 |  | -        PrimExpr offset_numerator = | 
| 196 |  | -            analyzer_->Simplify(op->a + op->b * ceildiv); | 
|  | 163 | +        // Skip analyzer simplification so we preserve straightforward div | 
|  | 164 | +        // expressions. | 
|  | 165 | +        PrimExpr offset_numerator = op->a + op->b * ceildiv; | 
| 197 | 166 |         return truncdiv(offset_numerator, op->b) - ceildiv; | 
| 198 | 167 |       } | 
| 199 | 168 | 
 | 
| @@ -429,9 +398,9 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string &target) { | 
| 429 | 398 | } | 
| 430 | 399 | 
 | 
| 431 | 400 | namespace transform { | 
| 432 |  | -using namespace tir::transform; | 
| 433 | 401 | 
 | 
| 434 |  | -tvm::transform::Pass LowerIntrin() { | 
|  | 402 | +tir::transform::Pass LowerIntrin() { | 
|  | 403 | +  using namespace tir::transform; | 
| 435 | 404 |   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | 
| 436 | 405 |     auto *n = f.CopyOnWrite(); | 
| 437 | 406 |     auto target = f->GetAttr<Target>(tvm::attr::kTarget); | 
|  | 
0 commit comments