Skip to content

Commit e2ae388

Browse files
piiswrongtqchen
authored andcommitted
improve infer shape/type error message (#4)
* improve infer shape/type error message * fix dense infer shape
1 parent 986caf7 commit e2ae388

File tree

3 files changed

+86
-59
lines changed

3 files changed

+86
-59
lines changed

nnvm/src/top/nn.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,23 @@ inline bool DenseInferShape(const nnvm::NodeAttrs& attrs,
3535
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
3636
}
3737
CHECK_EQ(out_shape->size(), 1U);
38-
TShape dshape = (*in_shape)[DenseParam::kData];
39-
TShape oshape = (*out_shape)[0];
40-
// require data to be known
41-
if (dshape.ndim() == 0) return false;
42-
dim_t num_input;
43-
num_input = dshape.ProdShape(1, dshape.ndim());
44-
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kWeight, TShape({param.units, num_input}));
45-
if (param.use_bias) {
46-
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kBias, TShape({param.units}));
38+
if ((*out_shape)[0].ndim() != 0) {
39+
// reverse infer
40+
TShape dshape = (*out_shape)[0];
41+
dshape[dshape.ndim() - 1] = 0;
42+
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kData, dshape);
43+
}
44+
dim_t num_inputs = 0;
45+
if ((*in_shape)[DenseParam::kData].ndim() != 0) {
46+
TShape oshape = (*in_shape)[DenseParam::kData];
47+
num_inputs = oshape[oshape.ndim() - 1];
48+
oshape[oshape.ndim() - 1] = param.units;
49+
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
4750
}
48-
SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], param.units}));
49-
if (oshape.ndim() != 0) {
50-
dshape[0] = oshape[0];
51-
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kData, dshape);
51+
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kWeight,
52+
TShape({param.units, num_inputs}));
53+
if (param.use_bias) {
54+
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kBias, TShape({param.units}));
5255
}
5356
return true;
5457
}

nnvm/src/top/op_common.h

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,6 @@
1313

