Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/twinkle/metric/train_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def calculate(self):
self.lr = self.lr[0]
if isinstance(self.lr, list):
for idx, lr in enumerate(self.lr):
results[f'learning rate(param group {idx+1})'] = lr
results[f'learning rate(param group {idx + 1})'] = lr
else:
results['learning rate'] = self.lr
if self.step is not None:
Expand All @@ -54,7 +54,7 @@ def calculate(self):
if interval < 60:
results['total time elapse'] = f'{(time.time() - self.start_time):.0f} seconds'
else:
results['total time elapse'] = f'{(time.time() - self.start_time)/60:.1f} minutes'
results['total time elapse'] = f'{(time.time() - self.start_time) / 60:.1f} minutes'
results['speed'] = f'{speed:.2f} iters/s'
self.reset()
return results
46 changes: 39 additions & 7 deletions src/twinkle/model/transformers/moe/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dic
if ep_world_size <= 1:
return model

ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled()

if cfg.pad_to_max:
raise NotImplementedError('pad_to_max is not implemented.')
if cfg.all_to_all != 'torch':
Expand All @@ -44,7 +46,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dic
raise RuntimeError('EP process group is not available in device_mesh.')

for block in find_moe_blocks(model):
shard_experts(block, device_mesh, cfg)
shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled)
patch_forward(block, device_mesh, cfg)

return model
Expand Down Expand Up @@ -75,7 +77,8 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
return blocks


def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *,
ep_fsdp_enabled: bool) -> None:
num_experts = _get_num_experts(block)
ep_world_size = device_mesh.ep_world_size
ep_rank = device_mesh.ep_rank
Expand All @@ -88,6 +91,9 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel
local_end = local_start + experts_per_rank

if isinstance(block.experts, nn.ModuleList):
if ep_fsdp_enabled:
raise NotImplementedError('EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. '
'Only tensor experts (gate_up_proj/down_proj) are supported.')
local_experts = nn.ModuleList(block.experts[local_start:local_end])
block.experts = local_experts
block._ep_tensor_experts = False
Expand All @@ -102,6 +108,7 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel
block._ep_rank = ep_rank
block._ep_world_size = ep_world_size
block._ep_ignore_shared_experts = cfg.ignore_shared_experts
block._ep_fsdp_enabled = ep_fsdp_enabled


def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
Expand Down Expand Up @@ -193,11 +200,14 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
group=ep_group,
)
recv_out = torch.empty_like(recv_tokens)
for expert_id in torch.unique(recv_expert_ids).tolist():
idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1)
expert_in = recv_tokens.index_select(0, idx)
expert_out = _run_expert(block, expert_id, expert_in)
recv_out.index_copy_(0, idx, expert_out)
if getattr(block, '_ep_fsdp_enabled', False) and getattr(block, '_ep_tensor_experts', False):
recv_out = _run_experts_ep_fsdp_batch(block, recv_tokens, recv_expert_ids)
else:
for expert_id in torch.unique(recv_expert_ids).tolist():
idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1)
expert_in = recv_tokens.index_select(0, idx)
expert_out = _run_expert(block, expert_id, expert_in)
recv_out.index_copy_(0, idx, expert_out)

send_out = torch.empty_like(send_tokens)
send_out = dist_nn.functional.all_to_all_single(
Expand Down Expand Up @@ -327,6 +337,7 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to
expert = block.experts[expert_id]
return _run_module_with_casting(expert, expert_in)
experts = block.experts

gate_up = experts.gate_up_proj[expert_id]
down = experts.down_proj[expert_id]
compute_dtype = gate_up.dtype
Expand All @@ -340,6 +351,27 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to
return out


def _run_experts_ep_fsdp_batch(
block: nn.Module,
expert_in: torch.Tensor,
local_expert_ids: torch.Tensor,
) -> torch.Tensor:
input_dtype = expert_in.dtype
if expert_in.numel() == 0:
return torch.empty_like(expert_in)
experts = block.experts
top_k_index = local_expert_ids.view(-1, 1).to(torch.long)
top_k_weights = torch.ones(
(expert_in.shape[0], 1),
dtype=expert_in.dtype,
device=expert_in.device,
)
out = experts(expert_in, top_k_index, top_k_weights)
if out.dtype != input_dtype:
out = out.to(input_dtype)
return out


def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype:
for param in module.parameters():
if param.dtype.is_floating_point:
Expand Down
68 changes: 66 additions & 2 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set

from twinkle.utils import DeviceMesh, Platform
Expand All @@ -26,16 +27,32 @@ def __init__(self,
def wrap_model(self, model, optimizer=None):
if self.device_mesh is None:
return model, optimizer
from torch.distributed.fsdp import fully_shard
fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
if fsdp_mesh is not None:
ep_fsdp_mode = _is_ep_fsdp_mode_enabled(
self.device_mesh,
self.enable_ep,
)
if self.enable_ep:
_ensure_moe_patched_if_needed(model, self.device_mesh)
_place_ep_experts_on_local_device(model, self.device_mesh)
mp_policy = _build_mp_policy(self.mixed_precision)
reshard_after_forward = self.fsdp_config.get('reshard_after_forward', True)
ignored_params = _collect_expert_params(model) if self.enable_ep else None

if ep_fsdp_mode:
_ensure_ep_fsdp_supported(model)
ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh)
if ep_fsdp_mesh is None:
raise RuntimeError(
'Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp.')
_maybe_shard_ep_expert_blocks(
model,
mesh=ep_fsdp_mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
)

_maybe_shard_layers(
model,
mesh=fsdp_mesh,
Expand Down Expand Up @@ -85,6 +102,21 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', ))


def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
if device_mesh is None or not device_mesh.has_dim('dp'):
return None
ranks = device_mesh.get_ranks_for_dims('dp')
if len(ranks) <= 1:
return None
return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=('ep_fsdp', ))


def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool:
if not enable_ep or device_mesh is None:
return False
return device_mesh.is_implicit_ep_fsdp_enabled()


def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
ignored: Set[nn.Parameter] = set()
ep_patched = False
Expand Down Expand Up @@ -137,9 +169,41 @@ def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) ->
'Call apply_expert_parallel(model, device_mesh, config) before wrapping with FSDP2.')


def _ensure_ep_fsdp_supported(model: nn.Module) -> None:
for module in model.modules():
if not getattr(module, '_ep_patched', False):
continue
experts = getattr(module, 'experts', None)
if isinstance(experts, nn.ModuleList):
raise NotImplementedError('EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. '
'Only tensor experts (gate_up_proj/down_proj) are supported.')


def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool],
mp_policy: 'MixedPrecisionPolicy') -> int:
from torch.distributed.tensor import Shard
sharded_blocks = 0
for module in model.modules():
if not getattr(module, '_ep_patched', False):
continue
experts = getattr(module, 'experts', None)
if experts is None:
continue
# Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh.
# Non-expert params (router/gate etc.) are left to global FSDP wrapping.
fully_shard(
experts,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=lambda param: Shard(1),
)
sharded_blocks += 1
return sharded_blocks


def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool],
mp_policy: 'MixedPrecisionPolicy', ignored_params: Optional[Set[nn.Parameter]]) -> None:
from torch.distributed.fsdp import fully_shard
layers = getattr(model, 'layers', None)
if not isinstance(layers, nn.ModuleList):
return
Expand Down
27 changes: 24 additions & 3 deletions src/twinkle/utils/grad_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,23 @@ def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter],

has_dtensor_grad = any(hasattr(grad, 'to_local') for grad in grads)
has_local_tensor_grad = any(not hasattr(grad, 'to_local') for grad in grads)
if not (has_dtensor_grad and has_local_tensor_grad):
dtensor_mesh_keys = set()
for grad in grads:
if not hasattr(grad, 'to_local'):
continue
mesh = getattr(grad, 'device_mesh', None)
if mesh is None:
dtensor_mesh_keys.add('dtensor:unknown')
continue
try:
mesh_key = (tuple(mesh.mesh.flatten().tolist()), tuple(mesh.mesh_dim_names or ()))
except Exception:
mesh_key = repr(mesh)
dtensor_mesh_keys.add(mesh_key)

has_mixed_dtensor_mesh = len(dtensor_mesh_keys) > 1

