|
13 | 13 |
|
14 | 14 | namespace nnvm { |
15 | 15 | namespace top { |
16 | | - |
17 | | -/*! \brief exception throwed by InferShape error */ |
18 | | -struct InferShapeError : public dmlc::Error { |
19 | | - /*! \brief analyze message */ |
20 | | - std::string msg; |
21 | | - /*! \brief corresponding input index */ |
22 | | - int index; |
23 | | - // constructor |
24 | | - InferShapeError(const std::string& msg_, int index) |
25 | | - : dmlc::Error(msg_), msg(msg_), index(index) {} |
26 | | -}; |
27 | | - |
28 | | -/*! \brief exception throwed by InferShape error */ |
29 | | -struct InferTypeError : public dmlc::Error { |
30 | | - /*! \brief analyze message */ |
31 | | - std::string msg; |
32 | | - /*! \brief corresponding input index */ |
33 | | - int index; |
34 | | - // constructor |
35 | | - InferTypeError(const std::string& msg_, int index) |
36 | | - : dmlc::Error(msg_), msg(msg_), index(index) {} |
37 | | -}; |
38 | | - |
39 | 16 | /*! |
40 | 17 | * \brief Parse keyword arguments as PType arguments and save to parsed |
41 | 18 | * \tparam PType the arameter type. |
@@ -128,41 +105,88 @@ inline bool type_assign(int *y, const int& x) { |
128 | 105 | return true; |
129 | 106 | } |
130 | 107 |
|
| 108 | +template<typename AttrType> |
| 109 | +inline std::string attr_assign_error_msg(const NodeAttrs& attrs, |
| 110 | + int index, bool is_input, |
| 111 | + const AttrType& expected, |
| 112 | + const AttrType& actual, |
| 113 | + const char* attr_name) { |
| 114 | + static const auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames"); |
| 115 | + static const auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames"); |
| 116 | + const auto& flist = is_input ? flist_inputs : flist_outputs; |
| 117 | + std::string name; |
| 118 | + if (flist.count(attrs.op)) { |
| 119 | + name = flist[attrs.op](attrs)[index]; |
| 120 | + } else { |
| 121 | + name = (is_input ? "data" : "output") + std::to_string(index); |
| 122 | + } |
| 123 | + std::ostringstream msg; |
| 124 | + msg << "Operator " << attrs.op->name << "("; |
| 125 | + for (const auto& kv : attrs.dict) msg << kv.first << "=" << kv.second << ", "; |
| 126 | + msg << "name=" << attrs.name << ") expects " << name << "\'s " << attr_name |
| 127 | + << " to be " << expected << ", but got " << actual << "."; |
| 128 | + return msg.str(); |
| 129 | +} |
| 130 | + |
131 | 131 | /*! |
132 | 132 | * \brief macro assign shape to out if out is unknown otherwise check consistency |
133 | 133 | * Use macro so we can see the error file more clearly |
134 | | - * \param shape_array the shape array to store the result |
| 134 | + * \param inputs the shape array to store the result |
135 | 135 | * \param index the index of in the array |
136 | 136 | * \param shape the inferred shape |
137 | 137 | */ |
138 | | -#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ |
139 | | - { \ |
140 | | - if (!shape_assign(&(shape_array)[index], TShape(shape))) { \ |
141 | | - std::ostringstream os; \ |
142 | | - os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \ |
143 | | - << " inferred shape=" << shape; \ |
144 | | - throw InferShapeError(os.str(), index); \ |
145 | | - } \ |
| 138 | +#define NNVM_ASSIGN_INPUT_SHAPE(attrs, inputs, index, shape) \ |
| 139 | + { \ |
| 140 | + if (!shape_assign(&(inputs)[index], TShape(shape))) { \ |
| 141 | + LOG(FATAL) << attr_assign_error_msg(attrs, index, true, shape, \ |
| 142 | + (inputs)[index], "shape"); \ |
| 143 | + } \ |
| 144 | + } |
| 145 | + |
| 146 | +/*! |
| 147 | + * \brief macro assign shape to out if out is unknown otherwise check consistency |
| 148 | + * Use macro so we can see the error file more clearly |
| 149 | + * \param inputs the shape array to store the result |
| 150 | + * \param index the index of in the array |
| 151 | + * \param shape the inferred shape |
| 152 | + */ |
| 153 | +#define NNVM_ASSIGN_OUTPUT_SHAPE(attrs, outputs, index, shape) \ |
| 154 | + { \ |
| 155 | + if (!shape_assign(&(outputs)[index], TShape(shape))) { \ |
| 156 | + LOG(FATAL) << attr_assign_error_msg(attrs, index, false, shape, \ |
| 157 | + (outputs)[index], "shape"); \ |
| 158 | + } \ |
146 | 159 | } |
147 | 160 |
|
148 | 161 | /*! |
149 | 162 | * \brief macro assign type to out if out is unknown (-1) otherwise check consistency |
150 | 163 | * Use macro so we can see the error file more clearly |
151 | | - * \param type_array the type array to store the result |
| 164 | + * \param inputs the type array to store the result |
152 | 165 | * \param index the index of in the array |
153 | 166 | * \param type the inferred type |
154 | 167 | */ |
155 | | -#define TYPE_ASSIGN_CHECK(type_array, index, type) \ |
156 | | - { \ |
157 | | - if (!type_assign(&(type_array)[index], type)) { \ |
158 | | - std::ostringstream os; \ |
159 | | - os << "Type inconsistent, Provided=" \ |
160 | | - << type_string((type_array)[index]) << ',' \ |
161 | | - << " inferred type=" << type_string(type); \ |
162 | | - throw InferTypeError(os.str(), index); \ |
163 | | - } \ |
| 168 | +#define NNVM_ASSIGN_INPUT_TYPE(attrs, inputs, index, type) \ |
| 169 | + { \ |
| 170 | + if (!type_assign(&(inputs)[index], type)) { \ |
| 171 | + LOG(FATAL) << attr_assign_error_msg(attrs, index, true, type, \ |
| 172 | + (inputs)[index], "type"); \ |
| 173 | + } \ |
164 | 174 | } |
165 | 175 |
|
| 176 | +/*! |
| 177 | + * \brief macro assign type to out if out is unknown (-1) otherwise check consistency |
| 178 | + * Use macro so we can see the error file more clearly |
| 179 | + * \param inputs the type array to store the result |
| 180 | + * \param index the index of in the array |
| 181 | + * \param type the inferred type |
| 182 | + */ |
| 183 | +#define NNVM_ASSIGN_OUTPUT_TYPE(attrs, outputs, index, type) \ |
| 184 | + { \ |
| 185 | + if (!type_assign(&(outputs)[index], type)) { \ |
| 186 | + LOG(FATAL) << attr_assign_error_msg(attrs, index, false, type, \ |
| 187 | + (outputs)[index], "type"); \ |
| 188 | + } \ |
| 189 | + } |
166 | 190 |
|
167 | 191 | // simply return the shape as same |
168 | 192 | inline bool SameShape(const NodeAttrs& attrs, |
|
0 commit comments