Skip to content

Commit

Permalink
Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly hand…
Browse files Browse the repository at this point in the history
…les corner case" (#16241)

Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles corner case (#16235)"

This reverts commit f7b0193.
  • Loading branch information
tqchen authored Dec 14, 2023
1 parent e100a13 commit 6741678
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
3 changes: 0 additions & 3 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ class DataTypeLegalizer : public StmtExprMutator {
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// a map from original vars to ones with new dtype
std::unordered_map<const VarNode*, Var> var_remap_;
// number of iterations. The first iteration collects var_remap_,
// and the second iteration performs rewrite
int iter_ = 0;
};

/*!
Expand Down
21 changes: 3 additions & 18 deletions src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
ICHECK(iv != nullptr) << "Expected type to be IterVarNode"
<< ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
if (iter_ == 0) {
return GetRef<AttrStmt>(op);
}
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
Range dom = iv->dom;
Expand Down Expand Up @@ -396,9 +393,6 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
}

Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
if (iter_ == 0) {
return buffer;
}
bool is_enabled = is_enabled_;

is_enabled_ = true;
Expand Down Expand Up @@ -588,10 +582,6 @@ IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
: target_data_type_(std::move(target_data_type)) {}

PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
// collect var remap
VisitStmt(std::move(func->body));
iter_++;
// start rewrite
Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
new_buffer_map.Set(var, VisitBuffer(buffer));
Expand Down Expand Up @@ -628,15 +618,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
// In the first iteration, collect var_remap_
if (iter_ == 0) {
if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ &&
!var_remap_.count(op)) {
var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
}
return GetRef<Var>(op);
if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ &&
!var_remap_.count(op)) {
var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
}
// In the second iteration, rewrite the var
return DataTypeLegalizer::VisitExpr_(op);
}

Expand Down

0 comments on commit 6741678

Please sign in to comment.