feat(ckpt): support universal ckpt MOE extended Edition #423

Status: Open
Wants to merge 2 commits into base branch develop
internlm/model/modeling_internlm.py (1 addition, 2 deletions)
@@ -25,7 +25,6 @@
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
internlm1_mha_save_convert,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
@@ -122,7 +121,7 @@ def __init__(
)

# Compatible with the name of internlm1 Wqkv linear layer
self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert)
self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert)

self.dropout1 = nn.Dropout(drop_rate)
self.dropout2 = nn.Dropout(drop_rate)
internlm/model/modeling_moe.py (1 addition, 2 deletions)
@@ -22,7 +22,6 @@
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
internlm1_mha_save_convert,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
@@ -112,7 +111,7 @@ def __init__(
)

# Compatible with the name of internlm1 Wqkv linear layer
self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert)
self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert)

self.dropout1 = nn.Dropout(drop_rate)
self.dropout2 = nn.Dropout(drop_rate)
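Both modeling files above now register only the load-time compatibility hook, presumably so that state_dict() keeps emitting keys in their current naming, which the universal checkpoint indexes by fqn. As a minimal sketch of what such a load-time hook can do (assuming PyTorch's load_state_dict pre-hook signature with with_module=True; the real internlm1_mha_pre_load_convert may differ):

def internlm1_wqkv_pre_load_convert(module, state_dict, prefix, *args, **kwargs):
    # Hypothetical example: rename legacy InternLM1 "Wqkv" keys to the current
    # "wqkv" naming before load_state_dict() consumes the state dict.
    for suffix in ("weight", "bias"):
        legacy_key = f"{prefix}Wqkv.{suffix}"
        current_key = f"{prefix}wqkv.{suffix}"
        if legacy_key in state_dict:
            state_dict[current_key] = state_dict.pop(legacy_key)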
internlm/model/modules/linear.py (15 additions, 0 deletions)
@@ -649,12 +649,19 @@
self.tp_dim = 1
else:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
self.tp_dim = -1

self.complete_size = [out_features, in_features]
setattr(self.weight, "offset", self.offset)
setattr(self.weight, "complete_size", [out_features, in_features])
setattr(self.weight, "tp_dim", self.tp_dim)

if bias:
if self.tp_dim == 0:
setattr(self.bias, "tp_dim", 0)
else:
setattr(self.bias, "tp_dim", -1)

def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."
@@ -904,16 +911,24 @@ def __init__( # pylint: disable=W0231, W0233
self.weight = nn.Parameter(
torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype)
)
self.tp_dim = 2
assert self.weight.shape[self.tp_dim] != out_features
elif split_mode == "row":
self.weight = nn.Parameter(
torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype)
)
self.tp_dim = 1
assert self.weight.shape[self.tp_dim] != in_features
elif split_mode == "weight":
self.weight = nn.Parameter(
torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype)
)
self.tp_dim = 0
else: # none
self.weight = nn.Parameter(torch.empty(num_groups, in_features, out_features, device=device, dtype=dtype))
self.tp_dim = -1

setattr(self.weight, "tp_dim", self.tp_dim)

self.register_parameter("bias", None)
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
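The changes above attach a tp_dim attribute to each weight (and, where present, bias), recording which dimension of the full parameter is sharded across tensor/weight-parallel ranks, with -1 meaning not sharded. A hypothetical consumer of that convention (for illustration only, not actual InternEvo code) could merge per-rank shards back into the complete tensor like this:

import torch

def merge_tp_shards(shards, tp_dim):
    # shards: one tensor per TP rank for the same parameter, in rank order.
    if tp_dim == -1:
        return shards[0]                   # replicated: any rank's copy is complete
    return torch.cat(shards, dim=tp_dim)   # sharded: concatenate along the TP dim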
internlm/model/modules/mha.py (6 additions, 2 deletions)
@@ -160,7 +160,9 @@ def register_checkpoint_compatibility_hooks(
# hoping that model developers will make good use of it when adapting.
# Is this interface already meeting all reasonable requirements?
self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
self._register_state_dict_hook(pre_save_hook)
if pre_save_hook is not None:
logger.warning("pre_save_hook may break universal_ckpt compatibility")
self._register_state_dict_hook(pre_save_hook)

def forward(self, x, inference_params=None, **kwargs):
if inference_params is None:
@@ -471,7 +473,9 @@ def register_checkpoint_compatibility_hooks(
# hoping that model developers will make good use of it when adapting.
# Is this interface already meeting all reasonable requirements?
self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
self._register_state_dict_hook(pre_save_hook)
if pre_save_hook is not None:
logger.warning("pre_save_hook may break universal_ckpt compatibility")
self._register_state_dict_hook(pre_save_hook)

def forward(self, x, inference_params=None, **kwargs):
if inference_params is None:
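The new warning suggests the concern: a save-time hook rewrites state_dict keys after the fqn-based checkpoint metadata has been recorded, so the saved keys and the metadata can diverge. A hypothetical save hook of the kind now discouraged (names assumed, shown only to illustrate the hazard), using the signature that _register_state_dict_hook passes:

def internlm1_wqkv_save_convert(module, state_dict, prefix, local_metadata):
    # Hypothetical example: renaming keys back to the legacy "Wqkv" form at save
    # time would no longer match the fqn-keyed universal-checkpoint metadata.
    for suffix in ("weight", "bias"):
        current_key = f"{prefix}wqkv.{suffix}"
        if current_key in state_dict:
            state_dict[f"{prefix}Wqkv.{suffix}"] = state_dict.pop(current_key)
    return state_dict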
internlm/model/moe/base_layer.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ def __init__(
) -> None:
super().__init__()
# for elastic expert parallelism, experts may have multiple groups
expert_group_name = f"moe_ep_size_{ep_size}"
expert_group_name = "moe_ep_group"
if expert_group_name not in gpc.expert_parallel_group_names:
gpc.expert_parallel_group_names.append(expert_group_name)
self.gate = gate
internlm/solver/optimizer/hybrid_zero_optim.py (45 additions, 17 deletions)
@@ -149,11 +149,16 @@ def __init__(
assert self._param_bcast_sync_handler is not None

self._isp_communicator = isp_communicator
self.meta_for_zero = None
self.meta_for_zero = {"base_groups": {}}
self.meta_for_moe = {"base_groups": {}}
self.moe_group = []
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
if "moe" in param_group and param_group["moe"]:
self.moe_group.append(group_id)

group_params = param_group["params"]

# set the dtype for each param group
@@ -166,8 +171,6 @@
self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
self._zero_world_size.append(gpc.get_world_size(zero_mode))

if gpc.config.ckpt.need_metadata and self.meta_for_zero is None:
self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))]
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
self._broadcast_parallel_mode.append(zero_mode)

@@ -232,6 +235,12 @@ def __init__(
# managed by this data parallel rank
param_group["params"] = [fp32_flat_current_rank]

base_groups = self.optim.state_dict()["param_groups"][group_id]["params"]
if group_id in self.moe_group:
self.meta_for_moe["base_groups"][group_id] = base_groups
else:
self.meta_for_zero["base_groups"][group_id] = base_groups

# set reduction state
for param in self._fp16_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
@@ -285,20 +294,39 @@ def _partition_param_list(self, group_id, param_group):
numel_per_rank[rank_to_go] += param.numel()

if gpc.config.ckpt.need_metadata:
if group_id not in self.meta_for_zero[rank_to_go]:
self.meta_for_zero[rank_to_go][group_id] = {}

from internlm.train.pipeline import map_fqn_local_to_global

global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn
self.meta_for_zero[rank_to_go][group_id][global_fqn] = {
"tp_dim": getattr(param, "tp_dim", -1),
"pp": gpc.get_local_rank(ParallelMode.PIPELINE),
"zero1": rank_to_go,
"fqn": param.fqn,
"shape": param.shape,
"group_id": group_id,
}
if rank_to_go == self.zero_local_rank[group_id]:

from internlm.train.pipeline import map_fqn_local_to_global

global_fqn = (
map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn
)
if group_id in self.moe_group:
if group_id not in self.meta_for_moe:
self.meta_for_moe[group_id] = {}
tp_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
self.meta_for_moe[group_id][global_fqn] = {
"tp_dim": getattr(param, "tp_dim", -1),
"tp": gpc.get_local_rank(tp_mode),
"pp": gpc.get_local_rank(ParallelMode.PIPELINE),
"ep": gpc.get_local_rank(ParallelMode.EXPERT),
"edp": gpc.get_local_rank(ParallelMode.EXPERT_DATA),
"zero1": rank_to_go,
"fqn": param.fqn,
"shape": param.shape,
"group_id": group_id,
}
else:
if group_id not in self.meta_for_zero:
self.meta_for_zero[group_id] = {}
self.meta_for_zero[group_id][global_fqn] = {
"tp_dim": getattr(param, "tp_dim", -1),
"pp": gpc.get_local_rank(ParallelMode.PIPELINE),
"zero1": rank_to_go,
"fqn": param.fqn,
"shape": param.shape,
"group_id": group_id,
}

# check whether any rank is not assigned to parameters.
for rank, params in enumerate(params_per_rank):
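For orientation, a sketch of the metadata these branches build: dense parameters land in meta_for_zero keyed by group id and global fqn, while expert parameters land in meta_for_moe with the additional tp/ep/edp ranks. All fqn, rank, and shape values below are illustrative placeholders, not output from a real run:

import torch

meta_for_zero = {
    "base_groups": {0: [0]},  # per-group param ids taken from optim.state_dict()
    0: {
        "model.layers.0.attention.wqkv.weight": {
            "tp_dim": 0, "pp": 0, "zero1": 0,
            "fqn": "layers.0.attention.wqkv.weight",
            "shape": torch.Size([4096, 4096]), "group_id": 0,
        },
    },
}

meta_for_moe = {
    "base_groups": {1: [0]},
    1: {
        "model.layers.0.feed_forward.experts.wo.weight": {
            "tp_dim": 2, "tp": 0, "pp": 0, "ep": 0, "edp": 0, "zero1": 0,
            "fqn": "layers.0.feed_forward.experts.wo.weight",
            "shape": torch.Size([8, 2048, 4096]), "group_id": 1,
        },
    },
}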