 import torch
-
+import torch.optim as optim
+import warnings
+from pytorch_block_sparse import BlockSparseMatrix
 
 class SparseOptimizerStrategy:
     def run(self, block_sparse_matrix):
         raise NotImplementedError()
 
 
 class MagnitudeSparseOptimizerStrategy(SparseOptimizerStrategy):
-    def __init__(self, ratio, new_coefficients_method="discrete", new_coefficients_scale=0.1):
-        self.ratio = ratio
-        self.new_coefficients_method = new_coefficients_method
+    def __init__(self, cleanup_ratio, new_coefficients_distribution="uniform", new_coefficients_scale=0.1):
+        self.cleanup_ratio = cleanup_ratio
+        self.new_coefficients_distribution = new_coefficients_distribution
         self.new_coefficients_scale = new_coefficients_scale
 
     def initialize_new_blocks(self, old_data, new_data):
         mean, std = old_data.mean(), old_data.std()
 
-        if self.new_coefficients_method == "gaussian":
+        if self.new_coefficients_distribution == "gaussian":
             new_data.normal_(mean=mean * self.new_coefficients_scale, std=std * self.new_coefficients_scale)
-        elif self.new_coefficients_method == "discrete":
+        elif self.new_coefficients_distribution == "uniform":
+            # Roughly uniform in [-std * scale, +std * scale), centered on zero
+            # (uniform_ rather than random_: random_ fills a float tensor with integers).
             new_data.uniform_(0, 1)
             new_data -= 0.5
             new_data *= 2 * std * self.new_coefficients_scale
         else:
-            raise Exception("Unknown new coefficients method %s" % self.new_coefficients_method)
+            raise Exception("Unknown new coefficients distribution %s" % self.new_coefficients_distribution)
 
     def run(self, block_sparse_matrix):
         bsm = block_sparse_matrix
@@ -33,7 +34,7 @@ def run(self, block_sparse_matrix):
         _, indices = norms.sort()
 
         # Extract the worst blocks
-        bad_blocks = indices[:int(indices.shape[0] * self.ratio)]
+        bad_blocks = indices[:int(indices.shape[0] * self.cleanup_ratio)]
 
         # Find available positions
         block_mask = ~bsm.block_mask_build(None)
@@ -52,19 +53,190 @@ def run(self, block_sparse_matrix):
 
         new_block_mask[bad_blocks] = True
 
-        new_block_mask = new_block_mask.unsqueeze(-1).repeat(bsm.block_shape).float()
+        # Expand the per-block mask to element granularity: each block flag becomes
+        # a block_shape[0] x block_shape[1] patch of identical values.
+        new_block_mask = new_block_mask.unsqueeze(-1)
+        new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[0], dim=0)
+        new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[1], dim=1)
+        new_block_mask = new_block_mask.float()
 
         new_blocks = torch.zeros_like(bsm.data)
 
         self.initialize_new_blocks(bsm.data, new_blocks)
 
         new_blocks *= new_block_mask
 
+        # 1.0 where the old coefficients survive, 0.0 on the re-initialized blocks;
+        # returned so attached optimizers can reset their state at the same positions.
+        state_keep_mask = 1.0 - new_block_mask
+
         with torch.no_grad():
-            bsm.data *= 1.0 - new_block_mask
+            bsm.data *= state_keep_mask
             bsm.data += new_blocks
 
+        return state_keep_mask
+
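+# The _RequiredParameter / required pair below mirrors the sentinel that
+# torch.optim uses internally for options a param group must specify explicitly.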
+class _RequiredParameter(object):
+    """Singleton class representing a required parameter for an Optimizer."""
+
+    def __repr__(self):
+        return "<required parameter>"
+
+required = _RequiredParameter()
+
+
+class OptimizerStateUpdater:
+    def __init__(self, optimizer, sparse_object):
+        self.optimizer = optimizer
+        if not isinstance(sparse_object, BlockSparseMatrix):
+            raise Exception(f"Unknown sparse_object type {sparse_object}")
+
+        self.sparse_object = sparse_object
+
+    def update_state_data(self, param, state_keep_mask):
+        raise NotImplementedError()
+
+    def update_state(self, state_keep_mask):
+        if isinstance(self.sparse_object, BlockSparseMatrix):
+            search_param = self.sparse_object.data
+        else:
+            raise Exception(f"Unknown sparse_object type {self.sparse_object}")
+
+        found = False
+        for param_group in self.optimizer.param_groups:
+            for param in param_group["params"]:
+                if param is search_param:
+                    found = True
+                    self.update_state_data(param, state_keep_mask)
+
+        return found
+
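+# When a block is replaced, Adam's per-element moment estimates for it are stale.
+# AdamOptimizerStateUpdater clears them by multiplying exp_avg / exp_avg_sq
+# (and max_exp_avg_sq, kept for amsgrad) by state_keep_mask.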
+class AdamOptimizerStateUpdater(OptimizerStateUpdater):
+    def update_state_data(self, param, state_keep_mask):
+        opt = self.optimizer
+
+        param_state = opt.state[param]
+
+        for key in param_state:
+            if key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']:
+                param_state[key] *= state_keep_mask
+            elif key == 'step':
+                # We cannot alter the step count: it is shared by the whole tensor, so
+                # bias_correction1 / bias_correction2 may be slightly off for the new
+                # coefficients, but that should not be a big issue.
+                pass
+            else:
+                raise Exception(f"Unknown key in Adam parameter state {key}")
+
+class SparseOptimizer(torch.optim.Optimizer):
+    """Periodically replaces the weakest blocks of BlockSparseMatrix objects.
+
+    Example:
+        optimizer = sparse_cleaner.SparseOptimizer([BlockSparseMatrix, BlockSparseMatrix], method="magnitude", new_coefficients_distribution="uniform")
+        optimizer.add_param_group(dict(sparse_objects=[BlockSparseMatrix], lr=0.5, method="magnitude", new_coefficients_distribution="gaussian", new_coefficients_scale=1.0))
+    """
+    METHODS = ["magnitude"]
+    COEFFICIENTS_DISTRIBUTION = ["uniform", "gaussian"]
+    allowed_keys = {"lr", "method", "new_coefficients_scale", "new_coefficients_distribution"}
+
+    def __init__(self, sparse_objects, lr=1e-1, method="magnitude", new_coefficients_scale=0.1, new_coefficients_distribution="uniform"):
+        if lr <= 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+
+        defaults = dict(lr=lr,
+                        method=method,
+                        new_coefficients_scale=new_coefficients_scale,
+                        new_coefficients_distribution=new_coefficients_distribution)
+
+        super(SparseOptimizer, self).__init__([{"sparse_objects": sparse_objects}], defaults)
+        self.attached_optimizers = []
+
+    @staticmethod
+    def sparse_objects(model):
+        # Collect every BlockSparseMatrix found in the model, e.g. to build
+        # SparseOptimizer(SparseOptimizer.sparse_objects(model)).
+        ret = []
+        for name, module in model.named_modules():
+            if isinstance(module, BlockSparseMatrix):
+                ret.append(module)
+
+        return ret
+
+    def attach_optimizer(self, optimizer):
+        if optimizer in self.attached_optimizers:
+            warnings.warn("Optimizer already attached")
+            return
+        self.attached_optimizers.append(optimizer)
+
+    def add_param_group(self, sparse_objects_group):
+        # Analogous to torch.optim.Optimizer.add_param_group, but groups carry
+        # "sparse_objects" (BlockSparseMatrix instances) instead of "params".
+        assert isinstance(sparse_objects_group, dict), "param group must be a dict"
+
+        for k in sparse_objects_group:
+            if k == "sparse_objects":
+                continue
+            elif k not in self.allowed_keys:
+                raise Exception("Unknown cleaning parameter %s" % k)
+
+        sparse_objects = sparse_objects_group['sparse_objects']
+
+        if isinstance(sparse_objects, BlockSparseMatrix):
+            sparse_objects_group['sparse_objects'] = [sparse_objects]
+        else:
+            sparse_objects_group['sparse_objects'] = list(sparse_objects)
 
+        sparse_objects = sparse_objects_group['sparse_objects']
 
+        for p in sparse_objects:
+            if isinstance(p, BlockSparseMatrix):
+                continue
+            else:
+                raise Exception("I don't know how to clean this type of object: %s" % p)
 
+        for name, default in self.defaults.items():
+            if default is required and name not in sparse_objects_group:
+                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
+                                 name)
+            else:
+                sparse_objects_group.setdefault(name, default)
+
+        if sparse_objects_group["method"] not in self.METHODS:
+            raise Exception(f"Invalid method {sparse_objects_group['method']}")
+
+        if sparse_objects_group["new_coefficients_distribution"] not in self.COEFFICIENTS_DISTRIBUTION:
+            raise Exception(f"Invalid new coefficients distribution {sparse_objects_group['new_coefficients_distribution']}")
+
+        param_set = set()
+        for group in self.param_groups:
+            param_set.update(set(group['sparse_objects']))
+
+        if not param_set.isdisjoint(set(sparse_objects_group['sparse_objects'])):
+            raise ValueError("some parameters appear in more than one parameter group")
+
+        self.param_groups.append(sparse_objects_group)
+
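+    # clean() runs one block-replacement pass on a single BlockSparseMatrix, then
+    # propagates the resulting state_keep_mask to every attached optimizer that
+    # holds its data tensor.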
+    def clean(self, p, method, clean_ratio, new_coefficients_scale, new_coefficients_distribution):
+        if not isinstance(p, BlockSparseMatrix):
+            raise Exception("I don't know how to clean this: %s" % p)
+
+        if method == "magnitude":
+            cleaner = MagnitudeSparseOptimizerStrategy(clean_ratio,
+                                                       new_coefficients_distribution=new_coefficients_distribution,
+                                                       new_coefficients_scale=new_coefficients_scale)
+        else:
+            raise Exception(f"Unknown cleaning method {method}")
+
+        state_keep_mask = cleaner.run(p)
+
+        if len(self.attached_optimizers) != 0:
+            found = False
+            for optimizer in self.attached_optimizers:
+                if isinstance(optimizer, optim.Adam):
+                    updater = AdamOptimizerStateUpdater(optimizer, p)
+                    # Update every attached optimizer: `found or update_state(...)`
+                    # would short-circuit and skip the remaining ones.
+                    found = updater.update_state(state_keep_mask) or found
+
+            if not found:
+                raise Exception(f"Could not find sparse object {p} in optimizers {self.attached_optimizers}")
+        else:
+            warnings.warn("No attached optimizer.")
+
+    def step(self):
+        for group in self.param_groups:
+            # The "learning rate" of this optimizer is the fraction of blocks replaced per step
+            clean_ratio = group['lr']
+            if clean_ratio == 0.0:
+                continue
+            for p in group['sparse_objects']:
+                self.clean(p,
+                           clean_ratio=clean_ratio,
+                           method=group['method'],
+                           new_coefficients_scale=group['new_coefficients_scale'],
+                           new_coefficients_distribution=group['new_coefficients_distribution'],
+                           )
 
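+
+# Minimal usage sketch (illustrative only; `model` and the training loop are
+# assumptions, not part of this commit, and it presumes the BlockSparseMatrix
+# data tensors appear in model.parameters()):
+#
+#   sparse_optimizer = SparseOptimizer(SparseOptimizer.sparse_objects(model),
+#                                      lr=0.1, method="magnitude")
+#   adam = optim.Adam(model.parameters(), lr=1e-3)
+#   sparse_optimizer.attach_optimizer(adam)
+#
+#   for batch in loader:
+#       ...              # forward / backward
+#       adam.step()
+#       sparse_optimizer.step()  # replace the weakest 10% of blocks, reset their Adam state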