242 changes: 171 additions & 71 deletions colossalai/shardformer/layer/_operation.py
@@ -169,47 +169,57 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None


def _AllgatherLinear(input_, weight, process_group):
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
# currently only supports a single output tensor
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)

input_shape = input_.shape
weight_shape = weight.shape

output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
#output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

# initialization of ring communication
input_shape[1]
recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
recv_tensor = input_.clone()
send_tensor = input_.clone()

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([send_op, recv_op])
recv_tensors = {}
send_tensors = {}
for k, v in input_to_gather.items():
recv_tensors[k] = v.clone()
send_tensors[k] = v.clone()

def communicate_step():
comm_ops = []
for k in recv_tensors:
comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
return dist.batch_isend_irecv(comm_ops)

def switch_step():
for k in recv_tensors:
tmp_tensor = send_tensors[k]
send_tensors[k] = recv_tensors[k]
recv_tensors[k] = tmp_tensor

output_tensors = []

handles = communicate_step()
# first round: special case, retrieve from the local tensor
output_tensors[0] = F.linear(input_, weight)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
handle.wait()

tmp_tensor = send_tensor
send_tensor = recv_tensor
recv_tensor = tmp_tensor
switch_step()

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()

# actual computation
output_tensors[i + 1] = F.linear(send_tensor, weight)
output_tensors.append(func(**send_tensors, **input_local))

# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1)
output_tensors.append(func(**recv_tensors, **input_local))

return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
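
For orientation while reading the diff, a minimal usage sketch of the new helper (not part of this PR; it assumes `torch.distributed` is already initialized, e.g. under `torchrun`, and that `pg` is the sequence-parallel process group):

```python
import torch
import torch.nn.functional as F

# `input_to_gather` holds the tensors that travel around the ring (one shard per rank);
# `input_local` holds the tensors that stay resident on this rank. The result matches
# F.linear(all_gather(x_shard, dim=1), w), but each chunk's matmul is overlapped with
# the P2P transfer of the next chunk.
def ring_gather_linear(x_shard: torch.Tensor, w: torch.Tensor, pg) -> torch.Tensor:
    return _ring_as_gather(
        F.linear,
        input_to_gather={"input": x_shard},
        input_local={"weight": w},
        process_group=pg,
        gather_dim=1,  # sequence dimension
    )
```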


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -249,6 +259,41 @@ def backward(ctx, grad_output):
return output, None, None


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward

Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.

"""

@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.process_group = process_group
ctx.dim = dim

return _gather(input_, dim, process_group)

@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group

# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)

return output, None, None
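
An illustrative sketch of the semantics of the new autograd function (assumed call site, not from this diff); `x_shard` and `pg` are placeholder names:

```python
# x_shard: (batch, seq_len // world_size, hidden) sequence shard on every rank;
# pg: the sequence-parallel process group; 1: the sequence dimension.
x_full = _GatherForwardReduceScatterBackward.apply(x_shard, pg, 1)
# forward : x_full holds the full sequence, shape (batch, seq_len, hidden), on every rank
# backward: grad wrt x_shard = reduce_scatter(grad of x_full, along dim 1)
```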


class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward

@@ -260,19 +305,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

if bias is not None:
input_parallel = _gather(input_, dim, process_group)
output = F.linear(input_parallel, weight, bias)
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather['input'] = input_
input_local['weight'] = weight

output = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)

if bias is not None:
output += bias
else:
output = _AllgatherLinear(input_, weight, process_group)
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)

return output
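
A hedged sanity check for the new `ring=True` path (assumed test code, not part of the PR): both schedules should produce the same forward output, since only the communication pattern changes; `x`, `w`, `b`, and `pg` are placeholder names:

```python
ref = linear_gather_forward_reducescatter_backward(
    x, w, b, pg, async_grad_reduce_scatter=True, dim=1, overlap=False, ring=False
)
out = linear_gather_forward_reducescatter_backward(
    x, w, b, pg, async_grad_reduce_scatter=True, dim=1, overlap=False, ring=True
)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```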

Expand Down Expand Up @@ -376,34 +437,43 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None


def _ReduceScatterLinear(input_, weight, process_group):
def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
# currently only supports a single output tensor
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)

input_shape = input_.shape

# initialization of ring communication
# communicate(e.g.): 0->1->2->3
# compute(e.g.): 3->2->1->0
input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()
recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
input_tensors = []
for _ in range(group_size):
input_tensors.append({})
for k, v in input_to_reducescatter.items():
input_shape = v.shape
assert input_shape[reducescatter_dim] % group_size == 0
_input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
for i in range(group_size):
input_tensors[i][k] = _input_tensors[i]
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()

# first round: special case, no reduce operation
output_tensor = F.linear(input_tensors[0], weight)
output_tensor = func(**input_tensors[0], **input_local)
recv_tensor = output_tensor.clone()
send_tensor = output_tensor.clone()
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])

def communicate_step():
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
return dist.batch_isend_irecv([recv_op, send_op])

handles = communicate_step()
# intermediate rounds: overlap computation with ring communication
for i in range(group_size - 2):
# actual computation
output_tensor = F.linear(input_tensors[i + 1], weight)
output_tensor = func(**input_tensors[i + 1], **input_local)

for handle in handles:
handle.wait()
@@ -413,12 +483,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
send_tensor = output_tensor
output_tensor = tmp_tensor

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()

# final round: special case, no need to send/recv again
output_tensor = F.linear(input_tensors[group_size - 1], weight)
output_tensor = func(**input_tensors[-1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
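
For symmetry with the gather helper, an assumed usage sketch of `_ring_as_reducescatter` (not part of the diff): the full-sequence (but tensor-parallel partial) activation is split along `reducescatter_dim`, each rank computes its chunks in ring order, and the partial outputs are summed as they circulate, leaving each rank with its reduce-scattered slice:

```python
import torch.nn.functional as F

# x_full: full sequence length, partial hidden dimension on each rank;
# w: this rank's shard of a row-parallel weight; pg: the sequence-parallel group.
def ring_reducescatter_linear(x_full, w, pg):
    return _ring_as_reducescatter(
        F.linear,
        input_to_reducescatter={"input": x_full},
        input_local={"weight": w},
        process_group=pg,
        reducescatter_dim=1,  # sequence dimension
    )
```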
@@ -436,27 +504,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim):
def forward(ctx, input_, weight, bias, process_group, dim, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
if bias is not None:
partial_output = F.linear(input_, weight, bias)

if ring is True:
input_to_reducescatter = {}
input_local = {}
input_to_reducescatter['input'] = input_
input_local['weight'] = weight

if bias is not None:
input_to_reducescatter['bias'] = bias

output = _ring_as_reducescatter(
F.linear,
input_to_reducescatter=input_to_reducescatter,
input_local=input_local,
process_group=process_group,
)
else:
return _ReduceScatterLinear(input_, weight, process_group)
if bias is not None:
partial_output = F.linear(input_, weight, bias)
else:
partial_output = F.linear(input_, weight)

output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)
output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)

return output
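
A hedged equivalence check for the reduce-scatter path (assumed, not part of the PR): the ring schedule accumulates the partial outputs one rank at a time, so it should agree with the default `dist.reduce_scatter` path up to floating-point summation order; `x`, `w`, and `pg` are placeholder names:

```python
out_eager = linear_reducescatter_forward_gather_backward(x, w, bias=None, process_group=pg, dim=1, ring=False)
out_ring = linear_reducescatter_forward_gather_backward(x, w, bias=None, process_group=pg, dim=1, ring=True)
torch.testing.assert_close(out_ring, out_eager, rtol=1e-3, atol=1e-3)
```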

@@ -484,7 +569,7 @@ def backward(ctx, grad_output):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None, None


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -533,17 +618,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

input_parallel = _gather(input_, dim, process_group)
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather['input'] = input_
input_local['other'] = weight

output = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim
)

else:
input_parallel = _gather(input_, dim, process_group)

output = torch.matmul(input_parallel, weight)
output = torch.matmul(input_parallel, weight)

if bias is not None:
output = output + bias
Expand Down Expand Up @@ -624,7 +724,7 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -877,10 +977,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)


@@ -892,15 +992,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
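
To summarize the API change: all three wrappers now take a `ring` argument that defaults to `False`, so existing call sites keep their behavior. A hedged usage sketch with placeholder tensor and group names:

```python
# Column-parallel matmul over a sequence-parallel input, using the new ring schedule.
out = matmul_gather_forward_reducescatter_backward(
    x_shard,                         # (batch, seq_len // world_size, hidden)
    w,                               # (hidden, out_features_per_rank)
    bias=None,
    process_group=pg,
    async_grad_reduce_scatter=True,
    dim=1,                           # sequence dimension to gather/reduce-scatter over
    overlap=False,                   # backward all-gather/compute overlap (unchanged by this PR)
    ring=True,                       # new: overlap the forward gather with the matmul
)
```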

