 
 
 def all_gather(tensor: Tensor, dim: int,
-               parallel_mode: ParallelMode) -> Tensor:
+               parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
 
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
2626 """
2727 depth = gpc .get_world_size (parallel_mode )
2828 temp = tensor .clone ()
29- shape = list (temp .shape )
30- shape [dim ] *= depth
31- out = torch .empty (shape , dtype = temp .dtype , device = get_current_device ())
32- out = list (torch .chunk (out , depth , dim = dim ))
33- out = [val .contiguous () for val in out ]
34- dist .all_gather (out , temp , group = gpc .get_group (parallel_mode ))
35- out = torch .cat (out , dim = dim )
36- return out
29+ # shape = list(temp.shape)
30+ # shape[dim] *= depth
31+ # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
32+ # out = list(torch.chunk(out, depth, dim=dim))
33+ # out = [val.contiguous() for val in out]
34+ shape = [1 ] * len (tensor .shape )
35+ shape [dim ] = depth
36+ out = tensor .repeat (shape )
37+ out = list (map (lambda x : x .contiguous (), torch .chunk (out , depth , dim = dim )))
38+ op = dist .all_gather (tensor_list = out ,
39+ tensor = temp ,
40+ group = gpc .get_group (parallel_mode ),
41+ async_op = async_op )
42+ # out = torch.cat(out, dim=dim)
43+ if async_op :
44+ return out , op
45+ else :
46+ return out
3747
3848
3949def reduce_scatter (tensor : Tensor , dim : int ,
40- parallel_mode : ParallelMode ) -> Tensor :
50+ parallel_mode : ParallelMode , async_op = False ) -> Tensor :
4151 """Reduces all tensors then scatters it in a specific dimension to all
4252 members in the parallel group.
4353
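Note on the new all_gather behavior: the function now returns the list of per-rank chunks (the torch.cat call is commented out), and with async_op=True it also returns the torch.distributed work handle, so the caller waits and concatenates. A minimal usage sketch under assumed conditions; the import paths (colossalai.context, colossalai.communication) and the ParallelMode.PARALLEL_1D group are assumptions for illustration, not part of this diff:

import torch
from colossalai.context import ParallelMode        # assumed import path
from colossalai.communication import all_gather    # assumed module path for this file

# Assumes colossalai has been launched so the 1D parallel group exists.
x = torch.randn(4, 8, device='cuda')

# With async_op=True the call returns the list of per-rank chunks plus a work handle.
chunks, handle = all_gather(x, dim=0,
                            parallel_mode=ParallelMode.PARALLEL_1D,
                            async_op=True)

# ... independent computation can overlap with the collective here ...

handle.wait()                           # communication must finish before the chunks are read
gathered = torch.cat(chunks, dim=0)     # concatenation is now the caller's responsibility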
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
     :rtype: Tensor
     """
     depth = gpc.get_world_size(parallel_mode)
-    temp = list(torch.chunk(tensor, depth, dim=dim))
-    temp = [val.contiguous() for val in temp]
-    out = torch.empty(temp[0].shape,
-                      dtype=temp[0].dtype,
-                      device=get_current_device())
-    dist.reduce_scatter(output=out,
-                        input_list=temp,
-                        group=gpc.get_group(parallel_mode))
-    return out
+    # temp = list(torch.chunk(tensor, depth, dim=dim))
+    # temp = [val.contiguous() for val in temp]
+    # out = torch.zeros(temp[0].shape,
+    #                   dtype=temp[0].dtype,
+    #                   device=get_current_device())
+    temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
+    out = temp[0].clone()
+    op = dist.reduce_scatter(output=out,
+                             input_list=temp,
+                             group=gpc.get_group(parallel_mode),
+                             async_op=async_op)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
-def scatter(tensor: Tensor, src: int, dim: int,
-            parallel_mode: ParallelMode) -> Tensor:
-    """Scatters in a specific dimension from source rank to all ranks in
-    the parallel group.
+def all_reduce(tensor: Tensor,
+               parallel_mode: ParallelMode,
+               async_op=False) -> Tensor:
+    op = dist.all_reduce(tensor,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    if async_op:
+        return tensor, op
+    else:
+        return tensor
+
+
+# def scatter(tensor: Tensor, src: int, dim: int,
+#             parallel_mode: ParallelMode) -> Tensor:
+#     """Scatters in a specific dimension from source rank to all ranks in
+#     the parallel group.
 
-    :param tensor: Tensor to be scattered
-    :param dim: The dimension scattering in
-    :param parallel_mode: Parallel group mode used in this communication
-    :type tensor: Tensor
-    :type dim: int
-    :type parallel_mode: ParallelMode
-    :return: The tensor generated by scatter
-    :rtype: Tensor
-    """
-    depth = gpc.get_world_size(parallel_mode)
-    temp = tensor.clone()
-    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
-    rank = gpc.get_local_rank(parallel_mode)
-    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
-    return out
+#     :param tensor: Tensor to be scattered
+#     :param dim: The dimension scattering in
+#     :param parallel_mode: Parallel group mode used in this communication
+#     :type tensor: Tensor
+#     :type dim: int
+#     :type parallel_mode: ParallelMode
+#     :return: The tensor generated by scatter
+#     :rtype: Tensor
+#     """
+#     depth = gpc.get_world_size(parallel_mode)
+#     temp = tensor.clone()
+#     dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
+#     rank = gpc.get_local_rank(parallel_mode)
+#     out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
+#     return out
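The same async pattern applies to reduce_scatter and to the new all_reduce helper: with async_op=True each returns its output tensor together with the work handle, so communication can overlap with computation and the result is only read after wait(). A rough sketch under the same assumptions as above (import paths and ParallelMode.PARALLEL_1D are illustrative, not part of this diff):

import torch
from colossalai.context import ParallelMode                       # assumed import path
from colossalai.communication import reduce_scatter, all_reduce   # assumed module path

# reduce_scatter: each rank ends up with one reduced chunk along `dim`.
x = torch.randn(16, 8, device='cuda')
out, handle = reduce_scatter(x, dim=0,
                             parallel_mode=ParallelMode.PARALLEL_1D,
                             async_op=True)
# ... overlap unrelated work while the collective runs ...
handle.wait()          # `out` is only valid once the handle has completed

# all_reduce operates in place on the input tensor.
y = torch.randn(8, 8, device='cuda')
y, handle = all_reduce(y, ParallelMode.PARALLEL_1D, async_op=True)
handle.wait()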