// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "gtest/gtest-message.h"
#include "gtest/gtest-test-part.h"
#include "gtest/gtest.h"
#include "gtest/gtest_pred_impl.h"

#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/place.h"

DECLARE_bool(convert_all_blocks);

namespace paddle {
namespace framework {

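// Dummy proto maker for a test-only "sum" op: one duplicable input X, one
// duplicable output Out, and no attributes.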
class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "").AsDuplicable();
    AddComment("");
  }
};

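// Minimal kernel-based "sum" op used only so the op can be registered for
// this test: InferShape is a no-op and the expected kernel type simply
// follows the place of input X.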
class SumOpWithKernel : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext &ctx) const override {
    return OpKernelType(proto::VarType::FP32, ctx.Input<Tensor>("X")->place());
  }
};

}  // namespace framework
}  // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(sum, paddle::framework::SumOpWithKernel,
                             paddle::framework::SumOpMaker);

namespace paddle {
namespace framework {
namespace details {

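// Creates `num` places: CUDAPlace(i) when use_cuda is set, otherwise
// CPUPlace. The tests below only exercise the CPU path.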
static std::vector<platform::Place> CreatePlaces(size_t num, bool use_cuda) {
  std::vector<platform::Place> result;
  result.reserve(num);
  for (size_t i = 0; i < num; ++i) {
    if (use_cuda) {
      result.emplace_back(platform::CUDAPlace(i));
    } else {
      result.emplace_back(platform::CPUPlace());
    }
  }
  return result;
}

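// Applies `build_strategy` to `graph` on a single CPU place with an empty
// loss name and one scope. Which communicator is constructed and passed
// depends on the NCCL/RCCL or BKCL build flags.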
void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) {
  std::string loss_name = "";
  Scope scope;
  std::vector<Scope *> scopes = {&scope};

  auto places = CreatePlaces(1, false);
  auto device = platform::Place2DeviceType(places[0]);

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  platform::NCCLCommunicator ctxs;
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
  platform::BKCLCommunicator ctxs;
#endif

  build_strategy->Apply(graph, places, loss_name, scopes, 1,
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
                        device, &ctxs);
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
                        device, &ctxs);
#else
                        device);
#endif
}

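// Builds a single-block program containing one sum op (a1 -> b1) and
// converts it into an ir::Graph.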
std::unique_ptr<ir::Graph> CreateGraph() {
  ProgramDesc prog;
  auto *op = prog.MutableBlock(0)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"a1"});
  op->SetOutput("Out", {"b1"});
  op->SetAttr("op_role", 1);

  prog.MutableBlock(0)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);

  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
  return g;
}

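// Builds a three-block program (block_0 plus two appended blocks), each
// block holding one sum op, and converts it into an ir::Graph. Block_0 runs
// InferVarType twice with different input var types.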
std::unique_ptr<ir::Graph> CreateMultiGraph() {
  ProgramDesc prog;
  prog.AppendBlock(prog.Block(0));
  prog.AppendBlock(prog.Block(0));

  // Set contents in block_0.
  auto *op = prog.MutableBlock(0)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"test_a", "test_b", "test_c"});
  op->SetOutput("Out", {"test_out"});
  op->SetAttr("op_role", 1);

  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
  prog.MutableBlock(0)->Var("test_out");
  op->InferVarType(prog.MutableBlock(0));

  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
  op->InferVarType(prog.MutableBlock(0));

  // Set contents in block_1.
  op = prog.MutableBlock(1)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"a1"});
  op->SetOutput("Out", {"b1"});
  op->SetAttr("op_role", 1);

  prog.MutableBlock(1)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(1)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);

  // Set contents in block_2.
  op = prog.MutableBlock(2)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"a2"});
  op->SetOutput("Out", {"b2"});
  op->SetAttr("op_role", 1);

  prog.MutableBlock(2)->Var("a2")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(2)->Var("b2")->SetType(proto::VarType::LOD_TENSOR);

  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
  return g;
}

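// Compares two graphs node by node. Nodes (and each node's inputs/outputs)
// are sorted by name first, so the comparison does not depend on iteration
// order; any mismatch in node type, name, connectivity, or op description
// makes the graphs differ.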
inline bool CheckSubGraphSame(ir::Graph *g1, ir::Graph *g2) {
  const auto &g1_nodes_set = g1->Nodes();
  const auto &g2_nodes_set = g2->Nodes();

  if (g1_nodes_set.size() != g2_nodes_set.size()) return false;

  std::vector<ir::Node *> g1_nodes(g1_nodes_set.begin(), g1_nodes_set.end());
  std::vector<ir::Node *> g2_nodes(g2_nodes_set.begin(), g2_nodes_set.end());

  auto comp = [](ir::Node *n1, ir::Node *n2) {
    return n1->Name() > n2->Name();
  };
  std::stable_sort(g1_nodes.begin(), g1_nodes.end(), comp);
  std::stable_sort(g2_nodes.begin(), g2_nodes.end(), comp);

  for (size_t i = 0; i < g1_nodes.size(); ++i) {
    const auto &n1 = g1_nodes[i];
    const auto &n2 = g2_nodes[i];

    if (n1->NodeType() != n2->NodeType()) return false;
    if (n1->Name() != n2->Name()) return false;

    auto n1_inputs = n1->inputs;
    auto n2_inputs = n2->inputs;
    if (n1_inputs.size() != n2_inputs.size()) return false;

    std::stable_sort(n1_inputs.begin(), n1_inputs.end(), comp);
    std::stable_sort(n2_inputs.begin(), n2_inputs.end(), comp);
    for (size_t j = 0; j < n1_inputs.size(); ++j) {
      if (n1_inputs[j]->Name() != n2_inputs[j]->Name()) return false;
    }

    auto n1_outputs = n1->outputs;
    auto n2_outputs = n2->outputs;
    if (n1_outputs.size() != n2_outputs.size()) return false;

    std::stable_sort(n1_outputs.begin(), n1_outputs.end(), comp);
    std::stable_sort(n2_outputs.begin(), n2_outputs.end(), comp);
    for (size_t j = 0; j < n1_outputs.size(); ++j) {
      if (n1_outputs[j]->Name() != n2_outputs[j]->Name()) return false;
    }

    if (n1->IsVar()) {
      const auto &var1 = n1->Var();
      const auto &var2 = n2->Var();
      if ((var1 == nullptr) != (var2 == nullptr)) return false;
    }

    if (n1->IsOp()) {
      const auto &op1 = n1->Op();
      const auto &op2 = n2->Op();
      if ((op1 == nullptr) != (op2 == nullptr)) return false;
      // Both descs may be null (e.g. for control-dependency nodes); skip the
      // OpDesc comparison in that case instead of dereferencing nullptr.
      if (op1 == nullptr) continue;

      const auto &op1_input = op1->InputNames();
      const auto &op2_input = op2->InputNames();
      if (op1_input.size() != op2_input.size()) return false;
      if (op1_input != op2_input) return false;

      for (size_t j = 0; j < op1_input.size(); ++j) {
        if (op1->Input(op1_input[j]) != op2->Input(op2_input[j])) return false;
      }

      const auto &op1_output = op1->OutputNames();
      const auto &op2_output = op2->OutputNames();
      if (op1_output.size() != op2_output.size()) return false;
      if (op1_output != op2_output) return false;

      for (size_t j = 0; j < op1_output.size(); ++j) {
        if (op1->Output(op1_output[j]) != op2->Output(op2_output[j]))
          return false;
      }
    }
  }
  return true;
}

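// Compares whole graphs. When FLAGS_convert_all_blocks is on, every
// sub-graph is compared pairwise; otherwise only the main graph is checked.
// If either graph is null, the check trivially passes.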
inline bool CheckGraphSame(ir::Graph *g1, ir::Graph *g2) {
  if (g1 == nullptr || g2 == nullptr) return true;

  if (FLAGS_convert_all_blocks) {
    if (g1->SubGraphsSize() != g2->SubGraphsSize()) return false;

    for (size_t i = 0; i < g1->SubGraphsSize(); ++i) {
      if (!CheckSubGraphSame(g1->GetSubGraph(i), g2->GetSubGraph(i)))
        return false;
    }
  } else {
    if (!CheckSubGraphSame(g1, g2)) return false;
  }
  return true;
}

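// Applying an empty BuildStrategy to a graph built from an empty program
// must leave the graph unchanged.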
TEST(BuildStrategy, Basic) {
  BuildStrategy build_strategy;

  ProgramDesc prog;
  ir::Graph old_graph(prog), graph(prog);

  BuildStrategyApply(&build_strategy, &graph);

  ASSERT_TRUE(CheckGraphSame(&old_graph, &graph));
}

TEST(BuildStrategy, TestSingleGraph) {
  BuildStrategy build_strategy;
  auto graph = CreateGraph();
  ir::Graph old_graph(graph->OriginProgram());

  BuildStrategyApply(&build_strategy, graph.get());

  // The graph should not change because no pass is applied here.
  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));
}

TEST(BuildStrategy, TestMultiGraph) {
  // Force FLAGS_convert_all_blocks to true so that all blocks of the
  // multi-block program are converted to sub-graphs in this test.
  bool flag_temp = FLAGS_convert_all_blocks;
  FLAGS_convert_all_blocks = true;

  BuildStrategy build_strategy;
  auto graph = CreateMultiGraph();
  ir::Graph old_graph(graph->OriginProgram());

  BuildStrategyApply(&build_strategy, graph.get());

  // The graph should not change because no pass is applied here.
  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));

  // Restore FLAGS_convert_all_blocks to its original value.
  FLAGS_convert_all_blocks = flag_temp;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle