Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FuseAddToOutput #3524

Merged
merged 10 commits into from
Sep 4, 2020
Merged

FuseAddToOutput #3524

merged 10 commits into from
Sep 4, 2020

Conversation

liujuncheng
Copy link
Collaborator

尝试将二元加法op和输入的生产者之一fuse,利用 cudnn 或者kernel支持的 inplace add 减少 add 的开销,提高性能。

add 通常是访存瓶颈操作,开销体现在两个输入的load和输出的store,假设输入或者输出的大小为N,add的执行时间约为 3 * N / BB 为显存带宽。以cuDNN 为例,如果将加法的一个输入做为生产者之一的op的额外输入并且和该op的output inplace,利用 cuDNN 的 scale parameter 机制,可以将该开销降低为 N / B (额外输入的load开销,加法操作的开销一般认为可以被内存操作掩盖) ;如果额外的输入和无法和output inplace,那么需要先进行memcpy,开销为 2 * N / B,总开销和原来使用 add op 相等。

支持的 op 需要添加一个可选的输入 _add_to_output,并在 kernel 中自行将原本的输出结果和_add_to_output中的内容相加。FuseAddToOutputPass 会在op graph上寻找符合条件的 add op 进行 fusion。这个PR中添加了 conv_data_gradnormalization 的支持,可以覆盖 resnet50 中的全部 add。

@daquexian
Copy link
Contributor

需不需要为这个功能添加相应的单测保证 fuse 之前和之后的输出是一样的呢

@liujuncheng
Copy link
Collaborator Author

需不需要为这个功能添加相应的单测保证 fuse 之前和之后的输出是一样的呢

目前CI的集成测试部分覆盖了这个场景,我也考虑一下怎么加单元测试

@jackalcooper jackalcooper added this to the 0.1.11 milestone Sep 1, 2020
@chengtbf
Copy link
Contributor

chengtbf commented Sep 2, 2020

add 通常是访存瓶颈操作,开销体现在两个输入的load和输出的store,假设输入或者输出的大小为N,add的执行时间约为 3 * N / BB 为显存带宽。以cuDNN 为例,如果将加法的一个输入做为生产者之一的op的额外输入并且和该op的output inplace,利用 cuDNN 的 scale parameter 机制,可以将该开销降低为 N / B (额外输入的load开销,加法操作的开销一般认为可以被内存操作掩盖) ;如果额外的输入和无法和output inplace,那么需要先进行memcpy,开销为 2 * N / B,总开销和原来使用 add op 相等。

第一句说到 add 的执行时间为 3 * N / B,当inplace成功时,开销降低到 N / B, inplace失败是,开销为 2 * N / B,为什么跟原来的add op相等?

@liujuncheng
Copy link
Collaborator Author

add 通常是访存瓶颈操作,开销体现在两个输入的load和输出的store,假设输入或者输出的大小为N,add的执行时间约为 3 * N / BB 为显存带宽。以cuDNN 为例,如果将加法的一个输入做为生产者之一的op的额外输入并且和该op的output inplace,利用 cuDNN 的 scale parameter 机制,可以将该开销降低为 N / B (额外输入的load开销,加法操作的开销一般认为可以被内存操作掩盖) ;如果额外的输入和无法和output inplace,那么需要先进行memcpy,开销为 2 * N / B,总开销和原来使用 add op 相等。

第一句说到 add 的执行时间为 3 * N / B,当inplace成功时,开销降低到 N / B, inplace失败是,开销为 2 * N / B,为什么跟原来的add op相等?

inplace失败的情况下,memcpy增加的开销是 2 * N / B,这句话“当inplace成功时,开销降低到 N / B”中的N / B还是需要的,加起来是 3 * N / B

@@ -437,6 +437,17 @@ def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
func_desc.job_config_proto.enable_cudnn_fused_normalization_add_relu = value