1414
namespace nnvm {
1515
namespace top {
16-
17-
/*! \brief exception throwed by InferShape error */
18-
struct InferShapeError : public dmlc::Error {
19-
/*! \brief analyze message */
20-
std::string msg;
21-
/*! \brief corresponding input index */
22-
int index;
23-
// constructor
24-
InferShapeError(const std::string& msg_, int index)
25-
: dmlc::Error(msg_), msg(msg_), index(index) {}
26-
};
27-
28-
/*! \brief exception throwed by InferShape error */
29-
struct InferTypeError : public dmlc::Error {
30-
/*! \brief analyze message */
31-
std::string msg;
32-
/*! \brief corresponding input index */
33-
int index;
34-
// constructor
35-
InferTypeError(const std::string& msg_, int index)
36-
: dmlc::Error(msg_), msg(msg_), index(index) {}
37-
};
38-
3916
/*!
4017
* \brief Parse keyword arguments as PType arguments and save to parsed
4118
* \tparam PType the arameter type.
@@ -128,41 +105,88 @@ inline bool type_assign(int *y, const int& x) {
128105
return true;
129106
}
130107

108+
template<typename AttrType>
109+
inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
110+
int index, bool is_input,
111+
const AttrType& expected,
112+
const AttrType& actual,
113+
const char* attr_name) {
114+
static const auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
115+
static const auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
116+
const auto& flist = is_input ? flist_inputs : flist_outputs;
117+
std::string name;
118+
if (flist.count(attrs.op)) {
119+
name = flist[attrs.op](attrs)[index];
120+
} else {
121+
name = (is_input ? "data" : "output") + std::to_string(index);
122+
}
123+
std::ostringstream msg;
124+
msg << "Operator " << attrs.op->name << "(";
125+
for (const auto& kv : attrs.dict) msg << kv.first << "=" << kv.second << ", ";
126+
msg << "name=" << attrs.name << ") expects " << name << "\'s " << attr_name
127+
<< " to be " << expected << ", but got " << actual << ".";
128+
return msg.str();
129+
}
130+
131131
/*!
132132
* \brief macro assign shape to out if out is unknown otherwise check consistency
133133
* Use macro so we can see the error file more clearly
134-
* \param shape_array the shape array to store the result
134+
* \param inputs the shape array to store the result
135135
* \param index the index of in the array
136136
* \param shape the inferred shape
137137
*/
138-
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \
139-
{ \
140-
if (!shape_assign(&(shape_array)[index], TShape(shape))) { \
141-
std::ostringstream os; \
142-
os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \
143-
<< " inferred shape=" << shape; \
144-
throw InferShapeError(os.str(), index); \
145-
} \
138+
#define NNVM_ASSIGN_INPUT_SHAPE(attrs, inputs, index, shape) \
139+
{ \
140+
if (!shape_assign(&(inputs)[index], TShape(shape))) { \
141+
LOG(FATAL) << attr_assign_error_msg(attrs, index, true, shape, \
142+
(inputs)[index], "shape"); \
143+
} \
144+
}
145+
146+
/*!
147+
* \brief macro assign shape to out if out is unknown otherwise check consistency
148+
* Use macro so we can see the error file more clearly
149+
* \param inputs the shape array to store the result
150+
* \param index the index of in the array
151+
* \param shape the inferred shape
152+
*/
153+
#define NNVM_ASSIGN_OUTPUT_SHAPE(attrs, outputs, index, shape) \
154+
{ \
155+
if (!shape_assign(&(outputs)[index], TShape(shape))) { \
156+
LOG(FATAL) << attr_assign_error_msg(attrs, index, false, shape, \
157+
(outputs)[index], "shape"); \
158+
} \
146159
}
147160

148161
/*!
149162
* \brief macro assign type to out if out is unknown (-1) otherwise check consistency
150163
* Use macro so we can see the error file more clearly
151-
* \param type_array the type array to store the result
164+
* \param inputs the type array to store the result
152165
* \param index the index of in the array
153166
* \param type the inferred type
154167
*/
155-
#define TYPE_ASSIGN_CHECK(type_array, index, type) \
156-
{ \
157-
if (!type_assign(&(type_array)[index], type)) { \
158-
std::ostringstream os; \
159-
os << "Type inconsistent, Provided=" \
160-
<< type_string((type_array)[index]) << ',' \
161-
<< " inferred type=" << type_string(type); \
162-
throw InferTypeError(os.str(), index); \
163-
} \
168+
#define NNVM_ASSIGN_INPUT_TYPE(attrs, inputs, index, type) \
169+
{ \
170+
if (!type_assign(&(inputs)[index], type)) { \
171+
LOG(FATAL) << attr_assign_error_msg(attrs, index, true, type, \
172+
(inputs)[index], "type"); \
173+
} \
164174
}
165175

176+
/*!
177+
* \brief macro assign type to out if out is unknown (-1) otherwise check consistency
178+
* Use macro so we can see the error file more clearly
179+
* \param inputs the type array to store the result
180+
* \param index the index of in the array
181+
* \param type the inferred type
182+
*/
183+
#define NNVM_ASSIGN_OUTPUT_TYPE(attrs, outputs, index, type) \
184+
{ \
185+
if (!type_assign(&(outputs)[index], type)) { \
186+
LOG(FATAL) << attr_assign_error_msg(attrs, index, false, type, \
187+
(outputs)[index], "type"); \
188+
} \
189+
}
166190

167191
// simply return the shape as same
168192
inline bool SameShape(const NodeAttrs& attrs,

nnvm/src/top/tensor.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ inline bool FlattenInferShape(const nnvm::NodeAttrs& attrs,
6464
for (uint32_t i = 1; i < dshape.ndim(); ++i) {
6565
target_dim *= dshape[i];
6666
}
67-
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({dshape[0], target_dim}));
67+
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, TShape({dshape[0], target_dim}));
6868
return true;
6969
}
7070

@@ -130,11 +130,11 @@ inline bool ConcatenateInferShape(const nnvm::NodeAttrs& attrs,
130130
if (dshape.ndim() == 0) return false;
131131

132132
for (size_t i = 0; i < in_shape->size(); ++i) {
133-
SHAPE_ASSIGN_CHECK(*in_shape, i, dshape);
133+
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
134134
}
135135

136136
if (!has_zero) dshape[param.axis] = size;
137-
SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
137+
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
138138
return dshape.Size() != 0;
139139
}
140140

@@ -210,7 +210,7 @@ inline bool CastInferType(const nnvm::NodeAttrs& attrs,
210210
std::vector<int> *out_attrs) {
211211
const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
212212
CHECK_EQ(out_attrs->size(), 1U);
213-
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
213+
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
214214
return true;
215215
}
216216

0 commit comments

Comments
 (0)