Commit 14a4342

resolve comments
1 parent f08e084 commit 14a4342

File tree

7 files changed: +84, -28 lines
colossalai/booster/plugin/gemini_plugin.py

Lines changed: 2 additions & 5 deletions

@@ -46,11 +46,8 @@ def get_param_info(model: nn.Module, optim: Optimizer):
     # 1. A mapping from integer param_id to param32 shape.
 
     param_info = {"id2shape": {}, "name2shape": {}}
-    for m_name, m_var in model.named_modules():
-        for p_name, p_var in m_var.named_parameters(recurse=False):
-            param_name = m_name + "." + p_name if m_name else p_name
-            original_shape = p_var.shape if isinstance(p_var, torch.Tensor) else None
-            param_info["name2shape"][param_name] = original_shape
+    for p_name, param in model.named_parameters(remove_duplicate=False):
+        param_info["name2shape"][p_name] = param.shape
 
     if optim is None:
         return param_info
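
Note (not part of the commit): a quick illustration of what the simplified name2shape bookkeeping records; the toy module below is invented. Passing remove_duplicate=False to named_parameters keeps tied parameters under every qualified name, which is what the old two-level loop over named_modules()/named_parameters(recurse=False) also produced.

import torch.nn as nn

# Toy model with a tied parameter, purely for illustration.
class TinyTiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight  # weight tying

model = TinyTiedModel()

# Mirrors the new loop in get_param_info: remove_duplicate=False keeps both
# "embed.weight" and "head.weight" even though they share the same storage.
name2shape = {
    p_name: param.shape
    for p_name, param in model.named_parameters(remove_duplicate=False)
}
print(name2shape)
# {'embed.weight': torch.Size([10, 4]), 'head.weight': torch.Size([10, 4])}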

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 22 additions & 6 deletions

@@ -32,6 +32,7 @@
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
+    search_padding_dim,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
 )
@@ -937,14 +938,29 @@ def shard_from_complete_optimizer_state(
             if isinstance(v, torch.Tensor) and k != "step":
                 # Shard state along tensor parallel group.
                 partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+                global_shape = current_shape
                 if partition_dim is not None:
-                    slice_size = current_shape[partition_dim]
                     # pad embedding params
-                    if partition_dim == 0:
-                        padding_size = current_shape[0] * self.tp_size - original_shape[0]
-                        if padding_size > 0:
-                            padding_data = torch.zeros_like(v[:padding_size, ...])
-                            v = torch.cat((v, padding_data), dim=0).contiguous()
+                    global_shape = (
+                        *current_shape[:partition_dim],
+                        current_shape[partition_dim] * self.tp_size,
+                        *current_shape[partition_dim + 1 :],
+                    )
+
+                padding_dim = search_padding_dim(global_shape, original_shape)
+                if padding_dim is not None:
+                    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
+                    padding_data = torch.zeros(
+                        *v.shape[:padding_dim],
+                        padding_size,
+                        *v.shape[padding_dim + 1 :],
+                        device=v.device,
+                        dtype=v.dtype,
+                    )
+                    v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+
+                if partition_dim is not None:
+                    slice_size = current_shape[partition_dim]
                     v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
 
                 # Shard state along data parallel group when using Zero.
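
Note (not part of the commit): a standalone sketch of the new shard_from_complete_optimizer_state path, using invented shapes (tp_size=2, a vocab of 5 rows padded to 6 so it splits evenly across ranks). The complete state loaded from the checkpoint is padded up to the reconstructed global shape and then re-split along the TP partition dim; the inline generator stands in for search_padding_dim rather than importing it.

import torch

# Hypothetical numbers, for illustration only.
tp_size, tp_rank = 2, 0
original_shape = torch.Size([5, 8])   # unpadded shape stored in the checkpoint
current_shape = torch.Size([3, 8])    # local (padded, TP-sharded) shape
partition_dim = 0                     # what search_tp_partition_dim would return

v = torch.randn(original_shape)       # complete optimizer state from the checkpoint

# Reconstruct the padded global shape from the local shard shape.
global_shape = (
    *current_shape[:partition_dim],
    current_shape[partition_dim] * tp_size,
    *current_shape[partition_dim + 1 :],
)

# Inline stand-in for search_padding_dim: first dim where global > original.
padding_dim = next(
    (d for d, n in enumerate(global_shape) if n > original_shape[d]), None
)
if padding_dim is not None:
    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
    padding_data = torch.zeros(
        *v.shape[:padding_dim], padding_size, *v.shape[padding_dim + 1 :],
        device=v.device, dtype=v.dtype,
    )
    v = torch.cat((v, padding_data), dim=padding_dim).contiguous()

# The padded state now splits evenly along the TP partition dim.
shard = v.split(current_shape[partition_dim], dim=partition_dim)[tp_rank]
assert shard.shape == current_shape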

colossalai/checkpoint_io/utils.py

Lines changed: 9 additions & 0 deletions

@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
     return partition_dim
 
 
+def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
+    padding_dim = None
+    for dim, length in enumerate(global_shape):
+        if length > original_shape[dim]:
+            padding_dim = dim
+            break
+    return padding_dim
+
+
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
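
Usage of the new helper with arbitrary example shapes (assumes ColossalAI is importable): it returns the first dimension where global_shape is larger than original_shape, or None when nothing was padded.

import torch

from colossalai.checkpoint_io.utils import search_padding_dim

padded = torch.Size([50264, 768])    # e.g. a vocab size rounded up for divisibility
original = torch.Size([50257, 768])  # the shape recorded before padding

print(search_padding_dim(padded, original))    # 0
print(search_padding_dim(original, original))  # None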

colossalai/shardformer/layer/parallel_module.py

Lines changed: 6 additions & 2 deletions

@@ -298,7 +298,9 @@ def _load_from_state_dict(
 
         if self.new_num_embeddings > self.old_num_embeddings:
             num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-            padding_embeddings = torch.zeros_like(input_param[:num_padding_tokens, ...])
+            padding_embeddings = torch.zeros(
+                num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
+            )
             input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
 
         if is_distributed_tensor(param):
@@ -359,7 +361,9 @@ def _load_from_state_dict(
     def resize_embedding_weight(self):
         num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
         valid_weight = self.weight.data
-        padding_weight = torch.zeros_like(self.weight[:num_padding_tokens, ...])
+        padding_weight = torch.zeros(
+            num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
+        )
         # padding to embedding
         self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()
 
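
One way to see what the switch from zeros_like on a slice to an explicit torch.zeros buys (toy sizes, not taken from the PR): when the number of padding rows exceeds the tensor's current row count, the sliced zeros_like silently produces too few rows, while torch.zeros builds exactly the requested amount on the right device and dtype.

import torch

# Toy sizes, for illustration only: 3 real embedding rows, 5 padding rows needed.
input_param = torch.randn(3, 4)
num_padding_tokens = 5

# Old pattern: slicing caps the padding at the tensor's own length.
old_padding = torch.zeros_like(input_param[:num_padding_tokens, ...])
print(old_padding.shape)  # torch.Size([3, 4])  -- only 3 rows, not 5

# New pattern: build exactly the requested number of padding rows.
new_padding = torch.zeros(
    num_padding_tokens, *input_param.shape[1:],
    device=input_param.device, dtype=input_param.dtype,
)
print(new_padding.shape)  # torch.Size([5, 4])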

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 28 additions & 7 deletions

@@ -11,7 +11,7 @@
 from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.accelerator import get_accelerator
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
@@ -524,7 +524,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
            else:
                if self.params_info is not None:
                    origin_shape = self.params_info["name2shape"][name]
-                    destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                    padding_dim = search_padding_dim(p_mapping[param].shape, origin_shape)
+                    if padding_dim is not None:
+                        unpadding_slices = [slice(None)] * p_mapping[param].dim()
+                        unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                        destination[prefix + name] = p_mapping[param][tuple(unpadding_slices)]
+                    else:
+                        destination[prefix + name] = p_mapping[param]
                else:
                    destination[prefix + name] = p_mapping[param]
        del p_mapping
@@ -653,12 +659,23 @@ def load(
            if state_key in state_dict:
                input_param = state_dict[state_key]
 
+                global_shape = dest_tensor.shape
                if source_device_mesh is not None and source_sharding_spec is not None:
                    global_shape = get_global_shape(dest_tensor)
-                    padding_num = global_shape[0] - input_param.shape[0]
-                    if padding_num > 0:
-                        padding_data = torch.zeros_like(input_param[:padding_num, ...])
-                        input_param = torch.cat((input_param, padding_data), dim=0)
+
+                padding_dim = search_padding_dim(global_shape, input_param.shape)
+                if padding_dim is not None:
+                    padding_num = global_shape[padding_dim] - input_param.shape[padding_dim]
+                    padding_data = torch.zeros(
+                        *input_param.shape[:padding_dim],
+                        padding_num,
+                        *input_param.shape[padding_dim + 1 :],
+                        device=input_param.device,
+                        dtype=input_param.dtype,
+                    )
+                    input_param = torch.cat((input_param, padding_data), dim=padding_dim)
+
+                if source_device_mesh is not None and source_sharding_spec is not None:
                    input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                elif shard_fn is not None and gather_fn is not None:
                    input_param = distribute_tensor_with_customization(
@@ -896,7 +913,11 @@ def state_dict_shard(
 
            if self.params_info is not None:
                origin_shape = self.params_info["name2shape"][name]
-                gathered_param = gathered_param[: origin_shape[0], ...]
+                padding_dim = search_padding_dim(gathered_param.shape, origin_shape)
+                if padding_dim is not None:
+                    unpadding_slices = [slice(None)] * gathered_param.dim()
+                    unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                    gathered_param = gathered_param[tuple(unpadding_slices)]
 
            block, block_size = sharder.append_param(prefix + name, gathered_param)
            if block is not None:
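
The un-padding in _save_to_state_dict and state_dict_shard now builds an index tuple instead of hard-coding dim 0. A minimal sketch with invented shapes (a parameter padded from 5 to 8 rows along dim 0):

import torch

origin_shape = torch.Size([5, 4])
gathered_param = torch.randn(8, 4)
padding_dim = 0  # what search_padding_dim(gathered_param.shape, origin_shape) would return

# slice(None) is equivalent to ":" -- keep every other dimension untouched,
# and cut the padded dimension back to its original length.
unpadding_slices = [slice(None)] * gathered_param.dim()
unpadding_slices[padding_dim] = slice(None, origin_shape[padding_dim])
trimmed = gathered_param[tuple(unpadding_slices)]

assert trimmed.shape == origin_shape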

colossalai/zero/gemini/gemini_optimizer.py

Lines changed: 17 additions & 7 deletions

@@ -13,7 +13,7 @@
 
 from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
@@ -705,7 +705,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict):
        Load saved optimizer states into parameter with given id.
        """
 
-        def cast(param, state_range, value, global_shape, key=None):
+        def cast(param, state_range, value, global_shape, origin_shape, key=None):
            """
            Make a copy of the needed segment of value and cast it to device of param.
            """
@@ -722,11 +722,21 @@ def cast(param, state_range, value, global_shape, key=None):
 
            if is_dtensor:
                global_shape = get_global_shape(real_param)
-                padding_num = global_shape[0] - origin_shape[0]
+
+            padding_dim = search_padding_dim(global_shape, origin_shape)
+            if padding_dim is not None:
+                padding_num = global_shape[padding_dim] - origin_shape[padding_dim]
                value = torch.reshape(value, origin_shape)
-                if padding_num > 0:
-                    padding_data = torch.zeros_like(value[:padding_num, ...])
-                    value = torch.cat((value, padding_data), dim=0).contiguous()
+                padding_data = torch.zeros(
+                    *value.shape[:padding_dim],
+                    padding_num,
+                    *value.shape[padding_dim + 1 :],
+                    device=value.device,
+                    dtype=value.dtype,
+                )
+                value = torch.cat((value, padding_data), dim=padding_dim).contiguous()
+
+            if is_dtensor:
                value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
            elif is_customized_distributed:
                value = torch.reshape(value, global_shape)
@@ -753,7 +763,7 @@ def cast(param, state_range, value, global_shape, key=None):
        origin_shape = global_shape
 
        for k, v in saved_states.items():
-            updated_states[k] = cast(fake_param, state_range, v, global_shape, k)
+            updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
            del v  # clean loaded states
        self.optim.state[fake_param].update(updated_states)
 

tests/test_checkpoint_io/test_gemini_checkpoint_io.py

Lines changed: 0 additions & 1 deletion

@@ -120,7 +120,6 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
    for group in optimizer.param_groups:
        group["lr"] = 0.1
 
-    optimizer.zero_grad()
    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
