Skip to content

Commit

Permalink
Squeeze bug fix. (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed May 29, 2018
1 parent 918dff2 commit 6c0b8ec
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
9 changes: 6 additions & 3 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
} else {
std::unordered_set<dim_t> axis_checker;
for (size_t i = 0; i < param.axis.ndim(); ++i) {
int real_axis;
if (param.axis[i] < 0) {
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
real_axis = param.axis[i] + static_cast<int>(shp.ndim());
} else {
real_axis = param.axis[i];
}
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
}
for (size_t i = 0; i < shp.ndim(); ++i) {
if (axis_checker.find(i) == axis_checker.end()) {
Expand Down
19 changes: 19 additions & 0 deletions nnvm/tests/python/unittest/test_infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,24 @@ def test_flatten():
sdict = infer_shape(y)
assert(sdict["y"][0] == [10, 200])

def test_squeeze():
x = sym.Variable("x", shape=(1, 1, 1, 10))
y = sym.squeeze(x, axis=(1,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [1, 10])

x = sym.Variable("x", shape=(1, 3, 1))
y = sym.squeeze(x, name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])

y = sym.squeeze(x, axis=(0), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3, 1])

y = sym.squeeze(x, axis=(0,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])

# Level 2
def test_conv2d():
Expand Down Expand Up @@ -331,3 +349,4 @@ def check(in_shape, out_shape, **kwargs):
test_reduce()
test_transpose()
test_prelu()
test_squeeze()

0 comments on commit 6c0b8ec

Please sign in to comment.