@@ -169,47 +169,57 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None


-def _AllgatherLinear(input_, weight, process_group):
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only supports a single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-    weight_shape = weight.shape
-
-    output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+    # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

     # initialization of ring communication
-    input_shape[1]
     recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
     send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
-    recv_tensor = input_.clone()
-    send_tensor = input_.clone()
-
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([send_op, recv_op])
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = v.clone()
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            tmp_tensor = send_tensors[k]
+            send_tensors[k] = recv_tensors[k]
+            recv_tensors[k] = tmp_tensor
+
+    output_tensors = []
+
+    handles = communicate_step()
     # first round: special case, retrieve from local tensor
-    output_tensors[0] = F.linear(input_, weight)
+    output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
             handle.wait()

-        tmp_tensor = send_tensor
-        send_tensor = recv_tensor
-        recv_tensor = tmp_tensor
+        switch_step()

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

         # actual computation
-        output_tensors[i + 1] = F.linear(send_tensor, weight)
+        output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
-    output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
-    return torch.cat(output_tensors[group_size - cur_rank:] + output_tensors[:group_size - cur_rank], dim=1)
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank:] + output_tensors[:group_size - cur_rank], dim=gather_dim)


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
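Reviewer note: a minimal sketch (not part of the patch; the function name and shapes are made up) of how the new _ring_as_gather helper is driven, matching the calling convention used by the autograd functions further down in this diff. It replaces "all-gather then compute" with per-shard compute that overlaps the batched isend/irecv ring exchange.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def example_rank_main(process_group):
    # Each rank holds a sequence shard of shape [batch, seq_len / world_size, hidden].
    hidden = 64
    local_shard = torch.randn(2, 8, hidden, device="cuda")
    weight = torch.randn(4 * hidden, hidden, device="cuda")  # replicated across ranks

    # Semantically equivalent to all-gathering the shards along gather_dim and applying F.linear,
    # but each step computes on one shard while the next shard is in flight over the ring.
    out = _ring_as_gather(
        F.linear,
        input_to_gather={"input": local_shard},  # keys must match func's parameter names
        input_local={"weight": weight},
        process_group=process_group,
        gather_dim=1,
    )
    return out  # [2, 8 * world_size, 4 * hidden]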
@@ -249,6 +259,41 @@ def backward(ctx, grad_output):
         return output, None, None


+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op and gradient calculation in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

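Reviewer note: an illustrative call of the new autograd wrapper (assumed shapes, reusing the imports and process_group from the sketch above). The forward all-gathers along dim, and the gradient flowing back is reduce-scattered to the local shard's shape.

x_local = torch.randn(2, 8, 64, device="cuda", requires_grad=True)  # this rank's sequence shard
x_full = _GatherForwardReduceScatterBackward.apply(x_local, process_group, 1)
# x_full: [2, 8 * world_size, 64]; the gradient w.r.t. x_local comes back reduce-scattered to [2, 8, 64].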
@@ -260,19 +305,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        if bias is not None:
-            input_parallel = _gather(input_, dim, process_group)
-            output = F.linear(input_parallel, weight, bias)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['weight'] = weight
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
         else:
-            output = _AllgatherLinear(input_, weight, process_group)
+            input_parallel = _gather(input_, dim, process_group)
+            if bias is not None:
+                output = F.linear(input_parallel, weight, bias)
+            else:
+                output = F.linear(input_parallel, weight)

         return output

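Reviewer note: both branches should produce the same output. A hedged sanity check one might run in a toy multi-rank test (the harness and tolerances are assumed, not part of the patch), reusing x_local and weight from the sketches above.

bias = torch.randn(256, device="cuda")  # matches the weight's out_features from the first sketch
args = (x_local, weight, bias, process_group, True, 1, False)  # ..., async_grad_reduce_scatter, dim, overlap
out_ring = _LinearWithGatherForwardReduceScatterBackward.apply(*args, True)    # ring=True
out_plain = _LinearWithGatherForwardReduceScatterBackward.apply(*args, False)  # ring=False
torch.testing.assert_close(out_ring, out_plain, rtol=1e-3, atol=1e-3)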
@@ -376,34 +437,43 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


-def _ReduceScatterLinear(input_, weight, process_group):
+def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
+    # currently only supports a single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-
     # initialization of ring communication
-    # communicate(e.g.): 0->1->2->3
-    # compute(e.g.): 3->2->1->0
-    input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
-    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
-    input_tensors.reverse()
     recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
     send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()

-    # first round: special case, no reduce operation
-    output_tensor = F.linear(input_tensors[0], weight)
+    output_tensor = func(**input_tensors[0], **input_local)
     recv_tensor = output_tensor.clone()
     send_tensor = output_tensor.clone()
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([recv_op, send_op])
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrieve from local tensor
    for i in range(group_size - 2):
         # actual computation
-        output_tensor = F.linear(input_tensors[i + 1], weight)
+        output_tensor = func(**input_tensors[i + 1], **input_local)

         for handle in handles:
             handle.wait()
@@ -413,12 +483,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
         send_tensor = output_tensor
         output_tensor = tmp_tensor

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

     # final round: special case, no need to send/recv again
-    output_tensor = F.linear(input_tensors[group_size - 1], weight)
+    output_tensor = func(**input_tensors[-1], **input_local)
     for handle in handles:
         handle.wait()
     output_tensor += recv_tensor
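Reviewer note: mirroring the gather helper, a minimal invocation sketch of _ring_as_reducescatter as the callers below use it (shapes illustrative, imports reused from the first sketch). Each rank contributes a partial product over the full sequence and keeps only the summed chunk that belongs to it.

world_size = dist.get_world_size(process_group)
full_partial = torch.randn(2, 8 * world_size, 64, device="cuda")  # partial activation over the full sequence
local_weight = torch.randn(128, 64, device="cuda")                # this rank's weight shard

out_shard = _ring_as_reducescatter(
    F.linear,
    input_to_reducescatter={"input": full_partial},
    input_local={"weight": local_weight},
    process_group=process_group,
    reducescatter_dim=1,
)
# Equivalent to F.linear(full_partial, local_weight) followed by dist.reduce_scatter along dim 1:
# out_shard is [2, 8, 128] and holds the cross-rank sum for this rank's sequence chunk.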
@@ -436,27 +504,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, dim):
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.dim = dim
-        if bias is not None:
-            partial_output = F.linear(input_, weight, bias)
+
+        if ring is True:
+            input_to_reducescatter = {}
+            input_local = {}
+            input_to_reducescatter['input'] = input_
+            input_local['weight'] = weight
+
+            if bias is not None:
+                input_to_reducescatter['bias'] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
         else:
-            return _ReduceScatterLinear(input_, weight, process_group)
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)

-        output_shape = list(partial_output.shape)
-        assert (
-            output_shape[dim] % dist.get_world_size(process_group) == 0
-        ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
-        output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

-        output_list = [
-            item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
-        ]
-        output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
-        dist.reduce_scatter(output, output_list, group=process_group)
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)

         return output

@@ -484,7 +569,7 @@ def backward(ctx, grad_output):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -533,17 +618,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['other'] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim,
+            )
+
+        else:
+            input_parallel = _gather(input_, dim, process_group)

-        output = torch.matmul(input_parallel, weight)
+            output = torch.matmul(input_parallel, weight)

         if bias is not None:
             output = output + bias
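Reviewer note on the kwargs above (a clarification, not a code change): _ring_as_gather forwards the dict entries as keyword arguments, so the keys must match the wrapped callable's parameter names: 'input'/'weight' for F.linear, but 'input'/'other' for torch.matmul. Written out explicitly with the tensors from the first sketch:

out_linear = F.linear(input=local_shard, weight=weight)         # keys 'input' and 'weight'
out_matmul = torch.matmul(input=local_shard, other=weight.t())  # keys 'input' and 'other'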
@@ -624,7 +724,7 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -877,10 +977,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

@@ -892,15 +992,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
     return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

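Reviewer note: from the caller side the change is opt-in, since existing call sites keep the default ring=False and the ring-based overlap is enabled per call. An illustrative, non-authoritative example of switching it on, reusing the tensors from the sketches above:

out1 = linear_gather_forward_reducescatter_backward(
    x_local, weight, bias, process_group,
    async_grad_reduce_scatter=True, dim=1, overlap=False, ring=True,
)
out2 = linear_reducescatter_forward_gather_backward(
    full_partial, local_weight, bias=None, process_group=process_group, dim=1, ring=True,
)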