Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

align reshape input param with pytorch #5804

Merged
merged 36 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
6f9fe41
align reshape input param with pytorch
BBuf Aug 9, 2021
d6321a0
Merge branch 'master' into fix_reshape_bug
BBuf Aug 9, 2021
16dd548
Merge branch 'master' into fix_reshape_bug
BBuf Aug 9, 2021
00ed776
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 9, 2021
68dc951
fix pixel_shuffle impl
BBuf Aug 10, 2021
a6afec8
Merge branch 'master' into fix_reshape_bug
BBuf Aug 10, 2021
19f3fa0
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
ad6b24a
fix ci error
BBuf Aug 10, 2021
356439c
Merge branch 'fix_reshape_bug' of https://github.com/Oneflow-Inc/onef…
BBuf Aug 10, 2021
dd7bb3a
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
d456970
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
c542267
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
b00701b
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
02a94b5
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
2fb31e1
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
4ab6d28
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 10, 2021
eb8befb
fix reshape result bug
BBuf Aug 11, 2021
7174818
Merge branch 'fix_reshape_bug' of https://github.com/Oneflow-Inc/onef…
BBuf Aug 11, 2021
577ad29
Merge branch 'master' into fix_reshape_bug
BBuf Aug 11, 2021
2c7bb13
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 11, 2021
ef9bde1
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 11, 2021
ac0e6f3
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 11, 2021
449eccf
Merge branch 'master' into fix_reshape_bug
BBuf Aug 12, 2021
bf95139
fix ci error
BBuf Aug 12, 2021
d0263d6
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
5728725
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
d00314b
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
f891d8e
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
b0d0338
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
54a69bc
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
6126074
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
2cae2c0
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
21b5312
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
a0ae12a
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
840f5f6
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 12, 2021
bd8950e
Merge branch 'master' into fix_reshape_bug
oneflow-ci-bot Aug 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/oneflow/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ def forward(self, x):
if len(mean.shape) == 1:
nd_params_shape = [1] * len(x.shape)
nd_params_shape[axis] = params_shape[0]
mean = mean.reshape(shape=nd_params_shape)
variance = variance.reshape(shape=nd_params_shape)
mean = flow.reshape(mean, shape=nd_params_shape)
variance = flow.reshape(variance, shape=nd_params_shape)
if self.weight and params_shape[0] == self.weight.nelement():
weight = self.weight.reshape(shape=nd_params_shape)
weight = flow.reshape(self.weight, shape=nd_params_shape)
if self.bias and params_shape[0] == self.bias.nelement():
bias = self.bias.reshape(shape=nd_params_shape)
bias = flow.reshape(self.bias, shape=nd_params_shape)
elif len(mean.shape) == len(x.shape):
pass
else:
Expand Down
8 changes: 4 additions & 4 deletions python/oneflow/nn/modules/instancenorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _forward(self, x):
variance = x.var(2, keepdim=True)
normalized = (x - mean) / flow.sqrt(variance + self.eps)
if self.weight and params_shape[0] == self.weight.nelement():
weight = self.weight.reshape(shape=nd_params_shape)
weight = flow.reshape(self.weight, shape=nd_params_shape)
if self.bias and params_shape[0] == self.bias.nelement():
bias = self.bias.reshape(shape=nd_params_shape)
bias = flow.reshape(self.bias, shape=nd_params_shape)
if self.weight:
normalized = normalized * weight
if self.bias:
Expand All @@ -50,9 +50,9 @@ def _forward(self, x):

def forward(self, x):
self._check_input_dim(x)
reshape_to_1d = x.reshape([x.shape[0], x.shape[1], -1])
reshape_to_1d = flow.reshape(x, [x.shape[0], x.shape[1], -1])
normalized_1d_out = self._forward(reshape_to_1d)
reshape_back_to_nd = normalized_1d_out.reshape(list(x.shape))
reshape_back_to_nd = flow.reshape(normalized_1d_out, list(x.shape))
return reshape_back_to_nd


