add masked fill op #3515
Conversation
oneflow/user/ops/masked_fill_op.cpp
Outdated
REGISTER_USER_OP("masked_fill")
    .Input("x")
    .Input("mask")
    .Attr("value", UserOpAttrType::kAtFloat)
masked_fill supports multiple data types, and a single float attr cannot represent all of them exactly; see scalar_mul for reference.
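A sketch of the scalar_mul-style registration the reviewer is pointing to, with separate typed operands so both integer and floating fill values stay exact. The attr names (`has_int_operand`, `int_operand`, `float_operand`) are taken from scalar_mul in this era of the codebase and should be treated as assumptions, not the final API:

```cpp
// Hypothetical sketch mirroring scalar_mul: two typed attrs plus flags,
// so the fill value is exact for both integer and floating dtypes.
REGISTER_USER_OP("masked_fill")
    .Input("x")
    .Input("mask")
    .Output("out")
    .Attr("has_int_operand", UserOpAttrType::kAtBool)
    .Attr("has_float_operand", UserOpAttrType::kAtBool)
    .Attr("int_operand", UserOpAttrType::kAtInt64)
    .Attr("float_operand", UserOpAttrType::kAtDouble);
```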
…dev_add_op_masked_fill
namespace {

__global__ void NaiveHalfFillGpu(const int64_t elem_cnt, const float16 x, float16* y) {
NewKernelUtil<device_type>::Fill already supports fp16. Shouldn't it be enough to handle the conversion of the operand to float16, rather than writing a new kernel?
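A minimal sketch of that suggestion, reusing the existing Fill with a cast instead of a dedicated half kernel. The attr name `value` comes from the quoted registration above; the surrounding kernel context (`ctx`, `out`, `T`) is assumed:

```cpp
// Hypothetical: read the scalar attr, cast it to the tensor's element type
// (including float16), and reuse the existing Fill utility.
const float value = ctx->Attr<float>("value");
NewKernelUtil<device_type>::Fill(ctx->device_ctx(), out->shape().elem_cnt(),
                                 static_cast<T>(value), out->mut_dptr<T>());
```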
REGISTER_HALF_CONSTANT_LIKE_KERNEL

}  // namespace oneflow
Note the blank line here.
namespace {

__global__ void HalfAddByScalarPtrGpu(const int64_t n, const half* x, const half* y, half* z) {
This should just be adding a float16 specialization to the existing XxxByScalarPtr, right?
namespace {

template<typename CondT>
__global__ void NaiveHalfWhere(const int64_t elem_cnt, const CondT* cond, const half* lhs,
Likewise, adding float16 support directly to where should be enough here.
__global__ void HalfAddByScalarPtrGpu(const int64_t n, const half* x, const half* y, half* z) {
  const half y_value = y[0];
  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] + y_value; }
The float16/half type cannot be added with a plain +; use __hadd instead.
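A sketch of the quoted kernel with that fix applied: device-side half arithmetic goes through CUDA's __hadd intrinsic (from cuda_fp16.h) rather than operator+, which is not available for half on all architectures and compilation modes. CUDA_1D_KERNEL_LOOP is OneFlow's grid-stride loop macro and is assumed to be in scope:

```cuda
#include <cuda_fp16.h>

// Sketch of the corrected kernel: half-precision addition via __hadd.
__global__ void HalfAddByScalarPtrGpu(const int64_t n, const half* x, const half* y, half* z) {
  const half y_value = y[0];
  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = __hadd(x[i], y_value); }
}
```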
change pair to macro Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Add masked_fill op
reference https://pytorch.org/docs/stable/tensors.html?highlight=masked_fill#torch.Tensor.masked_fill