from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import multistep_optimizer
from tensor2tensor.utils import yellowfin
+ from tensor2tensor.utils import registry

import tensorflow as tf

@@ -93,6 +94,83 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
  return train_op


+ @registry.register_optimizer
+ def adam(learning_rate, hparams):
+   # We change the default epsilon for Adam.
+   # Using LazyAdam as it's much faster for large vocabulary embeddings.
+   return tf.contrib.opt.LazyAdamOptimizer(
+       learning_rate,
+       beta1=hparams.optimizer_adam_beta1,
+       beta2=hparams.optimizer_adam_beta2,
+       epsilon=hparams.optimizer_adam_epsilon)
+
+
+ @registry.register_optimizer
+ def multistep_adam(learning_rate, hparams):
+   return multistep_optimizer.MultistepAdamOptimizer(
+       learning_rate,
+       beta1=hparams.optimizer_adam_beta1,
+       beta2=hparams.optimizer_adam_beta2,
+       epsilon=hparams.optimizer_adam_epsilon,
+       n=hparams.optimizer_multistep_accumulate_steps)
+
+
+ @registry.register_optimizer
+ def momentum(learning_rate, hparams):
+   return tf.train.MomentumOptimizer(
+       learning_rate,
+       momentum=hparams.optimizer_momentum_momentum,
+       use_nesterov=hparams.optimizer_momentum_nesterov)
+
+
+ @registry.register_optimizer
+ def yellow_fin(learning_rate, hparams):
+   return yellowfin.YellowFinOptimizer(
+       learning_rate=learning_rate,
+       momentum=hparams.optimizer_momentum_momentum)
+
+
+ @registry.register_optimizer
+ def true_adam(learning_rate, hparams):
+   return tf.train.AdamOptimizer(
+       learning_rate,
+       beta1=hparams.optimizer_adam_beta1,
+       beta2=hparams.optimizer_adam_beta2,
+       epsilon=hparams.optimizer_adam_epsilon)
+
+
+ @registry.register_optimizer
+ def adam_w(learning_rate, hparams):
+   # OpenAI GPT used weight decay.
+   # Given the internals of AdamW, weight decay dependent on the
+   # learning rate is chosen to match the OpenAI implementation.
+   # The weight decay update to each parameter is applied before the Adam
+   # gradient computation, which is different from that described
+   # in the paper and in the OpenAI implementation:
+   # https://arxiv.org/pdf/1711.05101.pdf
+   return tf.contrib.opt.AdamWOptimizer(
+       0.01 * learning_rate,
+       learning_rate,
+       beta1=hparams.optimizer_adam_beta1,
+       beta2=hparams.optimizer_adam_beta2,
+       epsilon=hparams.optimizer_adam_epsilon)
+
+
+ @registry.register_optimizer("Adafactor")
+ def register_adafactor(learning_rate, hparams):
+   return adafactor.adafactor_optimizer_from_hparams(hparams, learning_rate)
+
+
+ def _register_base_optimizer(key, fn):
+   registry.register_optimizer(key)(
+       lambda learning_rate, hparams: fn(learning_rate))
+
+
+ for k in tf.contrib.layers.OPTIMIZER_CLS_NAMES:
+   if k not in registry._OPTIMIZERS:
+     _register_base_optimizer(k, tf.contrib.layers.OPTIMIZER_CLS_NAMES[k])
+
+
class ConditionalOptimizer(tf.train.Optimizer):
  """Conditional optimizer."""

@@ -113,53 +191,7 @@ def __init__(self, optimizer_name, lr, hparams, use_tpu=False): # pylint: disab
        value=hparams.optimizer_adam_epsilon,
        hparams=hparams)

-     if optimizer_name == "Adam":
-       # We change the default epsilon for Adam.
-       # Using LazyAdam as it's much faster for large vocabulary embeddings.
-       self._opt = tf.contrib.opt.LazyAdamOptimizer(
-           lr,
-           beta1=hparams.optimizer_adam_beta1,
-           beta2=hparams.optimizer_adam_beta2,
-           epsilon=hparams.optimizer_adam_epsilon)
-     elif optimizer_name == "MultistepAdam":
-       self._opt = multistep_optimizer.MultistepAdamOptimizer(
-           lr,
-           beta1=hparams.optimizer_adam_beta1,
-           beta2=hparams.optimizer_adam_beta2,
-           epsilon=hparams.optimizer_adam_epsilon,
-           n=hparams.optimizer_multistep_accumulate_steps)
-     elif optimizer_name == "Momentum":
-       self._opt = tf.train.MomentumOptimizer(
-           lr,
-           momentum=hparams.optimizer_momentum_momentum,
-           use_nesterov=hparams.optimizer_momentum_nesterov)
-     elif optimizer_name == "YellowFin":
-       self._opt = yellowfin.YellowFinOptimizer(
-           learning_rate=lr, momentum=hparams.optimizer_momentum_momentum)
-     elif optimizer_name == "TrueAdam":
-       self._opt = tf.train.AdamOptimizer(
-           lr,
-           beta1=hparams.optimizer_adam_beta1,
-           beta2=hparams.optimizer_adam_beta2,
-           epsilon=hparams.optimizer_adam_epsilon)
-     elif optimizer_name == "AdamW":
-       # OpenAI GPT used weight decay.
-       # Given the internals of AdamW, weight decay dependent on the
-       # learning rate is chosen to match the OpenAI implementation.
-       # The weight decay update to each parameter is applied before the Adam
-       # gradient computation, which is different from that described
-       # in the paper and in the OpenAI implementation:
-       # https://arxiv.org/pdf/1711.05101.pdf
-       self._opt = tf.contrib.opt.AdamWOptimizer(
-           0.01 * lr,
-           lr,
-           beta1=hparams.optimizer_adam_beta1,
-           beta2=hparams.optimizer_adam_beta2,
-           epsilon=hparams.optimizer_adam_epsilon)
-     elif optimizer_name == "Adafactor":
-       self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
-     else:
-       self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
+     self._opt = registry.optimizer(optimizer_name)(lr, hparams)
    if _mixed_precision_is_enabled(hparams):
      if not hparams.mixed_precision_optimizer_loss_scaler:
        tf.logging.warning("Using mixed precision without a loss scaler will "
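
Note: a minimal sketch of how the registry hooks introduced above are meant to be used. The name "my_sgd", the function body, and the choice of tf.train.GradientDescentOptimizer are illustrative assumptions for this note, not code from this commit.

# Hypothetical example: register a custom optimizer under an explicit name,
# mirroring the "Adafactor" registration in this diff.
@registry.register_optimizer("my_sgd")
def my_sgd(learning_rate, hparams):
  # Plain SGD; ignores the Adam/momentum hparams that the built-in entries read.
  return tf.train.GradientDescentOptimizer(learning_rate)

# ConditionalOptimizer then resolves any registered name the same way:
#   self._opt = registry.optimizer("my_sgd")(lr, hparams)

The point of the change: optimizer construction moves out of the if/elif chain in ConditionalOptimizer.__init__ and into registered factories that share the (learning_rate, hparams) signature, so new optimizers can be added without touching ConditionalOptimizer.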