@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,16 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Taken and modified from fairscale:
+#     https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/oss.py
+# Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
 
-######
+import numpy as np
+from itertools import chain
 from functools import reduce
+from collections import OrderedDict
 
 import paddle
 from paddle import framework
+import paddle.distributed as dist
+from paddle.optimizer import Optimizer
+
 from ...utils.log_util import logger
+from ...utils.internal_storage import ParamStorage
+from ...meta_parallel.sharding.sharding_utils import Type
+
+# CUDA alignment 256 bytes
+alignment = {"gpu": 256, }
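+# `align` maps each supported dtype to its element size in bytes; together with
+# `alignment` it is used below to pad each parameter so that the per-rank
+# buffers stay 256-byte aligned.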
+align = {
+    Type.fp16.value: 2,
+    Type.fp32.value: 4,
+}
+
+__all__ = ["ShardingOptimizerStage2"]
 
 
-def _is_trainable(param: paddle.Tensor) -> bool:
+def _is_trainable(param):
     return not param.stop_gradient
 
 
@@ -41,13 +60,8 @@ class DygraphShardingOptimizer(object):
     # 3. dynamic trainable params, which is the case between pretraining and finetuning
     # 4. option to choose fused comm (needs more GPU memory) or un-fused comm
 
-    def __init__(
-            self,
-            hcg,
-            user_defined_strategy,
-            params,
-            inner_optimizer_class,
-            **inner_optimizer_kargs, ):
+    def __init__(self, hcg, user_defined_strategy, params,
+                 inner_optimizer_class, **inner_optimizer_kargs):
 
         if not isinstance(params, list):
             raise TypeError(
@@ -196,3 +210,245 @@ def _grad_clip(self):
 
     def __getattr__(self, item):
         return getattr(self._inner_optimizer, item)
+
+
+class ShardingOptimizerStage2(Optimizer):
+    """
+    A wrapper for the Sharding Stage2 Optimizer in Dygraph.
+
+    .. warning:: ShardingOptimizer encapsulates the optimization strategy and integrates it into the optimizer.
+
+    .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf
+
+    """
+
+    # TODO (Baibaifan)
+    # Feature Notes:
+    # 1. Unified memory for parameters and parameters.grad to InternalStorage.
+    # 2. Support the segmentation of optimizer parameters and partial updating of parameters.
+    # 3. Dynamically adjust training parameters and models.
+    # 4. Support offload function.
+    # 5. Support the establishment of independent communication groups.
+    # 6. Broadcast_fp16 is not supported now.
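+    #
+    # A minimal construction sketch (illustrative only, not part of the
+    # original change): it assumes two GPUs, an existing paddle.nn.Layer
+    # called `model`, and an already initialized paddle.distributed environment:
+    #
+    #     optim = paddle.optimizer.Adam(parameters=model.parameters())
+    #     group = paddle.distributed.new_group(list(range(2)))
+    #     sharded_optim = ShardingOptimizerStage2(
+    #         params=model.parameters(), optim=optim, group=group)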
+    def __init__(self,
+                 params,
+                 optim,
+                 group,
+                 broadcast_fp16=False,
+                 offload=False,
+                 device="gpu",
+                 accumulation_steps=None,
+                 **kw):
+
+        super().__init__(optim._learning_rate, params, kw)
+
+        # Segmentation information
+        self._dtype_rank_params = OrderedDict()  # {dtype: [param1, param2]} device, rank, params
+        self._param2rank = {}
+        self._segment_params = []
+        self._rank_buffer_size = {}  # {dtype: {rank: numel+alignment}}
+        self._param2align = {}  # {param.name: align}
+
+        # Default information
+        self._optim_defaults = kw
+        self._optim = optim
+        self._local_params = params
+        self._default_device = device
+        self._accumulation_steps = accumulation_steps
+
+        assert group is not None, "Distributed communication group must be given"
+        self.group = group
+        self.world_size = group.nranks
+        self.rank = group.rank
+
+        self.broadcast_fp16 = broadcast_fp16
+        self.param_storages = {}  # {dtype: {rank: InternalStorage}}
+        self.offload = offload  # Used for offload
+
+        # Update optimizer parameters and adjust parameter storage and usage
+        # according to rank.
+        self.update_opt_status()
+
+    def update_opt_status(self):
+        """Update optimizer status and parameter storage information; further
+        special functions are still to be developed.
+        """
+        # func 1
+        self._integration_params()
+
+        # func 2 TODO
+
+    # Segment helpers
+
+    def segment_params(self):
+        """
+        Divide all optimizer parameters as equally as possible across ranks.
+        """
+        if len(self._segment_params) == 0:
+            self._segment_params = [[] for _ in range(self.world_size)]
+            param_lists = [[] for _ in range(self.world_size)]
+            sizes = [0] * self.world_size
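+            # Greedy assignment: each param goes to the rank whose shard is
+            # currently smallest, keeping the per-rank numel roughly balanced.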
+            for param in self._local_params:
+                # Add this param to the rank with the smallest total size so far.
+                rank = sizes.index(min(sizes))
+                param_lists[rank].append(param)
+
+                # Accumulate the real numel of trainable params
+                sizes[rank] += np.prod(param.shape) if param.trainable else 0
+
+            for rank, params in enumerate(param_lists):
+                # param_group_rank = copy.copy(params)
+                self._segment_params[rank].extend(params)
+        return self._segment_params
+
+    @property
+    def local_params(self):
+        return self._local_params
+
+    @property
+    def accumulation_steps(self):
+        return self._accumulation_steps
+
+    @property
+    def param2rank(self):
+        """Map the params to the rank which owns them"""
+        if len(self._param2rank) == 0:
+            for rank, params in enumerate(self.segment_params()):
+                for param in params:
+                    self._param2rank[param.name] = rank
+        return self._param2rank
+
+    @property
+    def dtype_rank_params(self):
+        """
+        Divide the parameters into groups according to rank and dtype.
+        """
+        if len(self._dtype_rank_params) == 0:
+            # Assign the parameters of each rank according to the type
+            for param in self._local_params:
+                if param.dtype not in self._dtype_rank_params.keys():
+                    self._dtype_rank_params[
+                        param.dtype] = [[] for _ in range(self.world_size)]
+                self._dtype_rank_params[param.dtype][self.param2rank[
+                    param.name]].append(param)
+
+            # Sort per rank params by size
+            for dtype in self._dtype_rank_params.keys():
+                for rank_params in self._dtype_rank_params[dtype]:
+                    rank_params.sort(key=lambda x: np.prod(x.shape))
+
+        return self._dtype_rank_params
+
+    @property
+    def rank_buffer_size(self):
+        """
+        Count the buffer size (numel plus alignment padding) of each rank's
+        parameters under each dtype.
+        """
+        # CUDA alignment 256 bytes
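+        # e.g. a fp16 param with 1000 elements occupies 2000 bytes; 2000 % 256
+        # leaves 208, so 48 padding bytes (24 fp16 elements) are added and the
+        # param contributes 1024 elements to this rank's buffer.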
+        if len(self._rank_buffer_size) == 0:
+            for dtype in self.dtype_rank_params.keys():
+                if dtype not in self._rank_buffer_size.keys():
+                    self._rank_buffer_size[dtype] = {}
+                for dst_rank, per_rank_params in enumerate(
+                        self.dtype_rank_params[dtype]):
+                    if dst_rank not in self._rank_buffer_size[dtype].keys():
+                        self._rank_buffer_size[dtype][dst_rank] = 0
+                    for param in per_rank_params:
+                        if not param.trainable:
+                            continue
+                        size = np.prod(param.shape) * align[dtype]
+                        remaining = size % alignment[self._default_device]
+                        ali = 0 if remaining == 0 else alignment[
+                            self._default_device] - remaining
+                        align_ = ali // align[dtype]
+                        self._rank_buffer_size[dtype][dst_rank] += np.prod(
+                            param.shape) + align_
+                        self._param2align[param.name] = align_
+
+        return self._rank_buffer_size
+
+    def _integration_params(self):
+        """
+        Integrate the parameters into a contiguous memory according to rank,
+        and support the update of training parameters.
+        """
+
+        for dtype, per_rank_params in self.dtype_rank_params.items():
+            if dtype not in self.param_storages.keys():
+                self.param_storages[dtype] = {}
+
+            for dst_rank, params in enumerate(per_rank_params):
+                if len(params) > 0:
+
+                    # Merge all the trainable params in a single InternalStorage
+                    trainable_params = list(
+                        filter(lambda x: x.trainable, params))
+                    if trainable_params:
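+                        # ParamStorage flattens this rank's trainable params of
+                        # this dtype into one contiguous, padded buffer, so each
+                        # shard can later be broadcast with a single collective.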
+                        param_storage = ParamStorage(
+                            size=self.rank_buffer_size[dtype][dst_rank],
+                            dtype=dtype,
+                            device=self._default_device)
+
+                        param_storage.add_rank_params(trainable_params,
+                                                      self._param2align)
+                        self.param_storages[dtype][dst_rank] = param_storage
+
+        # Clear the InternalStorage keys which are not in use anymore
+        dtype_in_use = list(self.dtype_rank_params.keys())
+        dtype_to_pop = list(
+            filter(lambda x: x not in dtype_in_use, self.param_storages.keys()))
+        for d in dtype_to_pop:
+            self.param_storages.pop(d)
+
+    def step(self):
+        """
+        A wrapper for the inner Optimizer's step function that finishes the
+        update of the parameters owned by the current rank.
+        """
+
+        # Point the inner optimizer at only the parameters owned by the current rank
+        if len(self.dtype_rank_params.keys(
+        )) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp32.value][self.rank]
+        elif len(self.dtype_rank_params.keys(
+        )) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp16.value][self.rank]
+        else:
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp16.value][self.rank] + self.dtype_rank_params[
+                    Type.fp32.value][self.rank]
+
+        # Run the inner optimizer step on the current rank's shard
+        self._optim.step()
+
+        # Synchronize all the updated shards in between the ranks
+        self._broadcast_params()
+
+        # Restore the full parameter list on the inner optimizer
+        self._optim._parameter_list = self._local_params
+
+    def clear_cache(self):
+        self._segment_params.clear()
+        self._dtype_rank_params.clear()
+        self._param2rank.clear()
+
+    @paddle.no_grad()
+    def _broadcast_params(self):
+        """Broadcast each rank's parameter shard to all the other ranks"""
+
+        assert self._default_device == "gpu", "Only gpu is supported"
+
+        # Exchange all the shards with the other ranks
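+        # One broadcast per (dtype, owning rank) buffer, with the owner as the
+        # source, so every rank ends up holding the fully updated parameters.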
+        for dtype_per_rank in self.param_storages.values():
+            for dst_rank, internal_storage in dtype_per_rank.items():
+                dist.broadcast(
+                    tensor=internal_storage.buffer,
+                    src=dst_rank,
+                    group=self.group,
+                    use_calc_stream=True)
+
+                # Multi stream operation will be supported later
+                dist.wait(
+                    tensor=internal_storage.buffer,
+                    group=self.group,
+                    use_calc_stream=True)