Skip to content

Commit 407ff0f

Browse files
author
Shaden Smith
committed
pipe partitioning
1 parent db017fd commit 407ff0f

File tree

1 file changed

+23
-25
lines changed

1 file changed

+23
-25
lines changed

deepspeed/runtime/pipe/engine.py

Lines changed: 23 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -111,8 +111,8 @@ def __init__(self, *super_args, **super_kwargs):
111111

112112
# Partition input/output buffers
113113
# XXX temporarily disable while I revert some partition hacks.
114-
self.is_pipe_partitioned = False #self.is_model_parallel
115-
self.is_grad_partitioned = False
114+
self.is_pipe_partitioned = self.is_model_parallel
115+
self.is_grad_partitioned = False #self.is_model_parallel
116116

117117
model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
118118
num_params = sum([p.numel() for p in model_parameters])
@@ -554,12 +554,18 @@ def _exec_forward_pass(self, buffer_id):
554554
local_part=inputs[1],
555555
group=self.grid.get_slice_parallel_group())
556556

557+
inputs = part_input.full()
558+
inputs.requires_grad = True
559+
part_input = None
560+
self.pipe_buffers['inputs'][buffer_id] = inputs
561+
'''
557562
inputs = tuple([part_input.full(), inputs[2]])
558563
inputs[0].requires_grad = True
559564
# skip mask
560565
#inputs[1].requires_grad = True
561566
part_input = None
562567
self.pipe_buffers['inputs'][buffer_id] = inputs
568+
'''
563569

564570
# Zero out the gradients each time we use the tensor because only the data in
565571
# tensor changes across batches
@@ -569,13 +575,14 @@ def _exec_forward_pass(self, buffer_id):
569575

570576
# Partition the outputs if we are not the last stage
571577
if self.is_pipe_partitioned and not self.is_last_stage():
572-
part = PartitionedTensor(tensor=outputs[0],
578+
assert torch.is_tensor(outputs)
579+
part = PartitionedTensor(tensor=outputs,
573580
group=self.grid.get_slice_parallel_group())
574581
# Clear the large output data, but save the computation graph
575-
outputs[0].data = torch.zeros(1)
576-
self.pipe_buffers['output_tensors'][buffer_id] = outputs[0]
582+
outputs.data = torch.zeros(1)
583+
self.pipe_buffers['output_tensors'][buffer_id] = outputs
577584
# Inject the partitioned tensor into the output before sending
578-
outputs = tuple([part.to_meta(), part.data(), outputs[1]])
585+
outputs = tuple([part.to_meta(), part.data()])
579586
part = None
580587

581588
self.pipe_buffers['outputs'][buffer_id] = outputs
@@ -633,15 +640,11 @@ def _exec_backward_pass(self, buffer_id):
633640
local_part=outputs[1],
634641
group=self.grid.get_slice_parallel_group())
635642
self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
636-
outputs = tuple(
637-
[self.pipe_buffers['output_tensors'][buffer_id],
638-
outputs[2]])
643+
outputs = self.pipe_buffers['output_tensors'][buffer_id]
639644
else:
640645
# Already restored from partition
641-
self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
642-
outputs = tuple(
643-
[self.pipe_buffers['output_tensors'][buffer_id],
644-
outputs[1]])
646+
self.pipe_buffers['output_tensors'][buffer_id].data = outputs
647+
outputs = self.pipe_buffers['output_tensors'][buffer_id]
645648

646649
grad_tensors = self.grad_layer
647650
if self.is_grad_partitioned:
@@ -650,7 +653,7 @@ def _exec_backward_pass(self, buffer_id):
650653
meta=self.grad_layer[0],
651654
local_part=self.grad_layer[1],
652655
group=self.grid.get_slice_parallel_group())
653-
grad_tensors = tuple([part_grad.full(), self.grad_layer[2]])
656+
grad_tensors = part_grad.full()
654657
part_grad = None
655658
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
656659

@@ -873,13 +876,10 @@ def _exec_send_grads(self, buffer_id):
873876

874877
# Partition the gradient
875878
if self.is_grad_partitioned:
876-
part = PartitionedTensor(tensor=inputs[0].grad,
879+
assert torch.is_tensor(inputs)
880+
part = PartitionedTensor(tensor=inputs.grad,
877881
group=self.grid.get_slice_parallel_group())
878-
# Clear the large output data, but save the computation graph
879-
# Inject the partitoned tensor into the output before sending
880-
881-
# XXX Hack
882-
inputs = tuple([part.to_meta(), part.data(), inputs[1]])
882+
inputs = tuple([part.to_meta(), part.data()])
883883

884884
# XXX Terrible hack
885885
# Drop the attention mask from the input buffer here. It does not have
@@ -900,8 +900,6 @@ def _exec_send_grads(self, buffer_id):
900900
# First two sends are partitioned gradient
901901
p2p.send(inputs[0], self.prev_stage)
902902
p2p.send(inputs[1], self.prev_stage)
903-
# XXX hack hack hack
904-
#p2p.send(inputs[2].grad, self.prev_stage)
905903
else:
906904
for idx, buffer in enumerate(inputs):
907905
# Skip tensors that will not produce a grad
@@ -975,7 +973,7 @@ def _exec_recv_grads(self, buffer_id):
975973
local_part=outputs[1],
976974
group=self.grid.get_slice_parallel_group())
977975
outputs[0].data = part_output.full()
978-
outputs = tuple([outputs[0], outputs[2]])
976+
outputs = outputs[0]
979977
# save for backward
980978
self.pipe_buffers['outputs'][buffer_id] = outputs
981979

@@ -985,7 +983,7 @@ def _exec_recv_grads(self, buffer_id):
985983
s = list(outputs.size())
986984
self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0]
987985
else:
988-
sizes = [list(t.size()) for t in outputs if t.is_floating_point()]
986+
sizes = [list(t.size()) for t in outputs]# if t.is_floating_point()]
989987
self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0]
990988

991989
if isinstance(self.grad_layer, torch.Tensor):
@@ -999,7 +997,7 @@ def _exec_recv_grads(self, buffer_id):
999997
dtype=torch.long,
1000998
device=self.device)
1001999
p2p.recv(buffer, self.next_stage)
1002-
1000+
10031001
if self.wall_clock_breakdown():
10041002
self.timers('pipe_recv_grad').stop()
10051003

0 commit comments

Comments
 (0)