Labels
models:official (models that come under the official repository), type:feature
Description
Prerequisites
- I checked to make sure that this feature has not been requested already.
1. The entire URL of the file you are using
models/official/nlp/optimization.py
Lines 68 to 107 in 7ecbac3
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate.
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)
  if optimizer_type == 'adamw':
    logging.info('using Adamw optimizer')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('using Lamb optimizer')
    optimizer = tfa_optimizers.LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  else:
    raise ValueError('Unsupported optimizer type: ', optimizer_type)
  return optimizer
2. Describe the feature you request
Similar to beta_1, which is already exposed as a tunable parameter when creating optimizers, we could also expose beta_2, epsilon, weight_decay_rate, and exclude_from_weight_decay as tunable parameters by passing them as arguments to create_optimizer, as sketched below.
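A minimal sketch of what the extended signature could look like, assuming the new parameters simply default to the values currently hard-coded inside create_optimizer (this is an illustration of the request, not the repository's implementation):

def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9,
                     beta_2=0.999,
                     epsilon=1e-6,
                     weight_decay_rate=0.01,
                     exclude_from_weight_decay=None):
  """Creates an optimizer with a learning rate schedule and tunable hyperparameters."""
  # Defaults mirror the values that are currently hard-coded.
  if exclude_from_weight_decay is None:
    exclude_from_weight_decay = ['LayerNorm', 'layer_norm', 'bias']
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)
  if optimizer_type == 'adamw':
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=weight_decay_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        exclude_from_weight_decay=exclude_from_weight_decay)
  elif optimizer_type == 'lamb':
    optimizer = tfa_optimizers.LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=weight_decay_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        exclude_from_weight_decay=exclude_from_weight_decay)
  else:
    raise ValueError('Unsupported optimizer type: ', optimizer_type)
  return optimizer

A call overriding the new parameters (the values here are purely illustrative) might then look like:

optimizer = create_optimizer(
    init_lr=5e-5,
    num_train_steps=10000,
    num_warmup_steps=1000,
    beta_2=0.98,
    epsilon=1e-8,
    weight_decay_rate=0.1)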
3. Additional context
I was recently trying to fine-tune a Hugging Face RoBERTa model, and while doing so I wanted to add a scheduler as well as AdamW with custom parameters, which is how I came across these methods.
4. Are you willing to contribute it? (Yes or No)
Yes