Skip to content

Commit dbd6614

Browse files
authored
[CINN]optimize global memory read insert pointer (#68667)
* optimize global memory read insert pointer * polish code
1 parent 3eddd2a commit dbd6614

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

paddle/cinn/optim/eliminate_common_global_memory_read.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,9 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
335335
::common::errors::InvalidArgument(
336336
"The input expr should be a ScheduleBlockRealize"));
337337
current_sbr_ = node;
338+
if (current_block_) {
339+
insert_block_ = current_block_;
340+
}
338341
IRMutator<>::Visit(op, expr);
339342
}
340343

@@ -387,7 +390,11 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
387390
"buffer_name %s should not be in global_buffer_to_local_buffer_",
388391
buffer_name));
389392
global_buffer_to_local_buffer_[buffer_name] = new_tensor;
390-
block_to_insert_stmts_[current_block_].push_back(new_sbr);
393+
394+
PADDLE_ENFORCE_NOT_NULL(
395+
insert_block_,
396+
::common::errors::InvalidArgument("insert block CAN NOT be nullptr"));
397+
block_to_insert_stmts_[insert_block_].push_back(new_sbr);
391398
}
392399

393400
void SubstituteGlobalTensor(ir::Load* load_node,
@@ -405,7 +412,8 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
405412
std::unordered_map<std::string, ir::Expr> global_buffer_to_local_buffer_;
406413
std::unordered_map<ir::Block*, std::vector<ir::Expr>> block_to_insert_stmts_;
407414

408-
ir::Block* current_block_;
415+
ir::Block* current_block_{nullptr};
416+
ir::Block* insert_block_{nullptr};
409417
ir::ScheduleBlockRealize* current_sbr_;
410418
};
411419

0 commit comments

Comments
 (0)