Commit 4b7b83a

sequence parallel: inside text split (hpcaitech#6)
1 parent ad9e332 commit 4b7b83a

9 files changed: +31 -57 lines changed

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 1 addition & 5 deletions
@@ -660,10 +660,6 @@ def __init__(
         self.dp_pg = dp_process_group
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
-        self.use_all_to_all_sequence_parallel = (
-            self.model.shard_config.enable_sequence_parallelism
-            and self.model.shard_config.sequence_parallelism_mode == "all_to_all"
-        )
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
@@ -684,7 +680,6 @@ def __init__(
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
-            enable_sequence_parallel=self.use_all_to_all_sequence_parallel,
         )
 
     def sync_dp_grads(self):
@@ -1098,6 +1093,7 @@ def __init__(
             enable_sequence_parallelism=enable_sequence_parallelism,
             sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
+            zero_stage=zero_stage,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
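
For context, a hedged sketch of a HybridParallelPlugin configuration that reaches this code path. The argument names follow the plugin's public constructor, but the exact set required on this branch (for example a sequence-parallel group size) may differ, so treat the values as illustrative:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# zero_stage is now forwarded into ShardConfig (see the hunk above), so shardformer
# policies can branch on the ZeRO stage; the sequence-parallelism flags select the
# code paths touched by this commit.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)
booster = Booster(plugin=plugin)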

colossalai/shardformer/layer/_operation.py

Lines changed: 6 additions & 6 deletions
@@ -756,7 +756,7 @@ def backward(ctx, grad_output):
             grad_output = grad_output * dist.get_world_size(ctx.process_group)
         elif ctx.grad_scale == "down":
             grad_output = grad_output / dist.get_world_size(ctx.process_group)
-        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -819,7 +819,7 @@ def backward(ctx, grad_output):
             grad_output = grad_output * dist.get_world_size(ctx.process_group)
         elif ctx.grad_scale == "down":
             grad_output = grad_output / dist.get_world_size(ctx.process_group)
-        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
 
 
 class _AllToAll(torch.autograd.Function):
@@ -1020,12 +1020,12 @@ def matmul_gather_forward_reducescatter_backward(
     )
 
 
-def gather_forward_split_backward(input_, dim, process_group):
-    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
 
 
-def split_forward_gather_backward(input_, dim, process_group):
-    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
 
 
 def reduce_forward(input_, process_group):
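
The new trailing grad_scale argument is forwarded to the autograd functions, whose backward (see the context lines above) multiplies or divides the incoming gradient by the process-group world size for "up" / "down". Below is a minimal, self-contained sketch of the split-forward / gather-backward pattern with that scaling; it illustrates the semantics and is not ColossalAI's actual _SplitForwardGatherBackward implementation:

import torch
import torch.distributed as dist


class SplitForwardGatherBackwardSketch(torch.autograd.Function):
    """Keep one sequence chunk per rank in forward; rebuild the full gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group, grad_scale=None):
        ctx.dim = dim
        ctx.process_group = process_group
        ctx.grad_scale = grad_scale
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # assumes the split dimension is divisible by the group size
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if ctx.grad_scale == "up":
            grad_output = grad_output * world_size
        elif ctx.grad_scale == "down":
            grad_output = grad_output / world_size
        # all-gather the per-rank gradient chunks back into a full-sequence gradient
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
        # one gradient slot per forward argument: input_, dim, process_group, grad_scale
        return torch.cat(chunks, dim=ctx.dim), None, None, None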

colossalai/shardformer/modeling/gpt2.py

Lines changed: 6 additions & 13 deletions
@@ -853,8 +853,6 @@ def forward(
 
         # use variable seq_len to replace input_shape[-1]
         seq_len = input_shape[-1]
-        if sp_mode in ["ring", "all_to_all"]:
-            seq_len *= sp_size
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, seq_len)
@@ -866,8 +864,6 @@ def forward(
             past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
-            if sp_mode in ["ring", "all_to_all"]:
-                past_length *= sp_size
         if position_ids is None:
             position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
@@ -876,9 +872,6 @@ def forward(
         if sp_mode in ["ring", "all_to_all"]:
             position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)]
 
-        if sp_mode in ["ring", "all_to_all"]:
-            attention_mask = _gather(attention_mask, 1, sp_group)
-
         # GPT2Attention mask.
         if attention_mask is not None:
             if batch_size <= 0:
@@ -917,12 +910,12 @@ def forward(
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
         if inputs_embeds is None:
-            if sp_mode in ["ring"]:
-                input_ids = _gather(input_ids, 1, sp_group)
-                inputs_embeds = self.wte(input_ids)
-                inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-            else:
-                inputs_embeds = self.wte(input_ids)
+            inputs_embeds = self.wte(input_ids)
+        if sp_mode == "ring":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')
+
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
 
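
This is the "inside text split" of the commit title: the model now receives full-length input_ids, embeds them, and only then splits the embeddings along the sequence dimension (scaling the backward gradient "down" in all_to_all mode), instead of gathering pre-split token ids and re-splitting after the lookup. Because the embedding lookup is per-token, both orders produce the same local shard; a single-process sketch, where sp_size and rank stand in for the sequence-parallel group:

import torch

torch.manual_seed(0)
sp_size, rank = 4, 1
wte = torch.nn.Embedding(100, 16)
input_ids = torch.randint(0, 100, (2, 32))  # full (batch, seq_len); seq_len divisible by sp_size

# previous path: chunk the token ids first, then embed the local chunk
old_shard = wte(input_ids.chunk(sp_size, dim=1)[rank])

# new path: embed the full sequence, then chunk the embeddings along the sequence dim
new_shard = wte(input_ids).chunk(sp_size, dim=1)[rank]

assert torch.equal(old_shard, new_shard)  # identical local shards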

colossalai/shardformer/modeling/llama.py

Lines changed: 11 additions & 15 deletions
@@ -694,7 +694,7 @@ def forward(
     return forward
 
 
-def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0):
     logger = logging.get_logger(__name__)
 
     # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -804,10 +804,6 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        # sp: modify seq_length when using sequence parallel
-        if sp_mode in ["ring", "all_to_all"]:
-            seq_length *= sp_size
-
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
@@ -827,13 +823,12 @@ def forward(
             position_ids = position_ids.view(-1, seq_length).long()
 
         if inputs_embeds is None:
-            if sp_mode == "ring":
-                input_ids = _gather(input_ids, 1, sp_group)
-                inputs_embeds = self.embed_tokens(input_ids)
-                input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)]
-                inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-            else:
-                inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')
 
         # TODO use_distributed_mask
         use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False
@@ -864,8 +859,6 @@ def forward(
             attention_mask = _gather(attention_mask, 1, sp_group)
 
         hidden_states = inputs_embeds
-        if sp_mode == "split_gather":
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
 
         if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
@@ -922,7 +915,10 @@ def custom_forward(*inputs):
         hidden_states = self.norm(hidden_states)
 
         # Todo: Maybe this line can be optimized
-        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
+        if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
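
The final gather now branches on zero_stage: with ZeRO stage 1/2 the low-level optimizer divides flat gradients by the world size unconditionally (see low_level_optim.py below), and the extra grad_scale="up" on the all_to_all path presumably compensates for that division. A sketch of the selection factored into a helper; the helper name is ours, while gather_forward_split_backward is the real wrapper from _operation.py above:

from colossalai.shardformer.layer._operation import gather_forward_split_backward


def gather_sp_output(hidden_states, sp_mode, sp_group, zero_stage=0):
    """Gather sequence-parallel shards of the final hidden states (illustrative helper)."""
    if sp_mode in ("ring", "split_gather") or (sp_mode == "all_to_all" and zero_stage == 0):
        # plain gather: backward just splits the gradient, no rescaling
        return gather_forward_split_backward(hidden_states, 1, sp_group)
    if sp_mode == "all_to_all" and zero_stage in (1, 2):
        # ZeRO-1/2 averages flat grads by the world size, so scale the gradient back "up" here
        return gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
    return hidden_states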

colossalai/shardformer/policies/gpt2.py

Lines changed: 3 additions & 1 deletion
@@ -109,7 +109,9 @@ def module_policy(self):
                 SubModuleReplacementDescription(
                     suffix="mlp.c_proj",
                     target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.attn_dropout",

colossalai/shardformer/policies/llama.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage),
                 },
                 policy=policy,
                 target_key=LlamaModel,

colossalai/shardformer/shard/shard_config.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ class ShardConfig:
     enable_jit_fused: bool = False
     enable_sequence_parallelism: bool = False
     sequence_parallelism_mode: str = None
+    zero_stage: int = 0
     enable_sequence_overlap: bool = False
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int

colossalai/zero/low_level/low_level_optim.py

Lines changed: 1 addition & 4 deletions
@@ -77,10 +77,8 @@ def __init__(
         forced_dtype: Optional[torch.dtype] = None,
         moe_extra_dp_process_group: Optional[ProcessGroup] = None,
         master_weights: bool = True,  # master weights
-        enable_sequence_parallel: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
-        self._enable_sequence_parallel = enable_sequence_parallel
 
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._logger = get_dist_logger()
@@ -300,8 +298,7 @@ def _run_reduction(self):
 
         if self.moe_extra_dp_pg is None:
             flat_grads = self._bucket_store.get_flatten_grad()
-            if not self._enable_sequence_parallel:
-                flat_grads /= self._world_size
+            flat_grads /= self._world_size
         else:
             # record moe and non moe param
             moe_list = []

tests/test_shardformer/test_model/_utils.py

Lines changed: 1 addition & 12 deletions
@@ -159,18 +159,7 @@ def _criterion(outputs, inputs):
 
     shard_test_data = {}
     for k, v in data.items():
-        if k not in ["input_ids", "attention_mask"]:
-            shard_test_data[k] = data[k].clone()
-        else:
-            # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads()
-            shard_test_data[k] = (
-                torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[
-                    dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group)
-                ]
-                if booster.plugin.shard_config.enable_sequence_parallelism
-                and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"]
-                else data[k].clone()
-            )
+        shard_test_data[k] = data[k].clone()
     unshard_test_data = {}
     for k, v in data.items():
         unshard_test_data[k] = data[k].clone()
