Skip to content

Commit 0dd6a44

Browse files
authored
Revert "Revert "[Dy2Stat] Refactor ExecutorCache logic and pre-support BuildStrategy for pass (#34181)" (#34348)" (#34384)
This reverts commit 577fdde.
1 parent 937e21a commit 0dd6a44

File tree

11 files changed

+166
-148
lines changed

11 files changed

+166
-148
lines changed

paddle/fluid/framework/executor_cache.cc

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/executor_cache.h"
16+
#include "paddle/fluid/framework/op_info.h"
1617

1718
namespace paddle {
1819
namespace framework {
@@ -25,11 +26,11 @@ namespace framework {
2526

2627
namespace details {
2728

28-
static ExecutionStrategy GetExecutionStrategy(
29-
const ExecutorInfoCache::CacheKey &cache_key) {
29+
static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
3030
framework::ExecutionStrategy execution_strategy;
3131

32-
switch (cache_key.device_type_) {
32+
auto device_type = platform::Place2DeviceType(place);
33+
switch (device_type) {
3334
case platform::DeviceType::CPU: {
3435
execution_strategy.num_threads_ = 2;
3536
break;
@@ -46,9 +47,9 @@ static ExecutionStrategy GetExecutionStrategy(
4647
}
4748
default:
4849
PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
49-
cache_key.device_type_));
50+
device_type));
5051
}
51-
execution_strategy.use_device_ = cache_key.device_type_;
52+
execution_strategy.use_device_ = device_type;
5253

5354
return execution_strategy;
5455
}
@@ -136,58 +137,51 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
136137
return g_exe_cache_info_map;
137138
}
138139

139-
void ExecutorInfoCache::Finalize() {
140-
// NOTE(Aurelius84): DO NOT perform finalize in destructor
141-
// to avoid problems caused by destructor order of static
142-
// object.
143-
info_map_.clear();
144-
}
145-
146-
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey &cache_key,
140+
CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
141+
const platform::Place &place,
142+
int64_t start_op_index, int64_t end_op_index,
143+
bool is_grad, int64_t program_id,
147144
framework::Scope *scope) {
148145
auto &cached_exe_info = framework::ExecutorInfoCache::Instance();
149146

150-
if (!cached_exe_info.Has(cache_key)) {
151-
VLOG(1) << "create exe_info for " << cache_key.DebugString();
152-
147+
if (!cached_exe_info.Has(program_id, is_grad)) {
153148
// TODO(Aurelius84): Consider to use LRU algorithm to replace this.
154149
if (cached_exe_info.Size() > 4u /* max_cached_size*/) {
155150
VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
156151
"all cache!";
157152
cached_exe_info.Finalize();
158153
}
159154

160-
framework::BuildStrategy build_strategy;
161-
auto execution_strategy = details::GetExecutionStrategy(cache_key);
155+
VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
156+
auto execution_strategy = details::GetExecutionStrategy(place);
157+
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);
162158

159+
// 2. Construct Graph and ParallelExecutor.
163160
auto graph = std::make_shared<framework::ir::Graph>(
164-
*cache_key.program_desc_, cache_key.start_op_index_,
165-
cache_key.end_op_index_);
161+
program_desc, start_op_index, end_op_index);
166162
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
167-
cache_key.place_, scope, execution_strategy, build_strategy,
168-
graph.get());
163+
place, scope, execution_strategy, build_strategy, graph.get());
169164
parallel_executor->PrepareVariables(scope);
170165

171-
framework::ExecutorInfoCache::ValueType cache_val = {parallel_executor,
172-
graph};
173-
cached_exe_info.Insert(cache_key, cache_val);
174-
175-
bool is_new_created = true;
176-
return std::make_pair(parallel_executor, is_new_created);
166+
// 3. Insert value into cached map.
167+
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
168+
cached_value.executor_ = parallel_executor;
169+
cached_value.graph_ = std::move(graph);
170+
return std::make_pair(parallel_executor, /*is_new_created=*/true);
177171
} else {
178-
VLOG(1) << "get exe_info from cache by: " << cache_key.DebugString();
179-
bool is_new_created = false;
180-
auto cache_val = cached_exe_info.GetMutable(cache_key);
181-
auto parallel_executor = cache_val.first;
172+
VLOG(1) << "get exe_info from cache by: " << program_id
173+
<< " is_grad: " << is_grad;
174+
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
182175

176+
auto &parallel_executor = cached_value.executor_;
183177
// update op_handle scope_map in pe->executor_->Graph
184178
std::unordered_map<Scope *, Scope *> scope_map = {
185179
{parallel_executor->GetLocalScopes().front(), scope}};
186180
parallel_executor->ResetOpHandleScopeMapOfGraphs(scope_map);
187181
// need to recreate tmp variables in new scope
188182
parallel_executor->PrepareVariables(scope);
189183

190-
return std::make_pair(parallel_executor, is_new_created);
184+
return std::make_pair(parallel_executor, /*is_new_created=*/false);
191185
}
192186
}
193187

paddle/fluid/framework/executor_cache.h

Lines changed: 62 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -45,121 +45,92 @@ void ParseSafeEagerDeletionSkipVars(
4545
std::vector<std::string>* skip_eager_delete_vars);
4646

