diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index dbd073c0b14c7..d0412adc08df5 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -94,9 +94,13 @@ class WarpStoreCoeffFinder : private IRVisitor { CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; int coeff; - CHECK(arith::GetConstInt(ir::Simplify(m[0]), &coeff) && coeff > 0) + Expr mcoeff = ir::Simplify(m[0]); + + CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) << "LowerWarpMemory failed due to store index=" << index - << ", require positive constant coefficient on warp index"; + << ", require positive constant coefficient on warp index " << warp_index_ + << " but get " << mcoeff; + if (warp_coeff_ != 0) { CHECK_EQ(warp_coeff_, coeff) << "LowerWarpMemory failed due to two different store coefficient to warp index"; @@ -129,11 +133,6 @@ class WarpIndexFinder : private IRVisitor { } private: - void Visit(const NodeRef &node) final { - if (warp_index_.defined()) return; - IRVisitor::Visit(node); - } - /// Visitor implementation void Visit_(const AttrStmt *op) final { if (op->attr_key == attr::thread_extent) { @@ -145,7 +144,15 @@ class WarpIndexFinder : private IRVisitor { << "Expect threadIdx.x 's size to be equal to warp size(" << warp_size_ << ")" << " to enable warp memory" << " but get " << op->value << " instead"; - warp_index_ = iv; + if (warp_index_.defined()) { + CHECK(warp_index_.same_as(iv)) + << "Find two instance of " << warp_index_->thread_tag + << " in the same kernel. " + << "Please create it using thread_axis once and reuse the axis " + << "across multiple binds in the same kernel"; + } else { + warp_index_ = iv; + } } } IRVisitor::Visit_(op);