Skip to content

Commit 6464d90

Browse files
committed
Change function def to Node ref for more flexiblity (#27)
* Remove warning in g++5 * Change function def to Node ref for more flexiblity
1 parent c362a28 commit 6464d90

File tree

15 files changed

+111
-54
lines changed

15 files changed

+111
-54
lines changed

nnvm/example/src/operator.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ using nnvm::FMutateInputs;
1515
using nnvm::FInferShape;
1616
using nnvm::FInferType;
1717
using nnvm::FInplaceOption;
18+
using nnvm::Node;
1819
using nnvm::NodeAttrs;
1920
using nnvm::TShape;
2021
using nnvm::array_view;
2122

2223
// simply return the shape as same
23-
inline bool SameShape(const NodeAttrs& attrs,
24+
inline bool SameShape(const Node& n,
2425
std::vector<TShape> *ishape,
2526
std::vector<TShape> *oshape) {
2627
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
@@ -33,7 +34,7 @@ inline bool SameShape(const NodeAttrs& attrs,
3334
return true;
3435
}
3536

36-
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) {
37+
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const Node& n) {
3738
return {{0, 0}};
3839
}
3940

@@ -50,11 +51,11 @@ NNVM_REGISTER_OP(reshape)
5051
attrs->parsed = std::move(target);
5152
})
5253
.attr<FInferShape>(
53-
"FInferShape", [] (const NodeAttrs& attrs,
54+
"FInferShape", [] (const Node& n,
5455
std::vector<TShape> *ishape,
5556
std::vector<TShape> *oshape) {
5657
// get parsed attribute
57-
const TShape& target = nnvm::get<TShape>(attrs.parsed);
58+
const TShape& target = nnvm::get<TShape>(n.attrs.parsed);
5859
(*oshape)[0] = target;
5960
if ((*ishape)[0].ndim() == 0) return false;
6061
CHECK_EQ((*ishape)[0].Size(), target.Size())
@@ -77,10 +78,10 @@ NNVM_REGISTER_OP(cast)
7778
})
7879
.attr<FInferShape>("FInferShape", SameShape)
7980
.attr<FInferType>(
80-
"FInferType", [](const NodeAttrs& attrs,
81+
"FInferType", [](const Node& n,
8182
std::vector<int> *itype,
8283
std::vector<int> *otype) {
83-
(*otype)[0] = nnvm::get<int>(attrs.parsed);
84+
(*otype)[0] = nnvm::get<int>(n.attrs.parsed);
8485
return true;
8586
});
8687

@@ -109,7 +110,7 @@ NNVM_REGISTER_OP(cross_device_copy)
109110
NNVM_REGISTER_OP(conv2d)
110111
.describe("take conv of input")
111112
.set_num_inputs(2)
112-
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
113+
.attr<FListInputNames>("FListInputNames", [](const Node& n) {
113114
return std::vector<std::string>{"data", "weight"};
114115
});
115116

@@ -119,7 +120,7 @@ NNVM_REGISTER_OP(add)
119120
NNVM_REGISTER_OP(assign)
120121
.set_num_inputs(2)
121122
.set_num_outputs(1)
122-
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
123+
.attr<FMutateInputs>("FMutateInputs", [](const Node& n) {
123124
return std::vector<uint32_t>{0};
124125
});
125126

nnvm/include/dmlc/base.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
__cplusplus >= 201103L || defined(_MSC_VER))
5959
#endif
6060

61+
/*! \brief strict CXX11 support */
62+
#ifndef DMLC_STRICT_CXX11
63+
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
64+
#endif
65+
6166
/// check if g++ is before 4.6
6267
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
6368
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
@@ -69,6 +74,7 @@
6974
#endif
7075
#endif
7176

77+
7278
/*!
7379
* \brief Enable std::thread related modules,
7480
* Used to disable some module in mingw compile.
@@ -82,6 +88,13 @@
8288
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
8389
#endif
8490

91+
/*! \brief helper macro to supress unused warning */
92+
#if defined(__GNUC__)
93+
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
94+
#else
95+
#define DMLC_ATTRIBUTE_UNUSED
96+
#endif
97+
8598
/*! \brief helper macro to generate string concat */
8699
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
87100
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)

nnvm/include/dmlc/json.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
#include <typeindex>
2626
#include <typeinfo>
2727
#include <unordered_map>
28+
#if DMLC_STRICT_CXX11
2829
#include "./any.h"
30+
#endif // DMLC_STRICT_CXX11
2931
#endif // DMLC_USE_CXX11
3032

3133
namespace dmlc {
@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
320322
};
321323

322324
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
323-
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
325+
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
326+
__make_AnyJSONType ## _ ## KeyName ## __
324327

