This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed
Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -219,7 +219,7 @@ class Adam {
219219 void Update (VariableHandle input) {
220220 PADDLE_ENFORCE (get_global_tape ().HasBeenBackwarded (),
221221 " optimization must happen after the backward" );
222- auto hyperparams = input->MutableHyperParams ();
222+ auto * hyperparams = input->MutableHyperParams (" adam " );
223223 // initialize states if they haven't been created
224224 if (hyperparams->empty ()) {
225225 framework::AttributeMap attrs;
Original file line number Diff line number Diff line change @@ -80,7 +80,13 @@ class Variable {
8080 return var_.GetMutable <T>();
8181 }
8282
83- std::vector<VariableHandle>* MutableHyperParams () { return &hyperparams_; }
83+ std::vector<VariableHandle>* MutableHyperParams (
84+ const std::string& optimizer) {
85+ PADDLE_ENFORCE (hyperparams_.find (optimizer) != hyperparams_.end (),
86+ " %s optimizer is not supported" ,
87+ optimizer);
88+ return &hyperparams_[optimizer];
89+ }
8490
8591 private:
8692 int count () {
@@ -94,8 +100,9 @@ class Variable {
94100 // Not own
95101 std::weak_ptr<Variable> grad_;
96102
97- // Adam Optimizer hyperparameter
98- std::vector<VariableHandle> hyperparams_;
103+ // Optimizer hyperparameters
104+ std::unordered_map<std::string, std::vector<VariableHandle>> hyperparams_{
105+ {" adam" , {}}};
99106};
100107
101108} // namespace tape
You can’t perform that action at this time.
0 commit comments