@@ -45,121 +45,92 @@ void ParseSafeEagerDeletionSkipVars(
4545 std::vector<std::string>* skip_eager_delete_vars);
4646
4747} // namespace details
48- class ExecutorInfoCache {
48+
49+ class ExecutorInfo {
4950 public:
50- struct CacheKey {
51- CacheKey (const ProgramDesc* program_desc, const platform::Place& place,
52- int64_t start_op_index, int64_t end_op_index, bool is_grad)
53- : program_desc_(program_desc),
54- place_ (place),
55- start_op_index_(start_op_index),
56- end_op_index_(end_op_index),
57- is_grad_(is_grad) {
58- device_type_ = platform::Place2DeviceType (place);
59- PADDLE_ENFORCE_NOT_NULL (program_desc_,
60- " program_desc should not be null." );
61- }
62-
63- std::string DebugString () const {
64- std::stringstream ss;
65-
66- ss << " \n CacheKey(program_desc: " << program_desc_;
67- ss << " , start_op_index: " << start_op_index_;
68- ss << " , end_op_index: " << end_op_index_;
69- ss << " , is_grad: " << is_grad_;
70- ss << " , device_type: " << device_type_ << " )" ;
71-
72- return ss.str ();
73- }
74-
75- const ProgramDesc* program_desc_;
76- platform::Place place_;
77- int64_t start_op_index_;
78- int64_t end_op_index_;
79- bool is_grad_;
80- platform::DeviceType device_type_;
81- };
51+ struct CacheValue {
52+ std::shared_ptr<ParallelExecutor> executor_{nullptr };
53+ std::shared_ptr<ir::Graph> graph_{nullptr };
8254
83- using KeyType = size_t ;
84- using ValueType =
85- std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
86-
87- struct KeyHasher {
88- size_t operator ()(const CacheKey& key) const noexcept {
89- size_t seed = 10 ;
90- auto * prog_desc = key.program_desc_ ;
91- /*
92- * Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
93- * because a new program will hold same pointer address after an older
94- * program is destructed with a small probability. Add op size while
95- * hashing because program may contains at least one block.
96- */
97- hash_combine (&seed, prog_desc);
98- for (size_t i = 0 ; i < prog_desc->Size (); ++i) {
99- hash_combine (&seed, &prog_desc->Block (i));
100- hash_combine (&seed, prog_desc->Block (i).OpSize ());
101- }
102- hash_combine (&seed, static_cast <int >(key.device_type_ ));
103- hash_combine (&seed, key.start_op_index_ );
104- hash_combine (&seed, key.end_op_index_ );
105- hash_combine (&seed, key.is_grad_ );
106- VLOG (3 ) << " hash value is : " << seed
107- << " of key: " << key.DebugString ();
108- return seed;
109- }
110-
111- template <typename T>
112- void hash_combine (size_t * seed, const T& val) const {
113- std::hash<T> hasher;
114- (*seed) ^= hasher (val) + 0x9e3779b9 + ((*seed) << 6 ) + ((*seed >> 2 ));
115- }
55+ std::vector<std::string> skip_eager_delete_vars_;
11656 };
11757
58+ bool IsAvailable (bool is_grad) {
59+ const auto & executor =
60+ is_grad ? backward_info_.executor_ : forward_info_.executor_ ;
61+ return executor != nullptr ;
62+ }
63+
64+ CacheValue& GetMutable (bool is_grad) {
65+ return is_grad ? backward_info_ : forward_info_;
66+ }
67+
68+ private:
69+ CacheValue forward_info_;
70+ CacheValue backward_info_;
71+ };
72+
73+ class ExecutorInfoCache {
74+ public:
11875 static ExecutorInfoCache& Instance ();
11976
120- ValueType GetMutable (const CacheKey& key) {
121- auto key_val = key_hash_func_ (key);
77+ const BuildStrategy& GetBuildStrategy (int64_t program_id) {
78+ // If not found, insert build_strategy with default value.
79+ return strategy_map_[program_id];
80+ }
81+
82+ void SetBuildStrategy (int64_t program_id,
83+ const BuildStrategy& build_strategy) {
12284 PADDLE_ENFORCE_EQ (
123- Has (key_val ), true ,
124- platform::errors::NotFound ( " %s doesn't exist in ExecutorInfoCache " ,
125- key. DebugString () ));
126- return info_map_[key_val] ;
85+ strategy_map_. count (program_id ), 0 ,
86+ platform::errors::PreconditionNotMet (
87+ " program_id: %s already exist in ExecutorInfoCache " , program_id ));
88+ strategy_map_[program_id] = build_strategy ;
12789 }
12890
129- bool Has (const CacheKey& key) const {
130- auto key_val = key_hash_func_ (key);
131- return Has (key_val );
91+ bool Has (int64_t program_id, bool is_grad) {
92+ return info_map_. find (program_id) != info_map_. end () &&
93+ info_map_[program_id]. IsAvailable (is_grad );
13294 }
13395
134- bool Has ( const KeyType& key) const {
135- return info_map_. find (key) != info_map_. end ( );
96+ ExecutorInfo::CacheValue& GetMutable ( int64_t program_id, bool is_grad) {
97+ return info_map_[program_id]. GetMutable (is_grad );
13698 }
13799
138- void Insert (const CacheKey& key, ValueType value) {
139- auto key_val = key_hash_func_ (key);
140- PADDLE_ENFORCE_EQ (
141- Has (key_val), false ,
142- platform::errors::NotFound (" %s has existed in ExecutorInfoCache" ,
143- key.DebugString ()));
144- info_map_.insert ({key_val, value});
100+ void UpdateSkipEagerDeleteVars (int64_t program_id, bool is_grad,
101+ const std::vector<std::string>& skip_vars) {
102+ auto & cached_value = GetMutable (program_id, is_grad);
103+ cached_value.skip_eager_delete_vars_ = std::move (skip_vars);
104+ }
105+
106+ std::vector<std::string>& SkipEagerDeleteVars (int64_t program_id,
107+ bool is_grad) {
108+ auto & cached_value = GetMutable (program_id, is_grad);
109+ return cached_value.skip_eager_delete_vars_ ;
145110 }
146111
147112 size_t Size () const { return info_map_.size (); }
148113
149- void Finalize ();
114+ void Finalize () {
115+ // NOTE(Aurelius84): DO NOT perform finalize in destructor
116+ // to avoid problems caused by destructor order of static
117+ // object.
118+ info_map_.clear ();
119+ strategy_map_.clear ();
120+ }
150121
151122 private:
152- ExecutorInfoCache () = default;
153- DISABLE_COPY_AND_ASSIGN (ExecutorInfoCache);
154-
155- KeyHasher key_hash_func_;
156- std::unordered_map<KeyType, ValueType> info_map_;
123+ std::unordered_map<int64_t , ExecutorInfo> info_map_;
124+ std::unordered_map<int64_t , BuildStrategy> strategy_map_;
157125};
158126
159127using CacheInfo =
160128 std::pair<std::shared_ptr<ParallelExecutor>, bool /* is_new_created*/ >;
161129
162- CacheInfo GetExecutorInfoFromCache (const ExecutorInfoCache::CacheKey& cache_key,
130+ CacheInfo GetExecutorInfoFromCache (const ProgramDesc& program_desc,
131+ const platform::Place& place,
132+ int64_t start_op_index, int64_t end_op_index,
133+ bool is_grad, int64_t program_id,
163134 framework::Scope* scope);
164135
165136} // namespace framework
0 commit comments