[NllbMoe] Update code to properly support loss computation (#25429)
* update nllb_moe

* fix

* doc nits

* nits

* add a small test

* fixup

* remove adapted from
ArthurZucker authored Aug 17, 2023
1 parent 9264fc9 commit 181d778
Showing 2 changed files with 23 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
     return incremental_indices.long() + padding_idx


-# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel
 def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
     Returns:
         The auxiliary loss.
     """
+    if router_probs is None:
+        return 0
+
     num_experts = router_probs.shape[-1]

     # cast the expert indices to int64, otherwise one-hot encoding will fail
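For reference, a minimal sketch of the Switch Transformer-style auxiliary loss that the new None guard protects (simplified; the library's version does extra reshaping of the expert indices):

import torch

def load_balancing_loss_sketch(router_probs: torch.Tensor, expert_indices: torch.Tensor):
    # Dense-only forward passes produce no router outputs; the new guard
    # turns the auxiliary loss into a no-op instead of failing on .shape.
    if router_probs is None:
        return 0
    num_experts = router_probs.shape[-1]
    # One-hot over the selected experts; one_hot requires int64 indices.
    expert_mask = torch.nn.functional.one_hot(expert_indices.long(), num_experts)
    # Fraction of tokens dispatched to each expert ...
    tokens_per_expert = expert_mask.float().mean(dim=-2)
    # ... and the mean router probability assigned to each expert.
    router_prob_per_expert = router_probs.mean(dim=-2)
    # Minimal when both distributions are uniform across experts.
    return (tokens_per_expert * router_prob_per_expert).mean() * num_experts**2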
@@ -699,7 +701,9 @@ def forward(
         if self.is_sparse:
             hidden_states, router_states = self.ffn(hidden_states, attention_mask)
         else:
-            hidden_states = self.ffn(hidden_states)
+            # router_states set to None to track which layers have None gradients.
+            hidden_states, router_states = self.ffn(hidden_states), None
+
         hidden_states = self.ff_dropout(hidden_states)

         hidden_states = residual + hidden_states
@@ -830,7 +834,8 @@ def forward(
         if self.is_sparse:
             hidden_states, router_states = self.ffn(hidden_states, attention_mask)
         else:
-            hidden_states = self.ffn(hidden_states)
+            hidden_states, router_states = self.ffn(hidden_states), None
+
         hidden_states = self.ff_dropout(hidden_states)

         hidden_states = residual + hidden_states
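The decoder layer gets the same fix as the encoder layer above: the dense branch now returns router_states=None, so both branches yield a (hidden_states, router_states) pair. A toy sketch of that contract (names illustrative, not the library's):

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, dim: int, sparse: bool):
        super().__init__()
        self.is_sparse = sparse
        self.ffn = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor):
        if self.is_sparse:
            # A real sparse FFN also returns its router's probs/indices.
            return self.ffn(hidden_states), torch.zeros(1)
        # Dense layers return None so callers always unpack two values
        # and can skip the None entries when computing the router loss.
        return self.ffn(hidden_states), None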
@@ -1730,7 +1735,7 @@ def forward(

         if output_router_logits:
             encoder_router_logits = outputs[-1]
-            decoder_router_logits = outputs[5 if output_attentions else 3]
+            decoder_router_logits = outputs[3 if output_attentions else 4]

             # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
             encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
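Downstream, the unpacked router logits feed the auxiliary terms folded into the final loss. A hedged sketch of that assembly (the function and the router_aux_loss_coef name follow the SwitchTransformers convention and are assumptions, not this file's exact code):

import torch
from torch.nn import CrossEntropyLoss

def total_loss_sketch(lm_logits, labels, encoder_aux_loss, decoder_aux_loss,
                      router_aux_loss_coef=0.001):
    # Token-level cross entropy over the LM head output.
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
    # Fold in the routers' auxiliary terms; both are 0 for dense-only runs
    # thanks to the None guard added to load_balancing_loss_func above.
    return loss + router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)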
@@ -1775,7 +1780,6 @@ def forward(
             decoder_router_logits=outputs.decoder_router_logits,
         )

-    # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits
     def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
@@ -1784,11 +1788,10 @@ def _unpack_router_logits(self, router_outputs):
             router_logits, expert_indexes = router_output
             total_router_logits.append(router_logits)
             total_expert_indexes.append(expert_indexes)
-        if len(total_expert_indexes) > 0:
-            total_router_logits = torch.cat(total_router_logits, dim=1)
-        if len(total_expert_indexes) > 0:
-            torch.cat(total_expert_indexes, dim=1)
-        return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
+        total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
+        total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
+        return total_router_logits, total_expert_indexes

     # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
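Worth spelling out what the deleted block got wrong: it gated the concatenation on the wrong list, discarded the result of torch.cat(total_expert_indexes, dim=1), and its return line still called torch.cat on lists that stay empty whenever no sparse layer produced router states, which raises a RuntimeError. The rewrite returns None in that case, matching the None guard in load_balancing_loss_func. A standalone sketch of the new behavior (hypothetical inputs):

import torch

def unpack_router_logits_sketch(router_outputs):
    total_router_logits, total_expert_indexes = [], []
    for router_output in router_outputs:
        if router_output is not None:  # dense layers contribute None
            router_logits, expert_indexes = router_output
            total_router_logits.append(router_logits)
            total_expert_indexes.append(expert_indexes)
    # Empty collections (no sparse layers) now yield None instead of raising.
    logits = torch.cat(total_router_logits, dim=1) if total_router_logits else None
    indexes = torch.stack(total_expert_indexes, dim=1) if total_expert_indexes else None
    return logits, indexes

# Dense-only case: both outputs are None, and the loss guard then returns 0.
assert unpack_router_logits_sketch([None, None]) == (None, None)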
10 changes: 10 additions & 0 deletions tests/models/nllb_moe/test_modeling_nllb_moe.py
@@ -337,6 +337,16 @@ def test_generate_fp16(self):
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

+    def test_get_loss(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs()
+        input_dict["output_router_logits"] = True
+        input_dict["labels"] = input_dict["input_ids"]
+        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
+        out = model(**input_dict)
+        self.assertIsNotNone(out.loss)
+        self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
+        self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
+

 @require_torch
 @require_sentencepiece
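The new test exercises the full path. In user code, the equivalent check looks roughly like this (illustrative inputs; the released facebook/nllb-moe-54b checkpoint is very large in practice):

from transformers import AutoTokenizer, NllbMoeForConditionalGeneration

model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")

inputs = tokenizer("Hello world", return_tensors="pt")
# With labels and output_router_logits=True, the returned loss now includes
# the router z-loss / auxiliary terms that this commit wires up.
outputs = model(**inputs, labels=inputs["input_ids"], output_router_logits=True)
print(outputs.loss)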
