Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into remove_recurrent_op_and_static_rnn_api
phlrain committed Nov 23, 2022
2 parents c8f184c + 3daf518 commit 53b0eb6
Showing 32 changed files with 868 additions and 343 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -76,6 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(map_depthwise_conv_to_conv_pass inference)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
@@ -101,6 +102,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
278 changes: 278 additions & 0 deletions paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -0,0 +1,278 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.h"

#include <string>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
namespace framework {
namespace ir {
namespace {

// Permutes a 4-D CPU tensor between NCHW and NHWC and updates its layout tag.
void TransDataLayout(DataLayout from_layout,
                     DataLayout to_layout,
                     const phi::DenseTensor &in,
                     phi::DenseTensor *out) {
  PADDLE_ENFORCE_EQ(
      arity(in.dims()),
      4,
      platform::errors::InvalidArgument(
          "The input's dimension arity must be 4, but got %s.", in.dims()));

  auto &pool = platform::DeviceContextPool::Instance();

  auto src_dim = in.dims();
  std::vector<int64_t> dst_dim;

  auto axis = GetAxis(from_layout, to_layout);
  dst_dim.resize(axis.size());
  for (size_t i = 0; i < axis.size(); i++) {
    dst_dim[i] = src_dim[axis[i]];
  }

  out->Resize(phi::make_ddim(dst_dim));
  out->mutable_data(phi::CPUPlace(), in.dtype());

  framework::VisitDataType(
      framework::TransToProtoVarType(in.dtype()),
      CastDataLayout(pool.Get(phi::CPUPlace()), axis, in, out));

  out->set_layout(to_layout);
}

// Inserts a transfer_layout op between prev_node (a var node) and next_node
// (an op node). A transfer node already created for prev_node is reused via
// `cache`, so each var is converted at most once.
void InsertLayoutTransOp(ir::Graph *graph,
                         ir::Node *prev_node,
                         ir::Node *next_node,
                         DataLayout from_layout,
                         DataLayout to_layout,
                         framework::BlockDesc *block_desc,
                         std::unordered_map<ir::Node *, ir::Node *> *cache) {
  auto do_insert = [&](const std::string &in_var_name,
                       const std::string &out_var_name) {
    auto update_op_desc = [&](framework::OpDesc &desc,
                              const std::string &x_name,
                              const std::string &out_name) {
      desc.SetType("transfer_layout");
      desc.SetInput("X", {x_name});
      desc.SetOutput("Out", {out_name});
      desc.SetAttr("src_layout", static_cast<int>(from_layout));
      desc.SetAttr("dst_layout", static_cast<int>(to_layout));
      desc.Flush();
    };
    CHECK_NOTNULL(block_desc);
    if (cache->count(prev_node) == 0) {
      framework::OpDesc op_desc(block_desc);
      update_op_desc(op_desc, in_var_name, out_var_name);
      auto *op_node = graph->CreateOpNode(&op_desc);
      auto *op_out_var_desc = block_desc->Var(out_var_name);

      op_out_var_desc->SetPersistable(false);
      op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
      auto to_shape = prev_node->Var()->GetShape();
      if (from_layout == DataLayout::kNCHW) {
        auto n = to_shape[0];
        auto c = to_shape[1];
        auto h = to_shape[2];
        auto w = to_shape[3];
        op_out_var_desc->SetShape({n, h, w, c});
      } else {
        auto n = to_shape[0];
        auto h = to_shape[1];
        auto w = to_shape[2];
        auto c = to_shape[3];
        op_out_var_desc->SetShape({n, c, h, w});
      }

      auto *op_out_var_node = graph->CreateVarNode(op_out_var_desc);
      IR_NODE_LINK_TO(op_node, op_out_var_node);
      cache->insert(std::make_pair(prev_node, op_out_var_node));
    }
    next_node->Op()->RenameInput(prev_node->Name(),
                                 cache->at(prev_node)->Name());
    IR_NODE_LINK_TO(prev_node, cache->at(prev_node)->inputs.front());
    IR_NODE_LINK_TO(cache->at(prev_node), next_node);

    IR_NODE_UNLINK(prev_node, next_node);
  };

  if (from_layout == DataLayout::kNCHW && to_layout == DataLayout::kNHWC) {
    auto in_var_name = prev_node->Var()->Name();
    auto out_var_name = in_var_name + "_nchw_to_nhwc";
    do_insert(in_var_name, out_var_name);
  } else if (from_layout == DataLayout::kNHWC &&
             to_layout == DataLayout::kNCHW) {
    auto in_var_name = prev_node->Var()->Name();
    auto out_var_name = in_var_name + "_nhwc_to_nchw";
    do_insert(in_var_name, out_var_name);
  }
}

} // namespace

void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be nullptr."));
  FusePassBase::Init("data_layout_transfer", graph);
  auto *scope = param_scope();

  PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                    true,
                    platform::errors::InvalidArgument(
                        "The graph should be the main graph when applying "
                        "conv2d_fusion_layout_transfer_pass."));

  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal("scope must not be nullptr when applying "
                              "conv2d_fusion_layout_transfer_pass"));

  // Multiple blocks are not supported yet, so use the block of the first op.
  std::unordered_map<ir::Node *, ir::Node *> cache;
  auto op_nodes = ir::TopologySortOperations(*graph);
  auto iter = op_nodes.cbegin();
  auto *block_desc = (*iter)->Op()->Block();

  // Var nodes whose shape has already been rewritten to NHWC.
  std::unordered_set<ir::Node *> vars_shape_nhwc;

  // Only conv2d_fusion is supported for now.
  std::string target_op_type = "conv2d_fusion";
  std::unordered_set<ir::Node *> valid_ops;

  auto OpIsValid = [&](ir::Node *op_node) -> bool {
    if (op_node->Op()->Type() != target_op_type) return false;

    auto data_format =
        op_node->Op()->GetAttrIfExists<std::string>("data_format");
    if (data_format != "NCHW") return false;

    auto filter_names = op_node->Op()->Input("Filter");

    // If the filter's channel count is not a multiple of 8, conv2d_fusion
    // will not run in NHWC layout, so skip the op.
    for (const auto &filter_name : filter_names) {
      auto *filter_var = scope->FindLocalVar(filter_name);
      const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
      if (filter_tensor.dims().size() == 4 &&
          (filter_tensor.dims()[0] % 8 != 0 ||
           filter_tensor.dims()[1] % 8 != 0)) {
        return false;
      }
    }
    return true;
  };

  for (auto *op_node : op_nodes) {
    CHECK_EQ(op_node->IsOp(), true);
    if (OpIsValid(op_node)) {
      valid_ops.insert(op_node);
      auto *op_desc = op_node->Op();
      auto nhwc_attr = framework::Attribute(std::string("NHWC"));
      op_desc->SetAttr("data_format", nhwc_attr);
      op_desc->Flush();

      // transfer weights
      auto filter_names = op_desc->Input("Filter");
      for (const auto &filter_name : filter_names) {
        auto *filter_var = scope->FindLocalVar(filter_name);
        auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
        phi::DenseTensor temp_tensor = *filter_tensor;
        filter_tensor->clear();

        TransDataLayout(
            DataLayout::kNCHW, DataLayout::kNHWC, temp_tensor, filter_tensor);
      }
      auto op_inputs = op_node->inputs;
      for (auto *in_var_node : op_inputs) {
        CHECK_EQ(in_var_node->IsVar(), true);
        if (in_var_node->Var()->Persistable()) {
          if (std::find(filter_names.cbegin(),
                        filter_names.cend(),
                        in_var_node->Var()->Name()) != filter_names.cend()) {
            auto from_shape = in_var_node->Var()->GetShape();
            in_var_node->Var()->SetShape(
                {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
          }
        }
      }

      // transfer outputs
      auto op_outputs = op_node->outputs;
      for (auto *out_var_node : op_outputs) {
        CHECK_EQ(out_var_node->IsVar(), true);
        if (out_var_node->Var()->Persistable()) continue;

        auto from_shape = out_var_node->Var()->GetShape();
        out_var_node->Var()->SetShape(
            {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
        vars_shape_nhwc.insert(out_var_node);
      }
    }
  }

  // Insert transfer_layout ops around the rewritten conv2d_fusion ops.
  for (auto *op_node : op_nodes) {
    CHECK_EQ(op_node->IsOp(), true);

    if (valid_ops.count(op_node)) {
      auto op_inputs = op_node->inputs;
      for (auto *in_var_node : op_inputs) {
        CHECK_EQ(in_var_node->IsVar(), true);

        if (in_var_node->Var()->Persistable()) continue;
        if (vars_shape_nhwc.count(in_var_node)) continue;

        InsertLayoutTransOp(graph,
                            in_var_node,
                            op_node,
                            DataLayout::kNCHW,
                            DataLayout::kNHWC,
                            block_desc,
                            &cache);
      }
    } else {
      auto op_inputs = op_node->inputs;
      for (auto *in_var_node : op_inputs) {
        CHECK_EQ(in_var_node->IsVar(), true);

        if (vars_shape_nhwc.count(in_var_node)) {
          InsertLayoutTransOp(graph,
                              in_var_node,
                              op_node,
                              DataLayout::kNHWC,
                              DataLayout::kNCHW,
                              block_desc,
                              &cache);
        }
      }
    }
  }
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(conv2d_fusion_layout_transfer_pass,
              paddle::framework::ir::Conv2dFusionLayoutTransferPass);
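
For context, a minimal sketch of enabling the new pass from the inference API. This is illustrative only: `model_dir` is a hypothetical path, and it assumes the standard paddle_infer::Config / PaddlePassBuilder interface (AppendPass), which this commit does not change.

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  // Hypothetical model directory; replace with a real exported model.
  paddle_infer::Config config("model_dir");
  config.EnableUseGpu(100 /*memory pool init size, MB*/, 0 /*device id*/);
  // Append the pass added by this commit (assumes the usual pass-builder API).
  config.pass_builder()->AppendPass("conv2d_fusion_layout_transfer_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}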
34 changes: 34 additions & 0 deletions paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.h
@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class Conv2dFusionLayoutTransferPass : public FusePassBase {
 public:
  Conv2dFusionLayoutTransferPass() = default;
  virtual ~Conv2dFusionLayoutTransferPass() = default;

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
61 changes: 61 additions & 0 deletions paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.cc
@@ -0,0 +1,61 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h"

#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init("map_depthwise_conv_to_conv_pass", graph);

  int found_count = 0;
  // Op types to rewrite; only the type string changes, all attributes
  // (including `groups`) are preserved.
  std::unordered_map<std::string, std::string> replaced_map{
      {"depthwise_conv2d", "conv2d"},
  };

  auto nodes = graph->Nodes();

  for (auto& node : nodes) {
    if (!node->IsOp()) continue;
    auto* op_desc = node->Op();
    std::string op_type = op_desc->Type();
    if (!replaced_map.count(op_type)) continue;
    op_desc->SetType(replaced_map[op_type]);
    op_desc->Flush();
    ++found_count;
  }

  AddStatis(found_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(map_depthwise_conv_to_conv_pass,
              paddle::framework::ir::MapDepthwiseConv2ConvPass);
REGISTER_PASS_CAPABILITY(map_depthwise_conv_to_conv_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("depthwise_conv2d", 1)
            .LE("conv2d", 1));