
FuseAddToOutput #3524

Merged: 10 commits, Sep 4, 2020
1 change: 1 addition & 0 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -936,6 +936,7 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
JUST(DoPass("GenerateBackwardAndOptimizerOpConfs"));
JUST(DoPass("CudnnFusedNormalizationAddReluPass"));
JUST(DoPass("PruneCastToStaticShapeOpsPass"));
JUST(DoPass("FuseAddToOutputPass"));
JUST(DoPass("IndexedSlicesOptimizerRewritePass"));
JUST(DoPass("SplitSparseSoftmaxCrossEntropyOpPass"));
JUST(DoPass("DoParallelCastBeforeWideningTypeCast"));
1 change: 1 addition & 0 deletions oneflow/core/job/job_conf.proto
@@ -91,6 +91,7 @@ message JobConfigProto {
optional bool cudnn_conv_heuristic_search_algo = 205 [default = true];
optional bool cudnn_conv_use_deterministic_algo_only = 206 [default = false];
optional bool enable_cudnn_fused_normalization_add_relu = 207;
optional bool enable_fuse_add_to_output = 208 [default = false];

optional bool enable_reuse_mem = 300 [default = true];
optional bool enable_inplace = 301 [default = true];
130 changes: 130 additions & 0 deletions oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp
@@ -0,0 +1,130 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {

class FuseAddToOutputPass final : public OpGraphPass {
public:
FuseAddToOutputPass() = default;
~FuseAddToOutputPass() override = default;

bool IsEnabled() const override { return GlobalJobDesc().job_conf().enable_fuse_add_to_output(); }
Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};

Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
const HashMap<std::string, user_op::OpArg> supported_op_type_name2output_arg(
{{"conv_data_grad", user_op::OpArg("dx", 0)}, {"normalization", user_op::OpArg("y", 0)}});
HashMap<std::string, OperatorConf> op_name2op_conf;
Contributor:
How about fuse_add_to_output_op_name2op_conf? Or add a comment: this op_name2op_conf is about fusing the add op that is the sole consumer of an op's output into that op.

auto IsAddToOutputSupported = [&](const OpNode* node, const LogicalBlobId& lbi) -> bool {
const OperatorConf& op_conf = node->op().op_conf();
if (!op_conf.has_user_conf()) { return false; }
if (op_name2op_conf.find(op_conf.name()) != op_name2op_conf.end()) { return false; }
auto it = supported_op_type_name2output_arg.find(op_conf.user_conf().op_type_name());
if (it == supported_op_type_name2output_arg.end()) { return false; }
const user_op::UserOpConfWrapper user_op_conf(op_conf);
if (GenLogicalBlobId(user_op_conf.output(it->second.name(), it->second.index())) != lbi) {
return false;
}
// add op should be the only consumer
int64_t output_consumer_cnt = 0;
for (const OpEdge* out_edge : node->out_edges()) {
if (std::find(out_edge->lbis().cbegin(), out_edge->lbis().cend(), lbi)
!= out_edge->lbis().cend()) {
output_consumer_cnt += 1;
}
}
if (output_consumer_cnt != 1) { return false; }
// already fused
if (user_op_conf.has_input("_add_to_output", 0)) { return false; }
Contributor:
The many if branches here need explanatory comments.

Collaborator (author):
Added the necessary comments.

return true;
};
HashSet<std::string> ctrl_in_op_names;
op_graph.ForEachNode([&](const OpNode* op_node) {
for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {
ctrl_in_op_names.insert(ctrl_in_op_name);
}
});

auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
op_graph.ForEachNode([&](const OpNode* op_node) {
const OperatorConf& op_conf = op_node->op().op_conf();
if (!op_conf.has_user_conf()) { return; }
if (!op_conf.ctrl_in_op_name().empty()) { return; }
if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; }
if (op_conf.user_conf().op_type_name() != "add_n") { return; }
if (op_name2op_conf.find(op_conf.name()) != op_name2op_conf.end()) { return; }
const user_op::UserOpConfWrapper user_op_conf(op_conf);
if (user_op_conf.input_size("in") != 2) { return; }

const LogicalBlobId in_0 = GenLogicalBlobId(user_op_conf.input("in", 0));
const LogicalBlobId in_1 = GenLogicalBlobId(user_op_conf.input("in", 1));
const LogicalBlobId out = GenLogicalBlobId(user_op_conf.output("out", 0));
const OpNode* in_0_node = op_graph.OpNode4OpName(in_0.op_name());
const OpNode* in_1_node = op_graph.OpNode4OpName(in_1.op_name());

const OpNode* add_to_node;
const LogicalBlobId* add_to_lbi;
const LogicalBlobId* sum_lbi;
if ((!IsReachable(in_0.op_name(), in_1.op_name())) && IsAddToOutputSupported(in_0_node, in_0)) {
Contributor:
The logic that finds a suitable add and fuses it into the preceding op's output is quite convoluted; the whole decision process should be documented. My summary after reading it:

  1. Iterate over all add_n ops with n = 2 that have no ctrl-in edges.
  2. Check whether such an add_2 op can be fused with one of its inputs. The criteria: the op producing that input is a ConvGrad or BN type, has not already been fused, and its output has exactly one consumer, namely this add op.

This guarantees the fusion does not otherwise change the graph, and that there is no recursive fusion.

Collaborator (author):
The code here is self-explanatory.

(A toy Python sketch of this decision process follows the end of this file's diff below.)

add_to_node = in_0_node;
add_to_lbi = &in_1;
sum_lbi = &in_0;
} else if ((!IsReachable(in_1.op_name(), in_0.op_name()))
&& IsAddToOutputSupported(in_1_node, in_1)) {
add_to_node = in_1_node;
add_to_lbi = &in_0;
sum_lbi = &in_1;
} else {
return;
}
OperatorConf new_add_to_op_conf = add_to_node->op().op_conf();
*(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))["_add_to_output"]
.mutable_s()
->Add() = GenLogicalBlobName(*add_to_lbi);
job_builder->MutOpsOnlyOnce({new_add_to_op_conf});
for (const OpEdge* out_edge : op_node->out_edges()) {
const OpNode* consumer = out_edge->dst_node();
const std::string& consumer_op_name = consumer->op().op_name();
if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) {
op_name2op_conf[consumer_op_name] = consumer->op().op_conf();
}
for (const std::string& ibn : consumer->op().input_bns()) {
if (consumer->op().BnInOp2Lbi(ibn) == out) {
OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name);
PbMessage* conf =
MutableMessageInPbMessage(&consumer_op_conf, consumer_op_conf.op_type_case());
ReplaceInputLbnInOpCustomizedConf(conf, ibn, GenLogicalBlobName(out),
GenLogicalBlobName(*sum_lbi));
}
}
}
job_builder->DelOps({op_conf});
});
for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }
return Maybe<void>::Ok();
}

} // namespace

REGISTER_FUNCTION_PASS("FuseAddToOutputPass", FuseAddToOutputPass);

} // namespace oneflow
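
A toy sketch of the fusion decision described in the review thread above, in plain Python with invented names (not OneFlow API). Reachability and ctrl-in checks are omitted for brevity, and every blob queried is assumed to have a producer inside the graph:

```python
# Toy model of FuseAddToOutputPass: each op is a dict and a graph is a list
# of ops; the structure is invented for illustration only.
SUPPORTED_OUTPUT = {"conv_data_grad": "dx", "normalization": "y"}

def consumers(graph, blob):
    return [op for op in graph if blob in op["inputs"].values()]

def producer(graph, blob):
    return next(op for op in graph if blob in op["outputs"].values())

def try_fuse_add_n(graph, add_op):
    if add_op["type"] != "add_n" or len(add_op["inputs"]) != 2:
        return False
    a, b = add_op["inputs"]["in_0"], add_op["inputs"]["in_1"]
    for sum_blob, other_blob in ((a, b), (b, a)):
        prod = producer(graph, sum_blob)
        out_arg = SUPPORTED_OUTPUT.get(prod["type"])
        if (out_arg is not None
                and prod["outputs"].get(out_arg) == sum_blob   # fusable output arg
                and "_add_to_output" not in prod["inputs"]     # not fused already
                and len(consumers(graph, sum_blob)) == 1):     # add is sole consumer
            prod["inputs"]["_add_to_output"] = other_blob      # fuse the addend
            out_blob = add_op["outputs"]["out"]
            for c in consumers(graph, out_blob):               # rewire consumers
                c["inputs"] = {k: (sum_blob if v == out_blob else v)
                               for k, v in c["inputs"].items()}
            graph.remove(add_op)
            return True
    return False

# normalization -> add_n -> relu collapses to normalization -> relu.
g = [
    {"type": "normalization", "inputs": {"x": "x0"}, "outputs": {"y": "bn_y"}},
    {"type": "add_n", "inputs": {"in_0": "bn_y", "in_1": "skip0"},
     "outputs": {"out": "sum"}},
    {"type": "relu", "inputs": {"in": "sum"}, "outputs": {"out": "r"}},
]
assert try_fuse_add_n(g, g[1])
assert g[0]["inputs"]["_add_to_output"] == "skip0"
assert g[1]["inputs"]["in"] == "bn_y"  # relu now reads the fused output
```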
14 changes: 13 additions & 1 deletion oneflow/python/framework/function_util.py
@@ -428,7 +428,7 @@ def set_cudnn_conv_heuristic_search_algo(func_desc, value):

@oneflow_function_config("enable_cudnn_fused_normalization_add_relu")
def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
r"""Whether enable cudnn_fused_normalization_add_relu
r"""Whether enable cudnn_fused_normalization_add_relu.

Args:
func_desc ([type]): [description]
@@ -437,6 +437,18 @@ def set_enable_cudnn_fused_normalization_add_relu(func_desc, value):
func_desc.job_config_proto.enable_cudnn_fused_normalization_add_relu = value


@oneflow_function_config("enable_fuse_add_to_output")
def set_enable_fuse_add_to_output(func_desc, value):
r"""Whether enable fuse_add_to_output.
If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.

Args:
func_desc ([type]): [description]
value ([type]): [description]
"""
func_desc.job_config_proto.enable_fuse_add_to_output = value
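
Hypothetical usage sketch for the new flag, assuming the usual FunctionConfig pattern in this codebase (the exact global_function signature varies across OneFlow versions):

```python
import oneflow as flow

func_config = flow.FunctionConfig()
func_config.enable_fuse_add_to_output(True)  # enable the pass for this job

# The decorator call below follows the common 0.x pattern; adjust to your version.
@flow.global_function(function_config=func_config)
def train_job() -> None:
    ...  # eligible add_n ops get fused during job completion
```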


@oneflow_function_config("cudnn_conv_use_deterministic_algo_only")
def set_cudnn_conv_use_deterministic_algo_only(func_desc, value):
r"""Set value to cudnn conv_use_deterministic_only algorithm
54 changes: 38 additions & 16 deletions oneflow/user/kernels/conv_cudnn_kernels.cpp
@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/user/ops/nn_util.h"
#include "oneflow/core/device/cudnn_conv_util.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {

@@ -239,27 +240,48 @@ class ConvDataGradGpuKernel final : public user_op::OpKernel {
const CudnnConvArgs& args = args_and_algo.args;
const cudnnConvolutionBwdDataAlgoPerf_t& algo_perf = args_and_algo.algo_perf;

const void* alpha = CudnnSPOnePtr<T>();
const void* beta;
if (ctx->user_op_conf().has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), dx->data_type());
CHECK_EQ(add_to_output->shape(), dx->shape());
Memcpy<DeviceType::kGPU>(
ctx->device_ctx(), dx->mut_dptr<void>(), add_to_output->dptr<void>(),
add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
beta = CudnnSPOnePtr<T>();
} else {
beta = CudnnSPZeroPtr<T>();
}

OF_CUDNN_CHECK(cudnnConvolutionBackwardData(
ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr<T>(), args.wdesc.Get(), filter->dptr(),
ctx->device_ctx()->cudnn_handle(), alpha, args.wdesc.Get(), filter->dptr(),
args.ydesc.Get(), dy->dptr(), args.cdesc.Get(), algo_perf.algo, buf->mut_dptr(),
args.params.max_ws_size, CudnnSPZeroPtr<T>(), args.xdesc.Get(), dx->mut_dptr()));
args.params.max_ws_size, beta, args.xdesc.Get(), dx->mut_dptr()));
}
};
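
A note on the alpha/beta change above: like cuDNN's other convolution entry points, cudnnConvolutionBackwardData blends its result into the existing contents of the output buffer,

$$\mathrm{dx} \leftarrow \alpha \cdot \mathrm{ConvBwdData}(\mathrm{filter}, \mathrm{dy}) + \beta \cdot \mathrm{dx}$$

so the kernel first copies _add_to_output into dx and then calls cuDNN with beta = 1; the unfused path keeps beta = 0. The SetInplaceProposalFn added below further proposes that dx reuse _add_to_output's buffer when the graph allows it, avoiding a separate allocation.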

#define REGISTER_CONV_DATA_GRAD_FLOATING_KERNEL(dtype) \
REGISTER_USER_KERNEL("conv_data_grad") \
.SetCreateFn<ConvDataGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \
const JobDesc& job_desc = ctx->job_desc(); \
const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); \
const auto* filter = ctx->TensorDesc4ArgNameAndIndex("filter", 0); \
const auto* dx = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \
return InferTmpSizeWithCudnn<cudnnConvolutionBwdDataAlgoPerf_t>( \
dx, filter, dy, job_desc, ctx->user_op_conf(), \
job_desc.job_conf().has_cudnn_conv_force_bwd_data_algo(), \
job_desc.job_conf().cudnn_conv_force_bwd_data_algo()); \
#define REGISTER_CONV_DATA_GRAD_FLOATING_KERNEL(dtype) \
REGISTER_USER_KERNEL("conv_data_grad") \
.SetCreateFn<ConvDataGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \
const JobDesc& job_desc = ctx->job_desc(); \
const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); \
const auto* filter = ctx->TensorDesc4ArgNameAndIndex("filter", 0); \
const auto* dx = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \
return InferTmpSizeWithCudnn<cudnnConvolutionBwdDataAlgoPerf_t>( \
dx, filter, dy, job_desc, ctx->user_op_conf(), \
job_desc.job_conf().has_cudnn_conv_force_bwd_data_algo(), \
job_desc.job_conf().cudnn_conv_force_bwd_data_algo()); \
}) \
.SetInplaceProposalFn([](const user_op::InferContext& ctx, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
if (ctx.user_op_conf().has_input("_add_to_output", 0)) { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \
} \
return Maybe<void>::Ok(); \
})

REGISTER_CONV_DATA_GRAD_FLOATING_KERNEL(float);
9 changes: 9 additions & 0 deletions oneflow/user/kernels/conv_kernels.cpp
@@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/nn_util.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

@@ -530,6 +531,14 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel {
conv_state->dilation_rate_3d_.data(),
conv_state->padding_before_3d_.data(), GetImgMutDptr<T>(dx, i));
}
if (ctx->user_op_conf().has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), dx->data_type());
CHECK_EQ(add_to_output->shape(), dx->shape());
KernelUtil<DeviceType::kCPU, T>::Addition(
ctx->device_ctx(), add_to_output->shape().elem_cnt(), dx->mut_dptr<T>(), dx->dptr<T>(),
add_to_output->dptr<T>());
Contributor:
Does fuse add to output bring any speedup in the CPU case?

Collaborator (author):
That depends on the kernel implementation. Memory access has a cost on the CPU as well, so in theory there is room for optimization here too.

}
}
};
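
To make the reply above concrete, here is a tiny NumPy sketch (illustrative only) of the memory-traffic argument: the fused form skips one output buffer and one extra pass over memory.

```python
import numpy as np

dx = np.random.rand(1 << 20).astype(np.float32)
addend = np.random.rand(1 << 20).astype(np.float32)

# Unfused: a separate add_n kernel writes a fresh buffer and re-reads
# both inputs, i.e. one extra full pass over memory.
out = dx + addend

# Fused (what ConvDataGradCpuKernel does above): accumulate into dx
# right after it is produced, while it is still warm in cache.
dx += addend
```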