if not (has_dtensor_grad and has_local_tensor_grad) and not has_mixed_dtensor_mesh:
grad_norm = torch.nn.utils.clip_grad_norm_(
parameters,
max_grad_norm,
Expand Down Expand Up @@ -64,6 +80,11 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
reduce_device = torch.device(Platform.get_local_device())
else:
reduce_device = torch.device('cpu')
reduce_group = group
if has_mixed_dtensor_mesh:
# Different DTensor meshes cannot be reduced by DTensor op propagation (e.g. aten.stack).
# Fall back to world reduction over local shards.
reduce_group = None

if norm_type == float('inf'):
local_norm = 0.0
Expand All @@ -74,7 +95,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
local_norm = max(local_norm, local_grad.detach().abs().max().item())
total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32)
if dist.is_initialized():
dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group)
dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group)
total_norm = float(total_norm_tensor.item())
else:
local_sq = 0.0
Expand All @@ -85,7 +106,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
local_sq += local_grad.detach().float().pow(2).sum().item()
total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32)
if dist.is_initialized():
dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group)
dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=reduce_group)
total_norm = float(total_sq_tensor.sqrt().item())

clip_coef = float(max_grad_norm) / (total_norm + 1e-6)
Expand Down
38 changes: 38 additions & 0 deletions src/twinkle/utils/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,44 @@ def get_dim_group(self, dims):
key = tuple(c for i, c in enumerate(coord) if i != dim_idx)
return group_map[key]

def get_ranks_for_dims(self, dims):
    """Return the sorted global ranks that share this rank's coordinates on
    every mesh dimension except the ones named in ``dims``.

    ``dims`` may be a single dimension name or an iterable of names. Named
    dimensions are left free (fully sliced); all other dimensions are
    pinned to the current rank's coordinate.

    Raises:
        ValueError: if mesh_dim_names is unset or a name is unknown.
        RuntimeError: if the current rank is not part of the mesh.
    """
    if self.mesh_dim_names is None:
        raise ValueError('mesh_dim_names is not set.')
    if isinstance(dims, str):
        dims = (dims, )
    unknown = [d for d in dims if d not in self.mesh_dim_names]
    if unknown:
        raise ValueError(f"Dimension '{unknown[0]}' not found in mesh. Available: {self.mesh_dim_names}")

    coord = self._get_coord()
    if coord is None:
        raise RuntimeError('Current rank is not found in mesh.')

    # Free the requested dims, pin every other dim to our own coordinate.
    index = tuple(
        slice(None) if name in dims else coord[axis]
        for axis, name in enumerate(self.mesh_dim_names)
    )
    return sorted(self.mesh[index].flatten().tolist())

def is_implicit_ep_fsdp_enabled(self) -> bool:
    """Whether implicit EP_FSDP can be inferred from the dp/ep group sizes.

    True only when both the ep and dp groups are larger than one AND the
    mesh factorizes as world_size == ep_world_size * dp_world_size.

    NOTE(review): this check treats dp and ep as the only parallel dims; a
    mesh that also carries tp/pp would fail the dp == world // ep test and
    raise — confirm whether such meshes need to be supported.

    Raises:
        ValueError: if world_size is not divisible by ep_world_size, or the
            dp size does not equal world_size // ep_world_size.
    """
    ep_size = self.ep_world_size or 1
    dp_size = self.dp_world_size or 1
    if min(ep_size, dp_size) <= 1:
        return False

    total = self.world_size or 1
    if total % ep_size != 0:
        raise ValueError(f'world_size ({total}) must be divisible by ep_world_size ({ep_size}) '
                         'to infer implicit EP_FSDP from dp.')
    if dp_size != total // ep_size:
        raise ValueError(f'Implicit EP_FSDP requires dp_world_size == world_size // ep_world_size, '
                         f'but got dp_world_size={dp_size}, world_size={total}, '
                         f'ep_world_size={ep_size}.')
    return True
Comment on lines +211 to +226
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in is_implicit_ep_fsdp_enabled seems to assume that the device mesh consists only of dp and ep dimensions. The check dp_world_size != world_size // ep_world_size will likely fail for more complex meshes that also include other parallelism dimensions like tensor parallelism (tp) or pipeline parallelism (pp), as it would incorrectly require tp_size * pp_size == 1. This could be a potential limitation or bug. Please consider making the logic more robust to handle arbitrary mesh dimensions or clearly documenting this limitation.


@property
def order(self):
"""The order of the dimensions for megatron"""
Expand Down
Loading