21 changes: 0 additions & 21 deletions paddle/inference/CMakeLists.txt
@@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
 # Merge all modules into a single static library
 cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES})
 
-# ptools
-# just for testing; we may need to change the storage format of inference_model
-# and remove the dependency on pickle.
-# download from http://www.picklingtools.com/
-# build in the C++ sub-directory, using the command
-#   make -f Makefile.Linux libptools.so
-set(PTOOLS_LIB)
-set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools")
-find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++)
-find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++)
-if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB)
-  add_definitions(-DPADDLE_USE_PTOOLS)
-  set(PTOOLS_LIB ptools)
-  message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}")
-  add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL)
-  set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB})
-  include_directories(${PTOOLS_ROOT}/C++)
-  include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include)
-  add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools
-endif()
-
 add_executable(example example.cc)
 if(APPLE)
   set(OPTIONAL_LINK_FLAGS)
18 changes: 3 additions & 15 deletions paddle/inference/example.cc
@@ -18,33 +18,21 @@ limitations under the License. */
 #include "paddle/inference/inference.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
-DEFINE_string(feed_var_names, "", "Names of feeding variables");
-DEFINE_string(fetch_var_names, "", "Names of fetching variables");
 
 int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
-      FLAGS_fetch_var_names.empty()) {
+  if (FLAGS_dirname.empty()) {
     // Example:
     // ./example --dirname=recognize_digits_mlp.inference.model
-    //           --feed_var_names="x"
-    //           --fetch_var_names="fc_2.tmp_2"
-    std::cout << "Usage: ./example --dirname=path/to/your/model "
-                 "--feed_var_names=x --fetch_var_names=y"
-              << std::endl;
+    std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
    exit(1);
   }
 
   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
-  std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
-  std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
 
   std::string dirname = FLAGS_dirname;
-  std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
-  std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
 
   paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
+  engine->LoadInferenceModel(dirname);
 
   paddle::framework::LoDTensor input;
   srand(time(0));
40 changes: 29 additions & 11 deletions paddle/inference/inference.cc
@@ -25,19 +25,37 @@ limitations under the License. */

 namespace paddle {
 
+void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
+  std::string model_filename = dirname + "/__model__.dat";
+  LOG(INFO) << "loading model from " << model_filename;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
+  std::string program_desc_str;
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
+
+  program_ = new framework::ProgramDesc(program_desc_str);
+  GenerateLoadProgram(dirname);
+
+  framework::BlockDesc* global_block = program_->MutableBlock(0);
+  feed_var_names_.clear();
+  fetch_var_names_.clear();
+  for (auto* op : global_block->AllOps()) {
+    if (op->Type() == "feed") {
+      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
+    } else if (op->Type() == "fetch") {
+      fetch_var_names_.push_back(op->Input("X")[0]);
+    }
+  }
+}
+
 void InferenceEngine::LoadInferenceModel(
Review comment (Contributor):

I think we can remove this function. If there are no feed_op and fetch_op in the ProgramDesc, users can specify these when calling Run().

Review reply (@sidgoyal78, Contributor, Jan 19, 2018):

Sorry, I don't understand this properly. Based on the updated design, the Run() function does not take the vectors of feed_var_names and fetch_var_names as input, right?

void Run(const ProgramDesc* program,
         Scope* scope,
         std::map<std::string, Tensor>& feeds,
         std::map<std::string, Tensor>& fetchs,
         std::string& feed_var_name = "feed",
         std::string& fetch_var_name = "fetch") {

So can you please explain how users can specify that information when calling Run()?

Review reply (@Xreki, Contributor, Jan 19, 2018):

We can get the feed_var_names from the argument std::map<std::string, Tensor>& feeds, where each std::string key represents a variable name and the Tensor is the input data.

The argument is a std::map because the corresponding argument in the Python implementation is a dict.

Have a look at the example, which shows the detailed usage of Executor.Run().
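To illustrate the idea, a minimal self-contained sketch (FeedVarNames and the bare Tensor stand-in are assumptions for this illustration, not part of the PR or the final API):

#include <map>
#include <string>
#include <vector>

struct Tensor {};  // stand-in for framework::Tensor in this sketch

// The feed variable names can be recovered as the keys of the `feeds` map;
// the mapped Tensor is the corresponding input data.
std::vector<std::string> FeedVarNames(
    const std::map<std::string, Tensor>& feeds) {
  std::vector<std::string> names;
  names.reserve(feeds.size());
  for (const auto& kv : feeds) {
    names.push_back(kv.first);
  }
  return names;
}

With this convention, Run() needs no separate feed_var_names argument: the map carries both the names and the data.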

Review reply (Contributor):

Okay, will take a look. Thanks for the reply.

Review reply (Contributor, Author):

Yes, will remove this function in the next PR.

     const std::string& dirname,
     const std::vector<std::string>& feed_var_names,
     const std::vector<std::string>& fetch_var_names) {
-#ifdef PADDLE_USE_PTOOLS
-  std::string model_filename = dirname + "/__model__";
-  LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
-  Val v;
-  LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
-  std::string program_desc_str = v["program_desc_str"];
-  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
-  // PicklingTools cannot parse the vector of strings correctly.
-#else
   std::string model_filename = dirname + "/__model__.dat";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
   LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
   inputfs.read(&program_desc_str[0], program_desc_str.size());
   inputfs.close();
-#endif
+
   program_ = new framework::ProgramDesc(program_desc_str);
   GenerateLoadProgram(dirname);

@@ -62,7 +80,7 @@
 }
 
 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable()) {
+  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
Review comment (Contributor):

We should not use the name of a Variable to decide whether it is the input or output of a feed_op or fetch_op, because the name is not fixed and it is possible to specify other names.

Review reply (Contributor, Author):

Agree. And we don't need to check fetch, because the fetch variable will not be an input to any op. We can get the feed var name from the feed op's input info.

Will fix in the future PR.
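A minimal sketch of that direction, reusing only the accessors this PR already exercises (BlockDesc::AllOps, OpDesc::Type, OpDesc::Input) and assuming the paddle framework headers; IsFeedHolder is a hypothetical helper, not code from this PR:

// Hypothetical helper: a variable is the feed holder iff some feed op
// consumes it as its "X" input, so no fixed "feed" name is needed.
bool IsFeedHolder(framework::BlockDesc* block, const std::string& name) {
  for (auto* op : block->AllOps()) {
    if (op->Type() == "feed" && op->Input("X")[0] == name) {
      return true;
    }
  }
  return false;
}

IsParameter could then test var->Persistable() && !IsFeedHolder(global_block, var->Name()) instead of comparing against fixed names.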

     // There are many unreachable variables in the program
     for (size_t i = 0; i < program_->Size(); ++i) {
       const framework::BlockDesc& block = program_->Block(i);
1 change: 1 addition & 0 deletions paddle/inference/inference.h
@@ -28,6 +28,7 @@ class InferenceEngine {
     delete load_program_;
   }
 
+  void LoadInferenceModel(const std::string& dirname);
   void LoadInferenceModel(const std::string& dirname,
                           const std::vector<std::string>& feed_var_names,
                           const std::vector<std::string>& fetch_var_names);
31 changes: 31 additions & 0 deletions python/paddle/v2/fluid/io.py
@@ -15,6 +15,7 @@
 import cPickle as pickle
 
 from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
+from . import core
 
 __all__ = [
     'save_vars',
@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
     return inference_program
 
 
+def prepend_feed_ops(inference_program, feeded_var_names):
+    global_block = inference_program.global_block()
+    feed_var = global_block.create_var(
+        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
Review comment (Contributor):

There might be some problems if the name is fixed to 'feed'.

Review reply (Contributor, Author):

Will fix this in the next PR.

+    for i, name in enumerate(feeded_var_names):
+        out = global_block.var(name)
+        global_block.prepend_op(
+            type='feed',
+            inputs={'X': [feed_var]},
+            outputs={'Out': [out]},
+            attrs={'col': i})


+def append_fetch_ops(inference_program, fetch_var_names):
+    global_block = inference_program.global_block()
+    fetch_var = global_block.create_var(
+        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
+
+    for i, name in enumerate(fetch_var_names):
+        global_block.append_op(
+            type='fetch',
+            inputs={'X': [name]},
+            outputs={'Out': [fetch_var]},
+            attrs={'col': i})


 def save_inference_model(dirname,
                          feeded_var_names,
                          target_vars,
@@ -241,6 +269,9 @@ def save_inference_model(dirname,
         "fetch_var_names": fetch_var_names
     }, f, -1)
 
+    prepend_feed_ops(inference_program, feeded_var_names)
Review comment (Contributor):

We can remove lines 265-270 now and change the implementation of load_inference_model.

Review reply (Contributor, Author):

Will do this in the next PR. Thanks!
+    append_fetch_ops(inference_program, fetch_var_names)
 
     # Save only programDesc of inference_program in binary format
     # in another file: __model__.dat
     with open(model_file_name + ".dat", "wb") as fp: