Add feed and fetch op to ProgramDesc before saving for inference #7636
Changes from all commits
First, the changes to the C++ inference engine:
@@ -25,19 +25,37 @@ limitations under the License. */

namespace paddle {

void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
  std::string model_filename = dirname + "/__model__.dat";
  LOG(INFO) << "loading model from " << model_filename;
  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
  std::string program_desc_str;
  inputfs.seekg(0, std::ios::end);
  program_desc_str.resize(inputfs.tellg());
  inputfs.seekg(0, std::ios::beg);
  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
  inputfs.read(&program_desc_str[0], program_desc_str.size());
  inputfs.close();

  program_ = new framework::ProgramDesc(program_desc_str);
  GenerateLoadProgram(dirname);

  framework::BlockDesc* global_block = program_->MutableBlock(0);
  feed_var_names_.clear();
  fetch_var_names_.clear();
  for (auto* op : global_block->AllOps()) {
    if (op->Type() == "feed") {
      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
    } else if (op->Type() == "fetch") {
      fetch_var_names_.push_back(op->Input("X")[0]);
    }
  }
}
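The loop at the end recovers the feed/fetch variable names that were baked into the saved program. Since the Python helper `prepend_feed_ops` (later in this diff) prepends one feed op per input, the feed ops sit in the block in reverse order, which is why the C++ code inserts each name at the front of `feed_var_names_` to restore the original order. The same scan can be written on the Python side; the sketch below is illustrative only and assumes the `paddle.v2.fluid` framework API of this era (`Program.global_block()`, `Operator.type`, and the `input`/`output` accessors).

# Sketch (not part of this PR): Python mirror of the C++ scan above,
# assuming the paddle.v2.fluid framework API of this era.
def scan_feed_fetch_names(program):
    feed_var_names = []
    fetch_var_names = []
    for op in program.global_block().ops:
        if op.type == 'feed':
            # Feed ops were prepended one by one, so they appear reversed;
            # inserting at the front restores the original order.
            feed_var_names.insert(0, op.output('Out')[0])
        elif op.type == 'fetch':
            fetch_var_names.append(op.input('X')[0])
    return feed_var_names, fetch_var_names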
void InferenceEngine::LoadInferenceModel(
    const std::string& dirname,
    const std::vector<std::string>& feed_var_names,
    const std::vector<std::string>& fetch_var_names) {
#ifdef PADDLE_USE_PTOOLS
  std::string model_filename = dirname + "/__model__";
  LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
  Val v;
  LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
  std::string program_desc_str = v["program_desc_str"];
  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
  // PicklingTools cannot parse the vector of strings correctly.
#else
  std::string model_filename = dirname + "/__model__.dat";
  LOG(INFO) << "loading model from " << model_filename;
  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);

@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
  inputfs.read(&program_desc_str[0], program_desc_str.size());
  inputfs.close();
#endif

  program_ = new framework::ProgramDesc(program_desc_str);
  GenerateLoadProgram(dirname);

@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
}

bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable()) {
+  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
|
Review comment: We should not use the name of `feed` and `fetch` to decide whether a variable is a parameter.

Reply: Agree. And we don't need to check the names here. Will fix in the future PR.
  // There are many unreachable variables in the program
  for (size_t i = 0; i < program_->Size(); ++i) {
    const framework::BlockDesc& block = program_->Block(i);
And on the Python side, the changes to the save-for-inference path:

@@ -15,6 +15,7 @@
import cPickle as pickle

from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core

__all__ = [
    'save_vars',

@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
    return inference_program


def prepend_feed_ops(inference_program, feeded_var_names):
    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
Review comment: There might be some problem if the name is fixed to `feed`.

Reply: Will fix this in the next PR.
    for i, name in enumerate(feeded_var_names):
        out = global_block.var(name)
        global_block.prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})

def append_fetch_ops(inference_program, fetch_var_names):
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)

    for i, name in enumerate(fetch_var_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})
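Together, the two helpers bracket the inference program with its I/O ops; the `col` attribute records which slot of the `feed`/`fetch` list variable each op reads from or writes to. A hypothetical usage sketch (the `inference_program` and the variable names `x` and `y_predict` are placeholders, not taken from this PR):

# Hypothetical sketch: assumes an existing fluid Program named
# `inference_program` with input variable 'x' and output 'y_predict'.
prepend_feed_ops(inference_program, ['x'])
append_fetch_ops(inference_program, ['y_predict'])

# The global block now reads: feed op(s), original ops, fetch op(s).
print([op.type for op in inference_program.global_block().ops])

Both helpers are then called from `save_inference_model` below, so the serialized `ProgramDesc` already carries its feed and fetch ops.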

def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,

@@ -241,6 +269,9 @@ def save_inference_model(dirname,
            "fetch_var_names": fetch_var_names
        }, f, -1)
    prepend_feed_ops(inference_program, feeded_var_names)
Review comment: We can remove the pickled `feed_var_names`/`fetch_var_names` above, since they can now be recovered from the feed/fetch ops in the `ProgramDesc`.

Reply: Will do this in the next PR. Thanks!
    append_fetch_ops(inference_program, fetch_var_names)

    # Save only programDesc of inference_program in binary format
    # in another file: __model__.dat
    with open(model_file_name + ".dat", "wb") as fp:
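End to end, saving might look like the sketch below. This is an assumption-laden example: the hunk above truncates the signature of `save_inference_model`, so the trailing `executor` argument and the layer calls are taken from the `paddle.v2.fluid` API of this era rather than from this diff.

# Hedged sketch of saving an inference model (assumed fluid API of this
# era; the `executor` argument is not shown in the truncated hunk above).
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Writes the pickled __model__ file, the binary __model__.dat holding the
# ProgramDesc with feed/fetch ops added, and the persistable parameters.
fluid.io.save_inference_model("./model_dir", ['x'], [y_predict], exe)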
Review comment: I think we can remove this function. If there are no `feed_op` and `fetch_op` in the `ProgramDesc`, users can specify these when calling `Run()`.
Reply: Sorry, I don't understand this properly. Based on the updated design, the `Run()` function does not take as input the vectors of `feed_var_names` and `fetch_var_names`, right? So can you please explain the idea that users can specify that information when calling `Run()`?
Reply: We can get the `feed_var_names` from the argument `std::map<std::string, Tensor>& feeds`, where the `std::string` represents a name and the `Tensor` is the input data. The argument is a `std::map` because the corresponding argument in the Python implementation is a dict. Have a look at the example, which shows the detailed usage of `Executor.Run()`.
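For context, a sketch of the Python-side counterpart this design mirrors; it assumes the `load_inference_model` helper from the same `io.py` era and the `Executor.run` signature of the time, where `feed` is exactly a dict from variable name to input data:

# Sketch (assumed fluid API of this era): the `feed` dict plays the role
# of the C++ `std::map<std::string, Tensor>& feeds` discussed above.
import numpy as np
import paddle.v2.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    "./model_dir", exe)

results = exe.run(program,
                  feed={feed_names[0]: np.random.rand(1, 13).astype('float32')},
                  fetch_list=fetch_targets)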
Reply: Okay, will take a look. Thanks for the reply.
Reply: Yes, will remove this function in the next PR.