
Commit db017fd

Author: Shaden Smith
Minor tweaks to support Megatron 2.4 + DS 3D

1 parent c697d7a · commit db017fd

File tree: 5 files changed, +67 additions, −9 deletions


deepspeed/runtime/activation_checkpointing/checkpointing.py

Lines changed: 14 additions & 6 deletions
@@ -24,7 +24,7 @@
 from deepspeed.runtime.config import DeepSpeedConfig
 from deepspeed.utils import logger
-from deepspeed.runtime.utils import move_to_device, see_memory_usage
+from deepspeed.runtime.utils import move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
 from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

 # DeepSpeed Checkpointing Enabled or Disabled

@@ -213,9 +213,12 @@ def model_parallel_cuda_manual_seed(seed):
     model parallel regions.
     """
     global mpu
+
+    tp_rank = bwc_tensor_model_parallel_rank(mpu)
+
     # 2718 is just for fun and any POSITIVE value will work.
     offset = seed + 2718
-    model_parallel_seed = offset + mpu.get_model_parallel_rank()
+    model_parallel_seed = offset + tp_rank
     # Data parallel gets the original sedd.
     data_parallel_seed = seed

@@ -225,7 +228,7 @@ def model_parallel_cuda_manual_seed(seed):
             'model parallel rank {}, and data parallel rank {} with '
             'model parallel seed: {} and data parallel seed: {}'.format(
                 torch.distributed.get_rank(),
-                mpu.get_model_parallel_rank(),
+                tp_rank,
                 mpu.get_data_parallel_rank(),
                 model_parallel_seed,
                 data_parallel_seed),

@@ -384,9 +387,14 @@ def save_args_for_backward(*all_args):
         global data_offsets, size_offsets
         if mp_rank is None:
             if mpu is not None:
-                mp_rank = mpu.get_model_parallel_rank()
-                mp_size = mpu.get_model_parallel_world_size()
-                mp_group = mpu.get_model_parallel_group()
+                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
+                    mp_rank = mpu.get_tensor_model_parallel_rank()
+                    mp_size = mpu.get_tensor_model_parallel_world_size()
+                    mp_group = mpu.get_tensor_model_parallel_group()
+                else:
+                    mp_rank = mpu.get_model_parallel_rank()
+                    mp_size = mpu.get_model_parallel_world_size()
+                    mp_group = mpu.get_model_parallel_group()
             else:
                 mp_rank = 0
                 mp_size = 1
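The seeding change only swaps where the rank comes from; the arithmetic is unchanged. A minimal sketch of the derivation follows, with a made-up seed value and tensor-parallel degree (neither is taken from the commit):

```python
# Illustrative sketch only: how per-rank seeds are derived after this change.
seed = 1234
offset = seed + 2718  # 2718 is arbitrary; any positive value works

for tp_rank in range(4):  # pretend 4-way tensor model parallelism
    model_parallel_seed = offset + tp_rank  # 3952, 3953, 3954, 3955 -> distinct per rank
    data_parallel_seed = seed               # unchanged across tensor-parallel ranks
```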

deepspeed/runtime/engine.py

Lines changed: 8 additions & 1 deletion
@@ -1275,7 +1275,14 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
             self.skipped_steps += 1
         else:
             if self.lr_scheduler is not None:
-                self.lr_scheduler.step(**(lr_kwargs or {}))
+                try:
+                    self.lr_scheduler.step(**(lr_kwargs or {}))
+                except TypeError:
+                    # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
+                    # We don't currently have a way to specify lr_kwargs from
+                    # pipe_engine.train_batch()
+                    self.lr_scheduler.step(increment=self.train_batch_size())

         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
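The try/except works because a Megatron-style scheduler requires an explicit increment: calling step() with no arguments raises TypeError, which routes execution to the fallback call. A self-contained sketch, where MegatronStyleLR is a hypothetical stand-in (not DeepSpeed or Megatron code):

```python
# Hedged sketch: MegatronStyleLR stands in for a scheduler whose step()
# requires an increment; it is not the real Megatron scheduler class.
class MegatronStyleLR:
    def __init__(self):
        self.num_iters = 0

    def step(self, increment):
        self.num_iters += increment

lr_scheduler = MegatronStyleLR()
lr_kwargs = None        # pipe_engine.train_batch() cannot pass lr_kwargs
train_batch_size = 8    # stand-in for self.train_batch_size()

try:
    lr_scheduler.step(**(lr_kwargs or {}))   # raises TypeError: missing 'increment'
except TypeError:
    lr_scheduler.step(increment=train_batch_size)
```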

deepspeed/runtime/pipe/engine.py

Lines changed: 2 additions & 1 deletion
@@ -110,7 +110,8 @@ def __init__(self, *super_args, **super_kwargs):
         self.is_model_parallel = self.grid.model_parallel_size > 1

         # Partition input/output buffers
-        self.is_pipe_partitioned = self.is_model_parallel
+        # XXX temporarily disable while I revert some partition hacks.
+        self.is_pipe_partitioned = False #self.is_model_parallel
         self.is_grad_partitioned = False

         model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())

deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion
@@ -585,7 +585,10 @@ def load_state_dir(self, load_dir, strict=True):
         self._synchronize_tied_weights()

     def _is_checkpointable(self, funcs):
-        if self.__class__.__name__ == 'GPT2ModelPipe':
+        # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
+        # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
+        # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
+        if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
             return all('ParallelTransformerLayerPipe' in f.__class__.__name__
                        for f in funcs)
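For context, a self-contained sketch of the class-name filter this check applies: only pipeline stages made up entirely of transformer layers are treated as checkpointable, so a stage containing an embedding (whose integer inputs cannot require_grad) is excluded. The two layer classes below are empty stand-ins, not the real Megatron/DeepSpeed modules:

```python
# Stand-in classes for illustration only; the real ones live in Megatron/DeepSpeed.
class ParallelTransformerLayerPipe: ...
class EmbeddingPipe: ...

def is_checkpointable(funcs):
    # Mirrors the check above: every layer in the stage must be a transformer layer.
    return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)

print(is_checkpointable([ParallelTransformerLayerPipe(), ParallelTransformerLayerPipe()]))  # True
print(is_checkpointable([EmbeddingPipe(), ParallelTransformerLayerPipe()]))                 # False
```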

deepspeed/runtime/utils.py

Lines changed: 39 additions & 0 deletions
@@ -49,6 +49,45 @@ def set_random_seed(seed):
     torch.manual_seed(seed)


+def bwc_tensor_model_parallel_rank(mpu=None):
+    """Backwards-compatible way of querying the tensor model parallel rank from
+    an ``mpu`` object.
+
+    *Tensor* model parallelism means that tensors are physically split across
+    processes. This contrasts with *pipeline* model parallelism, in which the
+    layers are partitioned but tensors left intact.
+
+    The API for tensor model parallelism has changed across versions and this
+    helper provides a best-effort implementation across versions of ``mpu``
+    objects. The preferred mechanism is
+    ``mpu.get_tensor_model_parallel_rank()``.
+
+    This should "just work" with both Megatron-LM and DeepSpeed's pipeline
+    parallelism.
+
+    Args:
+        mpu (model parallel unit, optional): The tensor model parallel rank.
+            If ``mpu=None``, returns 0. Defaults to ``None``.
+
+    Returns:
+        int: the rank
+    """
+    if mpu is None:
+        # No model parallelism in easy :)
+        return 0
+
+    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
+        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+        return mpu.get_tensor_model_parallel_rank()
+    elif hasattr(mpu, 'get_slice_parallel_rank'):
+        # Some DeepSpeed + pipeline parallelism versions
+        return mpu.get_slice_parallel_rank()
+    else:
+        # Deprecated Megatron and DeepSpeed convention
+        return mpu.get_model_parallel_rank()
+
+
 def move_to_device(item, device):
     """
     Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
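A small usage sketch for the new helper: the mpu classes below are stubs invented for illustration, standing in for new-style and old-style Megatron mpu modules.

```python
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank

class NewStyleMPU:
    # Megatron 2.x style accessor (preferred)
    def get_tensor_model_parallel_rank(self):
        return 3

class OldStyleMPU:
    # Deprecated convention
    def get_model_parallel_rank(self):
        return 1

assert bwc_tensor_model_parallel_rank(None) == 0           # no model parallelism
assert bwc_tensor_model_parallel_rank(NewStyleMPU()) == 3  # preferred accessor wins
assert bwc_tensor_model_parallel_rank(OldStyleMPU()) == 1  # falls back to the old API
```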
