@@ -215,11 +215,16 @@ def recurse(cl):
 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
-    def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None):
+    def __init__(self,
+                 enabled=True,
+                 mem_efficient_linear=True,
+                 ds_config=None,
+                 dtype=None):
         self.mem_efficient_linear = mem_efficient_linear
         self.enabled = enabled
-        self._set_dtype(config, dtype)
-        assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
+        self._set_dtype(ds_config, dtype)
+        assert self.dtype in [
+            torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
 
     def __enter__(self):
         if not self.enabled:
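This hunk only reflows the constructor, but for orientation: `InsertPostInitMethodToModuleSubClasses` is a context manager that temporarily wraps the `__init__` of every `torch.nn.Module` subclass so a post-init hook runs right after each submodule is constructed. A minimal sketch of that patching pattern; `patch_init` and the restore step are illustrative names, not the class's actual helpers:

```python
import torch

# Sketch of the wrap-__init__ pattern used by InsertPostInitMethodToModuleSubClasses.
# patch_init is an illustrative helper, not part of the DeepSpeed API.
def patch_init(cls, post_init):
    old_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        post_init(self)  # e.g. convert/partition the freshly created parameters

    cls.__init__ = wrapped_init
    return old_init      # kept so __exit__ can restore the original __init__


restore = patch_init(torch.nn.Linear, lambda m: print(type(m).__name__, "initialized"))
layer = torch.nn.Linear(4, 4)        # prints "Linear initialized"
torch.nn.Linear.__init__ = restore   # undo the patch, as __exit__ does
```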
@@ -287,8 +292,8 @@ def _disable_class(cls):
         torch.Tensor.__new__ = torch.Tensor.__old_new__
         torch.empty = _orig_torch_empty
 
-        #un doing it here will undo it during training
-        #if self.mem_efficient_linear:
+        # undoing it here will undo it during training
+        # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
         # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
@@ -303,8 +308,7 @@ def _post_init_method(self, module):
 
     def _set_dtype(self, ds_config, dtype):
         if ds_config is not None and dtype is None:
-            _ds_config = DeepSpeedConfig(ds_config)
-            self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
+            self.dtype = torch.half if ds_config.fp16_enabled else torch.float
         elif dtype is None:
             self.dtype = torch.half
         else:
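After this change, `_set_dtype` receives an already-parsed config object (or `None`) instead of a path or dict, so the dtype decision reduces to a three-way branch. A minimal sketch of that logic as a free-standing helper, assuming a config object that exposes `fp16_enabled`:

```python
import torch

# Hypothetical helper mirroring _set_dtype's branching; `ds_config` stands in
# for an already-parsed DeepSpeedConfig exposing an fp16_enabled attribute.
def resolve_dtype(ds_config=None, dtype=None):
    if ds_config is not None and dtype is None:
        return torch.half if ds_config.fp16_enabled else torch.float
    elif dtype is None:
        return torch.half  # no config and no explicit dtype: default to fp16
    return dtype           # an explicit dtype always wins

assert resolve_dtype() is torch.half
assert resolve_dtype(dtype=torch.float) is torch.float
```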
@@ -321,9 +325,11 @@ def __init__(self,
                  mem_efficient_linear=True,
                  remote_device=None,
                  pin_memory=False,
+                 config_dict_or_path=None,
                  config=None,
                  enabled=True,
-                 dtype=None):
+                 dtype=None,
+                 mpu=None):
         """A context to enable massive model construction for training with
         ZeRO-3. Models are automatically partitioned (or, sharded) across the
         system and converted to half precision.
@@ -343,12 +349,14 @@ def __init__(self,
             pin_memory (bool, optional): Potentially increase performance by
                 using pinned memory for model weights. ``remote_device`` must be
                 ``"cpu"``. Defaults to ``False``.
-            config (``json file`` or dict, optional): If provided, provides configuration
+            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                 for swapping fp16 params to NVMe.
+            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
             enabled (bool, optional): If ``False``, this context has no
                 effect. Defaults to ``True``.
             dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                 Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
+            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
 
         This context accelerates model initialization and enables models that
         are too large to allocate in their entirety in CPU memory. It has the
@@ -420,9 +428,11 @@ def get_model():
                 model = deepspeed.zero.Init(module=model)
         """
 
+        _ds_config = DeepSpeedConfig(config_dict_or_path,
+                                     mpu) if config_dict_or_path is not None else None
         super().__init__(enabled=enabled,
                          mem_efficient_linear=mem_efficient_linear,
-                         config=config,
+                         ds_config=_ds_config,
                          dtype=dtype)
         if not torch.distributed.is_initialized():
             init_distributed()
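With this change, `zero.Init` parses the config once (passing `mpu` through to `DeepSpeedConfig`) and hands the parsed object down to the base class. A hedged usage sketch of the new entry point; the config values below are placeholders, and the snippet is meant to run under a distributed launcher:

```python
import deepspeed
import torch

# Placeholder config; any dict or path to a JSON file accepted by DeepSpeedConfig works.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}

# mpu is optional; pass a model-parallel unit object only when model parallelism is used.
with deepspeed.zero.Init(config_dict_or_path=ds_config, mpu=None):
    # Parameters are partitioned across ranks as each submodule is constructed.
    model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)])
```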
@@ -435,21 +445,20 @@ def get_model():
         self.rank = torch.distributed.get_rank(group=self.ds_process_group)
         self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)
 
-        #Local device is the device where the parameters are consumed
-        #It is the device where parameters are fully instantiated using allgather
+        # Local device is the device where the parameters are consumed
+        # It is the device where parameters are fully instantiated using allgather
         self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
 
-        self._validate_remote_device(remote_device, config)
+        self._validate_remote_device(remote_device, _ds_config)
 
-        #Remote device is the device where parameter partiitons are stored
-        #It can be same as local_device or it could be CPU or NVMe.
+        # Remote device is the device where parameter partitions are stored
+        # It can be the same as local_device or it could be CPU or NVMe.
         self.remote_device = self.local_device if remote_device is None else remote_device
         self.pin_memory = pin_memory if (
             self.remote_device == OFFLOAD_CPU_DEVICE) else False
 
         # Enable fp16 param swapping to NVMe
         if self.remote_device == OFFLOAD_NVME_DEVICE:
-            _ds_config = DeepSpeedConfig(config)
             self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
         else:
             self.param_swapper = None
@@ -463,22 +472,21 @@ def get_model():
                 self._convert_to_deepspeed_param(param)
                 param.partition()
 
-    def _validate_remote_device(self, remote_device, ds_config):
-        if ds_config is not None:
-            _ds_config = DeepSpeedConfig(ds_config)
+    def _validate_remote_device(self, remote_device, _ds_config):
+        if _ds_config is not None:
             if remote_device in [None, OFFLOAD_CPU_DEVICE]:
                 if _ds_config.zero_config.offload_param is not None:
                     offload_param_device = _ds_config.zero_config.offload_param[
                         OFFLOAD_PARAM_DEVICE]
                     assert offload_param_device != OFFLOAD_NVME_DEVICE, \
-                    f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
+                        f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
 
             if remote_device == OFFLOAD_NVME_DEVICE:
                 assert _ds_config.zero_config.offload_param is not None, \
-                f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
+                    f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
 
                 assert _ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
-                f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
+                    f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
 
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
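The constants checked in `_validate_remote_device` correspond to the `offload_param` section of the ZeRO config. A hedged sketch of a config dict that would satisfy the NVMe checks above; the key names are assumed from the constant names (`OFFLOAD_PARAM`, `OFFLOAD_PARAM_DEVICE`, `OFFLOAD_PARAM_NVME_PATH`) and the path is a placeholder:

```python
# Assumed ZeRO-3 config shape for remote_device == "nvme"; "/local_nvme" is a
# placeholder mount point, not a value from this PR.
ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",
        },
    },
}
```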
@@ -624,7 +632,7 @@ def _ensure_availability_of_partitioned_params(self, params):
 
     def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
-        #fetches from nvme if the partition is not available and in nvme
+        # fetches from nvme if the partition is not available and in nvme
         self._ensure_availability_of_partitioned_params(param_list)
 
         handles = []
@@ -651,10 +659,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
     def _partition(self, param_list, force=False, has_been_updated=False):
         for param in param_list:
             #print_rank_0(f"Before Partitioning Param {param.ds_id}")
-            #self._param_status(param)
+            # self._param_status(param)
             self._partition_param(param, has_been_updated=has_been_updated)
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
-            #if param.ds_tensor is not None:
+            # if param.ds_tensor is not None:
             #    assert id(param.data) == id(param.ds_tensor.data), \
             #        "After the parameters are initially partitioned, make sure we are not recreating the partition."
             #print_rank_0(f"After Partitioning Param {param.ds_id}")
@@ -678,7 +686,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
         # if numel in empty_buffers:
         #     empty_buffers[numel].append(buffer)
 
-        #if torch.distributed.get_rank():
+        # if torch.distributed.get_rank():
         #    print(f"Releasing {param.data.numel()}")
         if param.ds_tensor is not None and not has_been_updated:
 
@@ -687,7 +695,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             see_memory_usage(
                 f'Before partitioning param {param.ds_id} {param.shape}',
                 force=False)
-            #param.data does not store anything meaningful in partitioned state
+            # param.data does not store anything meaningful in partitioned state
             param.data = torch.ones(1, dtype=self.dtype).to(param.device)
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
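The line `param.data = torch.ones(1, ...)` is the core memory trick in this file: once the shard lives elsewhere (in `param.ds_tensor` or on NVMe), the full-size tensor behind `param.data` can be released by repointing it at a one-element placeholder. A toy, CPU-only illustration of that effect, not the DeepSpeed code path itself:

```python
import torch

# Full-size parameter: 4096 x 4096 halves, roughly 32 MiB of storage.
param = torch.nn.Parameter(torch.empty(4096, 4096, dtype=torch.half))
print(param.data.numel())   # 16777216 elements held locally

# Repoint .data at a 1-element placeholder; the original storage becomes
# unreferenced and can be reclaimed, while the Parameter object survives.
param.data = torch.ones(1, dtype=torch.half)
print(param.data.numel())   # 1
```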
@@ -765,7 +773,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
 
         #param.data = param.ds_tensor.data
 
-        #param.data does not store anything meaningful in partitioned state
+        # param.data does not store anything meaningful in partitioned state
 
         see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                          force=False)
@@ -1002,7 +1010,8 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
                                            dtype=param.dtype,
                                            device=param.device)
         else:
-            assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
+            assert partition_buffer.numel(
+            ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
 
         rank = torch.distributed.get_rank(group=self.ds_process_group)
         start = partition_size * rank
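For context, `_partition_gradient` goes on to copy each rank's contiguous slice `[start, start + partition_size)` of the flattened gradient into `partition_buffer`, which is why the buffer must hold at least `partition_size` elements. A small self-contained sketch of that offset arithmetic; the rank and world size are hard-coded stand-ins for the values taken from the process group:

```python
import torch

world_size, rank = 4, 2
grad = torch.arange(10, dtype=torch.float)                      # flattened gradient
partition_size = (grad.numel() + world_size - 1) // world_size  # 3, padded partition

partition_buffer = torch.zeros(partition_size, dtype=grad.dtype)
assert partition_buffer.numel() >= partition_size               # the check above

start = partition_size * rank                  # 6 for rank 2
end = min(start + partition_size, grad.numel())
partition_buffer[:end - start].copy_(grad[start:end])
print(partition_buffer)                        # tensor([6., 7., 8.])
```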