Skip to content
This repository has been archived by the owner on Feb 20, 2021. It is now read-only.

Commit

Permalink
✨ Add utility functions for LR scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
emaballarin committed May 24, 2020
1 parent 35c0249 commit 4d2ed35
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyromaniac/optim/util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Re-export the public LR-scheduling helpers at package level.
from .scheduling import sched_sss, sched_ssse

# Explicit public API of this subpackage.
__all__ = ["sched_sss", "sched_ssse"]

# The import above also binds the submodule name in the package namespace;
# drop it so only the functions are exposed (consistent with __all__).
del scheduling
30 changes: 30 additions & 0 deletions pyromaniac/optim/util/scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2020- AI-CPS@UniTS
# Copyright (c) 2020- Emanuele Ballarin <emanuele@ballarin.cc>
# SPDX-License-Identifier: Apache-2.0

import numpy as np


def sched_sss(start_lr: float, stop_lr: float, nr_decays: int):
    r"""Obtain the canonical scheduling parametrization from the start/stop/step one.

    Converts a (start LR, final LR, number of decay steps) description of an
    exponential learning-rate schedule into its canonical
    (initial LR, multiplicative decay factor) form, where
    ``decay_factor = (stop_lr / start_lr) ** (1.0 / nr_decays)``.

    Arguments:
        start_lr (float): initial learning rate (must be > 0)
        stop_lr (float): final learning rate (must be > 0)
        nr_decays (int): number of learning rate decay steps (must be > 0)

    Returns:
        tuple: ``(start_lr, decay_factor)``.

    Raises:
        ValueError: if any argument is not strictly positive.
    """
    # Validate up front: a zero nr_decays would otherwise surface as a bare
    # ZeroDivisionError, and non-positive LRs would yield NaN/garbage factors.
    if start_lr <= 0.0 or stop_lr <= 0.0:
        raise ValueError("Learning rates must be strictly positive")
    if nr_decays <= 0:
        raise ValueError("The number of decay steps must be strictly positive")

    # Plain float exponentiation; bitwise-identical to the previous
    # np.power(...).item() on scalar inputs, without the NumPy round-trip.
    return start_lr, (stop_lr / start_lr) ** (1.0 / nr_decays)


def sched_ssse(start_lr: float, stop_lr: float, nr_decays: int, epochs: int):
    r"""Obtain the canonical scheduling parametrization from the start/stop/step one, with explicit epoch specification.

    Arguments:
        start_lr (float): initial learning rate (must be > 0)
        stop_lr (float): final learning rate (must be > 0)
        nr_decays (int): number of learning rate decay steps (must be > 0)
        epochs (int): number of total learning epochs

    Returns:
        tuple: ``((start_lr, decay_factor), epochs_per_decay)`` — the
        canonical pair from :func:`sched_sss` (kept nested for backward
        compatibility) plus the integer number of epochs between
        consecutive decay steps, computed as ``epochs // nr_decays``.

    Raises:
        ValueError: propagated from :func:`sched_sss` on invalid arguments.
    """
    # Delegate the (start, gamma) computation; int() guards against a float
    # `epochs` slipping through the annotation.
    return sched_sss(start_lr, stop_lr, nr_decays), int(epochs // nr_decays)

0 comments on commit 4d2ed35

Please sign in to comment.