deepspeed/runtime/pipe/engine.py: 6 additions & 5 deletions
@@ -49,14 +49,15 @@ class PipelineEngine(DeepSpeedEngine):
This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
is provided.
"""
- def __init__(self, *super_args, **super_kwargs):
+ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"

# We schedule the all-reduces, so disable it in super().backward()
self.enable_backward_allreduce = False
+ self.has_bool_tensors = has_bool_tensors

# used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
self.pipeline_enable_backward_allreduce = True
@@ -854,7 +855,7 @@ def _exec_send_activations(self, buffer_id):
# NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
# We could do char, but with half() we can eventually flatten with other fp16
# messages (TODO)
- if self.module.__class__.__name__ == 'GPT2ModelPipe':
+ if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
outputs = list(outputs)
outputs[-1] = outputs[-1].half()
outputs = tuple(outputs)
@@ -873,7 +874,7 @@ def _exec_send_activations(self, buffer_id):
f'{type(outputs)}')

# Restore the boolean tensor
- if self.module.__class__.__name__ == 'GPT2ModelPipe':
+ if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
outputs = list(outputs)
outputs[-1] = outputs[-1].bool()
outputs = tuple(outputs)
@@ -899,7 +900,7 @@ def _exec_send_grads(self, buffer_id):
# a grad that needs to be communicated. We free the buffer immediately
# after, so no need to restore it. The receiver also has a hack that skips
# the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
- if self.module.__class__.__name__ == 'GPT2ModelPipe':
+ if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
inputs = list(inputs)
inputs.pop()
inputs = tuple(inputs)
@@ -960,7 +961,7 @@ def _exec_recv_activations(self, buffer_id):

# NCCL does not like to send torch.BoolTensor types, so un-cast the
# attention mask
- if self.module.__class__.__name__ == 'GPT2ModelPipe':
+ if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
recvd[-1] = recvd[-1].bool()

recvd = tuple(recvd)
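
The pattern in these hunks (cast to half() before send, cast back to bool() after recv) works because 0/1 values round-trip exactly through fp16. A minimal self-contained sketch of the same trick in plain PyTorch, outside the engine:

    import torch

    # A boolean attention mask like the one GPT2ModelPipe (or any model run
    # with has_bool_tensors=True) appends to its stage outputs.
    mask = torch.tensor([[True, False], [False, True]])

    payload = mask.half()      # fp16 stand-in that NCCL can transport
    restored = payload.bool()  # the receiving stage re-casts to torch.bool

    assert torch.equal(mask, restored)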
deepspeed/runtime/pipe/module.py: 8 additions & 1 deletion
@@ -95,7 +95,8 @@ def __init__(self,
base_seed=1234,
partition_method='parameters',
activation_checkpoint_interval=0,
- activation_checkpoint_func=checkpointing.checkpoint):
+ activation_checkpoint_func=checkpointing.checkpoint,
+ checkpointable_layers=None):
"""Modules to be parallelized with pipeline parallelism.

The key constraint that enables pipeline parallelism is the
@@ -137,6 +138,10 @@ def forward(self, inputs):

self.loss_fn = loss_fn

+ self.checkpointable_layers = checkpointable_layers
+ if checkpointable_layers is not None:
+     assert isinstance(checkpointable_layers, list), "param `checkpointable_layers` must be a list."

self.seed_layers = seed_layers
self.seed_fn = seed_fn
self.base_seed = base_seed
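
Taken together, the new argument lets any pipeline model, not just GPT(2)ModelPipe, declare which layer classes may be activation-checkpointed. A hedged usage sketch; the layer classes below are hypothetical placeholders, and an actual run needs a distributed launch (e.g., via the deepspeed launcher):

    import torch
    from deepspeed.pipe import PipelineModule

    # Hypothetical stand-ins for a real model's pipeline layers.
    class EmbeddingPipe(torch.nn.Module):
        def forward(self, x):
            return x

    class TransformerBlockPipe(torch.nn.Module):
        def forward(self, x):
            return x

    layers = [EmbeddingPipe()] + [TransformerBlockPipe() for _ in range(4)]
    model = PipelineModule(layers=layers,
                           num_stages=2,
                           activation_checkpoint_interval=1,
                           checkpointable_layers=['TransformerBlockPipe'])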
@@ -602,6 +607,8 @@ def _is_checkpointable(self, funcs):
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
for f in funcs)
+ if self.checkpointable_layers is not None:
+     return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
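
With this change, _is_checkpointable tries three rules in order: the hard-coded GPT(2)ModelPipe name match, then the user-supplied whitelist (which, when provided, decides outright), and finally the fallback that checkpoints any stage whose funcs own parameters. The whitelist branch is a pure class-name test, shown in isolation below with illustrative classes:

    # Isolated illustration of the whitelist branch added above.
    class TransformerBlockPipe: pass
    class EmbeddingPipe: pass

    whitelist = ['TransformerBlockPipe']

    print(all(f.__class__.__name__ in whitelist
              for f in [TransformerBlockPipe(), TransformerBlockPipe()]))  # True
    print(all(f.__class__.__name__ in whitelist
              for f in [EmbeddingPipe(), TransformerBlockPipe()]))  # False: a non-whitelisted layer blocks the stage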