Skip to content

Commit d757fec

Browse files
authored
Merge pull request huggingface#1112 from ayasyrev/sched_noise_dup_code
sched noise dup code remove
2 parents 7c67d6a + cf57695 commit d757fec

File tree

2 files changed

+25
-30
lines changed

2 files changed

+25
-30
lines changed

timm/scheduler/plateau_lr.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self,
4343
min_lr=lr_min
4444
)
4545

46-
self.noise_range = noise_range_t
46+
self.noise_range_t = noise_range_t
4747
self.noise_pct = noise_pct
4848
self.noise_type = noise_type
4949
self.noise_std = noise_std
@@ -82,25 +82,12 @@ def step(self, epoch, metric=None):
8282

8383
self.lr_scheduler.step(metric, epoch) # step the base scheduler
8484

85-
if self.noise_range is not None:
86-
if isinstance(self.noise_range, (list, tuple)):
87-
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
88-
else:
89-
apply_noise = epoch >= self.noise_range
90-
if apply_noise:
91-
self._apply_noise(epoch)
85+
if self._is_apply_noise(epoch):
86+
self._apply_noise(epoch)
87+
9288

9389
def _apply_noise(self, epoch):
94-
g = torch.Generator()
95-
g.manual_seed(self.noise_seed + epoch)
96-
if self.noise_type == 'normal':
97-
while True:
98-
# resample if noise out of percent limit, brute force but shouldn't spin much
99-
noise = torch.randn(1, generator=g).item()
100-
if abs(noise) < self.noise_pct:
101-
break
102-
else:
103-
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
90+
noise = self._calculate_noise(epoch)
10491

10592
# apply the noise on top of previous LR, cache the old value so we can restore for normal
10693
# stepping of base scheduler

timm/scheduler/scheduler.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,29 @@ def update_groups(self, values):
8585
param_group[self.param_group_field] = value
8686

8787
def _add_noise(self, lrs, t):
88+
if self._is_apply_noise(t):
89+
noise = self._calculate_noise(t)
90+
lrs = [v + v * noise for v in lrs]
91+
return lrs
92+
93+
def _is_apply_noise(self, t) -> bool:
94+
"""Return True if scheduler in noise range."""
8895
if self.noise_range_t is not None:
8996
if isinstance(self.noise_range_t, (list, tuple)):
9097
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
9198
else:
9299
apply_noise = t >= self.noise_range_t
93-
if apply_noise:
94-
g = torch.Generator()
95-
g.manual_seed(self.noise_seed + t)
96-
if self.noise_type == 'normal':
97-
while True:
100+
return apply_noise
101+
102+
def _calculate_noise(self, t) -> float:
103+
g = torch.Generator()
104+
g.manual_seed(self.noise_seed + t)
105+
if self.noise_type == 'normal':
106+
while True:
98107
# resample if noise out of percent limit, brute force but shouldn't spin much
99-
noise = torch.randn(1, generator=g).item()
100-
if abs(noise) < self.noise_pct:
101-
break
102-
else:
103-
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104-
lrs = [v + v * noise for v in lrs]
105-
return lrs
108+
noise = torch.randn(1, generator=g).item()
109+
if abs(noise) < self.noise_pct:
110+
return noise
111+
else:
112+
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
113+
return noise

0 commit comments

Comments
 (0)