Expand Down
14 changes: 7 additions & 7 deletions python/oneflow/nn/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def forward(self, input, target):
if input_shape_len == 3:
(b, c, h) = (input.shape[0], input.shape[1], input.shape[2])
input = flow.F.transpose(input, perm=(0, 2, 1))
input = input.reshape(shape=[-1, input.shape[2]])
input = flow.reshape(input, shape=[-1, input.shape[2]])
target = target.flatten()
elif input_shape_len == 4:
(b, c, h, w) = (
Expand All @@ -191,7 +191,7 @@ def forward(self, input, target):
input.shape[3],
)
input = flow.F.transpose(input, perm=(0, 2, 3, 1))
input = input.reshape(shape=[-1, input.shape[3]])
input = flow.reshape(input, shape=[-1, input.shape[3]])
target = target.flatten()
elif input_shape_len >= 5:
raise NotImplemented
Expand Down Expand Up @@ -423,10 +423,10 @@ def forward(self, input, target):
elif len(input.shape) == 3:
(b, c, h) = (input.shape[0], input.shape[1], input.shape[2])
input = flow.F.transpose(input, perm=(0, 2, 1))
input = input.reshape(shape=[-1, input.shape[2]])
input = flow.reshape(input, shape=[-1, input.shape[2]])
target = target.flatten()
res = self.nllloss_1d(input, target)
res = res.reshape((b, h))
res = res.reshape(b, h)
elif len(input.shape) == 4:
(b, c, h, w) = (
input.shape[0],
Expand All @@ -435,10 +435,10 @@ def forward(self, input, target):
input.shape[3],
)
input = flow.F.transpose(input, perm=(0, 2, 3, 1))
input = input.reshape(shape=[-1, input.shape[3]])
input = input.reshape(-1, input.shape[3])
target = target.flatten()
res = self.nllloss_1d(input, target)
res = res.reshape((b, h, w))
res = res.reshape(b, h, w)
else:
raise NotImplemented
if self.ignore_index is not None:
Expand All @@ -447,7 +447,7 @@ def forward(self, input, target):
ones = flow.ones(
condition.shape, dtype=condition.dtype, device=condition.device
)
condition = ones.sub(condition).reshape(tuple(res.shape))
condition = flow.reshape(ones.sub(condition), tuple(res.shape))
res = flow.where(condition, res, zeros)
if self.reduction == "mean":
res = res.sum()
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/nn/modules/meshgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, inputs):
for i in range(size):
view_shape = [1] * size
view_shape[i] = -1
outputs.append(inputs[i].reshape(view_shape).expand(*shape))
outputs.append(flow.reshape(inputs[i], view_shape).expand(*shape))
return outputs


Expand Down
8 changes: 4 additions & 4 deletions python/oneflow/nn/modules/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def norm_op(input, ord=None, dim=None, keepdim=False):
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape((3, 3))
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
Expand Down Expand Up @@ -293,7 +293,7 @@ def norm_op(input, ord=None, dim=None, keepdim=False):

Using the :attr:`dim` argument to compute matrix norms::

>>> m = flow.tensor(np.arange(8, dtype=np.float32)).reshape((2, 2, 2))
>>> m = flow.tensor(np.arange(8, dtype=np.float32)).reshape(2, 2, 2)
>>> LA.norm(m, dim=(1,2))
tensor([ 3.7417, 11.225 ], dtype=oneflow.float32)
"""
Expand Down Expand Up @@ -360,7 +360,7 @@ def vector_norm_tensor_op(input, ord=2, dim=None, keepdim=False):
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape((3, 3))
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
Expand Down Expand Up @@ -419,7 +419,7 @@ def matrix_norm_tensor_op(input, ord="fro", dim=(-2, -1), keepdim=False):
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape((3,3))
>>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape(3,3)
>>> a
tensor([[0., 1., 2.],
[3., 4., 5.],
Expand Down
8 changes: 4 additions & 4 deletions python/oneflow/nn/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ def forward(self, x):
if len(mean.shape) == 1:
nd_params_shape = [1] * len(x.shape)
nd_params_shape[self.begin_norm_axis] = params_shape[0]
mean = mean.reshape(shape=nd_params_shape)
variance = variance.reshape(shape=nd_params_shape)
mean = flow.reshape(mean, shape=nd_params_shape)
variance = flow.reshape(variance, nd_params_shape)
if self.weight and params_shape[0] == self.weight.nelement():
weight = self.weight.reshape(shape=nd_params_shape)
weight = flow.reshape(self.weight, shape=nd_params_shape)
if self.bias and params_shape[0] == self.bias.nelement():
bias = self.bias.reshape(shape=nd_params_shape)
bias = flow.reshape(self.bias, shape=nd_params_shape)
elif len(mean.shape) == len(x.shape):
pass
else:
Expand Down
36 changes: 15 additions & 21 deletions python/oneflow/nn/modules/pixelshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,32 +115,26 @@ def forward(self, input: Tensor) -> Tensor:
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)"
_new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor))
out = input.reshape(
[
_batch,
_new_c,
self.h_upscale_factor * self.w_upscale_factor,
_height,
_width,
]
_batch,
_new_c,
self.h_upscale_factor * self.w_upscale_factor,
_height,
_width,
)
out = out.reshape(
[
_batch,
_new_c,
self.h_upscale_factor,
self.w_upscale_factor,
_height,
_width,
]
_batch,
_new_c,
self.h_upscale_factor,
self.w_upscale_factor,
_height,
_width,
)
out = out.permute(0, 1, 4, 2, 5, 3)
out = out.reshape(
[
_batch,
_new_c,
_height * self.h_upscale_factor,
_width * self.w_upscale_factor,
]
_batch,
_new_c,
_height * self.h_upscale_factor,
_width * self.w_upscale_factor,
)
return out

Expand Down
36 changes: 35 additions & 1 deletion python/oneflow/nn/modules/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def _input_args_is_flow_size(args):
return all((isinstance(x, flow.Size) for x in args)) and len(args) == 1


@register_tensor_op("reshape")
def reshape_op(input, shape: Sequence[int] = None):
"""This operator reshapes a Tensor.

