LearningRateControl, make sure lr is float
The learning rate might happen to be np.float64 or similar.
That's not really a big problem,
but it then makes the LR file a bit inconsistent.
albertz committed Oct 16, 2024
1 parent 0894e52 commit f09da06
Showing 1 changed file with 4 additions and 3 deletions.
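The commit message points at values like np.float64 sneaking in. Below is a minimal sketch of how that happens and why it can show up in the LR file if values are written via repr() (plain NumPy, independent of RETURNN; the NumPy >= 2.0 scalar repr is assumed):

    import numpy as np

    base_lr = 1e-3
    decayed = base_lr * np.exp(-0.1)   # NumPy arithmetic returns np.float64, not float
    print(type(decayed))               # <class 'numpy.float64'>
    print(repr(decayed))               # np.float64(0.000904...) under NumPy >= 2.0
    print(repr(float(decayed)))        # 0.000904... as a plain Python float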
returnn/torch/updater.py (7 changes: 4 additions & 3 deletions)
@@ -92,7 +92,7 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
         :param float initial_learning_rate:
         """
         self.config = config
-        self.learning_rate = initial_learning_rate
+        self.learning_rate = float(initial_learning_rate)
         self._effective_learning_rate = self.learning_rate
         self.network = network
         self._device = device
@@ -150,7 +150,7 @@ def set_learning_rate(self, value):
         :param float value: New learning rate.
         """
-        self.learning_rate = value
+        self.learning_rate = float(value)
         self._update_effective_learning_rate()

     def get_effective_learning_rate(self) -> float:
@@ -162,9 +162,10 @@ def get_effective_learning_rate(self) -> float:
     def _update_effective_learning_rate(self):
         self._effective_learning_rate = self.learning_rate
         if self.learning_rate_function is not None:
-            self._effective_learning_rate = self.learning_rate_function(
+            lr = self.learning_rate_function(
                 global_train_step=self._current_train_step, epoch=self._current_epoch, learning_rate=self.learning_rate
             )
+            self._effective_learning_rate = float(lr)
         if self.optimizer:
             for param_group in self.optimizer.param_groups:
                 param_group["lr"] = self._effective_learning_rate
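With this change, both the base learning rate and the value returned by learning_rate_function are cast to a plain float before reaching the optimizer param groups. A hedged sketch of the kind of user-supplied function that motivates the cast (the function below is hypothetical; only its keyword signature mirrors the call in the diff):

    import numpy as np

    def dynamic_lr(*, global_train_step, epoch, learning_rate):
        # Hypothetical linear warmup over the first 1000 steps.
        # np.interp returns a NumPy scalar, not a Python float.
        return learning_rate * np.interp(global_train_step, [0.0, 1000.0], [0.1, 1.0])

    lr = dynamic_lr(global_train_step=500, epoch=1, learning_rate=1e-3)
    assert isinstance(lr, np.floating)   # np.float64 before the cast
    assert type(float(lr)) is float      # what the updater now stores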