@oneflow_function_config("enable_fuse_add_to_output")
def set_enable_fuse_add_to_output(func_desc, value):
r"""Whether enable fuse_add_to_output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以把comment里对这个开关的描述用英文简单说明一下? 比如目的、功效。作为我们API的注释文档。#3524 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里才是 If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里才是 If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.

尴尬

CHECK_EQ(add_to_output->shape(), dx->shape());
KernelUtil<DeviceType::kCPU, T>::Addition(
ctx->device_ctx(), add_to_output->shape().elem_cnt(), dx->mut_dptr<T>(), dx->dptr<T>(),
add_to_output->dptr<T>());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU情况下的fuse add to output会有提升吗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU情况下的fuse add to output会有提升吗

取决于kernel的实现,CPU访存也是有开销的,所以理论上也有机会优化

}
}
if (output_consumer_cnt != 1) { return false; }
if (user_op_conf.has_input("_add_to_output", 0)) { return false; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的众多if逻辑,要加注释

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的众多if逻辑,要加注释

添加了必要的注释

Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
const HashMap<std::string, user_op::OpArg> supported_op_type_name2output_arg(
{{"conv_data_grad", user_op::OpArg("dx", 0)}, {"normalization", user_op::OpArg("y", 0)}});
HashMap<std::string, OperatorConf> op_name2op_conf;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fuse_add_to_output_op_name2op_conf ?
或者加注释,这个op_name2op_conf是把该op的输出的唯一消费者 add fuse到该op上。

const OpNode* add_to_node;
const LogicalBlobId* add_to_lbi;
const LogicalBlobId* sum_lbi;
if ((!IsReachable(in_0.op_name(), in_1.op_name())) && IsAddToOutputSupported(in_0_node, in_0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里寻找合适的add,然后fuse到前一个op的output上的逻辑非常绕,需要说明整个判断逻辑。我这里读完简单复述一下:

  1. 遍历所有的add n (n = 2),且没有ctrl in的op
  2. 判断该add_2 op 是否能跟其中的一个in fuse。 判断的标准是:in所在的op是ConvGrad或者BN类型,且没有被fuse过,且输出只有一个消费者,即被这个add op所消费,才能fuse。

保证fuse不会导致其他的图的变化;保证没有递归fuse

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里寻找合适的add,然后fuse到前一个op的output上的逻辑非常绕,需要说明整个判断逻辑。我这里读完简单复述一下:

  1. 遍历所有的add n (n = 2),且没有ctrl in的op
  2. 判断该add_2 op 是否能跟其中的一个in fuse。 判断的标准是:in所在的op是ConvGrad或者BN类型,且没有被fuse过,且输出只有一个消费者,即被这个add op所消费,才能fuse。

保证fuse不会导致其他的图的变化;保证没有递归fuse

这里的代码是自解释的

@chengtbf
Copy link
Contributor

chengtbf commented Sep 2, 2020

可以构造一个单测,网络里有一个bn或者conv grad,后面接一个add,fuse开关打开add被fuse了,且输出不变?可能这个单测也过于简单。如果resnet50里一定有这个fuse产生,那就用resnet50的网络的集成测试,开了fuse结果不变。

@liujuncheng
Copy link
Collaborator Author

可以构造一个单测,网络里有一个bn或者conv grad,后面接一个add,fuse开关打开add被fuse了,且输出不变?可能这个单测也过于简单。如果resnet50里一定有这个fuse产生,那就用resnet50的网络的集成测试,开了fuse结果不变。

单元测试的问题是不容易判断是否被fuse了
CNN基本上一定会触发,所以集成测试就可以作为他的测试用例,未来这个开关应该是默认开启的,目前一两个版本暂时默认关闭。

@@ -428,7 +428,7 @@ def set_cudnn_conv_heuristic_search_algo(func_desc, value):

@oneflow_function_config("enable_cudnn_fused_normalization_add_relu")
def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
r"""Whether enable cudnn_fused_normalization_add_relu
r"""Whether enable cudnn_fused_normalization_add_relu. If enabled, try to fuse a binary element-wise add to one of the predecessors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个API注释加错位置了吧。。。。另外可以补一句 to improve performance

@@ -152,6 +160,7 @@ REGISTER_USER_OP("normalization")
.Input("moving_variance")
.Input("gamma")
.Input("beta")
.OptionalInput("_add_to_output")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个以下划线开头的arg命名感觉很奇怪。。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个以下划线开头的arg命名感觉很奇怪。。

下划线是为了描述这是一个特殊的arg,类似与保留字。如果是后面的add_to_output奇怪,我暂时想不到更好的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to output通俗易懂,挺好的。我只是奇怪这个下划线,去搜了一下保留字,确实有这种用法,是我见识太少了QAQ

@@ -428,7 +428,8 @@ def set_cudnn_conv_heuristic_search_algo(func_desc, value):

@oneflow_function_config("enable_cudnn_fused_normalization_add_relu")
def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
r"""Whether enable cudnn_fused_normalization_add_relu
r"""Whether enable cudnn_fused_normalization_add_relu.
If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是。。。这个开关的作用是fuse bn add relu 吧。下面 441 行的开关才是这句注释该加的地方吧。俊丞兄你加错行了

@liujuncheng liujuncheng merged commit fd2fe57 into master Sep 4, 2020
@liujuncheng liujuncheng deleted the dev_pr_fuse_add_to_output branch September 4, 2020 07:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants