[moe] brings batch/sequence-wise load balance loss #2061
base: main
Conversation
…d seq-wise aux loss for load balance
job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
We can add a condition here to decide whether or not to wrap the loss for MoE. For now, all models in torchtitan only return a single output, so it's OK as is.
If we subsume this MoE loss wrapper into build_loss_fn, we can avoid adding the logic here.
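For illustration, a minimal sketch of what folding the wrapper into a builder could look like. build_moe_aware_loss_fn and the moe_enabled flag are hypothetical names, not existing torchtitan API, and this sketch simply adds the weighted aux loss rather than cancelling its numeric value as the PR does:

```python
import functools

import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


def moe_loss(pred, labels, loss_fn, load_balance_loss_weight: float) -> torch.Tensor:
    # If the model returned (logits, aux_loss), fold the aux loss into the task loss.
    if isinstance(pred, tuple):
        pred, load_balance_loss = pred
        return loss_fn(pred, labels) + load_balance_loss_weight * load_balance_loss
    return loss_fn(pred, labels)


def build_moe_aware_loss_fn(job_config):
    # Hypothetical builder: wrap only when the config says the model has MoE layers,
    # so dense models keep the plain loss and train.py needs no special casing.
    if getattr(job_config.model, "moe_enabled", False):
        return functools.partial(
            moe_loss,
            loss_fn=cross_entropy_loss,
            load_balance_loss_weight=job_config.model.extra_losses.load_balance_loss_weight,
        )
    return cross_entropy_loss
```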
wwwjn
left a comment
Thank you! @shuhuayu is working on a more formal review, and I have some housekeeping comments.
@dataclass
class ExtraLosses:
This section is specifically for the MoE load-balancing loss for now. Do you foresee any other loss-related params being used in this section? If not, let's make the name more descriptive and specific.
Follow-up here: should we merge these configs into the Model dataclass?
load_balance_loss_weight: float = 0
"""Weight of load balance loss"""

load_balance_coeff: float | None = 1e-3
Probably rename this to loss_free_load_balance_coeff? And IIUC, because it's loss-free, we need to set it to None if we use loss-based load balancing; otherwise it will register an optimizer hook here:
torchtitan/torchtitan/components/optimizer.py, line 411 in 58fa181:
if _should_register_moe_balancing_hook(model_parts):
I think both loss-free and loss-based load balancing are used simultaneously in DeepSeek-V3.
Yes, DeepSeek-V3 (and GLM-4.5, as far as I know) uses both.
load_balance_coeff is the name used in the current repo, and yes, maybe we should name them properly.
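To make the naming discussion concrete, a purely hypothetical config sketch (none of these exact field names exist in torchtitan today) that keeps the loss-based weight and the loss-free bias coefficient side by side, since both mechanisms can be active at once:

```python
from dataclasses import dataclass


@dataclass
class MoELoadBalance:
    """Hypothetical config section dedicated to MoE load balancing."""

    # Loss-based balancing: batch- or sequence-wise auxiliary loss.
    load_balance_loss_type: str = "sequence"
    load_balance_loss_weight: float = 0.0  # 0 disables the auxiliary loss

    # Loss-free balancing: bias update applied via an optimizer hook.
    # None disables the hook, so the two mechanisms can be toggled independently.
    load_balance_bias_coeff: float | None = 1e-3
```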
)

@staticmethod
@torch.compile(fullgraph=True)
n00b q: do we always want to compile this loss? Is it for speed purposes? Should we give users an option to control whether to compile it, like the existing "if job_config.compile.enable and "loss" in job_config.compile.components" check in loss.py?
Yep, for speedup. I don't know whether, with compile + fullgraph enabled in the config, it gets compiled twice or not (I would expect it is handled automatically).
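A sketch of how the compilation could be gated on the existing compile config instead of an unconditional decorator, mirroring the check loss.py already uses; sequence_wise_balance_loss here stands in for the (undecorated) aux-loss function and is not an existing symbol:

```python
import torch


def build_balance_loss(job_config, sequence_wise_balance_loss):
    # Compile the aux-loss function only if the user opted in, following the same
    # pattern as build_cross_entropy_loss in torchtitan/components/loss.py.
    if job_config.compile.enable and "loss" in job_config.compile.components:
        return torch.compile(sequence_wise_balance_loss, fullgraph=True)
    return sequence_wise_balance_loss
```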
def moe_loss(
    pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
    labels: torch.Tensor,
    loss_fn: LossFunction,
I think we could keep a consistent API with the other loss functions: take job_config as input and plug the loss in like the other loss functions in TrainSpec:
build_loss_fn=build_cross_entropy_loss,
so that we could avoid the change in train.py. WDYT?
I agree. I think we can use a new build_loss_fn for models that may have MoE. Or we can update build_cross_entropy_loss by checking whether MoE is enabled from the config here:
torchtitan/torchtitan/components/loss.py, line 29 in ad9f188:
if job_config.compile.enable and "loss" in job_config.compile.components:
You mean something like build_multiple_loss? Or do we do build_ce_and_moe_loss and build_mse_and_moe_loss?
shuhuayu
left a comment
Thanks a lot for the PR @rakkit! I made some comments here.
indices: torch.Tensor,  # Shape: (B*S, K) - Selected Expert Indices
B: int,  # Batch size
S: int,  # Sequence length (T in the paper)
top_k: int,  # K_r
The K_r here is the same as K elsewhere in this function, right? Maybe we can use the consistent notation top_k in all comments and tell people this is K_r in the DeepSeek paper. Similarly, we can use N to denote the number of routed experts and tell people this is N_r in the DeepSeek paper.
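For reference, the sequence-wise balance loss from the DeepSeek-V3 paper that these comments point at, in the paper's notation (N_r routed experts, K_r activated per token, T tokens in the sequence, s'_{i,t} the normalized routing score):

$$
\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N_r} f_i \, P_i,
\qquad
f_i = \frac{N_r}{K_r T} \sum_{t=1}^{T} \mathbb{1}\bigl(\text{token } t \text{ selects expert } i\bigr),
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} s'_{i,t}
$$

So top_k in the code corresponds to K_r, and the number of routed experts N corresponds to N_r.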
# 1. Reshape inputs to handle each sequence separately: (B, S, N)
# This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
scores_per_seq = scores.view(B, S, N)
indices_per_seq = indices.view(B, S, top_k)
This is not used afterwards:
indices_per_seq = indices.view(B, S, top_k)
# f_i = (N / (K * T)) * count_i

# Flatten the top-k dimension to count hits per sequence: (B, S*K)
flat_indices_per_seq = indices_per_seq.view(B, -1)
Suggested change:
- flat_indices_per_seq = indices_per_seq.view(B, -1)
+ batch_indices_per_seq = indices.flatten(1)
selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
selection_counts.scatter_add_(1, flat_indices_per_seq, src)
It seems to me we do not need to create a new src tensor here. We may consider using torch.bincount to save memory.
Suggested change:
- selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
- src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
- selection_counts.scatter_add_(1, flat_indices_per_seq, src)
+ offset = (torch.arange(B, device=batch_indices_per_seq.device).unsqueeze(1) * N)
+ flat_indices = (batch_indices_per_seq + offset).reshape(-1)
+ selection_counts = torch.bincount(flat_indices, minlength=B * N).reshape(B, N)
+ selection_counts = selection_counts.to(dtype=scores.dtype)
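A small self-contained check of the two counting variants above (shapes and names follow the code under review; it only illustrates that the bincount path yields the same per-sequence expert counts as scatter_add_):

```python
import torch

B, S, top_k, N = 2, 8, 2, 4  # batch, seq len, experts per token, routed experts
indices = torch.randint(0, N, (B * S, top_k))
scores = torch.rand(B * S, N)

# Variant 1: scatter_add_ with an all-ones source tensor.
flat_indices_per_seq = indices.view(B, S, top_k).view(B, -1)
counts_scatter = torch.zeros((B, N), dtype=scores.dtype)
src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
counts_scatter.scatter_add_(1, flat_indices_per_seq, src)

# Variant 2: offset the per-sequence indices into one flat range and bincount.
batch_indices_per_seq = indices.view(B, S * top_k)
offset = torch.arange(B).unsqueeze(1) * N
flat = (batch_indices_per_seq + offset).reshape(-1)
counts_bincount = torch.bincount(flat, minlength=B * N).reshape(B, N).to(scores.dtype)

assert torch.equal(counts_scatter, counts_bincount)
```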
super().__init__()

num_experts = moe_args.num_experts
self.topk = moe_args.top_k
nit:
Suggested change:
- self.topk = moe_args.top_k
+ self.top_k = moe_args.top_k
load_balance_loss_weight: float = 0
"""Weight of load balance loss"""

load_balance_coeff: float | None = 1e-3
Suggested change:
- load_balance_coeff: float | None = 1e-3
+ load_balance_bias_coeff: float | None = 1e-3
losses_config = job_config.model.extra_losses
self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
self.moe_args.load_balance_coeff = losses_config.load_balance_coeff
Suggested change:
- self.moe_args.load_balance_coeff = losses_config.load_balance_coeff
+ self.moe_args.load_balance_bias_coeff = losses_config.load_balance_bias_coeff
if isinstance(pred, tuple):
    pred, load_balance_loss = pred
loss = loss_fn(pred, labels)
# USE STE to make the magnitude of loss remain the same
Maybe we can be more explicit here.
Suggested change:
- # USE STE to make the magnitude of loss remain the same
+ # Add auxiliary loss to the computation graph for gradients in the backward pass,
+ # but cancel out its numeric value so the forward pass only logs language model task loss.
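For concreteness, the cancellation described in the suggested comment is usually written with detach, so the auxiliary term contributes gradients but adds zero to the reported value. A minimal standalone sketch (not the PR's exact code):

```python
import torch


def combine_losses(task_loss: torch.Tensor, load_balance_loss: torch.Tensor) -> torch.Tensor:
    # Gradients flow through load_balance_loss, but its value cancels in the
    # forward pass, so the logged number is exactly the language-model task loss.
    return task_loss + load_balance_loss - load_balance_loss.detach()
```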
)
out = out.reshape(bs, slen, dim)
return out
This is a draft PR for:
For now, it only applies to the DeepSeek model, but I can add it to all other MoE models at the end.
(Also, we don't log the aux loss, but I can add it via an optimizer hook if you want.)
The main concern is that the aux loss does not work well with PP. From what I have tested, it works well only with 1F1B, and it is broken for ZBV or interleaved 1F1B.
To test it:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --model.extra_losses.load_balance_loss_weight=0.001