6 changes: 3 additions & 3 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -44,10 +44,10 @@
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

if optim is None:
return {}
param_info = {"id2shape": {}}

start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
@@ -527,7 +527,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -558,7 +558,7 @@ def configure(
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
params_info=params_info,
verbose=self.verbose,
)

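For reference, a minimal sketch of the `id2shape` bookkeeping that `get_param_info` builds in the hunk above (the helper name, toy model, and optimizer below are illustrative, not part of the diff):

```python
# Minimal sketch of the id2shape mapping built by get_param_info above.
# The toy model and optimizer are illustrative only.
import torch.nn as nn
from torch.optim import SGD, Optimizer


def get_param_info_sketch(optim: Optimizer) -> dict:
    if optim is None:
        return {}
    param_info = {"id2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        for param_id, param in enumerate(group["params"], start_index):
            # map a global integer param_id to the fp32 parameter shape
            param_info["id2shape"][param_id] = param.shape
        start_index += len(group["params"])
    return param_info


model = nn.Linear(4, 2)
print(get_param_info_sketch(SGD(model.parameters(), lr=0.1))["id2shape"])
# e.g. {0: torch.Size([2, 4]), 1: torch.Size([2])}
```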
11 changes: 5 additions & 6 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):

if optim is None:
return {}
param_info = {
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): The value used when padding the vocabulary size so that a faster kernel can be chosen. Defaults to 64.

"""

def __init__(
@@ -989,6 +986,7 @@ def __init__(
num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
) -> None:
super().__init__()
assert (
@@ -1095,6 +1093,7 @@ def __init__(
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
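The rounding implied by `make_vocab_size_divisible_by` can be sketched as follows (a hand-written illustration, not an API call; under tensor parallelism the effective multiple becomes `make_vocab_size_divisible_by * tp_size`, matching `VocabParallelEmbedding1D` further down):

```python
# Illustration of the vocabulary padding controlled by make_vocab_size_divisible_by.
# Mirrors the arithmetic used in the embedding changes below.
def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    multiple = make_vocab_size_divisible_by * tp_size
    if vocab_size % multiple != 0:
        vocab_size = vocab_size + multiple - (vocab_size % multiple)
    return vocab_size


print(padded_vocab_size(32000, 64, tp_size=1))  # 32000, already a multiple of 64
print(padded_vocab_size(32001, 64, tp_size=2))  # 32128, next multiple of 128
```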
29 changes: 28 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -14,6 +14,12 @@

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device

from .general_checkpoint_io import GeneralCheckpointIO
@@ -32,6 +38,7 @@
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
@@ -89,6 +96,8 @@ def _model_sharder(
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
@@ -231,7 +240,6 @@ def save_sharded_model(
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.

final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
@@ -251,6 +259,7 @@
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
@@ -867,6 +876,11 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)

padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)

return state_
@@ -899,6 +913,19 @@ def shard_from_complete_optimizer_state(
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
global_shape = current_shape
if partition_dim is not None:
# pad embedding params
global_shape = (
*current_shape[:partition_dim],
current_shape[partition_dim] * self.tp_size,
*current_shape[partition_dim + 1 :],
)

padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
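A standalone sketch of the pad-then-shard path added to `shard_from_complete_optimizer_state` above (the zero-padding helper here is only a stand-in for `colossalai.tensor.padded_tensor.to_padded_tensor`, and all shapes are made up):

```python
# Self-contained sketch of the pad-then-shard logic above.
from typing import Optional

import torch


def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
    # same logic as the helper added to colossalai/checkpoint_io/utils.py below
    for dim, length in enumerate(global_shape):
        if length > original_shape[dim]:
            return dim
    return None


def pad_with_zeros(v: torch.Tensor, new_length: int, padding_dim: int) -> torch.Tensor:
    # stand-in for to_padded_tensor: append zero rows along padding_dim
    pad_len = new_length - v.shape[padding_dim]
    if pad_len <= 0:
        return v
    pad_shape = list(v.shape)
    pad_shape[padding_dim] = pad_len
    return torch.cat([v, torch.zeros(pad_shape, dtype=v.dtype, device=v.device)], dim=padding_dim)


def pad_then_shard(v, current_shape, original_shape, partition_dim, tp_size, tp_rank):
    # reconstruct the padded global shape from the per-rank shard shape
    global_shape = current_shape
    if partition_dim is not None:
        global_shape = (
            *current_shape[:partition_dim],
            current_shape[partition_dim] * tp_size,
            *current_shape[partition_dim + 1 :],
        )
    padding_dim = search_padding_dim(torch.Size(global_shape), original_shape)
    if padding_dim is not None:
        v = pad_with_zeros(v, global_shape[padding_dim], padding_dim)
    if partition_dim is not None:
        v = v.split(current_shape[partition_dim], dim=partition_dim)[tp_rank]
    return v


# original vocab 100 padded to 128 and split over 2 TP ranks -> each rank holds 64 rows
full_state = torch.randn(100, 8)
shard = pad_then_shard(full_state, torch.Size([64, 8]), torch.Size([100, 8]), 0, tp_size=2, tp_rank=1)
print(shard.shape)  # torch.Size([64, 8])
```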
9 changes: 9 additions & 0 deletions colossalai/checkpoint_io/utils.py
@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
return partition_dim


def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
padding_dim = None
for dim, length in enumerate(global_shape):
if length > original_shape[dim]:
padding_dim = dim
break
return padding_dim


# ======================================
# Helper classes and functions for saving shard file
# ======================================
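A quick illustration of what `search_padding_dim` returns, assuming it is importable from `colossalai.checkpoint_io.utils` after this change (the shapes are invented):

```python
import torch

from colossalai.checkpoint_io.utils import search_padding_dim  # added by this PR

# dim 0 grew from 100 to 128, so it is reported as the padding dimension
print(search_padding_dim(torch.Size([128, 512]), torch.Size([100, 512])))  # 0
# identical shapes -> None, nothing was padded
print(search_padding_dim(torch.Size([100, 512]), torch.Size([100, 512])))  # None
```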
7 changes: 5 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -25,6 +25,9 @@
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",
111 changes: 95 additions & 16 deletions colossalai/shardformer/layer/embedding.py
@@ -21,10 +21,10 @@
)

from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]


class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor:
return output_parallel


class VocabParallelEmbedding1D(ParallelModule):
class PaddingEmbedding(PaddingParallelModule):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[nn.Parameter] = None,
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
self.embed_kwargs = kwargs
self.padding_idx = padding_idx
if num_embeddings % make_vocab_size_divisible_by != 0:
self.num_embeddings = (
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
)
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)

super().__init__(self.num_embeddings, num_embeddings, weight)

if weight is None:
self.reset_parameters()

def reset_parameters(self) -> None:
init.normal_(self.weight)
self._fill_padding_idx_with_zero()

def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)

def forward(self, input: Tensor) -> Tensor:
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# create the parallel module
padding_embedding = PaddingEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
weight=module.weight,
*args,
**kwargs,
)

return padding_embedding


class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension.

Args:
@@ -201,10 +274,10 @@ def __init__(
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
@@ -214,8 +287,23 @@
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)

self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
# generate weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)

# calculate new padding size
multiple = make_vocab_size_divisible_by * tensor_parallel_size
if num_embeddings % multiple != 0:
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)

# resize vocabulary size
super().__init__(self.num_embeddings, num_embeddings, weight)

# deal with tensor parallelism
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

@@ -226,13 +314,6 @@ def __init__(
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

# parameter
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
@@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor:
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0

output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)

# Mask the output embedding.
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
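Below is a rough, self-contained sketch of the vocab-range masking that `VocabParallelEmbedding1D.forward` performs on each tensor-parallel rank (ranks, sizes, and token ids are invented; the final cross-rank reduction is only noted in a comment):

```python
# Standalone sketch of the per-rank masking in VocabParallelEmbedding1D.forward above.
# All sizes, ranks, and token ids are invented for illustration.
import torch
import torch.nn.functional as F

vocab_size, embedding_dim, tp_size, tp_rank = 128, 8, 2, 1
rows_per_rank = vocab_size // tp_size                      # 64 rows on each rank
vocab_start = tp_rank * rows_per_rank
vocab_end = vocab_start + rows_per_rank

weight_shard = torch.randn(rows_per_rank, embedding_dim)   # this rank's slice of the table
input_ = torch.tensor([[3, 70, 127]])                      # 3 belongs to rank 0, 70/127 to rank 1
input_mask = (input_ < vocab_start) | (input_ >= vocab_end)

masked_input = input_.clone() - vocab_start                # shift ids into local row indices
masked_input[input_mask] = 0                               # park out-of-range ids on row 0
output_parallel = F.embedding(masked_input, weight_shard)

embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0                      # zero the rows this rank does not own
# reduce_forward / an all-reduce over the TP group would then sum the partial outputs
print(embedding_output.shape)  # torch.Size([1, 3, 8])
```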