Skip to content

Commit

Permalink
[MXNET-38]add reshape predicator function to c_predict_api (apache#9984)
Browse files Browse the repository at this point in the history
* add reshape predicator function to c_predict_api

* fix

* fix

* fix
  • Loading branch information
Ldpe2G authored and cjolivier01 committed Mar 10, 2018
1 parent 52b5196 commit 34d5e50
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
21 changes: 21 additions & 0 deletions include/mxnet/c_predict_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,27 @@ MXNET_DLL int MXPredCreatePartialOut(const char* symbol_json_str,
mx_uint num_output_nodes,
const char** output_keys,
PredictorHandle* out);
/*!
* \brief Change the input shape of an existing predictor.
* \param num_input_nodes Number of input nodes to the net,
* For feedforward net, this is 1.
* \param input_keys The name of input argument.
* For feedforward net, this is {"data"}
* \param input_shape_indptr Index pointer of shapes of each input node.
* The length of this array = num_input_nodes + 1.
* For feedforward net that takes 4 dimensional input, this is {0, 4}.
* \param input_shape_data A flatted data of shapes of each input node.
* For feedforward net that takes 4 dimensional input, this is the shape data.
* \param handle The original predictor handle.
* \param out The reshaped predictor handle.
* \return 0 when success, -1 when failure.
*/
int MXPredReshape(mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
PredictorHandle handle,
PredictorHandle* out);
/*!
* \brief Get the shape of output node.
* The returned shape_data and shape_ndim is only valid before next call to MXPred function.
Expand Down
97 changes: 97 additions & 0 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ struct MXAPIPredictor {
std::vector<NDArray> out_arrays;
// argument arrays
std::vector<NDArray> arg_arrays;
// auxiliary arrays
std::vector<NDArray> aux_arrays;
// output shapes
std::vector<TShape> out_shapes;
// uint32_t buffer for output shapes
Expand All @@ -51,6 +53,10 @@ struct MXAPIPredictor {
std::unordered_map<std::string, size_t> key2arg;
// executor
std::unique_ptr<Executor> exec;
// symbol
nnvm::Symbol sym;
// Context
Context ctx;
};

struct MXAPINDList {
Expand Down Expand Up @@ -243,6 +249,97 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
API_END_HANDLE_ERROR(delete ret);
}

int MXPredReshape(mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
PredictorHandle handle,
PredictorHandle* out) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());

API_BEGIN();
// shape inference
std::unordered_map<std::string, TShape> new_shape;
for (mx_uint i = 0; i < num_input_nodes; ++i) {
new_shape[std::string(input_keys[i])] =
TShape(input_shape_data + input_shape_indptr[i],
input_shape_data + input_shape_indptr[i + 1]);
}
ret->sym = p->sym;
std::vector<std::string> arg_names = ret->sym.ListInputNames(Symbol::kReadOnlyArgs);
std::vector<std::string> aux_names = ret->sym.ListInputNames(Symbol::kAuxiliaryStates);
std::vector<TShape> out_shapes(ret->sym.ListOutputNames().size());
std::vector<TShape> aux_shapes(aux_names.size());
std::vector<TShape> arg_shapes;
ret->key2arg = p->key2arg;

try {
std::vector<TShape> in_shapes;
in_shapes.reserve(arg_names.size());
for (std::string key : ret->sym.ListInputNames(Symbol::kAll)) {
if (new_shape.count(key) != 0) {
in_shapes.push_back(new_shape[key]);
} else {
in_shapes.push_back(TShape());
}
}
nnvm::Graph g; g.outputs = ret->sym.outputs;
g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
CHECK(infer_complete)
<< "The shape information of is not enough to get the shapes";
CopyAttr(g.indexed_graph(),
g.GetAttr<nnvm::ShapeVector>("shape"),
&arg_shapes, &out_shapes, &aux_shapes);
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}

ret->arg_arrays = p->arg_arrays;
ret->ctx = p->ctx;
for (size_t i=0; i < arg_names.size(); ++i) {
TShape newShape = arg_shapes[i];
NDArray &arr = p->arg_arrays[i];
if (new_shape.count(arg_names[i]) != 0) {
ret->arg_arrays[i].ReshapeAndAlloc(newShape);
} else {
CHECK_EQ(newShape.Size(), arr.shape().Size())
<< "arg " << arg_names[i]
<< " shape has been changed, only allow to change the shape of input data.";
}
}
p->arg_arrays.clear();

for (size_t i=0; i < aux_names.size(); ++i) {
TShape newShape = aux_shapes[i];
NDArray &arr = p->aux_arrays[i];
CHECK_EQ(newShape.Size(), arr.shape().Size())
<< "aux " << aux_names[i]
<< " shape has been changed, only allow to change the shape of input data.";
}
ret->aux_arrays = p->aux_arrays;
p->aux_arrays.clear();

// bind
{
std::map<std::string, Context> ctx_map;
std::vector<NDArray> grad_store;
grad_store.reserve(ret->arg_arrays.size());
std::vector<OpReqType> grad_req(ret->arg_arrays.size(), kNullOp);

ret->exec.reset(Executor::Bind(ret->sym, ret->ctx, ctx_map,
ret->arg_arrays,
grad_store, grad_req,
ret->aux_arrays,
p->exec.get()));
ret->out_shapes = out_shapes;
ret->out_arrays = ret->exec->outputs();
}
*out = ret.release();
API_END();
}

int MXPredGetOutputShape(PredictorHandle handle,
mx_uint out_index,
mx_uint** shape_data,
Expand Down

0 comments on commit 34d5e50

Please sign in to comment.