Commit a0ac5d3

SparseOptimizer object: handles sparse pattern optimization and the state update of other classical optimizers (Adam, ...), since parts of their state must be cleaned when the sparse pattern changes.
1 parent f1bfa8f commit a0ac5d3
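
For orientation, here is a minimal usage sketch distilled from the tests added in this commit (identifiers and shapes are taken from the test file below; the cleanup ratio is passed through the optimizer's lr parameter, and a CUDA device is assumed):

import torch
import torch.optim as optim
from pytorch_block_sparse import BlockSparseLinear, SparseOptimizer

# A block-sparse layer whose sparsity pattern will be re-optimized over time.
linear = BlockSparseLinear(256, 256, True, 0.5).cuda()

# lr is the fraction of blocks cleaned up at each SparseOptimizer step.
sparse_optimizer = SparseOptimizer(SparseOptimizer.sparse_objects(linear), lr=0.1)

# Attach the dense optimizer so its state is cleaned when the pattern changes.
adam = optim.Adam(linear.parameters())
sparse_optimizer.attach_optimizer(adam)

out = linear(torch.randn([1, 256]).cuda())   # forward
out.sum().backward()                         # backward
adam.step()                                  # usual coefficient update
sparse_optimizer.step()                      # prune/regrow blocks, clean stale Adam state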

File tree

3 files changed (+255 lines, -14 lines)


pytorch_block_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .block_sparse import BlockSparseMatrix
 from .block_sparse_linear import BlockSparseLinear
 from .util import BlockSparseModelPatcher
-from .sparse_optimizer import MagnitudeSparseOptimizerStrategy
+from .sparse_optimizer import SparseOptimizer

pytorch_block_sparse/sparse_optimizer.py

Lines changed: 182 additions & 10 deletions
@@ -1,28 +1,29 @@
 import torch
-
+import torch.optim as optim
+from pytorch_block_sparse import BlockSparseMatrix

 class SparseOptimizerStrategy:
     def run(self, block_sparse_matrix):
         raise NotImplementedError()


 class MagnitudeSparseOptimizerStrategy(SparseOptimizerStrategy):
-    def __init__(self, ratio, new_coefficients_method="discrete", new_coefficients_scale=0.1):
-        self.ratio = ratio
-        self.new_coefficients_method = new_coefficients_method
+    def __init__(self, cleanup_ratio, new_coefficients_distribution="uniform", new_coefficients_scale=0.1):
+        self.cleanup_ratio = cleanup_ratio
+        self.new_coefficients_distribution = new_coefficients_distribution
         self.new_coefficients_scale = new_coefficients_scale

     def initialize_new_blocks(self, old_data, new_data):
         mean, std = old_data.mean(), old_data.std()

-        if self.new_coefficients_method == "gaussian":
+        if self.new_coefficients_distribution == "gaussian":
             new_data.normal_(mean=mean * self.new_coefficients_scale, std=std * self.new_coefficients_scale)
-        elif self.new_coefficients_method == "discrete":
+        elif self.new_coefficients_distribution == "uniform":
             new_data.random_(0, 1)
             new_data -= 0.5
             new_data *= 2 * std * self.new_coefficients_scale
         else:
-            raise Exception("Unknown new coefficients method %s" % self.new_coefficients_method)
+            raise Exception("Unknown new coefficients method %s" % self.new_coefficients_distribution)

     def run(self, block_sparse_matrix):
         bsm = block_sparse_matrix
@@ -33,7 +34,7 @@ def run(self, block_sparse_matrix):
         _, indices = norms.sort()

         # Extract the worst blocks
-        bad_blocks = indices[:int(indices.shape[0] * self.ratio)]
+        bad_blocks = indices[:int(indices.shape[0] * self.cleanup_ratio)]

         # Find available positions
         block_mask = ~bsm.block_mask_build(None)
@@ -52,19 +53,190 @@ def run(self, block_sparse_matrix):

         new_block_mask[bad_blocks] = True

-        new_block_mask = new_block_mask.unsqueeze(-1).repeat(bsm.block_shape).float()
+        new_block_mask = new_block_mask.unsqueeze(-1)
+        new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[0], dim=0)
+        new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[1], dim=1)
+        new_block_mask = new_block_mask.float()

         new_blocks = torch.zeros_like(bsm.data)

         self.initialize_new_blocks(bsm.data, new_blocks)

         new_blocks *= new_block_mask

+        state_keep_mask = 1.0 - new_block_mask
+
         with torch.no_grad():
-            bsm.data *= 1.0 - new_block_mask
+            bsm.data *= state_keep_mask
             bsm.data += new_blocks

+        return state_keep_mask
+
+
+class _RequiredParameter(object):
+    """Singleton class representing a required parameter for an Optimizer."""
+
+    def __repr__(self):
+        return "<required parameter>"
+
+
+required = _RequiredParameter()
+
+
+class OptimizerStateUpdater():
+    def __init__(self, optimizer, sparse_object):
+        self.optimizer = optimizer
+        if not isinstance(sparse_object, BlockSparseMatrix):
+            raise Exception(f"Unknown sparse_object type {sparse_object}")
+
+        self.sparse_object = sparse_object
+
+    def update_state_data(self, param, state_keep_mask):
+        raise NotImplementedError()
+
+    def update_state(self, state_keep_mask):
+        if isinstance(self.sparse_object, BlockSparseMatrix):
+            search_param = self.sparse_object.data
+        else:
+            raise Exception(f"Unknown sparse_object type {self.sparse_object}")
+
+        found = False
+        for param_group in self.optimizer.param_groups:
+            for param in param_group["params"]:
+                if param is search_param:
+                    found = True
+                    self.update_state_data(param, state_keep_mask)
+
+        return found
+
+
+class AdamOptimizerStateUpdater(OptimizerStateUpdater):
+    def update_state_data(self, param, state_keep_mask):
+        opt = self.optimizer
+
+        param_state = opt.state[param]
+
+        for key in param_state:
+            if key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']:
+                param_state[key] *= state_keep_mask
+            elif key == 'step':
+                # We cannot really alter the step info, it's global, so bias_correction1 and bias_correction2 may
+                # not be completely correct for the new coefficients, but it should not be a big issue
+                pass
+            else:
+                raise Exception(f"Unknown key in Adam parameter state {key}")
+
+
+class SparseOptimizer(torch.optim.Optimizer):
+    METHODS = ["magnitude"]
+    COEFFICIENTS_DISTRIBUTION = ["uniform", "gaussian"]
+    allowed_keys = {"lr", "method", "new_coefficients_scale", "new_coefficients_distribution"}
+    """optimizer = sparse_cleaner.SparseOptimizer([BlockSparseMatrix, BlockSparseMatrix], method="magnitude", new_coefficients_distribution="uniform")
+       optimizer.add_param_group(dict(sparse_objects=[BlockSparseMatrix], lr=0.5, method="magnitude", new_coefficients_distribution="gaussian", new_coefficients_scale=1.0))"""
+    def __init__(self, sparse_objects, lr=1e-1, method="magnitude", new_coefficients_scale=0.1, new_coefficients_distribution="uniform"):
+        if not 0.0 < lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+
+        defaults = dict(lr=lr,
+                        method=method,
+                        new_coefficients_scale=new_coefficients_scale,
+                        new_coefficients_distribution=new_coefficients_distribution)
+
+        super(SparseOptimizer, self).__init__([{"sparse_objects": sparse_objects}], defaults)
+        self.attached_optimizers = []
+
+    @staticmethod
+    def sparse_objects(model):
+        ret = []
+        for name, module in model.named_modules():
+            if isinstance(module, BlockSparseMatrix):
+                ret.append(module)
+
+        return ret
+
+    def attach_optimizer(self, optimizer):
+        if optimizer in self.attached_optimizers:
+            Warning("Optimizer already attached")
+            return
+        self.attached_optimizers.append(optimizer)
+
+    def add_param_group(self, sparse_objects_group):
+        assert isinstance(sparse_objects_group, dict), "param group must be a dict"
+
+        for k in sparse_objects_group:
+            if k == "sparse_objects":
+                continue
+            elif k not in self.allowed_keys:
+                raise Exception("Unknown cleaning parameter %s" % k)
+
+        sparse_objects = sparse_objects_group['sparse_objects']
+
+        if isinstance(sparse_objects, BlockSparseMatrix):
+            sparse_objects_group['sparse_objects'] = [sparse_objects]
+        else:
+            sparse_objects_group['sparse_objects'] = list(sparse_objects)

+        sparse_objects = sparse_objects_group['sparse_objects']

+        for p in sparse_objects:
+            if isinstance(p, BlockSparseMatrix):
+                continue
+            else:
+                raise Exception("I don't know how to clean this type of object: %s" % p)

+        for name, default in self.defaults.items():
+            if default is required and name not in sparse_objects_group:
+                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
+                                 name)
+            else:
+                sparse_objects_group.setdefault(name, default)
+
+        if sparse_objects_group["method"] not in self.METHODS:
+            raise Exception(f"Invalid Method {sparse_objects_group['method']}")
+
+        if sparse_objects_group["new_coefficients_distribution"] not in self.COEFFICIENTS_DISTRIBUTION:
+            raise Exception(f"Invalid new coefficients distribution {sparse_objects_group['new_coefficients_distribution']}")
+
+        param_set = set()
+        for group in self.param_groups:
+            param_set.update(set(group['sparse_objects']))
+
+        if not param_set.isdisjoint(set(sparse_objects_group['sparse_objects'])):
+            raise ValueError("some parameters appear in more than one parameter group")
+
+        self.param_groups.append(sparse_objects_group)
+
+    def clean(self, p, method, clean_ratio, new_coefficients_scale, new_coefficients_distribution):
+        if not isinstance(p, BlockSparseMatrix):
+            raise Exception("I don't know how to clean this: %s" % p)
+
+        if method == "magnitude":
+            cleaner = MagnitudeSparseOptimizerStrategy(clean_ratio,
+                                                       new_coefficients_distribution=new_coefficients_distribution,
+                                                       new_coefficients_scale=new_coefficients_scale)
+        else:
+            raise Exception(f"Unknown cleaning method {method}")
+
+        state_keep_mask = cleaner.run(p)
+
+        if len(self.attached_optimizers) != 0:
+            found = False
+            for optimizer in self.attached_optimizers:
+                if isinstance(optimizer, optim.Adam):
+                    updater = AdamOptimizerStateUpdater(optimizer, p)
+                    found = found or updater.update_state(state_keep_mask)
+
+            if not found:
+                raise Exception(f"Could not find sparse object {p} in optimizers {self.attached_optimizers}")
+        else:
+            Warning("No attached optimizer.")
+
+    def step(self):
+        for group in self.param_groups:
+            clean_ratio = group['lr']
+            if clean_ratio == 0.0:
+                continue
+            for p in group['sparse_objects']:
+                self.clean(p,
+                           clean_ratio=clean_ratio,
+                           method=group['method'],
+                           new_coefficients_scale=group['new_coefficients_scale'],
+                           new_coefficients_distribution=group['new_coefficients_distribution'],
+                           )

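The interaction the commit message describes lives in the two pieces above: MagnitudeSparseOptimizerStrategy.run now returns a state_keep_mask that is 1.0 over positions whose coefficients survive and 0.0 over blocks that were just re-initialized, and AdamOptimizerStateUpdater.update_state_data multiplies Adam's exp_avg, exp_avg_sq (and max_exp_avg_sq) buffers by that mask so stale moment estimates do not carry over to the fresh blocks. A standalone sketch of that masking, with a hypothetical hand-built state dict standing in for optimizer.state[param]:

import torch

# Hypothetical Adam state for one parameter (keys as used by torch.optim.Adam).
param_state = {
    "step": 10,
    "exp_avg": torch.randn(4, 4),    # first-moment estimate
    "exp_avg_sq": torch.rand(4, 4),  # second-moment estimate
}

# 0.0 where blocks were re-initialized, 1.0 where the old coefficients survive.
state_keep_mask = torch.ones(4, 4)
state_keep_mask[:2, :2] = 0.0

for key in ("exp_avg", "exp_avg_sq"):
    param_state[key] *= state_keep_mask  # reset the moments of the fresh blocks

# "step" is left untouched: it is a global counter, so the bias corrections are
# only approximately right for the new coefficients, as the code comment notes.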

pytorch_block_sparse/tests/test_sparse_optimizer.py

Lines changed: 72 additions & 3 deletions
@@ -1,15 +1,84 @@
 from unittest import TestCase
 import unittest
-from pytorch_block_sparse import BlockSparseMatrix, MagnitudeSparseOptimizerStrategy
+from pytorch_block_sparse import BlockSparseMatrix, SparseOptimizer, BlockSparseLinear
+from pytorch_block_sparse.sparse_optimizer import MagnitudeSparseOptimizerStrategy
+import torch
+import torch.optim as optim

 class TestFun(TestCase):
+    def check_differences(self, bsm, reference_dense, expected_block_changes):
+        dense = bsm.to_dense()
+
+        differences = (reference_dense != dense)
+        block_shape = bsm.block_shape
+        differences = float(differences.float().sum() / (block_shape[0] * block_shape[1]))
+
+        self.assertEqual(differences, expected_block_changes)
+
     def test0(self):
-        bsm = BlockSparseMatrix.randn((256, 256), 32, block_shape=(32, 32), device="cuda")
+        size = (256, 256)
+        block_count = 32
+        cleanup_ratio = 0.1
+        block_shape = (32, 32)
+        bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device="cuda")

-        strategy = MagnitudeSparseOptimizerStrategy(0.1)
+        dense0 = bsm.to_dense()

+        strategy = MagnitudeSparseOptimizerStrategy(cleanup_ratio)
         strategy.run(bsm)

+        expected_block_changes = int(cleanup_ratio * block_count) * 2
+        self.check_differences(bsm, dense0, expected_block_changes)
+
+    def test_sparse_optimizer(self):
+        size = (256, 256)
+        block_count = 32
+        cleanup_ratio = 0.1
+        block_shape = (32, 32)
+        bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device="cuda")
+        dense0 = bsm.to_dense()
+
+        so = SparseOptimizer([bsm], lr=cleanup_ratio)
+
+        so.step()
+
+        expected_block_changes = int(cleanup_ratio * block_count) * 2
+        self.check_differences(bsm, dense0, expected_block_changes)
+
+    def test_sparse_optimizer_attached_optimizer(self):
+        size = (256, 256)
+        density = 0.5
+        cleanup_ratio = 0.1
+
+        linear = BlockSparseLinear(size[0], size[1], True, density).cuda()
+
+        sparse_objects = SparseOptimizer.sparse_objects(linear)
+
+        assert(len(sparse_objects) == 1)
+
+        so = SparseOptimizer(sparse_objects, lr=cleanup_ratio)
+
+        adam = optim.Adam(linear.parameters())
+
+        so.attach_optimizer(adam)
+
+        # Run forward and backward
+        a = torch.randn([1, size[0]]).abs().cuda()
+        out = linear(a)
+
+        loss = out.sum()
+
+        loss.backward()
+
+        adam.step()
+
+        dense0 = linear.weight.to_dense()
+
+        so.step()
+
+        block_count = linear.block_count
+        expected_block_changes = int(cleanup_ratio * block_count) * 2
+        self.check_differences(linear.weight, dense0, expected_block_changes)

