Skip to content

Commit

Permalink
reshape (#2342)
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Jun 5, 2016
1 parent f123a64 commit 26f9dfd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/operator/reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ enum ReshapeOpOutputs {kOut};


struct ShapeInfo {
inline size_t ndim() const {
return info.size();
}

inline size_t Size() const {
size_t sz = 1;
for (size_t i = 0; i < info.size(); ++i) {
sz *= info[i];
}
return sz;
}

std::vector<int> info;
};

Expand All @@ -40,9 +52,18 @@ inline std::istream &operator>>(std::istream &is, ShapeInfo &shape) {
return is;
}
}

int idx;
std::vector<int> tmp;
// deal with empty case
// safe to remove after stop using target_shape
size_t pos = is.tellg();
char ch = is.get();
if (ch == ')') {
shape.info = tmp;
return is;
}
is.seekg(pos);
// finish deal
while (is >> idx) {
tmp.push_back(idx);
char ch;
Expand Down Expand Up @@ -90,8 +111,9 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
bool keep_highest;
ShapeInfo shape;
DMLC_DECLARE_PARAMETER(ReshapeParam) {
int tmp[] = {0, 0};
DMLC_DECLARE_FIELD(target_shape)
.set_default(TShape())
.set_default(TShape(tmp, tmp + 2))
.describe("(Deprecated! Use shape instead.) Target new shape. One and only one dim can be 0, "
"in which case it will be inferred from the rest of dims");
DMLC_DECLARE_FIELD(keep_highest).set_default(false)
Expand Down Expand Up @@ -180,7 +202,7 @@ class ReshapeProp : public OperatorProperty {
param_.shape.info.size() > 0, true) << "targe_shape or shape must be present.";
const TShape &dshape = in_shape->at(reshape_enum::kData);
if (dshape.ndim() == 0) return false;
if (param_.target_shape.ndim() == 0) {
if (param_.shape.ndim() != 0) {
std::vector<int> tmp;
int src_idx = 0;
int neg_idx = -1;
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,8 @@ def test_reshape():
def test_reshape_new(src_shape, shape_args, dst_shape):
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, shape=shape_args)
js = net.tojson()
net = mx.sym.load_json(js)
_, output_shape, __ = net.infer_shape(data=src_shape)
assert output_shape[0] == dst_shape, \
'Src Shape = %s, Shape Arguments = %s, Dst Shape = %s, Output Shape = %s' \
Expand Down Expand Up @@ -919,6 +921,8 @@ def test_reshape_new(src_shape, shape_args, dst_shape):
# Test old api
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, target_shape=(2, 0))
js = net.tojson()
net = mx.sym.load_json(js)
_, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
assert(output_shape[0] == (2, 75))

Expand Down

0 comments on commit 26f9dfd

Please sign in to comment.