
Commit 6b42882

Use mpu in DeepSpeedConfig() call (#1271)
* Use mpu in DeepSpeedConfig() call
* Improve argument naming
1 parent bc17042 commit 6b42882
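
As context for the diff below, a minimal usage sketch of the updated deepspeed.zero.Init signature after this change. TinyModel and "ds_config.json" are illustrative placeholders, not part of the commit; the keyword arguments come from the new signature shown in the diff.

import torch
import deepspeed

class TinyModel(torch.nn.Module):  # toy module for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 1024)

# mpu may be any object implementing get_{model,data}_parallel_{rank,group,world_size}
# (e.g. a Megatron-style module); passing None keeps the default process-group behavior.
with deepspeed.zero.Init(config_dict_or_path="ds_config.json",  # dict or path to a DeepSpeed JSON config
                         mpu=None,
                         dtype=torch.half):
    model = TinyModel()  # parameters are partitioned as the module is constructed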

File tree: 1 file changed

deepspeed/runtime/zero/partition_parameters.py (38 additions, 29 deletions)
@@ -215,11 +215,16 @@ def recurse(cl):
 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
-    def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None):
+    def __init__(self,
+                 enabled=True,
+                 mem_efficient_linear=True,
+                 ds_config=None,
+                 dtype=None):
         self.mem_efficient_linear = mem_efficient_linear
         self.enabled = enabled
-        self._set_dtype(config, dtype)
-        assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
+        self._set_dtype(ds_config, dtype)
+        assert self.dtype in [
+            torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
 
     def __enter__(self):
         if not self.enabled:
@@ -287,8 +292,8 @@ def _disable_class(cls):
         torch.Tensor.__new__ = torch.Tensor.__old_new__
         torch.empty = _orig_torch_empty
 
-        #un doing it here will undo it during training
-        #if self.mem_efficient_linear:
+        # un doing it here will undo it during training
+        # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
         # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
@@ -303,8 +308,7 @@ def _post_init_method(self, module):
 
     def _set_dtype(self, ds_config, dtype):
         if ds_config is not None and dtype is None:
-            _ds_config = DeepSpeedConfig(ds_config)
-            self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
+            self.dtype = torch.half if ds_config.fp16_enabled else torch.float
         elif dtype is None:
             self.dtype = torch.half
         else:
@@ -321,9 +325,11 @@ def __init__(self,
                  mem_efficient_linear=True,
                  remote_device=None,
                  pin_memory=False,
+                 config_dict_or_path=None,
                  config=None,
                  enabled=True,
-                 dtype=None):
+                 dtype=None,
+                 mpu=None):
         """A context to enable massive model construction for training with
         ZeRO-3. Models are automatically partitioned (or, sharded) across the
         system and converted to half precision.
@@ -343,12 +349,14 @@ def __init__(self,
             pin_memory (bool, optional): Potentially increase performance by
                 using pinned memory for model weights. ``remote_device`` must be
                 ``"cpu"``. Defaults to ``False``.
-            config (``json file`` or dict, optional): If provided, provides configuration
+            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                 for swapping fp16 params to NVMe.
+            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
             enabled (bool, optional): If ``False``, this context has no
                 effect. Defaults to ``True``.
             dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                 Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
+            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
 
         This context accelerates model initialization and enables models that
         are too large to allocate in their entirety in CPU memory. It has the
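
The mpu argument documented above only needs the accessor interface named in the docstring. The stub below is a hypothetical minimal implementation, not from this commit, that treats the whole job as a single data-parallel group with no model parallelism.

import torch.distributed as dist

class SingleGroupMPU:
    """Hypothetical stand-in: no model parallelism, one data-parallel group."""
    def get_model_parallel_rank(self):
        return 0

    def get_model_parallel_world_size(self):
        return 1

    def get_model_parallel_group(self):
        return None  # placeholder; a real mpu returns a process group

    def get_data_parallel_rank(self):
        return dist.get_rank()

    def get_data_parallel_world_size(self):
        return dist.get_world_size()

    def get_data_parallel_group(self):
        return dist.group.WORLD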
@@ -420,9 +428,11 @@ def get_model():
                 model = deepspeed.zero.Init(module=model)
         """
 
+        _ds_config = DeepSpeedConfig(config_dict_or_path,
+                                     mpu) if config_dict_or_path is not None else None
         super().__init__(enabled=enabled,
                          mem_efficient_linear=mem_efficient_linear,
-                         config=config,
+                         ds_config=_ds_config,
                          dtype=dtype)
         if not torch.distributed.is_initialized():
             init_distributed()
@@ -435,21 +445,20 @@ def get_model():
         self.rank = torch.distributed.get_rank(group=self.ds_process_group)
         self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)
 
-        #Local device is the device where the parameters are consumed
-        #It is the device where parameters are fully instantiated using allgather
+        # Local device is the device where the parameters are consumed
+        # It is the device where parameters are fully instantiated using allgather
         self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
 
-        self._validate_remote_device(remote_device, config)
+        self._validate_remote_device(remote_device, _ds_config)
 
-        #Remote device is the device where parameter partitions are stored
-        #It can be same as local_device or it could be CPU or NVMe.
+        # Remote device is the device where parameter partitions are stored
+        # It can be same as local_device or it could be CPU or NVMe.
         self.remote_device = self.local_device if remote_device is None else remote_device
         self.pin_memory = pin_memory if (
             self.remote_device == OFFLOAD_CPU_DEVICE) else False
 
         # Enable fp16 param swapping to NVMe
         if self.remote_device == OFFLOAD_NVME_DEVICE:
-            _ds_config = DeepSpeedConfig(config)
             self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
         else:
             self.param_swapper = None
@@ -463,22 +472,21 @@ def get_model():
             self._convert_to_deepspeed_param(param)
             param.partition()
 
-    def _validate_remote_device(self, remote_device, ds_config):
-        if ds_config is not None:
-            _ds_config = DeepSpeedConfig(ds_config)
+    def _validate_remote_device(self, remote_device, _ds_config):
+        if _ds_config is not None:
             if remote_device in [None, OFFLOAD_CPU_DEVICE]:
                 if _ds_config.zero_config.offload_param is not None:
                     offload_param_device = _ds_config.zero_config.offload_param[
                         OFFLOAD_PARAM_DEVICE]
                     assert offload_param_device != OFFLOAD_NVME_DEVICE, \
-                    f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
+                        f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
 
             if remote_device == OFFLOAD_NVME_DEVICE:
                 assert _ds_config.zero_config.offload_param is not None, \
-                f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
+                    f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
 
                 assert _ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
-                f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
+                    f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
 
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
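
For reference, an illustrative config dict of the kind _validate_remote_device above expects when the remote device is NVMe: offload_param must be present with an nvme device and a non-None nvme_path. The keys follow the public DeepSpeed ZeRO-3 config schema; the path and batch size are placeholders.

ds_nvme_config = {
    "train_micro_batch_size_per_gpu": 1,   # placeholder value
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",              # matches OFFLOAD_NVME_DEVICE
            "nvme_path": "/local_nvme"     # placeholder path; must not be None
        }
    }
}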
@@ -624,7 +632,7 @@ def _ensure_availability_of_partitioned_params(self, params):
 
     def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
-        #fetches from nvme if the partition is not available and in nvme
+        # fetches from nvme if the partition is not available and in nvme
         self._ensure_availability_of_partitioned_params(param_list)
 
         handles = []
@@ -651,10 +659,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
     def _partition(self, param_list, force=False, has_been_updated=False):
         for param in param_list:
             #print_rank_0(f"Before Partitioning Param {param.ds_id}")
-            #self._param_status(param)
+            # self._param_status(param)
             self._partition_param(param, has_been_updated=has_been_updated)
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
-            #if param.ds_tensor is not None:
+            # if param.ds_tensor is not None:
             #    assert id(param.data) == id(param.ds_tensor.data), \
             #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
             #print_rank_0(f"After Partitioning Param {param.ds_id}")
@@ -678,7 +686,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
         # if numel in empty_buffers:
         #     empty_buffers[numel].append(buffer)
 
-        #if torch.distributed.get_rank():
+        # if torch.distributed.get_rank():
         #    print(f"Releasing {param.data.numel()}")
         if param.ds_tensor is not None and not has_been_updated:
 
@@ -687,7 +695,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             see_memory_usage(
                 f'Before partitioning param {param.ds_id} {param.shape}',
                 force=False)
-            #param.data does not store anything meaningful in partitioned state
+            # param.data does not store anything meaningful in partitioned state
             param.data = torch.ones(1, dtype=self.dtype).to(param.device)
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
@@ -765,7 +773,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
 
         #param.data = param.ds_tensor.data
 
-        #param.data does not store anything meaningful in partitioned state
+        # param.data does not store anything meaningful in partitioned state
 
         see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                          force=False)
@@ -1002,7 +1010,8 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
                                            dtype=param.dtype,
                                            device=param.device)
         else:
-            assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
+            assert partition_buffer.numel(
+            ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
 
         rank = torch.distributed.get_rank(group=self.ds_process_group)
         start = partition_size * rank