325328
/*!
326329
* \def DMLC_JSON_ENABLE_ANY
@@ -475,7 +478,7 @@ struct Handler {
475478
}
476479
};
477480

478-
#if DMLC_USE_CXX11
481+
#if DMLC_STRICT_CXX11
479482
// Manager to store json serialization strategy.
480483
class AnyJSONManager {
481484
public:
@@ -561,7 +564,7 @@ struct Handler<any> {
561564
CHECK(!reader->NextArrayItem()) << "invalid any json format";
562565
}
563566
};
564-
#endif // DMLC_USE_CXX11
567+
#endif // DMLC_STRICT_CXX11
565568

566569
} // namespace json
567570

nnvm/include/dmlc/parameter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ struct Parameter {
251251
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
252252
return &inst.manager; \
253253
} \
254-
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
254+
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
255+
__make__ ## PType ## ParamManager__ = \
255256
(*PType::__MANAGER__()) \
256257

257258
//! \endcond

nnvm/include/dmlc/registry.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
216216
* \sa FactoryRegistryEntryBase
217217
*/
218218
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
219-
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
219+
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
220220
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
221221

222222
/*!
@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
272272
*/
273273
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
274274
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
275-
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
275+
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
276+
__dmlc_registry_file_tag_ ## UniqueTag ## __();
276277
} // namespace dmlc
277278
#endif // DMLC_REGISTRY_H_

nnvm/include/nnvm/node.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ namespace nnvm {
1717

1818
// Forward declare node.
1919
class Node;
20-
2120
/*!
2221
* \brief we always used NodePtr for a reference pointer
2322
* to the node, so this alias can be changed in case.
@@ -48,8 +47,6 @@ struct NodeEntry {
4847
struct NodeAttrs {
4948
/*! \brief name of the node */
5049
std::string name;
51-
/*! \brief Vector representation of positional attributes */
52-
std::vector<double> scalars;
5350
/*! \brief The dictionary representation of attributes */
5451
std::unordered_map<std::string, std::string> dict;
5552
/*!
@@ -108,7 +105,7 @@ inline uint32_t Node::num_outputs() const {
108105
if (this->op->get_num_outputs == nullptr) {
109106
return this->op->num_outputs;
110107
} else {
111-
return this->op->get_num_outputs(this->attrs);
108+
return this->op->get_num_outputs(*this);
112109
}
113110
}
114111

@@ -117,7 +114,7 @@ inline uint32_t Node::num_inputs() const {
117114
if (this->op->get_num_inputs == nullptr) {
118115
return this->op->num_inputs;
119116
} else {
120-
return this->op->get_num_inputs(this->attrs);
117+
return this->op->get_num_inputs(*this);
121118
}
122119
}
123120

nnvm/include/nnvm/op.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,16 @@ class Op {
102102
uint32_t num_outputs = 1;
103103
/*!
104104
* \brief get number of outputs given information about the node.
105-
* \param attrs The attribute of the node
105+
* \param n The node
106106
* \return number of outputs.
107107
*/
108-
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
108+
std::function<uint32_t(const Node& n)> get_num_outputs = nullptr;
109109
/*!
110110
* \brief get number of inputs given information about the node.
111-
* \param attrs The attribute of the node
111+
* \param n The node
112112
* \return number of inputs
113113
*/
114-
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
114+
std::function<uint32_t(const Node& n)> get_num_inputs = nullptr;
115115
/*!
116116
* \brief Attribute parser to parse the NodeAttrs information.
117117
*
@@ -136,11 +136,11 @@ class Op {
136136
* attrs->parsed = std::move(param);
137137
* }
138138
* // The other function that can utilize the parsed result.
139-
* TShape SumInferShape(const NodeAttrs& attrs,
139+
* TShape SumInferShape(const NodePtr& ptr,
140140
* const std::vector<TShape>& ishapes) {
141141
* // we can use the parsed version of param
142142
* // without repeatively parsing the parameter
143-
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
143+
* const SumParam& param = nnvm::get<SumParam>(ptr->attrs.parsed);
144144
* }
145145
* \endcode
146146
*/
@@ -180,7 +180,7 @@ class Op {
180180
* \param fn The function to be set.
181181
* \return reference to self.
182182
*/
183-
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
183+
inline Op& set_num_inputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
184184
/*!
185185
* \brief Set the num_outputs
186186
* \param n The number of outputs to be set.
@@ -192,7 +192,7 @@ class Op {
192192
* \param fn The function to be set.
193193
* \return reference to self.
194194
*/
195-
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
195+
inline Op& set_num_outputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
196196
/*!
197197
* \brief Set the attr_parser function.
198198
* \param fn The number of outputs to be set.
@@ -279,10 +279,8 @@ class OpMap {
279279
};
280280

281281
// internal macros to make
282-
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
283-
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
284282
#define NNVM_REGISTER_VAR_DEF(OpName) \
285-
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
283+
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
286284

287285
/*!
288286
* \def NNVM_REGISTER_OP
@@ -300,7 +298,7 @@ class OpMap {
300298
* \endcode
301299
*/
302300
#define NNVM_REGISTER_OP(OpName) \
303-
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
301+
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
304302
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
305303

306304
// implementations of template functions after this.
@@ -377,7 +375,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
377375
return *this;
378376
}
379377

380-
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
378+
inline Op& Op::set_num_inputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
381379
this->get_num_inputs = fn;
382380
return *this;
383381
}
@@ -387,7 +385,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
387385
return *this;
388386
}
389387

390-
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
388+
inline Op& Op::set_num_outputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
391389
this->get_num_outputs = fn;
392390
return *this;
393391
}

nnvm/include/nnvm/op_attr_types.h

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <functional>
1313
#include "./base.h"
1414
#include "./tuple.h"
15+
#include "./node.h"
1516

1617
namespace nnvm {
1718

@@ -21,44 +22,44 @@ namespace nnvm {
2122
/*!
2223
* \brief Return list of input arguments names of each operator.
2324
*
24-
* \param attrs The attributes of the node.
25+
* \param n The node.
2526
* \return list of inputs
2627
* \note Register under "FListInputNames", default return {"data"}.
2728
*
2829
* FListInputNames enables automatic variable creation for missing arguments.
2930
*/
30-
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
31+
using FListInputNames = std::function<std::vector<std::string> (const Node& n)>;
3132

3233
/*!
3334
* \brief Return list of output arguments names of each operator.
3435
*
35-
* \param attrs The attributes of the node.
36+
* \param n The node.
3637
* \return list of inputs
3738
* \note Register under "FListOutputNames", default return {"outputs"}.
3839
*
3940
* FListOutputNames customized naming for operator outputs.
4041
*/
41-
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
42+
using FListOutputNames = std::function<std::vector<std::string> (const Node& n)>;
4243

4344
/*!
4445
* \brief Check whether operator will mutate k-th input.
45-
* \param attrs The attributes of the node.
46+
* \param n The node.
4647
* \return list of input indices it mutates.
4748
*
4849
* \note Register under "FMutateInputs", default return false
4950
* FMutateInputs enables mutation order handling correctly.
5051
*/
51-
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
52+
using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
5253

5354
/*!
5455
* \brief Inference function of certain type.
5556
* \tparam AttrType The type of the attribute to be infered.
5657
* \return whether all attributes are inferred.
5758
*/
5859
template<typename AttrType>
59-
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
60-
std::vector<AttrType> *in_attrs,
61-
std::vector<AttrType> *out_attrs)>;
60+
using FInferNodeEntryAttr = std::function<bool (const Node& n,
61+
std::vector<AttrType> *in_ptr,
62+
std::vector<AttrType> *out_ptr)>;
6263
/*!
6364
* \brief Shape inference function.
6465
* Update the shapes given the input shape information.
@@ -96,7 +97,7 @@ using TIsBackwardOp = bool;
9697
/*!
9798
* \brief Get possible inplace options.
9899
* This function enables optimization to reuse memory of inputs in output.
99-
* \param attrs The attributes of the node
100+
* \param n The node
100101
* \param in_data The input data.
101102
* \param out_data The output data.
102103
* \return list of pair of that maps input->output,
@@ -105,7 +106,20 @@ using TIsBackwardOp = bool;
105106
* \note Register under "FInplaceOption", by default no inplace can happen.
106107
*/
107108
using FInplaceOption = std::function<
108-
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
109+
std::vector<std::pair<int, int> > (const Node& n)>;
110+
111+
/*!
112+
* \brief Get the gradient node of the op node
113+
* This function generates the backward graph of the node
114+
* \param nodeptr The node to take gradient
115+
* \param out_grads Gradient of current node's outputs
116+
* \return gradients of the inputs
117+
*
118+
* \note Register under "FGradient"
119+
*/
120+
using FGradient = std::function<std::vector<NodeEntry>(
121+
const NodePtr& nodeptr,
122+
const std::vector<NodeEntry>& out_grads)>;
109123

110124
} // namespace nnvm
111125

nnvm/include/nnvm/pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace nnvm {
2323
* \param src The graph to be transformed.
2424
* \return The generated graph.
2525
*/
26-
typedef std::function<Graph (Graph src)> PassFunction;
26+
using PassFunction = std::function<Graph (Graph src)>;
2727

2828
/*!
2929
* \brief Apply a series of pass transformations on g.

0 commit comments

Comments
 (0)