Skip to content

Commit 034afc6

Browse files
committed
update (apache#26)
* updates (#1) * add scalars * change format * change inferattr interface * remove scalar * remove warning
1 parent 2db0d3a commit 034afc6

File tree

7 files changed

+91
-34
lines changed

7 files changed

+91
-34
lines changed

nnvm/example/src/operator.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ using nnvm::array_view;
2121

2222
// simply return the shape as same
2323
inline bool SameShape(const NodeAttrs& attrs,
24-
array_view<TShape*> ishape,
25-
array_view<TShape*> oshape) {
26-
if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false;
27-
for (TShape* pshape : oshape) {
28-
*pshape = *ishape[0];
24+
std::vector<TShape> *ishape,
25+
std::vector<TShape> *oshape) {
26+
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
27+
for (TShape& pshape : *oshape) {
28+
pshape = (*ishape)[0];
2929
}
30-
for (TShape* pshape : ishape) {
31-
*pshape = *ishape[0];
30+
for (TShape& pshape : *ishape) {
31+
pshape = (*ishape)[0];
3232
}
3333
return true;
3434
}
@@ -51,13 +51,13 @@ NNVM_REGISTER_OP(reshape)
5151
})
5252
.attr<FInferShape>(
5353
"FInferShape", [] (const NodeAttrs& attrs,
54-
array_view<TShape*> ishape,
55-
array_view<TShape*> oshape) {
54+
std::vector<TShape> *ishape,
55+
std::vector<TShape> *oshape) {
5656
// get parsed attribute
5757
const TShape& target = nnvm::get<TShape>(attrs.parsed);
58-
*oshape[0] = target;
59-
if (ishape[0]->ndim() == 0) return false;
60-
CHECK_EQ(ishape[0]->Size(), target.Size())
58+
(*oshape)[0] = target;
59+
if ((*ishape)[0].ndim() == 0) return false;
60+
CHECK_EQ((*ishape)[0].Size(), target.Size())
6161
<< "Reshape op: source target shape mismatch";
6262
return true;
6363
})
@@ -78,9 +78,9 @@ NNVM_REGISTER_OP(cast)
7878
.attr<FInferShape>("FInferShape", SameShape)
7979
.attr<FInferType>(
8080
"FInferType", [](const NodeAttrs& attrs,
81-
array_view<int*> itype,
82-
array_view<int*> otype) {
83-
*otype[0] = nnvm::get<int>(attrs.parsed);
81+
std::vector<int> *itype,
82+
std::vector<int> *otype) {
83+
(*otype)[0] = nnvm::get<int>(attrs.parsed);
8484
return true;
8585
});
8686

nnvm/include/nnvm/op.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifndef NNVM_OP_H_
77
#define NNVM_OP_H_
88

9+
#include <dmlc/parameter.h>
910
#include <string>
1011
#include <vector>
1112
#include <utility>
@@ -22,6 +23,7 @@ struct NodeAttrs;
2223
template<typename ValueType>
2324
class OpMap;
2425
class OpRegistryEntry;
26+
using dmlc::ParamFieldInfo;
2527

2628
/*! \brief constant to indicate it take any length of positional inputs */
2729
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
@@ -80,6 +82,8 @@ class Op {
8082
* This can be used to generate docstring automatically for the operator.
8183
*/
8284
std::string description;
85+
/* \brief description of inputs and keyword arguments*/
86+
std::vector<ParamFieldInfo> arguments;
8387
/*!
8488
* \brief number of inputs to the operator,
8589
* -1 means it is variable length
@@ -149,6 +153,22 @@ class Op {
149153
* \return reference to self.
150154
*/
151155
inline Op& describe(const std::string& descr); // NOLINT(*)
156+
/*!
157+
* \brief Add argument information to the function.
158+
* \param name Name of the argument.
159+
* \param type Type of the argument.
160+
* \param description Description of the argument.
161+
* \return reference to self.
162+
*/
163+
inline Op& add_argument(const std::string &name,
164+
const std::string &type,
165+
const std::string &description);
166+
/*!
167+
* \brief Append list if arguments to the end.
168+
* \param args Additional list of arguments.
169+
* \return reference to self.
170+
*/
171+
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
152172
/*!
153173
* \brief Set the num_inputs
154174
* \param n The number of inputs to be set.
@@ -340,6 +360,18 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
340360
return *this;
341361
}
342362

363+
inline Op& Op::add_argument(const std::string &name,
364+
const std::string &type,
365+
const std::string &description) {
366+
arguments.push_back({name, type, type, description});
367+
return *this;
368+
}
369+
370+
inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
371+
this->arguments.insert(arguments.end(), args.begin(), args.end());
372+
return *this;
373+
}
374+
343375
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
344376
this->num_inputs = n;
345377
return *this;

nnvm/include/nnvm/op_attr_types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr
5757
*/
5858
template<typename AttrType>
5959
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
60-
array_view<AttrType*> in_attrs,
61-
array_view<AttrType*> out_attrs)>;
60+
std::vector<AttrType> *in_attrs,
61+
std::vector<AttrType> *out_attrs)>;
6262
/*!
6363
* \brief Shape inference function.
6464
* Update the shapes given the input shape information.

nnvm/src/c_api/c_api_symbolic.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,26 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
2828
const char ***arg_descriptions,
2929
const char **return_type) {
3030
const Op *op = static_cast<const Op *>(creator);
31+
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
3132

3233
API_BEGIN();
3334
*name = op->name.c_str();
3435
*description = op->description.c_str();
35-
*num_doc_args = 0;
36+
*num_doc_args = static_cast<nn_uint>(op->arguments.size());
37+
if (return_type) *return_type = nullptr;
38+
ret->ret_vec_charp.clear();
39+
for (size_t i = 0; i < op->arguments.size(); ++i) {
40+
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
41+
}
42+
for (size_t i = 0; i < op->arguments.size(); ++i) {
43+
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
44+
}
45+
for (size_t i = 0; i < op->arguments.size(); ++i) {
46+
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
47+
}
48+
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
49+
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
50+
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
3651
API_END();
3752
}
3853

nnvm/src/core/symbolic.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,10 @@ void Symbol::Print(std::ostream &os) const {
151151
}
152152
if (!node->attrs.dict.empty()) {
153153
os << "Attrs:\n";
154-
for (auto &kv : node->attrs.dict) {
154+
// make an ordered copy because unordered_map doesn't guarantee order.
155+
std::map<std::string, std::string> sorted_dict(
156+
node->attrs.dict.begin(), node->attrs.dict.end());
157+
for (auto &kv : sorted_dict) {
155158
os << '\t' << kv.first << '=' << kv.second << '\n';
156159
}
157160
}

nnvm/src/pass/infer_shape_type.cc

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret,
4747
}
4848

4949
// temp space for shape inference.
50-
std::vector<AttrType*> ishape, oshape;
50+
std::vector<AttrType> ishape, oshape;
5151
// number of completed nodes
5252
size_t num_unknown = 0;
5353
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
5454
const auto& inode = idx[nid];
55+
uint32_t num_inputs = inode.inputs.size();
56+
uint32_t num_outputs = inode.source->num_outputs();
5557
if (inode.source->is_variable()) {
5658
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
5759
auto it = inode.source->attrs.dict.find(shape_attr_key);
5860
if (it != inode.source->attrs.dict.end()) {
59-
CHECK_EQ(inode.source->num_outputs(), 1);
61+
CHECK_EQ(num_outputs, 1);
6062
std::istringstream is(it->second);
6163
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
6264
}
6365
}
6466
continue;
6567
}
66-
ishape.resize(inode.inputs.size());
67-
for (uint32_t i = 0; i < ishape.size(); ++i) {
68-
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
69-
}
70-
oshape.resize(inode.source->num_outputs());
71-
for (uint32_t i = 0; i < oshape.size(); ++i) {
72-
oshape[i] = &rshape[idx.entry_id(nid, i)];
73-
}
7468
if (finfer_shape.count(inode.source->op)) {
69+
ishape.resize(num_inputs, def_value);
70+
for (uint32_t i = 0; i < ishape.size(); ++i) {
71+
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
72+
}
73+
oshape.resize(num_outputs, def_value);
74+
for (uint32_t i = 0; i < oshape.size(); ++i) {
75+
oshape[i] = rshape[idx.entry_id(nid, i)];
76+
}
7577
num_unknown +=
76-
!(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape));
78+
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
79+
for (uint32_t i = 0; i < num_inputs; ++i) {
80+
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
81+
}
82+
for (uint32_t i = 0; i < num_outputs; ++i) {
83+
rshape[idx.entry_id(nid, i)] = oshape[i];
84+
}
7785
} else if (is_backward.get(inode.source->op, false)) {
7886
// backward operator inference.
7987
CHECK_GE(inode.control_deps.size(), 1)
8088
<< "BackwardOp need to have control_deps to its forward op";
8189
const auto& fnode = idx[inode.control_deps[0]];
82-
CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs())
90+
CHECK_EQ(fnode.inputs.size(), num_outputs)
8391
<< "BackwardOp need to correspond to the forward node";
8492
bool known = true;
8593
for (size_t i = 0; i < fnode.inputs.size(); ++i) {
86-
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
87-
if (fis_none(*oshape[i])) known = false;
94+
rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
95+
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
8896
}
8997
num_unknown += !known;
9098
}

nnvm/tests/python/test_symbol.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def test_copy():
4141
z = sym.Variable('z')
4242
y = sym.exp(sym.add(x, x, name='add', gpu=2),
4343
name='exp', gpu=1, attr={"kk": "1"})
44-
4544
assert y.__copy__().debug_str() == y.debug_str()
4645

4746
if __name__ == "__main__":

0 commit comments

Comments
 (0)