Merged
45 commits
db017fd  Minor tweaks to support Megatron 2.4 + DS 3D (Jun 6, 2021)
407ff0f  pipe partitioning (Jun 6, 2021)
a096d32  re-enable grad buffer partitioning (Jun 11, 2021)
9b4093b  Avoid partitioning small activations (tjruwase, Jun 11, 2021)
182be7b  Merge pull request #4 from ShadenSmith/olruwase/partition_activation (tjruwase, Jun 11, 2021)
3e948df  send/recv (Jun 13, 2021)
b6a2cb3  isend/irecv missing wait (Jun 13, 2021)
6bb63b8  turn off async ops (Jun 14, 2021)
8097690  Merge branch 'megatron2.4-3d-sendrecv' into megatron2.4-3d (Jun 14, 2021)
bd9e953  less verbose load (Jun 26, 2021)
081ddb5  Merge branch 'master' into megatron2.4-3d (jeffra, Jun 30, 2021)
d26c258  added shaden's set_train_batch_size patches, plus formatting (jeffra, Jul 13, 2021)
9dbfdbd  Adds engine.was_step_applied() (#1251) (Jul 26, 2021)
d6945de  Cleaning up tensor/pipe parallel accounting. (#1252) (Jul 26, 2021)
f93e22b  Correctness fix PP+ZeRO for gradient accumulation + updates from mast… (jeffra, Jul 30, 2021)
e9b5dff  dont clear grads in stage 1 code path (jeffra, Jul 31, 2021)
4b35409  prevent none grads from being reduced (jeffra, Jul 31, 2021)
bc17042  fix empty grad zero tests (jeffra, Aug 2, 2021)
6b42882  Use mpu in DeepSpeedConfig() call (#1271) (tjruwase, Aug 9, 2021)
cce85b8  API for obtaining global gradient norm (#1292) (tjruwase, Aug 9, 2021)
e65e511  turn excessive noise off (#1293) (stas00, Aug 11, 2021)
db2f8a0  [zero] restore fp16 params if no zero ckpts available (#1322) (jeffra, Aug 25, 2021)
72ce55a  Fix PP checkpoint bloat (#1324) (tjruwase, Aug 25, 2021)
c7f3bc5  update for cuda-11.4 (#1329) (stas00, Aug 30, 2021)
ddaa406  Try something out (thomasw21, Sep 20, 2021)
b57a10b  Woops (thomasw21, Sep 21, 2021)
a7cca98  Make deepspeed pass any types of dtypes between stages (thomasw21, Sep 23, 2021)
2c5d1e4  Woops (thomasw21, Sep 23, 2021)
d6f7b00  Woops 2 (thomasw21, Sep 23, 2021)
33e2471  Woops 3 (thomasw21, Sep 23, 2021)
4d1b009  Try debugging deadlock (thomasw21, Sep 23, 2021)
ab64b54  Fix dtype (thomasw21, Sep 29, 2021)
8c29337  Fix some more things (Oct 5, 2021)
e34994e  Woops (Oct 5, 2021)
2c55a32  Use list comprehension instead of for loops, and increase the number … (thomasw21, Oct 5, 2021)
456c49a  Run pre-commit (thomasw21, Oct 6, 2021)
c187645  Use ValueError + error msg instead of NotImplemetedError (thomasw21, Oct 6, 2021)
48dc913  Merge remote-tracking branch 'origin/master' into big-science-fix-pas… (thomasw21, Oct 6, 2021)
7158a21  Use tuples instead of lists (thomasw21, Oct 6, 2021)
e2c875a  Merge branch 'master' into big-science-fix-passing-multiple-tensors (jeffra, Oct 6, 2021)
0d8daac  Merge branch 'master' into big-science-fix-passing-multiple-tensors (tjruwase, Oct 6, 2021)
673a326  Make sure to set as input a tensor when required, instead of a tuple … (thomasw21, Oct 6, 2021)
1c3dee5  Merge branch 'big-science-fix-passing-multiple-tensors' of github.com… (thomasw21, Oct 6, 2021)
8060021  Update inputs as well (thomasw21, Oct 7, 2021)
19de4aa  Merge branch 'master' into big-science-fix-passing-multiple-tensors (tjruwase, Oct 7, 2021)
119 changes: 81 additions & 38 deletions deepspeed/runtime/pipe/engine.py
@@ -49,6 +49,22 @@ class PipelineEngine(DeepSpeedEngine):
This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
is provided.
"""
ID_TO_DTYPE = [
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool
]
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}

def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
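
For orientation, here is a minimal standalone sketch of what the new ID_TO_DTYPE / DTYPE_TO_ID table is for: dtypes are mapped to integer ids so they can travel over the same LongTensor-based meta exchange the engine already uses between stages. encode_dtype and decode_dtype are hypothetical helper names, not part of the diff.

import torch

# Same ordering as the new ID_TO_DTYPE class attribute above.
ID_TO_DTYPE = [
    torch.float32, torch.float64, torch.complex64, torch.complex128,
    torch.float16, torch.bfloat16, torch.uint8, torch.int8,
    torch.int16, torch.int32, torch.int64, torch.bool
]
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}

def encode_dtype(tensor):
    # Sender side: ship the dtype as an integer index in a LongTensor,
    # the only kind of payload the existing meta exchange moves.
    return torch.LongTensor([DTYPE_TO_ID[tensor.dtype]])

def decode_dtype(id_tensor):
    # Receiver side: look the index back up to recover the dtype.
    return ID_TO_DTYPE[id_tensor.item()]

mask = torch.zeros(4, 4, dtype=torch.bool)
assert decode_dtype(encode_dtype(mask)) is torch.bool
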
@@ -568,18 +584,13 @@ def _exec_forward_pass(self, buffer_id):
local_part=inputs[1],
group=self.grid.get_slice_parallel_group())

inputs = part_input.full()
inputs.requires_grad = True
part_input = None
self.pipe_buffers['inputs'][buffer_id] = inputs
'''
inputs = tuple([part_input.full(), inputs[2]])
inputs = (part_input.full(), *inputs[2:])
inputs[0].requires_grad = True
# skip mask
#inputs[1].requires_grad = True
part_input = None
inputs = inputs[0] if len(inputs) == 1 else inputs
self.pipe_buffers['inputs'][buffer_id] = inputs
'''

# Zero out the gradients each time we use the tensor because only the data in
# tensor changes across batches
@@ -589,14 +600,26 @@ def _exec_forward_pass(self, buffer_id):

# Partition the outputs if we are not the last stage
if self.is_pipe_partitioned and not self.is_last_stage():
assert torch.is_tensor(outputs)
part = PartitionedTensor(tensor=outputs,
if isinstance(outputs, tuple):
first_output = outputs[0]
# TODO: Improve pipe partitioning to pass multiple tensors that require grads
assert all([
torch.is_tensor(elt) and elt.requires_grad is False
for elt in outputs[1:]
])
outputs_tail = outputs[1:]
elif torch.is_tensor(outputs):
first_output = outputs
outputs_tail = []
else:
raise ValueError("expecting a tensor or a tuple of tensors")
part = PartitionedTensor(tensor=first_output,
group=self.grid.get_slice_parallel_group())
# Clear the large output data, but save the computation graph
outputs.data = torch.zeros(1)
self.pipe_buffers['output_tensors'][buffer_id] = outputs
first_output.data = torch.zeros(1)
self.pipe_buffers['output_tensors'][buffer_id] = first_output
# Inject the partitioned tensor into the output before sending
outputs = tuple([part.to_meta(), part.data()])
outputs = (part.to_meta(), part.data(), *outputs_tail)
part = None

self.pipe_buffers['outputs'][buffer_id] = outputs
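
A hedged, dependency-free sketch of the new tuple handling in _exec_forward_pass: only the first (gradient-carrying) output is handed to PartitionedTensor, and the remaining tensors, e.g. an attention mask, are forwarded unchanged. split_outputs is a hypothetical helper name; PartitionedTensor itself is not reimplemented here.

import torch

def split_outputs(outputs):
    # Mirrors the branch added above: the first element is the only tensor
    # that may require grad and is the one that gets partitioned; the tail
    # must be plain tensors with requires_grad == False.
    if isinstance(outputs, tuple):
        first, tail = outputs[0], outputs[1:]
        assert all(torch.is_tensor(t) and not t.requires_grad for t in tail)
    elif torch.is_tensor(outputs):
        first, tail = outputs, ()
    else:
        raise ValueError("expecting a tensor or a tuple of tensors")
    return first, tail

hidden = torch.randn(2, 8, requires_grad=True)
mask = torch.ones(2, 8, dtype=torch.bool)
first, tail = split_outputs((hidden, mask))
# `first` would feed PartitionedTensor(...); the payload that gets sent is
# then (part.to_meta(), part.data(), *tail), as in the diff above.
assert first is hidden and len(tail) == 1 and tail[0] is mask
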
@@ -654,11 +677,11 @@ def _exec_backward_pass(self, buffer_id):
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
outputs = self.pipe_buffers['output_tensors'][buffer_id]
outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
else:
# Already restored from partition
self.pipe_buffers['output_tensors'][buffer_id].data = outputs
outputs = self.pipe_buffers['output_tensors'][buffer_id]
self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])

grad_tensors = self.grad_layer
if self.is_grad_partitioned:
@@ -667,7 +690,7 @@ def _exec_backward_pass(self, buffer_id):
meta=self.grad_layer[0],
local_part=self.grad_layer[1],
group=self.grid.get_slice_parallel_group())
grad_tensors = part_grad.full()
grad_tensors = (part_grad.full(), *grad_tensors[2:])
part_grad = None
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

@@ -774,6 +797,9 @@ def _send_tensor_meta(self, buffer, recv_stage):
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(
self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
# Useful for performance debugging.
@@ -828,16 +854,19 @@ def _recv_tensor_meta(self, send_stage):
count_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(count_tensor, send_stage)
num_tensors = count_tensor.item()
recv_shapes = []
recv_shapes_and_dtypes = []
for idx in range(num_tensors):
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shapes.append(recv_shape.tolist())
recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))

buffers = self._allocate_buffers(recv_shapes, num_buffers=1)[0]
buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
# Convert to tuples if requested.
if recv_type == 2:
buffers = tuple(buffers)
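
A hedged sketch of the per-tensor metadata handshake after this change, with plain Python lists standing in for the p2p.send / p2p.recv LongTensor exchange: the sender emits a count, then for each tensor its dtype id, ndims, and shape, and the receiver rebuilds the (shape, dtype) pairs that _recv_tensor_meta now collects in recv_shapes_and_dtypes. serialize_meta and deserialize_meta are hypothetical names.

import torch

ID_TO_DTYPE = [
    torch.float32, torch.float64, torch.complex64, torch.complex128,
    torch.float16, torch.bfloat16, torch.uint8, torch.int8,
    torch.int16, torch.int32, torch.int64, torch.bool
]  # same ordering as the class attribute in the diff
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}

def serialize_meta(tensors):
    # Sender side: count, then (dtype id, ndims, *shape) per tensor,
    # matching the p2p.send ordering in _send_tensor_meta.
    meta = [len(tensors)]
    for t in tensors:
        meta += [DTYPE_TO_ID[t.dtype], t.dim(), *t.size()]
    return meta

def deserialize_meta(meta):
    # Receiver side: rebuild the (shape, dtype) pairs that
    # _recv_tensor_meta appends to recv_shapes_and_dtypes.
    meta = list(meta)
    count = meta.pop(0)
    shapes_and_dtypes = []
    for _ in range(count):
        dtype = ID_TO_DTYPE[meta.pop(0)]
        ndims = meta.pop(0)
        shape = [meta.pop(0) for _ in range(ndims)]
        shapes_and_dtypes.append((shape, dtype))
    return shapes_and_dtypes

tensors = (torch.zeros(2, 3, dtype=torch.float16), torch.ones(2, 3, dtype=torch.bool))
assert deserialize_meta(serialize_meta(tensors)) == [([2, 3], torch.float16),
                                                     ([2, 3], torch.bool)]
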
@@ -890,10 +919,22 @@ def _exec_send_grads(self, buffer_id):

# Partition the gradient
if self.is_grad_partitioned:
assert torch.is_tensor(inputs)
part = PartitionedTensor(tensor=inputs.grad,
if isinstance(inputs, tuple):
first_input = inputs[0]
assert all([torch.is_tensor(elt) for elt in inputs[1:]])
inputs_grad_tail = [
elt.grad for elt in inputs[1:] if elt.grad is not None
]
elif torch.is_tensor(inputs):
first_input = inputs
inputs_grad_tail = []
else:
raise ValueError("expecting a tensor or a tuple of tensors")
assert torch.is_tensor(first_input)
part = PartitionedTensor(tensor=first_input.grad,
group=self.grid.get_slice_parallel_group())
inputs = tuple([part.to_meta(), part.data()])

inputs = (part.to_meta(), part.data(), *inputs_grad_tail)

# XXX Terrible hack
# Drop the attention mask from the input buffer here. It does not have
@@ -987,18 +1028,22 @@ def _exec_recv_grads(self, buffer_id):
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
outputs[0].data = part_output.full()
outputs = outputs[0]
outputs = ([outputs[0], *outputs[2:]])
# save for backward
self.pipe_buffers['outputs'][buffer_id] = outputs

# Allocate gradient if necessary
if self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
s = list(outputs.size())
self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0]
self.grad_layer = self._allocate_buffer(s,
dtype=outputs.dtype,
num_buffers=1)[0]
else:
sizes = [list(t.size()) for t in outputs] # if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0]
sizes_and_dtypes = [(list(t.size()),
t.dtype) for t in outputs if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes_and_dtypes,
num_buffers=1)[0]

if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
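
A hedged sketch of the allocation rule _exec_recv_grads now follows when grad_layer is first built: one receive buffer per output, carrying that output's shape and dtype, and only for floating-point outputs, since integer or bool tensors such as attention masks carry no gradient. grad_buffer_specs is a hypothetical helper; the engine itself passes these specs to self._allocate_buffer / self._allocate_buffers.

import torch

def grad_buffer_specs(outputs):
    # Mirrors the new sizes_and_dtypes computation: a (shape, dtype) entry
    # per floating-point output only.
    if isinstance(outputs, torch.Tensor):
        return [(list(outputs.size()), outputs.dtype)]
    return [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]

hidden = torch.zeros(2, 8, dtype=torch.float16)
mask = torch.ones(2, 8, dtype=torch.bool)
# The bool mask is skipped: no gradient buffer is allocated for it.
assert grad_buffer_specs((hidden, mask)) == [([2, 8], torch.float16)]
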
@@ -1070,25 +1115,20 @@ def _zero_grads(self, inputs):
if t.grad is not None:
t.grad.data.zero_()

def _allocate_zeros(self, shape, fp16=None, **kwargs):
def _allocate_zeros(self, shape, **kwargs):
""" Allocate a tensor of zeros on the engine's device.

Arguments:
shape: the shape of the tensor to allocate
fp16 (bool): whether to use FP16. default: defer to self.fp16_enabled()
kwargs: passed to torch.zeros()

Returns:
A tensor from torch.zeros() allocated on self.device.
"""
if "dtype" not in kwargs and self.fp16_enabled():
kwargs["dtype"] = torch.half

if fp16 is None:
fp16 = self.fp16_enabled()

if fp16:
return torch.zeros(shape, dtype=torch.half, device=self.device, **kwargs)
else:
return torch.zeros(shape, device=self.device, **kwargs)
return torch.zeros(shape, device=self.device, **kwargs)

def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers = []
@@ -1098,14 +1138,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers.append(self._allocate_zeros(shape, **kwargs))
return buffers

def _allocate_buffers(self, shapes, requires_grad=False, num_buffers=-1):
def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
buffers = []
if num_buffers == -1:
num_buffers = self.num_pipe_buffers
for count in range(num_buffers):
buffer = []
for shape in shapes:
buffer.append(self._allocate_zeros(shape, requires_grad=requires_grad))
for shape, dtype in shapes_and_dtypes:
buffer.append(
self._allocate_zeros(shape,
dtype=dtype,
requires_grad=requires_grad))
buffers.append(buffer)
return buffers
