@@ -335,6 +335,9 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
335
335
::common::errors::InvalidArgument (
336
336
" The input expr should be a ScheduleBlockRealize" ));
337
337
current_sbr_ = node;
338
+ if (current_block_) {
339
+ insert_block_ = current_block_;
340
+ }
338
341
IRMutator<>::Visit (op, expr);
339
342
}
340
343
@@ -387,7 +390,11 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
387
390
" buffer_name %s should not be in global_buffer_to_local_buffer_" ,
388
391
buffer_name));
389
392
global_buffer_to_local_buffer_[buffer_name] = new_tensor;
390
- block_to_insert_stmts_[current_block_].push_back (new_sbr);
393
+
394
+ PADDLE_ENFORCE_NOT_NULL (
395
+ insert_block_,
396
+ ::common::errors::InvalidArgument (" insert block CAN NOT be nullptr" ));
397
+ block_to_insert_stmts_[insert_block_].push_back (new_sbr);
391
398
}
392
399
393
400
void SubstituteGlobalTensor (ir::Load* load_node,
@@ -405,7 +412,8 @@ struct CommonGlobalMemoryEliminator : public ir::IRMutator<Expr*> {
405
412
std::unordered_map<std::string, ir::Expr> global_buffer_to_local_buffer_;
406
413
std::unordered_map<ir::Block*, std::vector<ir::Expr>> block_to_insert_stmts_;
407
414
408
- ir::Block* current_block_;
415
+ ir::Block* current_block_{nullptr };
416
+ ir::Block* insert_block_{nullptr };
409
417
ir::ScheduleBlockRealize* current_sbr_;
410
418
};
411
419
0 commit comments