
Commit 9af946b

padding vocab_size when using pipeline parallelism
padding vocab_size when using pipeline parallelism fix
1 parent 2c2c3cd commit 9af946b

File tree

5 files changed: +21 additions, -6 deletions

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 4 additions & 0 deletions
@@ -937,6 +937,7 @@ def __init__(
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -961,6 +962,7 @@ def __init__(
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 128,
     ) -> None:
         super().__init__()
         assert (
@@ -1033,6 +1035,8 @@ def __init__(
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
+            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,

colossalai/shardformer/modeling/gpt2.py

Lines changed: 5 additions & 4 deletions
@@ -783,11 +783,12 @@ def forward(
             scale = scale * (1 / float(self.layer_idx + 1))

         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )

-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)

colossalai/shardformer/modeling/llama.py

Lines changed: 3 additions & 2 deletions
@@ -481,8 +481,9 @@ def forward(
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal

-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
            query_states,
            key_states,
            value_states,
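Both modeling changes above follow the same pattern: instead of constructing a fresh ColoAttention wrapper on every forward call, the wrapper is built once on first use and cached on the module. A minimal, self-contained sketch of that lazy-initialization pattern; ToyBlock and _ToyAttention are illustrative stand-ins, not part of this commit or of ColossalAI:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class _ToyAttention:
    """Stand-in for an attention wrapper: holds config, has no trainable parameters."""

    def __init__(self, embed_dim: int, num_heads: int):
        self.scale = 1.0 / math.sqrt(embed_dim // num_heads)

    def __call__(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        return torch.matmul(F.softmax(scores, dim=-1), v)


class ToyBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def forward(self, q, k, v):
        # Same idea as the diffs above: build the attention helper once,
        # on first use, instead of re-creating it on every forward pass.
        if not hasattr(self, "attention"):
            self.attention = _ToyAttention(self.embed_dim, self.num_heads)
        return self.attention(q, k, v)


block = ToyBlock(embed_dim=64, num_heads=4)
x = torch.randn(2, 10, 64)
y = block(x, x, x)  # first call builds self.attention; later calls reuse it

The trade-off mirrors the diff: the first call pays the construction cost once, and every later call reuses the cached object.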

colossalai/shardformer/policies/gpt2.py

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,13 @@ def preprocess(self):
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
+        elif self.shard_config.pipeline_stage_manager is not None:
+            # padding vocab_size when using pipeline parallelism
+            new_vocab_size = vocab_size
+            multiple = self.shard_config.make_vocab_size_divisible_by
+            while (new_vocab_size % multiple) != 0:
+                new_vocab_size += 1
+            self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):

colossalai/shardformer/shard/shard_config.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ class ShardConfig:
    enable_all_optimization: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
+    parallel_output: bool = True
+    make_vocab_size_divisible_by: int = 128
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
    # pipeline_parallel_size: int
    # data_parallel_size: int
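Taken together, the plugin, policy, and config changes expose the two new knobs end to end. A rough usage sketch, assuming a distributed environment has already been launched (e.g. via colossalai's launcher) and relying on the plugin's pre-existing tp_size/pp_size arguments; all values are illustrative:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the distributed environment is already initialized.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    num_microbatches=4,
    parallel_output=True,               # new option, forwarded to ShardConfig by this commit
    make_vocab_size_divisible_by=128,   # new option: pad vocab_size up to this multiple
)
booster = Booster(plugin=plugin)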
