diff --git a/returnn/torch/updater.py b/returnn/torch/updater.py
index db5004d35..23ead3d40 100644
--- a/returnn/torch/updater.py
+++ b/returnn/torch/updater.py
@@ -92,7 +92,7 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
         :param float initial_learning_rate:
         """
         self.config = config
-        self.learning_rate = initial_learning_rate
+        self.learning_rate = float(initial_learning_rate)
         self._effective_learning_rate = self.learning_rate
         self.network = network
         self._device = device
@@ -150,7 +150,7 @@ def set_learning_rate(self, value):
 
         :param float value: New learning rate.
         """
-        self.learning_rate = value
+        self.learning_rate = float(value)
         self._update_effective_learning_rate()
 
     def get_effective_learning_rate(self) -> float:
@@ -162,9 +162,10 @@ def get_effective_learning_rate(self) -> float:
     def _update_effective_learning_rate(self):
         self._effective_learning_rate = self.learning_rate
         if self.learning_rate_function is not None:
-            self._effective_learning_rate = self.learning_rate_function(
+            lr = self.learning_rate_function(
                 global_train_step=self._current_train_step, epoch=self._current_epoch, learning_rate=self.learning_rate
             )
+            self._effective_learning_rate = float(lr)
         if self.optimizer:
             for param_group in self.optimizer.param_groups:
                 param_group["lr"] = self._effective_learning_rate
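
The net effect of this patch is that `self.learning_rate` and `self._effective_learning_rate` are always plain Python floats, even when the user-configured `learning_rate_function` returns something float-like such as a NumPy scalar. Below is a minimal sketch of the failure mode being guarded against; the schedule function `dynamic_lr` and its warmup constant are hypothetical illustrations, not part of the patch:

```python
import numpy

def dynamic_lr(*, global_train_step: int, epoch: int, learning_rate: float):
    """Hypothetical user-defined schedule: linear warmup over 1000 steps.
    numpy.minimum() silently returns numpy.float64, not a Python float."""
    return learning_rate * numpy.minimum(1.0, (global_train_step + 1) / 1000)

lr = dynamic_lr(global_train_step=10, epoch=1, learning_rate=1e-3)
print(type(lr))         # <class 'numpy.float64'>
print(type(float(lr)))  # <class 'float'> -- what the updater now stores
```

Coercing once at these entry points keeps `param_group["lr"]` and anything downstream that logs or serializes the learning rate working with a plain `float` rather than a NumPy or tensor scalar type.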