This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ba31f44

jackdafrozenator authored and committed
added optimizer registry (#1401)
* added optimizer registry
* fixed adafactor -> Adafactor
* fixed default naming
* improved base optimizer registration implementation
1 parent af22c24 commit ba31f44

File tree

tensor2tensor/utils/optimize.py
tensor2tensor/utils/registry.py

2 files changed: 124 additions & 47 deletions

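The pattern this commit introduces, in one place: an optimizer is registered once with a decorator and later looked up by name, instead of being wired into the if/elif chain in ConditionalOptimizer. A minimal usage sketch; the optimizer my_sgd and the empty HParams object are hypothetical, only register_optimizer, optimizer and list_optimizers come from this commit:

from tensor2tensor.utils import registry
import tensorflow as tf

@registry.register_optimizer  # default registration name is UpperCamelCase: "MySgd"
def my_sgd(learning_rate, hparams):
  # Registered functions must accept exactly (learning_rate, hparams).
  return tf.train.GradientDescentOptimizer(learning_rate)

hparams = tf.contrib.training.HParams()          # placeholder hparams for the sketch
opt = registry.optimizer("MySgd")(0.1, hparams)  # -> a GradientDescentOptimizer
print(registry.list_optimizers())                # now contains "MySgd"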

tensor2tensor/utils/optimize.py

Lines changed: 79 additions & 47 deletions
@@ -24,6 +24,7 @@
 from tensor2tensor.utils import mlperf_log
 from tensor2tensor.utils import multistep_optimizer
 from tensor2tensor.utils import yellowfin
+from tensor2tensor.utils import registry

 import tensorflow as tf

@@ -93,6 +94,83 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
   return train_op


+@registry.register_optimizer
+def adam(learning_rate, hparams):
+  # We change the default epsilon for Adam.
+  # Using LazyAdam as it's much faster for large vocabulary embeddings.
+  return tf.contrib.opt.LazyAdamOptimizer(
+      learning_rate,
+      beta1=hparams.optimizer_adam_beta1,
+      beta2=hparams.optimizer_adam_beta2,
+      epsilon=hparams.optimizer_adam_epsilon)
+
+
+@registry.register_optimizer
+def multistep_adam(learning_rate, hparams):
+  return multistep_optimizer.MultistepAdamOptimizer(
+      learning_rate,
+      beta1=hparams.optimizer_adam_beta1,
+      beta2=hparams.optimizer_adam_beta2,
+      epsilon=hparams.optimizer_adam_epsilon,
+      n=hparams.optimizer_multistep_accumulate_steps)
+
+
+@registry.register_optimizer
+def momentum(learning_rate, hparams):
+  return tf.train.MomentumOptimizer(
+      learning_rate,
+      momentum=hparams.optimizer_momentum_momentum,
+      use_nesterov=hparams.optimizer_momentum_nesterov)
+
+
+@registry.register_optimizer
+def yellow_fin(learning_rate, hparams):
+  return yellowfin.YellowFinOptimizer(
+      learning_rate=learning_rate,
+      momentum=hparams.optimizer_momentum_momentum)
+
+
+@registry.register_optimizer
+def true_adam(learning_rate, hparams):
+  return tf.train.AdamOptimizer(
+      learning_rate,
+      beta1=hparams.optimizer_adam_beta1,
+      beta2=hparams.optimizer_adam_beta2,
+      epsilon=hparams.optimizer_adam_epsilon)
+
+
+@registry.register_optimizer
+def adam_w(learning_rate, hparams):
+  # Openai gpt used weight decay.
+  # Given the internals of AdamW, weight decay dependent on the
+  # learning rate is chosen to match the openai implementation.
+  # The weight decay update to each parameter is applied before the adam
+  # gradients computation, which is different from that described
+  # in the paper and in the openai implementation:
+  # https://arxiv.org/pdf/1711.05101.pdf
+  return tf.contrib.opt.AdamWOptimizer(
+      0.01*learning_rate,
+      learning_rate,
+      beta1=hparams.optimizer_adam_beta1,
+      beta2=hparams.optimizer_adam_beta2,
+      epsilon=hparams.optimizer_adam_epsilon)
+
+
+@registry.register_optimizer("Adafactor")
+def register_adafactor(learning_rate, hparams):
+  return adafactor.adafactor_optimizer_from_hparams(hparams, learning_rate)
+
+
+def _register_base_optimizer(key, fn):
+  registry.register_optimizer(key)(
+      lambda learning_rate, hparams: fn(learning_rate))
+
+
+for k in tf.contrib.layers.OPTIMIZER_CLS_NAMES:
+  if k not in registry._OPTIMIZERS:
+    _register_base_optimizer(k, tf.contrib.layers.OPTIMIZER_CLS_NAMES[k])
+
+
 class ConditionalOptimizer(tf.train.Optimizer):
   """Conditional optimizer."""

@@ -113,53 +191,7 @@ def __init__(self, optimizer_name, lr, hparams, use_tpu=False): # pylint: disab
         value=hparams.optimizer_adam_epsilon,
         hparams=hparams)

-    if optimizer_name == "Adam":
-      # We change the default epsilon for Adam.
-      # Using LazyAdam as it's much faster for large vocabulary embeddings.
-      self._opt = tf.contrib.opt.LazyAdamOptimizer(
-          lr,
-          beta1=hparams.optimizer_adam_beta1,
-          beta2=hparams.optimizer_adam_beta2,
-          epsilon=hparams.optimizer_adam_epsilon)
-    elif optimizer_name == "MultistepAdam":
-      self._opt = multistep_optimizer.MultistepAdamOptimizer(
-          lr,
-          beta1=hparams.optimizer_adam_beta1,
-          beta2=hparams.optimizer_adam_beta2,
-          epsilon=hparams.optimizer_adam_epsilon,
-          n=hparams.optimizer_multistep_accumulate_steps)
-    elif optimizer_name == "Momentum":
-      self._opt = tf.train.MomentumOptimizer(
-          lr,
-          momentum=hparams.optimizer_momentum_momentum,
-          use_nesterov=hparams.optimizer_momentum_nesterov)
-    elif optimizer_name == "YellowFin":
-      self._opt = yellowfin.YellowFinOptimizer(
-          learning_rate=lr, momentum=hparams.optimizer_momentum_momentum)
-    elif optimizer_name == "TrueAdam":
-      self._opt = tf.train.AdamOptimizer(
-          lr,
-          beta1=hparams.optimizer_adam_beta1,
-          beta2=hparams.optimizer_adam_beta2,
-          epsilon=hparams.optimizer_adam_epsilon)
-    elif optimizer_name == "AdamW":
-      # Openai gpt used weight decay.
-      # Given the internals of AdamW, weight decay dependent on the
-      # learning rate is chosen to match the openai implementation.
-      # The weight decay update to each parameter is applied before the adam
-      # gradients computation, which is different from that described
-      # in the paper and in the openai implementation:
-      # https://arxiv.org/pdf/1711.05101.pdf
-      self._opt = tf.contrib.opt.AdamWOptimizer(
-          0.01*lr,
-          lr,
-          beta1=hparams.optimizer_adam_beta1,
-          beta2=hparams.optimizer_adam_beta2,
-          epsilon=hparams.optimizer_adam_epsilon)
-    elif optimizer_name == "Adafactor":
-      self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
-    else:
-      self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
+    self._opt = registry.optimizer(optimizer_name)(lr, hparams)
     if _mixed_precision_is_enabled(hparams):
       if not hparams.mixed_precision_optimizer_loss_scaler:
         tf.logging.warning("Using mixed precision without a loss scaler will "
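Worth noting about the replacement above: the default registration name is the UpperCamelCase of the function name (adam -> "Adam", multistep_adam -> "MultistepAdam", true_adam -> "TrueAdam", yellow_fin -> "YellowFin"), and "Adafactor" is registered under an explicit name, so the optimizer-name strings that ConditionalOptimizer accepted before the change keep resolving. The loop over tf.contrib.layers.OPTIMIZER_CLS_NAMES only fills in names that are not already taken, so for example "Adam" and "Momentum" keep the tensor2tensor-specific implementations. A rough sketch of the equivalence, with illustrative hparams values that are not part of the commit:

from tensor2tensor.utils import optimize  # importing runs the registrations above
from tensor2tensor.utils import registry
import tensorflow as tf

hparams = tf.contrib.training.HParams(    # illustrative values only
    optimizer_adam_beta1=0.9,
    optimizer_adam_beta2=0.997,
    optimizer_adam_epsilon=1e-9)

# Before: the "TrueAdam" branch built tf.train.AdamOptimizer directly.
# After: the same string resolves through the registry to true_adam() above.
opt = registry.optimizer("TrueAdam")(0.001, hparams)
assert isinstance(opt, tf.train.AdamOptimizer)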

tensor2tensor/utils/registry.py

Lines changed: 45 additions & 0 deletions
@@ -183,6 +183,51 @@ def list_models():
   return list(sorted(_MODELS))


+_OPTIMIZERS = {}
+
+
+def register_optimizer(name=None):
+  """Register an optimizer. name defaults to upper camel case of fn name."""
+
+  def default_opt_name(opt_fn):
+    return misc_utils.snakecase_to_camelcase(default_name(opt_fn))
+
+  def decorator(opt_fn, registration_name):
+    """Registers and returns optimizer_fn with registration_name or default."""
+    if registration_name is None:
+      registration_name = default_opt_name(opt_fn)
+
+    if registration_name in _OPTIMIZERS and not tf.executing_eagerly():
+      raise LookupError("Optimizer %s already registered." % registration_name)
+    args, varargs, keywords, _ = inspect.getargspec(opt_fn)
+
+    if len(args) != 2 or varargs is not None or keywords is not None:
+      raise ValueError("Optimizer registration function must take two "
+                       "arguments: learning_rate (float) and "
+                       "hparams (HParams).")
+    _OPTIMIZERS[registration_name] = opt_fn
+    return opt_fn
+
+  if callable(name):
+    opt_fn = name
+    registration_name = default_opt_name(opt_fn)
+    return decorator(opt_fn, registration_name=registration_name)
+
+  return lambda opt_fn: decorator(opt_fn, name)
+
+
+def optimizer(name):
+  if name not in _OPTIMIZERS:
+    raise LookupError("Optimizer %s never registered. "
+                      "Available optimizers:\n %s"
+                      % (name, "\n".join(list_optimizers())))
+  return _OPTIMIZERS[name]
+
+
+def list_optimizers():
+  return list(sorted(_OPTIMIZERS))
+
+
 def register_hparams(name=None):
   """Register an HParams set. name defaults to function name snake-cased."""
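For illustration, here are the decorator's two calling conventions and its signature check; the example functions are hypothetical and not part of the commit:

from tensor2tensor.utils import registry
import tensorflow as tf

@registry.register_optimizer               # bare decorator: registered as "PlainSgd"
def plain_sgd(learning_rate, hparams):
  return tf.train.GradientDescentOptimizer(learning_rate)

@registry.register_optimizer("FancySGD")   # explicit name overrides the default
def another_sgd(learning_rate, hparams):
  return tf.train.GradientDescentOptimizer(learning_rate)

# A function with the wrong signature is rejected at registration time ...
try:
  @registry.register_optimizer
  def bad_opt(learning_rate):              # missing hparams -> ValueError
    return tf.train.GradientDescentOptimizer(learning_rate)
except ValueError:
  pass

# ... and re-registering an existing name raises LookupError (outside eager mode).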

0 commit comments
