Skip to content

Commit

Permalink
reopen bn prim pir
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Jan 10, 2024
1 parent b1daab4 commit 8f29fc6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
19 changes: 12 additions & 7 deletions paddle/fluid/primitive/base/decomp_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,16 @@ std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
}

void DecompProgram::decomp_program() {
std::ostringstream orig_prog_stream;
std::unordered_map<pir::OpResult, int> orig_vars_dict;
for (size_t i = 0; i < src_vars_.size(); i++) {
orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
}
program_->Print(orig_prog_stream);
VLOG(4) << "[Prim] Origin program bofore decomp :\n"
<< orig_prog_stream.str();
if (VLOG_IS_ON(4)) {
std::ostringstream orig_prog_stream;
program_->Print(orig_prog_stream);
std::cout << "[Prim] Origin program before decomp :\n"
<< orig_prog_stream.str();
}
if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
return;
}
Expand Down Expand Up @@ -334,9 +336,12 @@ void DecompProgram::decomp_program() {
}
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
builder.SetInsertionPointToBlockEnd(block);
std::ostringstream decomp_prog_stream;
program_->Print(decomp_prog_stream);
VLOG(4) << "[Prim] New program after decomp :\n" << decomp_prog_stream.str();
if (VLOG_IS_ON(4)) {
std::ostringstream decomp_prog_stream;
program_->Print(decomp_prog_stream);
std::cout << "[Prim] New program after decomp :\n"
<< decomp_prog_stream.str();
}
dst_vars_ = tar_vars;
return;
}
Expand Down
6 changes: 2 additions & 4 deletions test/legacy_test/test_batch_norm_op_prim_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def setUp(self):
self.op_type = "batch_norm"
self.prim_op_type = "comp"
self.python_out_sig = ["Y"]
# (Todo: CZ) random error
self.check_prim_pir = False
self.check_prim_pir = True
self.initConfig()
self.initTestCase()

Expand Down Expand Up @@ -345,8 +344,7 @@ def initConfig(self):
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
# Todo(CZ): open this
self.check_prim_pir = False
self.check_prim_pir = True


@unittest.skipIf(
Expand Down
3 changes: 1 addition & 2 deletions test/legacy_test/test_batch_norm_op_prim_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ def initConfig(self):
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
# Todo(CZ): open this
self.check_prim_pir = False
self.check_prim_pir = True


class TestBatchNormOpNHWCShape2(TestBatchNormOp):
Expand Down

0 comments on commit 8f29fc6

Please sign in to comment.