16 changes: 13 additions & 3 deletions colossalai/nn/layer/parallel_2d/_operation.py
@@ -58,6 +58,7 @@ def matmul_2d(


class _Classifier2D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -76,7 +77,7 @@ def forward(
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
-
+ A = A.clone().detach()
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
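The substantive change in this hunk is the new `A = A.clone().detach()`: the classifier forward now works on a detached copy of its input rather than the live autograd tensor. A minimal sketch of the idiom (illustrative only, not the ColossalAI code; the `_ScaleByTwo` Function and the shapes are made up):

```python
import torch

class _ScaleByTwo(torch.autograd.Function):
    # Forward operates on a detached clone, so no extra autograd graph is
    # built inside the Function; the hand-written backward supplies the grad.

    @staticmethod
    def forward(ctx, a):
        a = a.clone().detach()  # mirrors the line added above
        ctx.save_for_backward(a)
        return a * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.randn(4, requires_grad=True)
_ScaleByTwo.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))
```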
@@ -181,6 +182,7 @@ class Matmul_AB_2D(torch.autograd.Function):
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
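For context, `Matmul_AB_2D` computes C = AB with A, B, and C partitioned on a q x q process grid. A single-process sketch of that blockwise schedule (illustrative values only; the real op broadcasts blocks over `ParallelMode.PARALLEL_2D_ROW` and `ParallelMode.PARALLEL_2D_COL` instead of looping):

```python
import torch

q = 2  # process grid is q x q; each "rank" holds one 4x4 block here
A, B = torch.randn(8, 8), torch.randn(8, 8)
As = [rb.chunk(q, dim=1) for rb in A.chunk(q, dim=0)]  # As[i][k]
Bs = [rb.chunk(q, dim=1) for rb in B.chunk(q, dim=0)]  # Bs[k][j]
C = [[torch.zeros(4, 4) for _ in range(q)] for _ in range(q)]
for k in range(q):          # step k combines A's column-k and B's row-k blocks
    for i in range(q):
        for j in range(q):
            C[i][j] += As[i][k] @ Bs[k][j]  # local partial product
full = torch.cat([torch.cat(row, dim=1) for row in C], dim=0)
assert torch.allclose(full, A @ B, atol=1e-4)
```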
@@ -308,6 +310,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -440,6 +443,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -552,6 +556,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:


class _Add_Bias_2D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -633,6 +638,7 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro


class _Layernorm_2D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
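Note the dtype asymmetry visible in these hunks: the matmul-style ops use `@custom_fwd(cast_inputs=torch.float16)` while `_Layernorm_2D` uses `torch.float32`. Under `torch.cuda.amp.autocast`, `custom_fwd` casts floating-point CUDA inputs to the requested dtype before `forward` runs, which keeps the mean/variance math in fp32 for stability. A minimal sketch of the decorator pair (hypothetical `_Square` Function, not from this file):

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class _Square(torch.autograd.Function):
    # Inside an autocast region, x arrives pre-cast to float32 on CUDA,
    # so the squared values do not overflow/underflow in half precision.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output
```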
@@ -689,6 +695,7 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r


class _AllGatherTensor2D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@@ -742,6 +749,7 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:


class _ReduceTensor2D(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@@ -766,6 +774,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:


class _ReduceScatterTensor2D(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -793,11 +802,12 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'

return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
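The assert above encodes the reduce-scatter shape contract: every rank contributes a full tensor and receives a `dim_size // world_size` slice of the element-wise sum, so the scattered dimension must split evenly. A dependency-free simulation with illustrative values (no `torch.distributed` needed):

```python
import torch

world_size = 4
inputs = [torch.randn(8, 2) for _ in range(world_size)]  # one tensor per rank
summed = torch.stack(inputs).sum(dim=0)                   # the "reduce" step
shards = summed.chunk(world_size, dim=0)                  # rank r keeps shards[r]
assert all(s.shape[0] == 8 // world_size for s in shards)
```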


class _ReduceByBatch2D(torch.autograd.Function):
+
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
@@ -834,4 +844,4 @@ def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor:
reduce_mean (bool, optional):
If set to ``True``, the output will be divided by the column parallel size; defaults to ``False``.
"""
- return _ReduceByBatch2D.apply(input_, reduce_mean)
+ return _ReduceByBatch2D.apply(input_, reduce_mean)
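Per its docstring, `reduce_by_batch_2d` with `reduce_mean=True` all-reduces over the column group and then divides by the column parallel size, i.e. a mean over the batch shards. A toy single-process sketch with made-up values:

```python
import torch

col_world_size = 2
per_rank = [torch.tensor(0.5), torch.tensor(1.5)]  # illustrative per-rank values
reduced = torch.stack(per_rank).sum(dim=0)  # what all_reduce leaves on every rank
mean = reduced / col_world_size             # the reduce_mean=True branch
assert mean.item() == 1.0
```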
36 changes: 21 additions & 15 deletions colossalai/nn/layer/parallel_2p5d/_operation.py
@@ -23,25 +23,26 @@ def get_parallel_rank(parallel_mode: ParallelMode):


class _Classifier2p5D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
- ctx: Any,
- A: Tensor,
- B: Tensor,
- bias,
- tesseract_dim: int,
- out_shape: Tuple[int, ...],
- row_rank: int,
- col_rank: int,
- row_parallel_mode: ParallelMode,
- col_parallel_mode: ParallelMode,
- data_parallel_rank: int,
- pipeline_parallel_rank: int,
- pipeline_parallel_size: int,
- tensor_parallel_size: int,
+ ctx: Any,
+ A: Tensor,
+ B: Tensor,
+ bias,
+ tesseract_dim: int,
+ out_shape: Tuple[int, ...],
+ row_rank: int,
+ col_rank: int,
+ row_parallel_mode: ParallelMode,
+ col_parallel_mode: ParallelMode,
+ data_parallel_rank: int,
+ pipeline_parallel_rank: int,
+ pipeline_parallel_size: int,
+ tensor_parallel_size: int,
) -> Tensor:
-
+ A = A.clone().detach()
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
@@ -509,6 +510,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:


class _Add_Bias_2p5D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
@@ -689,6 +691,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,


class _AllGatherTensor2p5D(torch.autograd.Function):
+
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@@ -777,6 +780,7 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:


class _ReduceTensor2p5D(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@@ -801,6 +805,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:


class _ReduceScatterTensor2p5D(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -833,6 +838,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel


class _RreduceByBatch2p5D(torch.autograd.Function):
+
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)