-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implementation of constantpad-3d op #5529
Conversation
@@ -0,0 +1,66 @@ | |||
/* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个系列的文件名,叫做 pad3d_xxxx
,改作全称的 constantpad3d_xxxx
会不会更好。
还有就是,依照命名的约定,我们一般不用 xxx_kernels_util.h(cpp,cu)
而是 xxx_kernel_util.h(cpp,cu)
也就是 util
里的 kernel
不用复数。
(目前合到 master 里的只有一个用了 xxx_kernels_util
,应该是review时的漏网之鱼
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,命名已修改
return static_cast<int8_t>(integral); | ||
} | ||
|
||
template<> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个空模板有什么意义吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和L26-27的配合使用
const int64_t d_idx = 2; | ||
const int64_t h_idx = 3; | ||
const int64_t w_idx = 4; | ||
// padding vector: [left, right, top, bottom, font, back] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
font->front
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape, | ||
const ShapeView& y_shape, const std::vector<int64_t>& padding, | ||
IN_T constant_value) { | ||
// for NCDHW format input tensor, index of n,c,d,h,w is 0,1,2,3,4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释是不是可以直接去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我留一处吧(其余的删掉
const int64_t d_idx = 2; | ||
const int64_t h_idx = 3; | ||
const int64_t w_idx = 4; | ||
// padding vector: [left, right, top, bottom, font, back] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
.Input("dy", op.GetGradTensorWithOpOutput("y", 0)) | ||
.Output("dx") | ||
.Attr("padding", op.attr<std::vector<int64_t>>("padding")) | ||
.Attr("floating_value", op.attr<double>("floating_value")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
带这两个attr的意义是什么呢?在实现的时候可以做自动类型推断吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个主要和2d系列的保持了对齐(后面要改的话,提另一个统一改吧
index_helper.OffsetToNdIndex(num, n, c, d, h, w); | ||
|
||
const int64_t src_num = n_channel * x_depth * x_height * x_width; | ||
if (pad_font <= d && d < pad_font + x_depth && w >= pad_left && w < x_width + pad_left |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里变量的拼写错误也fix一下吧,pad_front
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
constantpad 3d 的op实现包括: