Skip to content

Commit ce37cb9

Browse files
committed
resolve merge conflict. refine the example code.
2 parents 90d59cb + 7c1ff38 commit ce37cb9

File tree

139 files changed

+9647
-1546
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

139 files changed

+9647
-1546
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -268,6 +268,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
268268

269269
cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
270270
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
271+
cc_library(generator SRCS generator.cc)
271272

272273
# Get the current working branch
273274
execute_process(

paddle/fluid/framework/fleet/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,6 @@ else()
1919
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
2020
endif(WITH_GLOO)
2121

22-
cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context)
22+
cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context heter_service_proto)
2323

2424
cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
paddle/fluid/framework/generator.cc

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <deque>
16+
#include <memory>
17+
#include <unordered_map>
18+
#include <unordered_set>
19+
#include <utility>
20+
21+
#include "paddle/fluid/framework/generator.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
26+
std::shared_ptr<Generator> Generator::gen_instance_ = NULL;
27+
28+
GeneratorState* Generator::GetState() {
29+
std::lock_guard<std::mutex> lock(this->mutex);
30+
return this->state_.get();
31+
}
32+
33+
void Generator::SetState(GeneratorState* state_in) {
34+
std::lock_guard<std::mutex> lock(this->mutex);
35+
*this->state_ = *state_in;
36+
}
37+
38+
uint64_t Generator::GetCurrentSeed() {
39+
std::lock_guard<std::mutex> lock(this->mutex);
40+
return this->state_->current_seed;
41+
}
42+
43+
uint64_t Generator::Seed() {
44+
std::lock_guard<std::mutex> lock(this->mutex);
45+
uint64_t seed;
46+
std::random_device de;
47+
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
48+
this->state_->current_seed = seed;
49+
std::seed_seq seq({seed});
50+
this->state_->cpu_engine.seed(seq);
51+
52+
return this->state_->current_seed;
53+
}
54+
55+
void Generator::SetCurrentSeed(uint64_t seed) {
56+
std::lock_guard<std::mutex> lock(this->mutex);
57+
this->state_->current_seed = uint64_t(seed);
58+
std::seed_seq seq({seed});
59+
this->state_->cpu_engine.seed(seq);
60+
}
61+
62+
std::mt19937_64& Generator::GetCPUEngine() {
63+
std::lock_guard<std::mutex> lock(this->mutex);
64+
return this->state_->cpu_engine;
65+
}
66+
67+
void Generator::SetCPUEngine(std::mt19937_64 engine) {
68+
std::lock_guard<std::mutex> lock(this->mutex);
69+
this->state_->cpu_engine = std::mt19937_64(engine);
70+
}
71+
72+
uint64_t Generator::Random64() {
73+
std::lock_guard<std::mutex> lock(this->mutex);
74+
return this->state_->cpu_engine();
75+
}
76+
77+
} // namespace framework
78+
} // namespace paddle

paddle/fluid/framework/generator.h

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,96 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <stdint.h>
18+
#include <atomic>
19+
#include <deque>
20+
#include <iostream> // temp for debug
21+
#include <memory>
22+
#include <mutex> // NOLINT
23+
#include <random>
24+
#include <typeinfo>
25+
#include <utility>
26+
27+
namespace paddle {
28+
namespace framework {
29+
30+
struct GeneratorState {
31+
int64_t device = -1;
32+
uint64_t current_seed = 34342423252;
33+
std::mt19937_64 cpu_engine;
34+
};
35+
36+
struct Generator {
37+
Generator() {
38+
GeneratorState default_gen_state_cpu;
39+
default_gen_state_cpu.device = -1;
40+
default_gen_state_cpu.current_seed = 34342423252;
41+
std::seed_seq seq({34342423252});
42+
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
43+
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
44+
}
45+
explicit Generator(GeneratorState state_in)
46+
: state_{std::make_shared<GeneratorState>(state_in)} {}
47+
Generator(const Generator& other)
48+
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}
49+
50+
// get random state
51+
GeneratorState* GetState();
52+
// set random state
53+
void SetState(GeneratorState* state_in);
54+
// get current seed
55+
uint64_t GetCurrentSeed();
56+
// random a seed and get
57+
uint64_t Seed();
58+
59+
// set seed
60+
void SetCurrentSeed(uint64_t seed);
61+
// get cpu engine
62+
std::mt19937_64& GetCPUEngine();
63+
// set cpu engine
64+
void SetCPUEngine(std::mt19937_64 engine);
65+
66+
uint64_t Random64();
67+
68+
bool is_init_py = false;
69+
70+
// CPU Generator singleton
71+
static std::shared_ptr<Generator> GetInstance() {
72+
if (NULL == gen_instance_) {
73+
gen_instance_.reset(new paddle::framework::Generator());
74+
}
75+
return gen_instance_;
76+
}
77+
78+
static std::shared_ptr<Generator> GetInstanceX() {
79+
if (NULL == gen_instance_) {
80+
gen_instance_.reset(new paddle::framework::Generator());
81+
}
82+
gen_instance_->is_init_py = true;
83+
return gen_instance_;
84+
}
85+
86+
private:
87+
static std::shared_ptr<Generator> gen_instance_;
88+
std::shared_ptr<GeneratorState> state_;
89+
mutable std::mutex mutex;
90+
91+
Generator(const Generator& other, const std::lock_guard<std::mutex>&)
92+
: state_(std::make_shared<GeneratorState>(*(other.state_))) {}
93+
};
94+
95+
} // namespace framework
96+
} // namespace paddle

paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass,
368368
paddle::framework::ir::ConvTransposeBNFusePass);
369369
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
370370
paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
371+
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
372+
paddle::framework::ir::DepthwiseConvBNFusePass);
373+
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
374+
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);

paddle/fluid/framework/ir/conv_bn_fuse_pass.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
5656
std::string conv_type() const { return "conv2d_transpose"; }
5757
};
5858

59+
class DepthwiseConvBNFusePass : public ConvBNFusePass {
60+
public:
61+
std::string conv_type() const { return "depthwise_conv2d"; }
62+
};
63+
64+
class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
65+
public:
66+
std::string conv_type() const { return "depthwise_conv2d"; }
67+
};
68+
5969
} // namespace ir
6070
} // namespace framework
6171
} // namespace paddle

paddle/fluid/framework/ir/subgraph_detector.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -309,7 +309,8 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
309309
BriefNode *brief_node = itr.second;
310310

311311
if (!Agent(brief_node->node).marked()) {
312-
VLOG(4) << brief_node->node->id() << " node not a trt candidate.";
312+
VLOG(4) << brief_node->node->id() << " node named "
313+
<< brief_node->node->Name() << " is not a trt candidate.";
313314
continue;
314315
}
315316

paddle/fluid/framework/prune.cc

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
210210
should_run.push_back(true);
211211
} else {
212212
should_run.push_back(false);
213+
// If the output of an op modifies feed vars, the op should not clip.
214+
// For example, in the transformer structure, the third parameter returned
215+
// by beam_search op is generally assigned to a feed var. Cutting the
216+
// assign op will cause an error.
217+
if (parent_block_id != -1) {
218+
bool flag = false;
219+
for (auto& var : op_desc.outputs()) {
220+
for (auto& argu : var.arguments()) {
221+
if (feed_var_names.count(argu)) {
222+
flag = true;
223+
}
224+
}
225+
}
226+
if (flag) {
227+
should_run.back() = true;
228+
}
229+
}
213230
}
214231
}
215232

paddle/fluid/framework/prune_test.cc

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) {
185185
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
186186
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
187187
}
188+
189+
// If the output of an op modifies feed vars, the op should not clip.
190+
TEST(Prune, recurrrent_op_2) {
191+
f::ProgramDesc program;
192+
f::BlockDesc *block = program.MutableBlock(0);
193+
f::BlockDesc *sub_block = program.AppendBlock(*block);
194+
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
195+
f::AttributeMap{}, block);
196+
197+
std::vector<std::string> state_var_name(1, "y");
198+
AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}},
199+
{{"ex_states", state_var_name},
200+
{"states", state_var_name},
201+
{"sub_block", sub_block}},
202+
block);
203+
204+
EXPECT_TRUE(sub_block != nullptr);
205+
AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}},
206+
f::AttributeMap{}, sub_block);
207+
208+
f::proto::ProgramDesc *pdesc = program.Proto();
209+
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
210+
211+
f::proto::ProgramDesc pruned;
212+
std::set<std::string> feed_var_names = {"x", "a"};
213+
214+
f::Prune(*pdesc, feed_var_names, &pruned);
215+
EXPECT_EQ(pruned.blocks_size(), 2);
216+
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
217+
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
218+
}

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
8383
} else if (shape.size() == 3UL) {
8484
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
8585
}
86-
return nvinfer1::Dims4(shape[0], shape[1], 1, 1);
86+
nvinfer1::Dims dims;
87+
dims.nbDims = shape.size();
88+
for (size_t i = 0; i < shape.size(); i++) {
89+
dims.d[i] = shape[i];
90+
}
91+
return dims;
8792
}
8893
}
8994
} // NOLINT

0 commit comments

Comments (0)