-
Couldn't load subscription status.
- Fork 5.9k
Enable is_test attr of batch norm and drop out op for test program #8642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,6 @@ namespace framework { | |
|
|
||
| const std::string kFeedOpType = "feed"; | ||
| const std::string kFetchOpType = "fetch"; | ||
| const std::string kDropOutOpType = "dropout"; | ||
| const std::string kBatchNormOpType = "batch_norm"; | ||
|
|
||
| bool HasDependentVar(const proto::OpDesc& op_desc, | ||
| const std::set<std::string>& dependent_vars) { | ||
|
|
@@ -186,26 +184,26 @@ void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) { | |
| prune_impl(input, output, 0, -1, dependent_vars); | ||
| } | ||
|
|
||
| void inference_optimize_impl(const proto::ProgramDesc& input, | ||
| proto::ProgramDesc* output, int block_id) { | ||
| *output = input; | ||
| auto* op_field = output->mutable_blocks(block_id)->mutable_ops(); | ||
| void inference_optimize_impl(proto::ProgramDesc* input, int block_id) { | ||
| auto* op_field = input->mutable_blocks(block_id)->mutable_ops(); | ||
| for (auto& op_desc : *op_field) { | ||
| if (op_desc.type() == kDropOutOpType || | ||
| op_desc.type() == kBatchNormOpType) { | ||
| for (auto& attr : *op_desc.mutable_attrs()) { | ||
| if (attr.name() == "is_test") { | ||
| attr.set_b(true); | ||
| break; | ||
| } | ||
| for (auto& attr : *op_desc.mutable_attrs()) { | ||
| if (attr.name() == "is_test") { | ||
| attr.set_b(true); | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这么看,原来的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,生成inference_program的时候就经过了inference_optimize函数的处理了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以删除 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| void InferenceOptimize(const proto::ProgramDesc& input, | ||
| proto::ProgramDesc* output) { | ||
| inference_optimize_impl(input, output, 0); | ||
| *output = input; | ||
| int num_blocks = output->blocks_size(); | ||
| PADDLE_ENFORCE_GT(num_blocks, 0, "ProgramDesc must have at least one block"); | ||
| for (int i = 0; i < num_blocks; ++i) { | ||
| inference_optimize_impl(output, i); | ||
| } | ||
| } | ||
|
|
||
| } // namespace framework | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output -> inout
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.