Skip to content

Commit 7234141

Browse files
authored
Split canonical ops (PaddlePaddle#43)
* move popart_canonicalization files * split popart_canonicalization ops
1 parent ee3af7d commit 7234141

File tree

11 files changed

+244
-111
lines changed

11 files changed

+244
-111
lines changed

paddle/fluid/framework/ipu/CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
set(POPART_CANONICALIZATION_HANDLERS_SRC
2-
"popart_canonicalization/other_ops.cpp"
1+
set(POPART_CANONICALIZATION_SRC
2+
"popart_canonicalization/canonicalization_utils.cc"
3+
"popart_canonicalization/activation_ops.cc"
4+
"popart_canonicalization/logic_ops.cc"
5+
"popart_canonicalization/math_ops.cc"
6+
"popart_canonicalization/nn_ops.cc"
7+
"popart_canonicalization/tensor_ops.cc"
8+
"popart_canonicalization/other_ops.cc"
39
)
4-
cc_library(popart_canonicalization_utils SRCS popart_canonicalization_utils.cc
5-
${POPART_CANONICALIZATION_HANDLERS_SRC} DEPS framework_proto enforce)
10+
cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce)
611

712
cc_library(ipu_device SRCS device.cc DEPS enforce popart)
813
cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace {
21+
22+
//
23+
24+
} // namespace
25+
} // namespace framework
26+
} // namespace paddle

paddle/fluid/framework/ipu/popart_canonicalization_utils.cc renamed to paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/ipu/popart_canonicalization_utils.h"
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
1616

1717
namespace paddle {
1818
namespace framework {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace {
21+
22+
//
23+
24+
} // namespace
25+
} // namespace framework
26+
} // namespace paddle
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace {
21+
22+
ir::Node *elementwise_add_handler(ir::Graph *graph, ir::Node *node) {
23+
auto *op = node->Op();
24+
auto op_desc = std::make_unique<framework::OpDesc>();
25+
op_desc->SetType("Add");
26+
27+
std::vector<std::string> inputs;
28+
inputs.push_back(op->Input("X").front());
29+
inputs.push_back(op->Input("Y").front());
30+
op_desc->SetInput("__inputs__", inputs);
31+
std::vector<std::string> outputs;
32+
outputs.push_back(op->Output("Out").front());
33+
op_desc->SetOutput("__outputs__", outputs);
34+
35+
op_desc->Flush();
36+
return graph->CreateOpNode(op_desc.get());
37+
}
38+
39+
ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
40+
auto *op = node->Op();
41+
auto op_desc = std::make_unique<framework::OpDesc>();
42+
op_desc->SetType("ReduceMean");
43+
44+
std::vector<std::string> inputs;
45+
inputs.push_back(op->Input("X").front());
46+
op_desc->SetInput("__inputs__", inputs);
47+
std::vector<std::string> outputs;
48+
outputs.push_back(op->Output("Out").front());
49+
op_desc->SetOutput("__outputs__", outputs);
50+
auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all"));
51+
if (!reduce_all) {
52+
auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
53+
auto axes = std::vector<int64_t>{axes_.begin(), axes_.end()};
54+
op_desc->SetAttr("axes", axes);
55+
}
56+
auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keep_dim"));
57+
auto keepdims = int64_t{keepdims_};
58+
op_desc->SetAttr("keepdims", keepdims);
59+
60+
op_desc->Flush();
61+
return graph->CreateOpNode(op_desc.get());
62+
}
63+
64+
REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
65+
REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
66+
67+
} // namespace
68+
} // namespace framework
69+
} // namespace paddle
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace {
21+
22+
ir::Node *conv2d_handler(ir::Graph *graph, ir::Node *node) {
23+
auto *op = node->Op();
24+
auto op_desc = std::make_unique<framework::OpDesc>();
25+
op_desc->SetType("Conv");
26+
27+
std::vector<std::string> inputs;
28+
inputs.push_back(op->Input("Input").front());
29+
inputs.push_back(op->Input("Filter").front());
30+
if (op->HasInput("Bias")) {
31+
if (!op->Input("Bias").empty()) {
32+
inputs.push_back(op->Input("Bias").front());
33+
}
34+
}
35+
op_desc->SetInput("__inputs__", inputs);
36+
std::vector<std::string> outputs;
37+
outputs.push_back(op->Output("Output").front());
38+
op_desc->SetOutput("__outputs__", outputs);
39+
40+
auto dilations_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
41+
auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
42+
auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
43+
auto group = int64_t{group_};
44+
auto pads_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
45+
if (pads_.size() == 2) {
46+
pads_.push_back(pads_[0]);
47+
pads_.push_back(pads_[1]);
48+
}
49+
auto pads = std::vector<int64_t>{pads_.begin(), pads_.end()};
50+
auto stride_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
51+
auto stride = std::vector<int64_t>{stride_.begin(), stride_.end()};
52+
op_desc->SetAttr("dilations", dilations);
53+
op_desc->SetAttr("group", group);
54+
op_desc->SetAttr("pads", pads);
55+
op_desc->SetAttr("strides", stride);
56+
57+
op_desc->Flush();
58+
return graph->CreateOpNode(op_desc.get());
59+
}
60+
61+
REGISTER_HANDLER(conv2d, conv2d_handler);
62+
63+
} // namespace
64+
} // namespace framework
65+
} // namespace paddle
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace {
21+
22+
//
23+
24+
} // namespace
25+
} // namespace framework
26+
} // namespace paddle

0 commit comments

Comments
 (0)