Resolve tensorflow#299: Fix WarmupCosineDecay.
* The previous version scaled the cosine decay by the linear warmup value, so the max value was max_lr*0.5*(1+cos(warmup_steps/total_steps*pi)).
* The new version applies a linear warmup and then starts the cosine decay from cos(0.0), so the max value is now max_lr (see the usage sketch after this list).
* The previous version accepted a tensor of step values, which is not needed. The schedule now accepts a single scalar step value.
* Updated the tests to be consistent with the Keras LearningRateSchedule tests.
* Renamed the class from WarmUpCosine to WarmupCosineDecay, which is more consistent with the Keras LearningRateSchedules.
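For scale, with warmup_steps equal to half of total_steps the old formula capped the peak at max_lr*0.5*(1+cos(pi/2)) = 0.5*max_lr. Below is a minimal sketch (not part of the commit) of the behavior described in the bullets above; the import path follows the file changed in this diff, and the concrete values are illustrative assumptions:

```python
# Sketch only: exercises the WarmupCosineDecay schedule introduced by this commit.
import tensorflow as tf

from tensorflow_similarity.schedules import WarmupCosineDecay

schedule = WarmupCosineDecay(max_learning_rate=1e-3, total_steps=1000, warmup_steps=100)

# Linear warmup: halfway through warmup the rate is ~0.5 * max_learning_rate.
print(float(schedule(50)))   # ~5e-4
# The peak is now max_learning_rate at the end of warmup (cosine decay starts at cos(0.0)).
print(float(schedule(100)))  # ~1e-3
# The schedule can be passed directly to an optimizer as the learning rate.
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule)
```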
owenvallis committed Oct 24, 2022
1 parent 48eeaaf commit c6d4700
Showing 2 changed files with 185 additions and 187 deletions.
tensorflow_similarity/schedules.py: 152 changes (84 additions, 68 deletions)
@@ -11,98 +11,114 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from __future__ import annotations

-import tensorflow as tf
+from typing import Any

-from tensorflow_similarity.types import FloatTensor
+import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Similarity")
-class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
-    """A LearningRateSchedule that uses a cosine decay schedule with a warmup period.
+class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
+    """A cosine decay LearningRateSchedule with a linear warmup period.
+    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
+    SGDR: Stochastic Gradient Descent with Warm Restarts.
+    When training a model, it is often useful to lower the learning rate as
+    the training progresses. This schedule applies a linear warmup from 0.0
+    to a max learning rate followed by a cosine decay function. It requires
+    a `step` value to compute the decayed learning rate. You can just pass
+    a TensorFlow variable that you increment at each training step.
+    The schedule is a 1-arg callable that produces a decayed learning rate
+    when passed the current optimizer step. This can be useful for changing
+    the learning rate value across different invocations of optimizer
+    functions.
-    This learning rate schedule is useful for training when using the Barlow Twin Loss.
+    It is computed as:
-    The warmup period applies a linear scaling to the CosineDecay schedule.
+    ```python
+    def decayed_learning_rate(step):
+      step = min(step, total_steps)
+      if step < warmup_steps:
+        decayed = step / warmup_steps
+      else:
+        decay_steps = total_steps - warmup_steps
+        cosine_decay = 0.5 * (1. + cos(pi * (step - warmup_steps) / decay_steps))
+        decayed = (1 - alpha) * cosine_decay + alpha
+      return max_learning_rate * decayed
+    ```
+    Example usage:
+    ```python
+    total_steps = 1000
+    warmup_steps = 100
+    lr_decayed_fn = WarmupCosineDecay(
+        max_learning_rate, total_steps, warmup_steps)
+    ```
+    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
+    as the learning rate.
+    Returns:
+      A 1-arg callable learning rate schedule that takes the current optimizer
+      step and outputs the decayed learning rate, a scalar `Tensor` of the same
+      type as `max_learning_rate`.
    """

    def __init__(
        self,
-        initial_learning_rate: float,
-        decay_steps: int,
+        max_learning_rate: float,
+        total_steps: int,
        warmup_steps: int,
-        warmup_learning_rate: float = 0.0,
        alpha: float = 0.0,
-        name: str = "WarmUpCosine",
+        name: str | None = None,
    ):
"""Applies cosine decay to the learning rate.
"""Applies cosine decay with warmp up to the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` Tensor or a
Python number. The initial learning rate.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to decay over.
warmup_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to warmup over. Must be smaller than the number of
decay_steps.
warmup_learning_rate: A scalar `float32` or `float64` Tensor or a
Python number. The initial warmup learning rate. Must be smaller than
the initial_learning_rate. Defaults to 0.0.
alpha: A scalar `float32` or `float64` Tensor or a Python number.
Minimum learning rate value as a fraction of initial_learning_rate.
Defaults to 0.0.
name: String. Optional name of the operation. Defaults to 'WarmUpCosine'.
max_learning_rate: The max learning rate after warmup.
total_steps: Total number of steps in the schedule.
warmup_steps: Number of steps to warmup over. Must be smaller than the total number of steps.
alpha: Minimum learning rate value as a fraction of initial_learning_rate.
name: Optional name of the operation. Defaults to 'WarmupCosineDecay'.
"""

        super().__init__()

-        if warmup_learning_rate > initial_learning_rate:
-            raise ValueError("warmup_learning_rate must be smaller than the initial_learning_rate")
+        if warmup_steps >= total_steps:
+            raise ValueError("warmup_steps must be less than the total steps")

-        if warmup_steps > decay_steps:
-            raise ValueError("warmup_steps must be smaller than the decay_steps")
-        self.initial_learning_rate = initial_learning_rate
-        self.decay_steps = decay_steps
-        self.alpha = alpha
-        self.warmup_learning_rate = warmup_learning_rate
+        self.max_learning_rate = max_learning_rate
        self.warmup_steps = warmup_steps
+        self.total_steps = total_steps
+        self.alpha = alpha
        self.name = name

-        self.cosine_decay = tf.keras.optimizers.schedules.CosineDecay(
-            initial_learning_rate=initial_learning_rate,
-            decay_steps=decay_steps,
-            alpha=alpha,
-        )
-        # Compute the warmup increment.
-        self.tf_initial_learning_rate = tf.convert_to_tensor(self.initial_learning_rate, name="initial_learning_rate")
-        self.dtype = self.tf_initial_learning_rate.dtype
-        self.learning_rate_delta = tf.convert_to_tensor(
-            self.warmup_learning_rate / self.initial_learning_rate, self.dtype
-        )
-        self.warmup_inc = tf.math.divide_no_nan(
-            (1.0 - self.learning_rate_delta),
-            tf.convert_to_tensor(self.warmup_steps, self.dtype),
-        )

-        # If the warmup increment is zero we have no warm up phase and we set
-        # the learning rate delta to 1.0 to ensure the warmup_scaler value is
-        # always fixed at 1.0.
-        if self.warmup_inc == 0:
-            self.learning_rate_delta = tf.constant([1.0], self.dtype)

-    def __call__(self, step: FloatTensor) -> FloatTensor:
-        global_step_recomp = tf.cast(step, self.dtype)
-        warmup_scaler = tf.minimum(1.0, self.warmup_inc * global_step_recomp + self.learning_rate_delta)
-        learning_rate: FloatTensor = self.cosine_decay(global_step_recomp) * warmup_scaler
-        return learning_rate

-    def get_config(self) -> Dict[str, Any]:
+        self.max_learning_rate_tf = tf.convert_to_tensor(self.max_learning_rate, name="max_learning_rate")
+        self.dtype = self.max_learning_rate_tf.dtype
+        self.warmup_steps_tf = tf.cast(self.warmup_steps, self.dtype)

+        self.cosine_decay = tf.keras.experimental.CosineDecay(max_learning_rate, total_steps - warmup_steps, alpha)

+    def __call__(self, step):
+        with tf.name_scope(self.name or "WarmupCosineDecay"):
+            step = tf.cast(step, self.dtype)

+            learning_rate = tf.cond(
+                tf.math.less(step, self.warmup_steps_tf),
+                lambda: tf.math.divide_no_nan(step, self.warmup_steps_tf) * self.max_learning_rate_tf,
+                lambda: self.cosine_decay(step - self.warmup_steps_tf),
+            )

+        return learning_rate

+    def get_config(self) -> dict[str, Any]:
        return {
-            "initial_learning_rate": self.initial_learning_rate,
-            "decay_steps": self.decay_steps,
-            "alpha": self.alpha,
-            "warmup_learning_rate": self.warmup_learning_rate,
+            "max_learning_rate": self.max_learning_rate,
+            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
+            "alpha": self.alpha,
            "name": self.name,
        }
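A small sketch (assumed usage, not part of the diff) of why `get_config` matters here: it returns plain Python values that mirror the constructor arguments, so the registered schedule can be rebuilt from its config when a model that uses it is saved and reloaded.

```python
# Sketch only: round-trip the schedule through its serialization config.
from tensorflow_similarity.schedules import WarmupCosineDecay

schedule = WarmupCosineDecay(max_learning_rate=1e-3, total_steps=1000, warmup_steps=100)
config = schedule.get_config()
restored = WarmupCosineDecay(**config)
assert restored.total_steps == 1000 and restored.warmup_steps == 100
```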