4747
} // namespace details
48-
class ExecutorInfoCache {
48+
49+
class ExecutorInfo {
4950
public:
50-
struct CacheKey {
51-
CacheKey(const ProgramDesc* program_desc, const platform::Place& place,
52-
int64_t start_op_index, int64_t end_op_index, bool is_grad)
53-
: program_desc_(program_desc),
54-
place_(place),
55-
start_op_index_(start_op_index),
56-
end_op_index_(end_op_index),
57-
is_grad_(is_grad) {
58-
device_type_ = platform::Place2DeviceType(place);
59-
PADDLE_ENFORCE_NOT_NULL(program_desc_,
60-
"program_desc should not be null.");
61-
}
62-
63-
std::string DebugString() const {
64-
std::stringstream ss;
65-
66-
ss << "\n CacheKey(program_desc: " << program_desc_;
67-
ss << ", start_op_index: " << start_op_index_;
68-
ss << ", end_op_index: " << end_op_index_;
69-
ss << ", is_grad: " << is_grad_;
70-
ss << ", device_type: " << device_type_ << ")";
71-
72-
return ss.str();
73-
}
74-
75-
const ProgramDesc* program_desc_;
76-
platform::Place place_;
77-
int64_t start_op_index_;
78-
int64_t end_op_index_;
79-
bool is_grad_;
80-
platform::DeviceType device_type_;
81-
};
51+
struct CacheValue {
52+
std::shared_ptr<ParallelExecutor> executor_{nullptr};
53+
std::shared_ptr<ir::Graph> graph_{nullptr};
8254

83-
using KeyType = size_t;
84-
using ValueType =
85-
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
86-
87-
struct KeyHasher {
88-
size_t operator()(const CacheKey& key) const noexcept {
89-
size_t seed = 10;
90-
auto* prog_desc = key.program_desc_;
91-
/*
92-
* Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
93-
* because a new program will hold same pointer address after an older
94-
* program is destructed with a small probability. Add op size while
95-
* hashing because program may contains at least one block.
96-
*/
97-
hash_combine(&seed, prog_desc);
98-
for (size_t i = 0; i < prog_desc->Size(); ++i) {
99-
hash_combine(&seed, &prog_desc->Block(i));
100-
hash_combine(&seed, prog_desc->Block(i).OpSize());
101-
}
102-
hash_combine(&seed, static_cast<int>(key.device_type_));
103-
hash_combine(&seed, key.start_op_index_);
104-
hash_combine(&seed, key.end_op_index_);
105-
hash_combine(&seed, key.is_grad_);
106-
VLOG(3) << "hash value is : " << seed
107-
<< " of key: " << key.DebugString();
108-
return seed;
109-
}
110-
111-
template <typename T>
112-
void hash_combine(size_t* seed, const T& val) const {
113-
std::hash<T> hasher;
114-
(*seed) ^= hasher(val) + 0x9e3779b9 + ((*seed) << 6) + ((*seed >> 2));
115-
}
55+
std::vector<std::string> skip_eager_delete_vars_;
11656
};
11757

58+
bool IsAvailable(bool is_grad) {
59+
const auto& executor =
60+
is_grad ? backward_info_.executor_ : forward_info_.executor_;
61+
return executor != nullptr;
62+
}
63+
64+
CacheValue& GetMutable(bool is_grad) {
65+
return is_grad ? backward_info_ : forward_info_;
66+
}
67+
68+
private:
69+
CacheValue forward_info_;
70+
CacheValue backward_info_;
71+
};
72+
73+
class ExecutorInfoCache {
74+
public:
11875
static ExecutorInfoCache& Instance();
11976

120-
ValueType GetMutable(const CacheKey& key) {
121-
auto key_val = key_hash_func_(key);
77+
const BuildStrategy& GetBuildStrategy(int64_t program_id) {
78+
// If not found, insert build_strategy with default value.
79+
return strategy_map_[program_id];
80+
}
81+
82+
void SetBuildStrategy(int64_t program_id,
83+
const BuildStrategy& build_strategy) {
12284
PADDLE_ENFORCE_EQ(
123-
Has(key_val), true,
124-
platform::errors::NotFound("%s doesn't exist in ExecutorInfoCache",
125-
key.DebugString()));
126-
return info_map_[key_val];
85+
strategy_map_.count(program_id), 0,
86+
platform::errors::PreconditionNotMet(
87+
"program_id: %s already exist in ExecutorInfoCache", program_id));
88+
strategy_map_[program_id] = build_strategy;
12789
}
12890

129-
bool Has(const CacheKey& key) const {
130-
auto key_val = key_hash_func_(key);
131-
return Has(key_val);
91+
bool Has(int64_t program_id, bool is_grad) {
92+
return info_map_.find(program_id) != info_map_.end() &&
93+
info_map_[program_id].IsAvailable(is_grad);
13294
}
13395

134-
bool Has(const KeyType& key) const {
135-
return info_map_.find(key) != info_map_.end();
96+
ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
97+
return info_map_[program_id].GetMutable(is_grad);
13698
}
13799

138-
void Insert(const CacheKey& key, ValueType value) {
139-
auto key_val = key_hash_func_(key);
140-
PADDLE_ENFORCE_EQ(
141-
Has(key_val), false,
142-
platform::errors::NotFound("%s has existed in ExecutorInfoCache",
143-
key.DebugString()));
144-
info_map_.insert({key_val, value});
100+
void UpdateSkipEagerDeleteVars(int64_t program_id, bool is_grad,
101+
const std::vector<std::string>& skip_vars) {
102+
auto& cached_value = GetMutable(program_id, is_grad);
103+
cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
104+
}
105+
106+
std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
107+
bool is_grad) {
108+
auto& cached_value = GetMutable(program_id, is_grad);
109+
return cached_value.skip_eager_delete_vars_;
145110
}
146111

147112
size_t Size() const { return info_map_.size(); }
148113

149-
void Finalize();
114+
void Finalize() {
115+
// NOTE(Aurelius84): DO NOT perform finalize in destructor
116+
// to avoid problems caused by destructor order of static
117+
// object.
118+
info_map_.clear();
119+
strategy_map_.clear();
120+
}
150121

151122
private:
152-
ExecutorInfoCache() = default;
153-
DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache);
154-
155-
KeyHasher key_hash_func_;
156-
std::unordered_map<KeyType, ValueType> info_map_;
123+
std::unordered_map<int64_t, ExecutorInfo> info_map_;
124+
std::unordered_map<int64_t, BuildStrategy> strategy_map_;
157125
};
158126

159127
using CacheInfo =
160128
std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
161129

162-
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey& cache_key,
130+
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
131+
const platform::Place& place,
132+
int64_t start_op_index, int64_t end_op_index,
133+
bool is_grad, int64_t program_id,
163134
framework::Scope* scope);
164135

165136
} // namespace framework

paddle/fluid/operators/run_program_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
103103
"(bool, default false) Set to true for inference only, false "
104104
"for training.")
105105
.SetDefault(false);
106+
AddAttr<int64_t>(
107+
"program_id",
108+
"(int64_t)"
109+
"The unique hash id used as cache key for ExecutorInfoCache.");
106110
AddComment(R"DOC(
107111
RunProgram operator.
108112

0 commit comments

Comments
 (0)