
Commit 8bb59a2

linsj20 authored and KKZ20 committed
support mode 2 sp in gpt2 (hpcaitech#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
1 parent f1d4e18 commit 8bb59a2


6 files changed: +199 -81 lines changed


colossalai/shardformer/layer/_operation.py

Lines changed: 171 additions & 71 deletions
@@ -169,47 +169,57 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None


-def _AllgatherLinear(input_, weight, process_group):
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only support one single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-    weight_shape = weight.shape
-
-    output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+    #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

     # initialization of ring communication
-    input_shape[1]
     recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
     send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
-    recv_tensor = input_.clone()
-    send_tensor = input_.clone()
-
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([send_op, recv_op])
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = v.clone()
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            tmp_tensor = send_tensors[k]
+            send_tensors[k] = recv_tensors[k]
+            recv_tensors[k] = tmp_tensor
+
+    output_tensors = []
+
+    handles = communicate_step()
     # first round: special case, retrive from local tensor
-    output_tensors[0] = F.linear(input_, weight)
+    output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
             handle.wait()

-        tmp_tensor = send_tensor
-        send_tensor = recv_tensor
-        recv_tensor = tmp_tensor
+        switch_step()

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

         # actual computation
-        output_tensors[i + 1] = F.linear(send_tensor, weight)
+        output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
-    output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
-    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1)
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
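
The rotation in the final torch.cat of _ring_as_gather is worth spelling out: at step i, rank cur_rank computes on the chunk originally owned by rank (cur_rank + i) % group_size, so slicing the collected outputs at group_size - cur_rank puts chunk 0 first again. A minimal single-process sketch of just that rotation (hypothetical values; no torch.distributed calls involved):

```python
import torch

group_size, cur_rank = 4, 1
# chunks[j] stands in for func(...) applied to the sequence chunk owned by rank j
chunks = [torch.full((1, 2), float(j)) for j in range(group_size)]

# rank cur_rank produces outputs in the order the chunks travel around the ring
output_tensors = [chunks[(cur_rank + i) % group_size] for i in range(group_size)]

# same rotation as the return statement of _ring_as_gather
gathered = torch.cat(
    output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1
)
print(gathered)  # columns read 0 0 1 1 2 2 3 3 -> rank order restored
```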
@@ -249,6 +259,41 @@ def backward(ctx, grad_output):
         return output, None, None


+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

@@ -260,19 +305,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        if bias is not None:
-            input_parallel = _gather(input_, dim, process_group)
-            output = F.linear(input_parallel, weight, bias)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['weight'] = weight
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
         else:
-            output = _AllgatherLinear(input_, weight, process_group)
+            input_parallel = _gather(input_, dim, process_group)
+            if bias is not None:
+                output = F.linear(input_parallel, weight, bias)
+            else:
+                output = F.linear(input_parallel, weight)

         return output

@@ -376,34 +437,43 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


-def _ReduceScatterLinear(input_, weight, process_group):
+def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
+    # currently only support one single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-
     # initialization of ring communication
-    # communicate(e.g.): 0->1->2->3
-    # compute(e.g.): 3->2->1->0
-    input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
-    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
-    input_tensors.reverse()
     recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
     send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()

-    # first round: special case, no reduce operation
-    output_tensor = F.linear(input_tensors[0], weight)
+    output_tensor = func(**input_tensors[0], **input_local)
     recv_tensor = output_tensor.clone()
     send_tensor = output_tensor.clone()
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([recv_op, send_op])
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
     for i in range(group_size - 2):
         # actual computation
-        output_tensor = F.linear(input_tensors[i + 1], weight)
+        output_tensor = func(**input_tensors[i + 1], **input_local)

         for handle in handles:
             handle.wait()
@@ -413,12 +483,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
         send_tensor = output_tensor
         output_tensor = tmp_tensor

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

     # final round: special case, no need to send/recv again
-    output_tensor = F.linear(input_tensors[group_size - 1], weight)
+    output_tensor = func(**input_tensors[-1], **input_local)
     for handle in handles:
         handle.wait()
     output_tensor += recv_tensor
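
_ring_as_reducescatter walks the local chunks in the order chunks[cur_rank:] + chunks[:cur_rank], reversed, so the running partial sum each rank forwards around the ring arrives fully reduced at the rank that owns that chunk. A small single-process simulation of that schedule (illustrative only; no communication is performed):

```python
# Reproduce the chunk order built in _ring_as_reducescatter and check that the
# chunk each rank computes last (and therefore keeps) matches its rank index.
group_size = 4
for cur_rank in range(group_size):
    order = list(range(group_size))
    order = order[cur_rank:] + order[:cur_rank]
    order.reverse()
    assert order[-1] == cur_rank
    print(f"rank {cur_rank}: compute order {order}, keeps reduced chunk {order[-1]}")
```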
@@ -436,27 +504,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, dim):
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.dim = dim
-        if bias is not None:
-            partial_output = F.linear(input_, weight, bias)
+
+        if ring is True:
+            input_to_reducescatter = {}
+            input_local = {}
+            input_to_reducescatter['input'] = input_
+            input_local['weight'] = weight
+
+            if bias is not None:
+                input_to_reducescatter['bias'] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
         else:
-            return _ReduceScatterLinear(input_, weight, process_group)
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)

-        output_shape = list(partial_output.shape)
-        assert (
-            output_shape[dim] % dist.get_world_size(process_group) == 0
-        ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
-        output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

-        output_list = [
-            item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
-        ]
-        output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
-        dist.reduce_scatter(output, output_list, group=process_group)
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)

         return output

@@ -484,7 +569,7 @@ def backward(ctx, grad_output):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -533,17 +618,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['other'] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim
+            )
+
+        else:
+            input_parallel = _gather(input_, dim, process_group)

-        output = torch.matmul(input_parallel, weight)
+            output = torch.matmul(input_parallel, weight)

         if bias is not None:
             output = output + bias
@@ -624,7 +724,7 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -877,10 +977,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

@@ -892,15 +992,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
     return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

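
For context on how the new flag is consumed, the wrappers at the bottom of the file now accept ring and forward it to the autograd functions. A hedged usage sketch, not part of this commit, assuming a torchrun launch with one GPU per rank, the NCCL backend, and an illustrative locally sharded input of shape (batch, seq_per_rank, hidden):

```python
# sketch.py -- launch with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist
from colossalai.shardformer.layer._operation import linear_gather_forward_reducescatter_backward

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
pg = dist.new_group(list(range(dist.get_world_size())))

batch, seq_per_rank, hidden = 2, 16, 64
input_ = torch.randn(batch, seq_per_rank, hidden, device="cuda", requires_grad=True)
weight = torch.randn(4 * hidden, hidden, device="cuda", requires_grad=True)

# dim=1 is the (sharded) sequence dimension: gathered in forward,
# reduce-scattered in backward; ring=True selects the new ring schedule.
out = linear_gather_forward_reducescatter_backward(
    input_, weight, None, pg, async_grad_reduce_scatter=True, dim=1, overlap=False, ring=True
)
out.sum().backward()
dist.destroy_process_group()
```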