-
Notifications
You must be signed in to change notification settings - Fork 826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix upsample nearest bug #5347
Merged
Merged
fix upsample nearest bug #5347
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
34ee0a6
fix upsample nearest bug
BBuf 073dda1
Merge branch 'master' into fix_upsample_bug
BBuf c7831bc
Merge branch 'master' into fix_upsample_bug
oneflow-ci-bot 74a225d
Merge branch 'master' into fix_upsample_bug
oneflow-ci-bot 1ef6b8b
Merge branch 'master' into fix_upsample_bug
oneflow-ci-bot fb8e7f0
Merge branch 'master' into fix_upsample_bug
oneflow-ci-bot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ namespace { | |
|
||
static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale, | ||
const int64_t in_dim_size) { | ||
int64_t index = static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale)); | ||
int64_t index = static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
index = index > in_dim_size - 1 ? in_dim_size - 1 : index; | ||
index = index < static_cast<int64_t>(0) ? static_cast<int64_t>(0) : index; | ||
return index; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,8 +65,8 @@ REGISTER_USER_OP("upsample_grad") | |
LOG(FATAL) << "upsample_nearest only supports NCHW"; | ||
} | ||
*dx_shape = Shape({dy_shape.At(0), dy_shape.At(1), | ||
static_cast<int32_t>(dy_shape.At(2) / height_scale), | ||
static_cast<int32_t>(dy_shape.At(3) / width_scale)}); | ||
static_cast<int32_t>(std::round(dy_shape.At(2) / height_scale)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scale如果取1.5,输入图为3x3,那么输出图就是4x4,然后4/1.5=2.666xxx, 安装之前的做法直接取整这种情况会意外丢掉一个像素。 |
||
static_cast<int32_t>(std::round(dy_shape.At(3) / width_scale))}); | ||
return Maybe<void>::Ok(); | ||
}) | ||
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方的测试,一个scale_factor有点太单薄了,我建议scale_factor由out_size / in_size算出来,可以设置多组out_size和in_size来算.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我pr interpolate的时候加上吧,其实这里是因为昨天用interpolate包Upsample测试的时候没过,发现不对劲:cry:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,那麻烦ziqi加吧,我本地验证了几组数据没问题。