-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ReflectionPad2d #5172
add ReflectionPad2d #5172
Conversation
|
||
def forward(self, x): | ||
H, W = x.shape[2], x.shape[3] | ||
if self.padding[2] < H and self.padding[3] < H and self.padding[0] < W and self.padding[1] < W: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你这个对padding_size的限制有出处吗,看了一下和pytorch没对齐。具体可以看:https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考的是旧的接口
oneflow/oneflow/python/ops/pad.py
Line 294 in 04046fc
padding[2] < H and padding[3] < H and padding[0] < W and padding[1] < W |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,这个限制是没问题的,torch的op在c++层也做了同样的检查,不符合会报错的
.Input("x") | ||
.Output("y") | ||
.Attr("padding", boundary) | ||
.Attr("floating_value", float(1.)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此op没有floating_value 、integral_value属性、不需要加
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前这个op不传这两个属性会报错。因为C++端op里面有这三个属性。
oneflow/oneflow/user/ops/pad2d_ops.cpp
Line 51 in 0eec807
REGISTER_USER_OP(pad_2d_type) \ |
add ReflectionPad2d
doctest

docstring

unittest
