@@ -111,8 +111,8 @@ def __init__(self, *super_args, **super_kwargs):
111111
112112 # Partition input/output buffers
113113 # XXX temporarily disable while I revert some partition hacks.
114- self .is_pipe_partitioned = False # self.is_model_parallel
115- self .is_grad_partitioned = False
114+ self .is_pipe_partitioned = self .is_model_parallel
115+ self .is_grad_partitioned = False #self.is_model_parallel
116116
117117 model_parameters = filter (lambda p : p .requires_grad , self .module .parameters ())
118118 num_params = sum ([p .numel () for p in model_parameters ])
@@ -554,12 +554,18 @@ def _exec_forward_pass(self, buffer_id):
554554 local_part = inputs [1 ],
555555 group = self .grid .get_slice_parallel_group ())
556556
557+ inputs = part_input .full ()
558+ inputs .requires_grad = True
559+ part_input = None
560+ self .pipe_buffers ['inputs' ][buffer_id ] = inputs
561+ '''
557562 inputs = tuple([part_input.full(), inputs[2]])
558563 inputs[0].requires_grad = True
559564 # skip mask
560565 #inputs[1].requires_grad = True
561566 part_input = None
562567 self.pipe_buffers['inputs'][buffer_id] = inputs
568+ '''
563569
564570 # Zero out the gradients each time we use the tensor because only the data in
565571 # tensor changes across batches
@@ -569,13 +575,14 @@ def _exec_forward_pass(self, buffer_id):
569575
570576 # Partition the outputs if we are not the last stage
571577 if self .is_pipe_partitioned and not self .is_last_stage ():
572- part = PartitionedTensor (tensor = outputs [0 ],
578+ assert torch .is_tensor (outputs )
579+ part = PartitionedTensor (tensor = outputs ,
573580 group = self .grid .get_slice_parallel_group ())
574581 # Clear the large output data, but save the computation graph
575- outputs [ 0 ] .data = torch .zeros (1 )
576- self .pipe_buffers ['output_tensors' ][buffer_id ] = outputs [ 0 ]
582+ outputs .data = torch .zeros (1 )
583+ self .pipe_buffers ['output_tensors' ][buffer_id ] = outputs
577584 # Inject the partitioned tensor into the output before sending
578- outputs = tuple ([part .to_meta (), part .data (), outputs [ 1 ] ])
585+ outputs = tuple ([part .to_meta (), part .data ()])
579586 part = None
580587
581588 self .pipe_buffers ['outputs' ][buffer_id ] = outputs
@@ -633,15 +640,11 @@ def _exec_backward_pass(self, buffer_id):
633640 local_part = outputs [1 ],
634641 group = self .grid .get_slice_parallel_group ())
635642 self .pipe_buffers ['output_tensors' ][buffer_id ].data = part_output .full ()
636- outputs = tuple (
637- [self .pipe_buffers ['output_tensors' ][buffer_id ],
638- outputs [2 ]])
643+ outputs = self .pipe_buffers ['output_tensors' ][buffer_id ]
639644 else :
640645 # Already restored from partition
641- self .pipe_buffers ['output_tensors' ][buffer_id ].data = outputs [0 ]
642- outputs = tuple (
643- [self .pipe_buffers ['output_tensors' ][buffer_id ],
644- outputs [1 ]])
646+ self .pipe_buffers ['output_tensors' ][buffer_id ].data = outputs
647+ outputs = self .pipe_buffers ['output_tensors' ][buffer_id ]
645648
646649 grad_tensors = self .grad_layer
647650 if self .is_grad_partitioned :
@@ -650,7 +653,7 @@ def _exec_backward_pass(self, buffer_id):
650653 meta = self .grad_layer [0 ],
651654 local_part = self .grad_layer [1 ],
652655 group = self .grid .get_slice_parallel_group ())
653- grad_tensors = tuple ([ part_grad .full (), self . grad_layer [ 2 ]] )
656+ grad_tensors = part_grad .full ()
654657 part_grad = None
655658 #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
656659
@@ -873,13 +876,10 @@ def _exec_send_grads(self, buffer_id):
873876
874877 # Partition the gradient
875878 if self .is_grad_partitioned :
876- part = PartitionedTensor (tensor = inputs [0 ].grad ,
879+ assert torch .is_tensor (inputs )
880+ part = PartitionedTensor (tensor = inputs .grad ,
877881 group = self .grid .get_slice_parallel_group ())
878- # Clear the large output data, but save the computation graph
879- # Inject the partitoned tensor into the output before sending
880-
881- # XXX Hack
882- inputs = tuple ([part .to_meta (), part .data (), inputs [1 ]])
882+ inputs = tuple ([part .to_meta (), part .data ()])
883883
884884 # XXX Terrible hack
885885 # Drop the attention mask from the input buffer here. It does not have
@@ -900,8 +900,6 @@ def _exec_send_grads(self, buffer_id):
900900 # First two sends are partitioned gradient
901901 p2p .send (inputs [0 ], self .prev_stage )
902902 p2p .send (inputs [1 ], self .prev_stage )
903- # XXX hack hack hack
904- #p2p.send(inputs[2].grad, self.prev_stage)
905903 else :
906904 for idx , buffer in enumerate (inputs ):
907905 # Skip tensors that will not produce a grad
@@ -975,7 +973,7 @@ def _exec_recv_grads(self, buffer_id):
975973 local_part = outputs [1 ],
976974 group = self .grid .get_slice_parallel_group ())
977975 outputs [0 ].data = part_output .full ()
978- outputs = tuple ([ outputs [0 ], outputs [ 2 ]])
976+ outputs = outputs [0 ]
979977 # save for backward
980978 self .pipe_buffers ['outputs' ][buffer_id ] = outputs
981979
@@ -985,7 +983,7 @@ def _exec_recv_grads(self, buffer_id):
985983 s = list (outputs .size ())
986984 self .grad_layer = self ._allocate_buffer (s , num_buffers = 1 )[0 ]
987985 else :
988- sizes = [list (t .size ()) for t in outputs if t .is_floating_point ()]
986+ sizes = [list (t .size ()) for t in outputs ] # if t.is_floating_point()]
989987 self .grad_layer = self ._allocate_buffers (sizes , num_buffers = 1 )[0 ]
990988
991989 if isinstance (self .grad_layer , torch .Tensor ):
@@ -999,7 +997,7 @@ def _exec_recv_grads(self, buffer_id):
999997 dtype = torch .long ,
1000998 device = self .device )
1001999 p2p .recv (buffer , self .next_stage )
1002-
1000+
10031001 if self .wall_clock_breakdown ():
10041002 self .timers ('pipe_recv_grad' ).stop ()
10051003
0 commit comments