26 changes: 12 additions & 14 deletions paddle/fluid/framework/prune.cc
@@ -27,8 +27,6 @@ namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm";

bool HasDependentVar(const proto::OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
@@ -186,26 +184,26 @@ void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
prune_impl(input, output, 0, -1, dependent_vars);
}

void inference_optimize_impl(const proto::ProgramDesc& input,
proto::ProgramDesc* output, int block_id) {
*output = input;
Contributor:
output -> inout

Contributor Author:
Done.

auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
auto* op_field = input->mutable_blocks(block_id)->mutable_ops();
for (auto& op_desc : *op_field) {
if (op_desc.type() == kDropOutOpType ||
op_desc.type() == kBatchNormOpType) {
for (auto& attr : *op_desc.mutable_attrs()) {
if (attr.name() == "is_test") {
attr.set_b(true);
break;
}
for (auto& attr : *op_desc.mutable_attrs()) {
if (attr.name() == "is_test") {
attr.set_b(true);
break;
}
}
}
}

Contributor @Xreki (Mar 2, 2018):
Looking at it this way, the original inference_program already had this modification; I simply hadn't noticed it before. My understanding is that, regardless of the op type, any op that has an is_test attribute should have it set to true in the inference_program. Can the op type check on line 192 - 193 be removed?

Contributor Author @kexinzhao (Mar 2, 2018):
Yes, the inference_program is already processed by the inference_optimize function when it is generated.
You make a good point; this has been fixed.

Contributor:
Lines 30 - 31 can be deleted.

Contributor Author:
Done

void InferenceOptimize(const proto::ProgramDesc& input,
proto::ProgramDesc* output) {
inference_optimize_impl(input, output, 0);
*output = input;
int num_blocks = output->blocks_size();
PADDLE_ENFORCE_GT(num_blocks, 0, "ProgramDesc must have at least one block");
for (int i = 0; i < num_blocks; ++i) {
inference_optimize_impl(output, i);
}
}

} // namespace framework
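As the discussion above concludes, inference_optimize_impl now sets is_test on every op that carries that attribute, across every block, rather than only on dropout and batch_norm. A hypothetical Python-side sanity check of that behavior (not part of this PR; it assumes a test_program obtained via the clone(for_test=True) flag added below, plus fluid's usual Operator has_attr/attr helpers):

# Hypothetical check, not in this PR: every op exposing an is_test
# attribute in the cloned test program should now report True.
for block in test_program.blocks:
    for op in block.ops:
        if op.has_attr("is_test"):
            assert op.attr("is_test"), (
                "op %s should have is_test=True in the test program" % op.type)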
21 changes: 19 additions & 2 deletions python/paddle/fluid/framework.py
@@ -947,9 +947,26 @@ def to_string(self, throw_on_error, with_details=False):
def get_desc(self):
return self.desc

def clone(self):
def clone(self, for_test=False):
"""Clone the Program object

Set for_test to False when we want to clone the program for training.
Set for_test to True when we want to clone the program for testing.

Args:
for_test(bool): Some operators, such as batch_norm and dropout ops,
behave differently in training and testing. If for_test is True,
the is_test attributes in these operators will be set to True for
testing purposes; otherwise, they remain unchanged.

Returns(Program):
The cloned Program object.
"""
p = Program()
p.desc = core.ProgramDesc(self.desc)
if for_test:
p.desc = core.inference_optimize(self.desc)
else:
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp()
p.copy_param_info_from(self)
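For illustration, a minimal usage sketch of the new flag (the network below is hypothetical and only mirrors the pattern of the book tests updated later in this diff, where the program is cloned before the optimizer ops are appended):

import paddle.fluid as fluid

# Illustrative forward network; layer names and shapes are made up for this sketch.
image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(cost)

# Clone for evaluation before minimize() appends backward/optimizer ops,
# so the test program keeps only the forward pass with is_test set to True.
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_cost)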
@@ -115,7 +115,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
acc = fluid.layers.accuracy(input=predict, label=label)

# Test program
test_program = fluid.default_main_program().clone()
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimize_ops, params_grads = optimizer.minimize(avg_cost)
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -92,7 +92,7 @@ def train(nn_type,
else:
prediction, avg_loss, acc = net_conf(img, label)

test_program = fluid.default_main_program().clone()
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimize_ops, params_grads = optimizer.minimize(avg_loss)
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/book/test_recommender_system.py
@@ -157,7 +157,7 @@ def train(use_cuda, save_dirname, is_local=True):
scale_infer, avg_cost = model()

# test program
test_program = fluid.default_main_program().clone()
test_program = fluid.default_main_program().clone(for_test=True)

sgd_optimizer = SGDOptimizer(learning_rate=0.2)
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)