Skip to content

Commit 8770619

Browse files
committed
Attempted fix for global pruning error on zero weights pruned.
1 parent c831200 commit 8770619

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

sparselearning/core.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
import shutil
1313
import time
1414
from matplotlib import pyplot as plt
15-
16-
#from sparselearning.funcs import no_redistribution, momentum_redistribution, magnitude_redistribution, nonzero_redistribution
17-
#from sparselearning.funcs import global_momentum_growth, momentum_growth, random_growth, momentum_neuron_growth
18-
#from sparselearning.funcs import threshold_prune, magnitude_prune, global_magnitude_prune, magnitude_and_negativity_prune
19-
2015
from sparselearning.funcs import redistribution_funcs, growth_funcs, prune_funcs
2116

2217
def add_sparse_args(parser):
@@ -491,8 +486,11 @@ def calc_growth_redistribution(self):
491486
name2regrowth[name] = math.floor((self.total_removed/float(expected_killed))*name2regrowth[name])
492487
elif self.prune_mode == 'global_magnitude':
493488
expected_removed = self.baseline_nonzero*self.name2prune_rate[name]
494-
expected_vs_actual = self.total_removed/expected_removed
495-
name2regrowth[name] = math.floor(expected_vs_actual*name2regrowth[name])
489+
if expected_removed == 0.0:
490+
name2regrowth[name] = 0.0
491+
else:
492+
expected_vs_actual = self.total_removed/expected_removed
493+
name2regrowth[name] = math.floor(expected_vs_actual*name2regrowth[name])
496494

497495
return name2regrowth
498496

sparselearning/funcs.py

+1
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def momentum_neuron_growth(masking, name, new_mask, total_regrowth, weight):
271271
y, idx = torch.sort(M[i].flatten())
272272
if neuron_regrowth > available:
273273
neuron_regrowth = available
274+
# TODO: Work into more stable growth method
274275
threshold = y[-(neuron_regrowth)].item()
275276
if threshold == 0.0: continue
276277
if neuron_regrowth < 10: continue

0 commit comments

Comments
 (0)