 from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.accelerator import get_accelerator
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
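
The newly imported `search_padding_dim` comes from `colossalai.checkpoint_io.utils`. Based only on how it is called in the hunks below (padded/global shape first, original shape second, returning a dimension index or `None`), a plausible re-implementation would look like the sketch below; the actual helper in the library may differ.

```python
from typing import Optional, Sequence

def search_padding_dim(global_shape: Sequence[int], original_shape: Sequence[int]) -> Optional[int]:
    """Illustrative re-implementation (assumption): return the first dimension
    where the padded/global shape differs from the original shape, else None."""
    for dim, (padded, original) in enumerate(zip(global_shape, original_shape)):
        if padded != original:
            return dim
    return None
```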
@@ -524,7 +524,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): |
                 else:
                     if self.params_info is not None:
                         origin_shape = self.params_info["name2shape"][name]
-                        destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                        padding_dim = search_padding_dim(p_mapping[param].shape, origin_shape)
+                        if padding_dim is not None:
+                            unpadding_slices = [slice(None)] * p_mapping[param].dim()
+                            unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                            destination[prefix + name] = p_mapping[param][tuple(unpadding_slices)]
+                        else:
+                            destination[prefix + name] = p_mapping[param]
                     else:
                         destination[prefix + name] = p_mapping[param]
         del p_mapping
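
The save-side change above replaces the old `[: origin_shape[0], ...]` slice, which could only trim padding on dim 0, with a per-dimension slice list that trims whichever dimension `search_padding_dim` reports. A minimal standalone sketch of that slicing trick, with a hypothetical `unpad` helper (not part of this PR):

```python
import torch

def unpad(padded: torch.Tensor, padding_dim: int, original_size: int) -> torch.Tensor:
    # Keep every dimension whole except the padded one, which is trimmed back
    # to its original length (equivalent to padded[:original_size, ...] when
    # padding_dim == 0).
    slices = [slice(None)] * padded.dim()
    slices[padding_dim] = slice(None, original_size)
    return padded[tuple(slices)]

padded = torch.zeros(8, 70)          # e.g. a dim-1 length padded from 64 to 70
print(unpad(padded, 1, 64).shape)    # torch.Size([8, 64])
```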
@@ -653,12 +659,23 @@ def load( |
             if state_key in state_dict:
                 input_param = state_dict[state_key]
 
+                global_shape = dest_tensor.shape
                 if source_device_mesh is not None and source_sharding_spec is not None:
                     global_shape = get_global_shape(dest_tensor)
-                    padding_num = global_shape[0] - input_param.shape[0]
-                    if padding_num > 0:
-                        padding_data = torch.zeros_like(input_param[:padding_num, ...])
-                        input_param = torch.cat((input_param, padding_data), dim=0)
+
+                padding_dim = search_padding_dim(global_shape, input_param.shape)
+                if padding_dim is not None:
+                    padding_num = global_shape[padding_dim] - input_param.shape[padding_dim]
+                    padding_data = torch.zeros(
+                        *input_param.shape[:padding_dim],
+                        padding_num,
+                        *input_param.shape[padding_dim + 1 :],
+                        device=input_param.device,
+                        dtype=input_param.dtype,
+                    )
+                    input_param = torch.cat((input_param, padding_data), dim=padding_dim)
+
+                if source_device_mesh is not None and source_sharding_spec is not None:
                     input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                 elif shard_fn is not None and gather_fn is not None:
                     input_param = distribute_tensor_with_customization(
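
On the load path, the checkpoint tensor is now zero-padded along the detected dimension (rather than always dim 0) so it matches the global shape before being distributed. A self-contained sketch of that padding step, using a hypothetical `pad_to` helper (not part of this PR):

```python
import torch

def pad_to(t: torch.Tensor, padding_dim: int, target_size: int) -> torch.Tensor:
    # Append zeros along `padding_dim` until that dimension reaches `target_size`;
    # device and dtype follow the input tensor, as in the hunk above.
    padding_num = target_size - t.size(padding_dim)
    if padding_num <= 0:
        return t
    padding_data = torch.zeros(
        *t.shape[:padding_dim],
        padding_num,
        *t.shape[padding_dim + 1 :],
        device=t.device,
        dtype=t.dtype,
    )
    return torch.cat((t, padding_data), dim=padding_dim)

ckpt = torch.randn(50000, 512)        # e.g. a vocab dimension saved unpadded
print(pad_to(ckpt, 0, 50048).shape)   # torch.Size([50048, 512])
```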
@@ -896,7 +913,11 @@ def state_dict_shard( |
 
                 if self.params_info is not None:
                     origin_shape = self.params_info["name2shape"][name]
-                    gathered_param = gathered_param[: origin_shape[0], ...]
+                    padding_dim = search_padding_dim(gathered_param.shape, origin_shape)
+                    if padding_dim is not None:
+                        unpadding_slices = [slice(None)] * gathered_param.dim()
+                        unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                        gathered_param = gathered_param[tuple(unpadding_slices)]
 
                 block, block_size = sharder.append_param(prefix + name, gathered_param)
                 if block is not None:
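
Taken together, the load-side padding and the save/shard-side unpadding are meant to be inverses along the padded dimension. A quick round-trip check under that assumption, using plain tensors rather than the Gemini wrapper:

```python
import torch

original = torch.randn(10, 30)              # unpadded checkpoint tensor
padding_dim, padded_size = 1, 32            # pad dim 1 from 30 to 32

# Load path: zero-pad along the padded dimension.
pad = torch.zeros(10, padded_size - 30, dtype=original.dtype)
padded = torch.cat((original, pad), dim=padding_dim)

# Save path: trim the same dimension back with a per-dimension slice list.
slices = [slice(None)] * padded.dim()
slices[padding_dim] = slice(None, original.size(padding_dim))
restored = padded[tuple(slices)]

assert torch.equal(restored, original)
```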