colossalai/booster/plugin/hybrid_parallel_plugin.py (6 changes: 1 addition & 5 deletions)
@@ -660,10 +660,6 @@ def __init__(
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.use_all_to_all_sequence_parallel = (
self.model.shard_config.enable_sequence_parallelism
and self.model.shard_config.sequence_parallelism_mode == "all_to_all"
)
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
@@ -684,7 +680,6 @@ def __init__(
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
enable_sequence_parallel=self.use_all_to_all_sequence_parallel,
)

def sync_dp_grads(self):
@@ -1098,6 +1093,7 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
zero_stage=zero_stage,
)
self.amp_config = dict(
initial_scale=initial_scale,
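For orientation, a hedged usage sketch of the configuration this plugin change targets: all_to_all sequence parallelism combined with ZeRO, with zero_stage now also threaded into ShardConfig so downstream forwards can branch on it. The arguments shown (tp_size, pp_size, zero_stage, precision, enable_sequence_parallelism, sequence_parallelism_mode) exist on HybridParallelPlugin, but which mode combinations are accepted at this commit is not shown in the diff, so treat the values as illustrative only.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative configuration: ZeRO-1 data parallelism plus all_to_all sequence parallelism.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    precision="bf16",
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )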
colossalai/shardformer/layer/_operation.py (12 changes: 6 additions & 6 deletions)
@@ -756,7 +756,7 @@ def backward(ctx, grad_output):
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None


class _ReduceForward(torch.autograd.Function):
@@ -819,7 +819,7 @@ def backward(ctx, grad_output):
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _split(grad_output, ctx.dim, ctx.process_group), None, None
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None


class _AllToAll(torch.autograd.Function):
@@ -1020,12 +1020,12 @@ def matmul_gather_forward_reducescatter_backward(
)


def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)


def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


def reduce_forward(input_, process_group):
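Side note on the extra None in the two backward returns above: torch.autograd.Function.backward must return one gradient slot per forward argument, so adding grad_scale to forward adds a fourth (always None) slot. Below is a minimal sketch of the split-forward / gather-backward pattern with the optional "up"/"down" gradient scaling; it mirrors the idea, not the exact ColossalAI implementation, and the class name is hypothetical.

import torch
import torch.distributed as dist


class _SplitForwardGatherBackwardSketch(torch.autograd.Function):
    """Split the sequence dim in forward; all-gather (and optionally rescale) grads in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group, grad_scale=None):
        ctx.dim = dim
        ctx.process_group = process_group
        ctx.grad_scale = grad_scale
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # keep only this rank's chunk of the split dimension
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if ctx.grad_scale == "up":
            grad_output = grad_output * world_size
        elif ctx.grad_scale == "down":
            grad_output = grad_output / world_size
        # reassemble the full-sequence gradient from every rank's chunk
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
        grad_input = torch.cat(chunks, dim=ctx.dim)
        # one slot per forward argument: input_, dim, process_group, grad_scale
        return grad_input, None, None, None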
colossalai/shardformer/modeling/gpt2.py (19 changes: 6 additions & 13 deletions)
@@ -853,8 +853,6 @@ def forward(

# use variable seq_len to replace input_shape[-1]
seq_len = input_shape[-1]
if sp_mode in ["ring", "all_to_all"]:
seq_len *= sp_size

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_len)
@@ -866,8 +864,6 @@ def forward(
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if sp_mode in ["ring", "all_to_all"]:
past_length *= sp_size
if position_ids is None:
position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
@@ -876,9 +872,6 @@ def forward(
if sp_mode in ["ring", "all_to_all"]:
position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)]

if sp_mode in ["ring", "all_to_all"]:
attention_mask = _gather(attention_mask, 1, sp_group)

# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
@@ -917,12 +910,12 @@ def forward(
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

if inputs_embeds is None:
if sp_mode in ["ring"]:
input_ids = _gather(input_ids, 1, sp_group)
inputs_embeds = self.wte(input_ids)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
else:
inputs_embeds = self.wte(input_ids)
inputs_embeds = self.wte(input_ids)
if sp_mode == "ring":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')

position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

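Shape-wise, the rewritten embedding path above computes wte on the full sequence on every rank and then keeps only the local sequence chunk, so the blocks that follow run on seq_len / sp_size tokens per rank. A single-process stand-in with illustrative sizes (no process group, plain chunk instead of the collective op):

import torch

batch, seq_len, hidden, sp_size, sp_rank = 2, 16, 64, 4, 1   # illustrative sizes and rank
wte = torch.nn.Embedding(50257, hidden)                      # GPT-2-sized vocab, toy hidden dim

input_ids = torch.randint(0, 50257, (batch, seq_len))
inputs_embeds = wte(input_ids)                               # (2, 16, 64): full sequence everywhere
local_embeds = inputs_embeds.chunk(sp_size, dim=1)[sp_rank]  # (2, 4, 64): this rank's chunk
assert local_embeds.shape == (batch, seq_len // sp_size, hidden)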
colossalai/shardformer/modeling/llama.py (26 changes: 11 additions & 15 deletions)
@@ -694,7 +694,7 @@ def forward(
return forward


def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0):
logger = logging.get_logger(__name__)

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -804,10 +804,6 @@ def forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# sp: modify seq_length when using sequence parallel
if sp_mode in ["ring", "all_to_all"]:
seq_length *= sp_size

seq_length_with_past = seq_length
past_key_values_length = 0

@@ -827,13 +823,12 @@ def forward(
position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
if sp_mode == "ring":
input_ids = _gather(input_ids, 1, sp_group)
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)]
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
else:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids)

if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')

# TODO use_distributed_mask
use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False
@@ -864,8 +859,6 @@ def forward(
attention_mask = _gather(attention_mask, 1, sp_group)

hidden_states = inputs_embeds
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)

if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
@@ -922,7 +915,10 @@ def custom_forward(*inputs):
hidden_states = self.norm(hidden_states)

# Todo: Maybe this line can be optimized
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")

# add hidden states from the last decoder layer
if output_hidden_states:
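Why grad_scale="up" only for all_to_all with ZeRO-1/2: LowLevelZeroOptimizer now always divides the flattened gradients by its world size (see low_level_optim.py below), while under all_to_all each rank's gradient comes from a different sequence chunk and is meant to contribute additively rather than be averaged again (the behaviour the removed enable_sequence_parallel bypass used to provide). Multiplying by the world size in the gather's backward cancels that division; with zero_stage == 0 that optimizer is not in play, so no extra scaling is added. A toy sanity check of the net factor, under the simplifying assumption that the sequence-parallel ranks are exactly the ranks ZeRO reduces over; the helper name is hypothetical:

def a2a_chunk_grad_factor(zero_stage: int, world_size: int) -> float:
    """Net multiplier applied to one rank's sequence-chunk gradient (sketch only)."""
    scale_up = world_size if zero_stage in (1, 2) else 1.0     # grad_scale="up" branch above
    zero_divide = world_size if zero_stage in (1, 2) else 1.0  # flat_grads /= world_size below
    return scale_up / zero_divide


assert a2a_chunk_grad_factor(zero_stage=1, world_size=8) == 1.0   # chunk grads end up summed
assert a2a_chunk_grad_factor(zero_stage=0, world_size=8) == 1.0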
colossalai/shardformer/policies/gpt2.py (4 changes: 3 additions & 1 deletion)
@@ -109,7 +109,9 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel_mode": sp_mode,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
colossalai/shardformer/policies/llama.py (2 changes: 1 addition & 1 deletion)
@@ -121,7 +121,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group),
"forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage),
},
policy=policy,
target_key=LlamaModel,
colossalai/shardformer/shard/shard_config.py (1 change: 1 addition & 0 deletions)
@@ -37,6 +37,7 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
zero_stage: int = 0
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
colossalai/zero/low_level/low_level_optim.py (5 changes: 1 addition & 4 deletions)
@@ -77,10 +77,8 @@ def __init__(
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
enable_sequence_parallel: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._enable_sequence_parallel = enable_sequence_parallel

self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger()
@@ -300,8 +298,7 @@ def _run_reduction(self):

if self.moe_extra_dp_pg is None:
flat_grads = self._bucket_store.get_flatten_grad()
if not self._enable_sequence_parallel:
flat_grads /= self._world_size
flat_grads /= self._world_size
else:
# record moe and non moe param
moe_list = []
tests/test_shardformer/test_model/_utils.py (13 changes: 1 addition & 12 deletions)
@@ -159,18 +159,7 @@ def _criterion(outputs, inputs):

shard_test_data = {}
for k, v in data.items():
if k not in ["input_ids", "attention_mask"]:
shard_test_data[k] = data[k].clone()
else:
# todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads()
shard_test_data[k] = (
torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[
dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group)
]
if booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"]
else data[k].clone()
)
shard_test_data[k] = data[k].clone()
unshard_test_data = {}
for k, v in data.items():
unshard_test_data[k] = data[k].clone()