Skip to content

Commit

Permalink
Merge pull request huggingface#1112 from ayasyrev/sched_noise_dup_code
Browse files Browse the repository at this point in the history
sched noise dup code remove
  • Loading branch information
rwightman authored Mar 21, 2022
2 parents 7c67d6a + cf57695 commit d757fec
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 30 deletions.
23 changes: 5 additions & 18 deletions timm/scheduler/plateau_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self,
min_lr=lr_min
)

self.noise_range = noise_range_t
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
Expand Down Expand Up @@ -82,25 +82,12 @@ def step(self, epoch, metric=None):

self.lr_scheduler.step(metric, epoch) # step the base scheduler

if self.noise_range is not None:
if isinstance(self.noise_range, (list, tuple)):
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
if self._is_apply_noise(epoch):
self._apply_noise(epoch)


def _apply_noise(self, epoch):
g = torch.Generator()
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
noise = self._calculate_noise(epoch)

# apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler
Expand Down
32 changes: 20 additions & 12 deletions timm/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,29 @@ def update_groups(self, values):
param_group[self.param_group_field] = value

def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs

def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
if apply_noise:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
return apply_noise

def _calculate_noise(self, t) -> float:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
lrs = [v + v * noise for v in lrs]
return lrs
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
return noise
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise

0 comments on commit d757fec

Please sign in to comment.