Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix upsample nearest bug #5347

Merged
merged 6 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions oneflow/python/test/modules/test_upsample2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,34 @@ def _test_upsample2d_bilinear_aligncorner_backward(test_case, device):
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))


def _test_interpolate_float_scale(test_case, device):
input = flow.Tensor(
np.arange(1, 10).reshape((1, 1, 3, 3)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=1.5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方的测试,一个scale_factor有点太单薄了,我建议scale_factor由out_size / in_size算出来,可以设置多组out_size和in_size来算.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我pr interpolate的时候加上吧,其实这里是因为昨天用interpolate包Upsample测试的时候没过,发现不对劲:cry:

Copy link
Contributor Author

@BBuf BBuf Jun 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,那麻烦ziqi加吧,我本地验证了几组数据没问题。

of_out = m(input)
np_out = np.array(
[
[
[
[1.0, 1.0, 2.0, 3.0],
[1.0, 1.0, 2.0, 3.0],
[4.0, 4.0, 5.0, 6.0],
[7.0, 7.0, 8.0, 9.0],
]
]
]
)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = np.array([[[[4.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 1.0]]]])
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
Expand All @@ -286,6 +314,7 @@ def test_upsample2d(test_case):
_test_upsample2d_bilinear_4dim,
_test_upsample2d_backward,
_test_upsample2d_bilinear_aligncorner_backward,
_test_interpolate_float_scale,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/kernels/upsample_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace {

static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale,
const int64_t in_dim_size) {
int64_t index = static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale));
int64_t index = static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale)));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index = index > in_dim_size - 1 ? in_dim_size - 1 : index;
index = index < static_cast<int64_t>(0) ? static_cast<int64_t>(0) : index;
return index;
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/kernels/upsample_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace {

__device__ int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale,
const int64_t in_dim_size) {
return max(min(static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale)),
return max(min(static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale))),
in_dim_size - 1),
static_cast<int64_t>(0));
}
Expand Down
4 changes: 2 additions & 2 deletions oneflow/user/ops/upsample_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ REGISTER_USER_OP("upsample_grad")
LOG(FATAL) << "upsample_nearest only supports NCHW";
}
*dx_shape = Shape({dy_shape.At(0), dy_shape.At(1),
static_cast<int32_t>(dy_shape.At(2) / height_scale),
static_cast<int32_t>(dy_shape.At(3) / width_scale)});
static_cast<int32_t>(std::round(dy_shape.At(2) / height_scale)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scale如果取1.5,输入图为3x3,那么输出图就是4x4,然后4/1.5=2.666xxx, 安装之前的做法直接取整这种情况会意外丢掉一个像素。

static_cast<int32_t>(std::round(dy_shape.At(3) / width_scale))});
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
Expand Down