FuseAddToOutput #3524
Conversation
Should we add a unit test for this feature to make sure the output before and after the fuse is the same?
The CI integration tests currently cover this scenario; I'll also think about how to add a unit test.
The first sentence says the execution time of add is …
In the case where the inplace placement fails, the extra cost of the memcpy is …
@@ -437,6 +437,17 @@ def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
    func_desc.job_config_proto.enable_cudnn_fused_normalization_add_relu = value


@oneflow_function_config("enable_fuse_add_to_output")
def set_enable_fuse_add_to_output(func_desc, value):
    r"""Whether enable fuse_add_to_output
Could you briefly describe this switch in English in the comment, e.g. its purpose and effect? It would serve as the docstring of our API. #3524 (comment)
This is the place for: "If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance."
Oops, my mistake.
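For context, a minimal sketch of how a user would turn the new flag on, assuming the usual `function_config` mechanism from the diff above (the exact surrounding API, such as how the config is attached to a global function, may differ between OneFlow versions of this era):

```python
import oneflow as flow

# Build a function config and enable the new pass; the attribute name matches the
# string registered by @oneflow_function_config("enable_fuse_add_to_output") above.
func_config = flow.function_config()
func_config.enable_fuse_add_to_output(True)

# The config is then passed to a job, e.g.
# @flow.global_function(function_config=func_config), and FuseAddToOutputPass runs
# when that job's graph is compiled.
```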
CHECK_EQ(add_to_output->shape(), dx->shape());
KernelUtil<DeviceType::kCPU, T>::Addition(
    ctx->device_ctx(), add_to_output->shape().elem_cnt(), dx->mut_dptr<T>(), dx->dptr<T>(),
    add_to_output->dptr<T>());
Does fuse add to output give any improvement in the CPU case?
It depends on the kernel implementation. CPU memory access has a cost as well, so in theory there is room for optimization there too.
  }
}
// The producer's output must have exactly one consumer (the add op itself),
// otherwise fusing would change other parts of the graph.
if (output_consumer_cnt != 1) { return false; }
// Ops that already have an "_add_to_output" input have been fused before; skip them.
if (user_op_conf.has_input("_add_to_output", 0)) { return false; }
The many if checks here need comments.
Added the necessary comments.
Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
  const HashMap<std::string, user_op::OpArg> supported_op_type_name2output_arg(
      {{"conv_data_grad", user_op::OpArg("dx", 0)}, {"normalization", user_op::OpArg("y", 0)}});
  HashMap<std::string, OperatorConf> op_name2op_conf;
fuse_add_to_output_op_name2op_conf ?
Or add a comment: this op_name2op_conf records the ops onto which the add that is the sole consumer of their output gets fused.
const OpNode* add_to_node;
const LogicalBlobId* add_to_lbi;
const LogicalBlobId* sum_lbi;
if ((!IsReachable(in_0.op_name(), in_1.op_name())) && IsAddToOutputSupported(in_0_node, in_0)) {
The logic here for finding a suitable add and fusing it onto the output of the preceding op is quite convoluted; the whole decision logic needs to be documented. Let me restate it after reading the code:
- Traverse all add_n ops (n = 2) that have no ctrl-in edges.
- Check whether this add_2 op can be fused with one of its inputs. The criteria: the op producing that input is a ConvGrad or BN type op, has not been fused before, and its output has exactly one consumer, namely this add op; only then can it be fused.

This guarantees that the fusion does not change other parts of the graph, and that there is no recursive fusion.
The code here is self-explanatory.
You could build a unit test: a network with a bn or a conv grad followed by an add; with the fuse switch on, the add gets fused and the output stays unchanged? That unit test might be too simple, though. If this fusion is guaranteed to occur in resnet50, then use the resnet50 integration test: with fuse enabled, the result should not change.
The problem with a unit test is that it is not easy to tell whether the fusion actually happened.
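A minimal sketch of the comparison proposed here, with a hypothetical `run_model(enable_fuse)` helper (not part of this PR) that builds and runs the same network under a given function config:

```python
import numpy as np

def check_fuse_does_not_change_output(run_model, rtol=1e-5, atol=1e-5):
    # run_model is a hypothetical helper: it builds the network (e.g. resnet50, or a
    # small bn/conv-grad + add graph), runs one step with enable_fuse_add_to_output
    # set as requested, and returns the output as a numpy array.
    out_ref = run_model(enable_fuse=False)
    out_fused = run_model(enable_fuse=True)
    np.testing.assert_allclose(out_fused, out_ref, rtol=rtol, atol=atol)
```

This checks that the output is unchanged, but as noted above it cannot by itself confirm that the fusion was actually applied.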
@@ -428,7 +428,7 @@ def set_cudnn_conv_heuristic_search_algo(func_desc, value):


 @oneflow_function_config("enable_cudnn_fused_normalization_add_relu")
 def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
-    r"""Whether enable cudnn_fused_normalization_add_relu
+    r"""Whether enable cudnn_fused_normalization_add_relu. If enabled, try to fuse a binary element-wise add to one of the predecessors
This API comment was added in the wrong place, wasn't it.... Also, you could append "to improve performance".
@@ -152,6 +160,7 @@ REGISTER_USER_OP("normalization")
    .Input("moving_variance")
    .Input("gamma")
    .Input("beta")
    .OptionalInput("_add_to_output")
This arg name starting with an underscore feels odd..
The underscore is meant to indicate that this is a special arg, similar to a reserved word. If it is the add_to_output part that seems odd, I can't think of anything better for now.
add to output is plain and easy to understand, that part is fine. I was only puzzled by the underscore; I searched around for reserved-word naming and this usage does exist, so I simply hadn't seen it before QAQ
@@ -428,7 +428,8 @@ def set_cudnn_conv_heuristic_search_algo(func_desc, value):


 @oneflow_function_config("enable_cudnn_fused_normalization_add_relu")
 def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
-    r"""Whether enable cudnn_fused_normalization_add_relu
+    r"""Whether enable cudnn_fused_normalization_add_relu.
+    If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.
No... this switch is for fusing bn add relu, isn't it? The switch at line 441 below is where this comment should go. Juncheng, you added it to the wrong line.
Try to fuse a binary element-wise add op with the producer of one of its inputs, using the inplace add supported by cuDNN or by the kernel, to reduce the cost of the add and improve performance.

add is usually a memory-bound op; its cost comes from loading the two inputs and storing the output. If the size of each input/output is N, the execution time of add is roughly `3 * N / B`, where `B` is the device memory bandwidth. Taking cuDNN as an example, if one of the add's inputs is passed as an extra input to the producer op and placed inplace with that op's output, the cuDNN scale parameter mechanism reduces this cost to `N / B` (the load cost of the extra input; the cost of the addition itself is generally assumed to be hidden by the memory operations). If the extra input cannot be placed inplace with the output, a memcpy is needed first, which costs `2 * N / B`, so the total cost is the same as with the original add op.

A supported op needs to add an optional input `_add_to_output` and, in its kernel, add the contents of `_add_to_output` onto its regular output. `FuseAddToOutputPass` searches the op graph for add ops that satisfy the conditions and performs the fusion. This PR adds support for `conv_data_grad` and `normalization`, which covers all of the adds in resnet50.
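To make the kernel-side contract concrete, here is a small self-contained numpy sketch (not OneFlow code; `some_op` is a stand-in for a supported producer such as `conv_data_grad` or `normalization`): the fused op must produce exactly what the original op followed by the element-wise add would produce.

```python
import numpy as np

def some_op(x):
    # Stand-in for a supported producer op (e.g. conv_data_grad or normalization).
    return 2.0 * x

def some_op_fused(x, add_to_output=None):
    # Fused variant: accumulate the optional "_add_to_output" input onto the output,
    # mirroring what the kernel does (on GPU this would go through cuDNN's scale
    # parameter or an inplace add rather than a separate add op).
    y = some_op(x)
    if add_to_output is not None:
        assert add_to_output.shape == y.shape
        y += add_to_output
    return y

x = np.random.rand(4, 8).astype(np.float32)
addend = np.random.rand(4, 8).astype(np.float32)
unfused = some_op(x) + addend      # original graph: producer op followed by add_n(2)
fused = some_op_fused(x, addend)   # graph after FuseAddToOutputPass rewrites the add
np.testing.assert_allclose(fused, unfused)
```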