[OP] new mode to reshape (#2114)

* [OP] new mode to reshape

* update desc
antinucleon committed May 13, 2016
1 parent 6d99054 commit b5a3693
Showing 3 changed files with 173 additions and 29 deletions.
157 changes: 136 additions & 21 deletions src/operator/reshape-inl.h
@@ -25,17 +25,84 @@ enum ReshapeOpInputs {kData};
enum ReshapeOpOutputs {kOut};
} // namespace reshape_enum


struct ShapeInfo {
  std::vector<int> info;
};


inline std::istream &operator>>(std::istream &is, ShapeInfo &shape) {
  // skip leading whitespace until the opening '('
  while (true) {
    char ch = is.get();
    if (ch == '(') break;
    if (!isspace(ch)) {
      is.setstate(std::ios::failbit);
      return is;
    }
  }
  // read comma-separated integers until the closing ')'
  int idx;
  std::vector<int> tmp;
  while (is >> idx) {
    tmp.push_back(idx);
    char ch;
    do {
      ch = is.get();
    } while (isspace(ch));
    if (ch == ',') {
      // allow a trailing comma right before ')', as in a python tuple
      while (true) {
        ch = is.peek();
        if (isspace(ch)) {
          is.get(); continue;
        }
        if (ch == ')') {
          is.get(); break;
        }
        break;
      }
      if (ch == ')') break;
    } else if (ch == ')') {
      break;
    } else {
      is.setstate(std::ios::failbit);
      return is;
    }
  }
  shape.info = tmp;
  return is;
}


inline std::ostream &operator<<(std::ostream &os, const ShapeInfo &shape) {
  os << '(';
  for (index_t i = 0; i < shape.info.size(); ++i) {
    if (i != 0) os << ',';
    os << shape.info[i];
  }
  // python-style tuple: single-element tuples keep a trailing comma
  if (shape.info.size() == 1) os << ',';
  os << ')';
  return os;
}

struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
  TShape target_shape;
  bool keep_highest;
  ShapeInfo shape;
  DMLC_DECLARE_PARAMETER(ReshapeParam) {
    DMLC_DECLARE_FIELD(target_shape)
    .set_default(TShape())
    .describe("(Deprecated! Use shape instead.) Target new shape. One and only one dim can be 0, "
              "in which case it will be inferred from the rest of dims");
    DMLC_DECLARE_FIELD(keep_highest).set_default(false)
    .describe("(Deprecated! Use shape instead.) Whether to keep the highest dim unchanged. "
              "If set to yes, then the first dim in target_shape is ignored "
              "and always fixed to the input's first dim");
    DMLC_DECLARE_FIELD(shape)
    .set_default(ShapeInfo())
    .describe("Target new shape. Set a dim to 0 to keep it the same as the input. Set a dim "
              "to -1 to have it inferred from the rest of dims. One and only one dim "
              "can be -1");
  }
};
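As a quick illustration of the new parameter (a hedged sketch; the unit tests added at the end of this commit cover more cases): 0 keeps the corresponding input dim, and a single -1 is inferred from the remaining dims.

import mxnet as mx

data = mx.sym.Variable("data")
# new style: 0 keeps the input dim, -1 is inferred from the rest
out = mx.sym.Reshape(data, shape=(0, -1))
_, out_shape, _ = out.infer_shape(data=(2, 3, 5, 5))
print(out_shape[0])                     # (2, 75)

# old style, still accepted but now marked deprecated
out_old = mx.sym.Reshape(data, target_shape=(2, 0))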

@@ -109,32 +176,80 @@ class ReshapeProp : public OperatorProperty {
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
    CHECK_EQ(in_shape->size(), 1) << "Input: [data]";
    CHECK_EQ(param_.target_shape.ndim() > 0 ||
             param_.shape.info.size() > 0, true) << "target_shape or shape must be present.";
    const TShape &dshape = in_shape->at(reshape_enum::kData);
    if (dshape.ndim() == 0) return false;
    if (param_.target_shape.ndim() == 0) {
      // new mode: build the output shape from the shape parameter
      std::vector<int> tmp;
      int src_idx = 0;
      int neg_idx = -1;
      size_t new_size = dshape.Size();
      bool keep = true;
      for (index_t i = 0; i < param_.shape.info.size(); ++i) {
        int proposed_dim = param_.shape.info[i];
        if (proposed_dim == 0) {
          // keep the same dim as the input
          CHECK_EQ(keep, true) << "After setting a manual dim, can't keep original dim";
          tmp.push_back(dshape[src_idx++]);
          new_size /= tmp.back();
        } else if (proposed_dim < 0) {
          // infer this dim from the rest
          CHECK_LT(neg_idx, 0) << "One and only one dim can be inferred";
          neg_idx = i;
          tmp.push_back(0);
          src_idx++;
        } else {
          // greater than 0: explicit new dim
          CHECK_EQ(new_size % proposed_dim, 0) << "Illegal dim setting, can't be divided.";
          tmp.push_back(proposed_dim);
          new_size /= proposed_dim;
          // after setting a manual shape, the original dims can no longer be kept
          if (param_.shape.info.size() != dshape.ndim()) {
            keep = false;
          } else {
            src_idx++;
          }
        }
      }

      if (neg_idx > 0) {
        tmp[neg_idx] = new_size;
      }
      TShape oshape(tmp.begin(), tmp.end());
      CHECK_EQ(oshape.Size(), dshape.Size())
        << "Target shape size is different to source. "
        << "Target: " << oshape.Size()
        << "\nSource: " << dshape.Size();
      out_shape->clear();
      out_shape->push_back(oshape);
    } else {
      // old mode: target_shape / keep_highest, kept for backward compatibility
      LOG(INFO) << "Using target_shape will be deprecated.";
      TShape oshape = param_.target_shape;
      int neg_count = 0;
      index_t neg_idx = 0;
      index_t start_idx = param_.keep_highest ? 1 : 0;
      if (param_.keep_highest) {
        oshape[0] = dshape[0];
      }
      for (index_t i = start_idx; i < oshape.ndim(); ++i) {
        if (oshape[i] == 0) {
          neg_count++;
          neg_idx = i;
        }
      }
      if (neg_count == 1) {
        oshape[neg_idx] = 1;
        oshape[neg_idx] = dshape.Size() / oshape.Size();
      }

      CHECK(oshape.Size() == dshape.Size())
        << "Target shape size is different to source. "
        << "Target: " << param_.target_shape.Size()
        << "\nSource: " << dshape.Size();
      out_shape->clear();
      out_shape->push_back(oshape);
    }
    return true;
  }
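The new branch boils down to a small dim-inference rule. A minimal Python sketch (the function name is illustrative, and the keep/src bookkeeping checks are omitted) that mirrors the same arithmetic:

def infer_reshape(dshape, proposed):
    # 0 copies the matching input dim; a single -1 is inferred so that the
    # total number of elements stays unchanged.
    out, remaining, neg_idx, src = [], 1, -1, 0
    for d in dshape:
        remaining *= d                   # total element count of the input
    for i, p in enumerate(proposed):
        if p == 0:                       # keep the corresponding input dim
            out.append(dshape[src]); remaining //= dshape[src]; src += 1
        elif p < 0:                      # defer: infer this dim at the end
            assert neg_idx < 0, "one and only one dim can be -1"
            out.append(0); neg_idx = i; src += 1
        else:                            # explicit new dim
            assert remaining % p == 0, "illegal dim setting, can't be divided"
            out.append(p); remaining //= p
            if len(proposed) == len(dshape):   # ranks match: stay aligned with input dims
                src += 1
    if neg_idx >= 0:
        out[neg_idx] = remaining
    return tuple(out)

assert infer_reshape((2, 3, 5, 5), (0, 0, -1)) == (2, 3, 25)
assert infer_reshape((2, 3, 5, 5), (5, 3, 0, -1)) == (5, 3, 5, 2)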

Expand Down
4 changes: 2 additions & 2 deletions src/operator/reshape.cc
@@ -23,11 +23,11 @@ DMLC_REGISTER_PARAMETER(ReshapeParam);

MXNET_REGISTER_OP_PROPERTY(Reshape, ReshapeProp)
.describe("Reshape input to target shape")
.add_argument("data", "Symbol", "Input data to reshape.")
.add_argument("data", "Symbol", "Input data to reshape.")
.add_arguments(ReshapeParam::__FIELDS__());

MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp)
.describe("Flatten input")
.add_argument("data", "Symbol", "Input data to flatten.");
.add_argument("data", "Symbol", "Input data to flatten.");
} // namespace op
} // namespace mxnet
41 changes: 35 additions & 6 deletions tests/python/unittest/test_operator.py
@@ -221,16 +221,16 @@ def check_softmax_with_ignore_label(xpu):

    grad0 = grad.asnumpy()

    for i in range(int(shape[0]/2)):
        l_np[i] = 0
    l[:] = l_np

    exec1.forward()
    exec1.backward()
    grad1 = grad.asnumpy()

    assert(abs(np.sum(grad1[:int(shape[0]/2)])) < 1e-5)
    assert(reldiff(grad0[int(shape[0]/2):], grad1[int(shape[0]/2):]) < 1e-5)

def check_softmax_with_shape(shape, xpu):
    X = mx.symbol.Variable('X')
@@ -750,7 +750,7 @@ def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3),
    be.forward(True)
    out_o = be.outputs[0].asnumpy()
    ndo = be.outputs[0]

    out_grads = np.zeros(shape=be.outputs[0].shape, dtype=np.float32)
    out_grads[0,0, 16,16] = 1.0
    out_grad = mx.nd.array(out_grads)
@@ -784,7 +784,7 @@ def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3),
    be.backward([impulse_error])
    out_orig = be.outputs[0].asnumpy()
    kernel_gradient = be.grad_arrays[1].asnumpy()

    dkernel = mx.nd.array(rnd_kernel_s + kernel_gradient)

    be = net.bind(mx.cpu(), args={ 'input' : white_in, 'test_convolution_weight' : dkernel})
@@ -798,7 +798,35 @@ def test_convolution_dilated_impulse_response():
def test_convolution_dilated_impulse_response():
    for dil in [ (1,1), (2,2), (3,3) ]:
        for ks in [ (3,3), (4,4), (2,3), (3,2), (1,1) ]:
            test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks)

def test_reshape():
    # case 1: -1 infers the flattened remainder
    net = mx.sym.Variable("data")
    net = mx.sym.Reshape(net, shape=(0, -1))
    _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
    assert(output_shape[0] == (2, 75))
    # case 2: leading 0s keep the input dims
    net = mx.sym.Variable("data")
    net = mx.sym.Reshape(net, shape=(0, 0, -1))
    _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
    assert(output_shape[0] == (2, 3, 25))
    # case 3: explicit dims mixed with 0 and -1
    net = mx.sym.Variable("data")
    net = mx.sym.Reshape(net, shape=(5, 3, 0, -1))
    _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
    assert(output_shape[0] == (5, 3, 5, 2))
    # case 4: all 0s reproduce the input shape
    net = mx.sym.Variable("data")
    net = mx.sym.Reshape(net, shape=(0, 0, 0, 0))
    _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
    assert(output_shape[0] == (2, 3, 5, 5))
    # case 5: deprecated target_shape API still works
    net = mx.sym.Variable("data")
    net = mx.sym.Reshape(net, target_shape=(2, 0))
    _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
    assert(output_shape[0] == (2, 75))


if __name__ == '__main__':
    test_convolution_grouping()
@@ -824,5 +852,6 @@ def test_convolution_dilated_impulse_response():
    test_batchnorm_training()
    check_softmax_with_ignore_label(mx.cpu())
    test_convolution_dilated_impulse_response()
    test_reshape()
    #check_softmax_with_shape((3,4), mx.cpu())
    #check_multi_softmax_with_shape((3,4,5), mx.cpu())