Expand Down Expand Up @@ -60,6 +59,41 @@ def reshape_op(input, shape: Sequence[int] = None):
return flow.F.reshape(input, shape)


@register_tensor_op("reshape")
def reshape_tensor_op(input, *shape):
"""This operator reshapes a Tensor.

We can set one dimension in `shape` as `-1`, the operator will infer the complete shape.

Args:
x: A Tensor.
*shape: tuple of python::ints or int...
Returns:
A Tensor has the same type as `x`.

For example:

.. code-block:: python

>>> import numpy as np
>>> import oneflow as flow
>>> x = np.array(
... [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
... ).astype(np.float32)
>>> input = flow.Tensor(x)

>>> y = input.reshape(2, 2, 2, -1).shape
>>> y
flow.Size([2, 2, 2, 2])

"""
if _input_args_is_int(shape):
new_shape = _single(shape)
else:
raise ValueError("the input shape parameter of reshape is not illegal!")
return flow.F.reshape(input, new_shape)


@register_tensor_op("view")
def view_op(input, *shape):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self):
super(FlattenLayer, self).__init__()

def forward(self, x): # x shape: (batch, *, *, ...)
res = x.reshape(shape=[x.shape[0], -1])
res = x.reshape(x.shape[0], -1)
return res


Expand Down
4 changes: 2 additions & 2 deletions python/oneflow/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,15 +578,15 @@ def _test_tensor_reshape(test_case):
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
).astype(np.float32)
input = flow.Tensor(x)
of_shape = input.reshape(shape=[2, 2, 2, -1]).numpy().shape
of_shape = input.reshape(2, 2, 2, -1).numpy().shape
np_shape = (2, 2, 2, 2)
test_case.assertTrue(np.array_equal(of_shape, np_shape))

@autotest()
def test_reshape_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = x.reshape(shape=(-1,))
y = x.reshape(-1,)
return y

@autotest()
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/utils/vision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> Tensor:
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict
return Tensor(parsed.astype(m[2]), dtype=m[0]).reshape(shape=s)
return flow.reshape(Tensor(parsed.astype(m[2]), dtype=m[0]), shape=s)


def read_label_file(path: str) -> Tensor:
Expand Down
6 changes: 3 additions & 3 deletions python/oneflow/utils/vision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def to_tensor(pic):
if pic.mode == "1":
img = 255 * img

img = img.reshape(shape=(pic.size[1], pic.size[0], len(pic.getbands())))
img = flow.reshape(img, shape=(pic.size[1], pic.size[0], len(pic.getbands())))
# put it from HWC to CHW format
res = img.permute(2, 0, 1)
if img.dtype == flow.int:
Expand Down Expand Up @@ -189,9 +189,9 @@ def normalize(
)
)
if mean.ndim == 1:
mean = mean.reshape(shape=(-1, 1, 1))
mean = mean.reshape(-1, 1, 1)
if std.ndim == 1:
std = std.reshape(shape=(-1, 1, 1))
std = std.reshape(-1, 1, 1)
tensor = tensor.sub(mean).div(std)
# tensor.sub_(mean).div_(std)
return tensor
Expand Down