Fix flaky SwitchTransformersModelTest::test_training_gradient (huggingface#35587)

* fix

* Update tests/models/switch_transformers/test_modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
3 people authored Jan 9, 2025
1 parent eb4579c commit 82dd6c1
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -576,6 +576,8 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
     test_torchscript = False
     # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests
     model_split_percents = [0.5, 0.8, 0.9]
+    # `SwitchTransformers` is a MOE in which not all experts will get gradients because they are not all used in a single forward pass
+    test_all_params_have_gradient = False

     def setUp(self):
         self.model_tester = SwitchTransformersModelTester(self)
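
For context on the two lines added above: in a mixture-of-experts model such as SwitchTransformers, the router sends each token to only a subset of experts, so experts that receive no tokens in a given forward pass never enter the autograd graph and finish the training step with `grad is None`. Because routing depends on the inputs, a blanket "every parameter has a gradient" assertion passes on some runs and fails on others, which is the flakiness the commit title refers to. A minimal, self-contained sketch of the effect (not part of this commit; `ToyTop1MoE` is an illustrative stand-in, not the real modeling code):

# Toy top-1 MoE layer: experts that receive no tokens are skipped entirely,
# so their parameters get no gradient from `backward()`.
import torch
import torch.nn as nn


class ToyTop1MoE(nn.Module):
    def __init__(self, hidden_size=8, num_experts=4):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))

    def forward(self, x):  # x: (num_tokens, hidden_size)
        probs = self.router(x).softmax(dim=-1)
        expert_idx = probs.argmax(dim=-1)  # top-1 routing per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():  # an expert with no routed tokens never enters the graph
                out[mask] = probs[mask, i].unsqueeze(-1) * expert(x[mask])
        return out


torch.manual_seed(0)
moe = ToyTop1MoE()
loss = moe(torch.randn(4, 8)).sum()  # few tokens -> some experts are likely unused
loss.backward()

for name, param in moe.named_parameters():
    # Skipped experts report `grad is None`, so asserting `grad is not None`
    # for every parameter passes or fails depending on how tokens were routed.
    print(name, param.grad is not None)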
9 changes: 6 additions & 3 deletions tests/test_modeling_common.py
@@ -221,6 +221,8 @@ class ModelTesterMixin:
     test_mismatched_shapes = True
     test_missing_keys = True
     test_model_parallel = False
+    # Used in `check_training_gradient_checkpointing` to NOT check all params having gradient (e.g. for some MOE models)
+    test_all_params_have_gradient = True
     is_encoder_decoder = False
     has_attentions = True
     _is_composite = False
@@ -895,9 +897,10 @@ def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=No
             loss.backward()
             optimizer.step()

-            for k, v in model.named_parameters():
-                if v.requires_grad:
-                    self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
+            if self.test_all_params_have_gradient:
+                for k, v in model.named_parameters():
+                    if v.requires_grad:
+                        self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")

     def test_training(self):
         if not self.model_tester.is_training:
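
A short sketch (not from the repository; all class and helper names are illustrative) of how the new `test_all_params_have_gradient` switch is meant to be used: the mixin keeps the strict gradient check by default, and a test class whose model legitimately leaves some parameters without gradients opts out, mirroring the `SwitchTransformersModelTest` change above.

import unittest

import torch
import torch.nn as nn


class ModelWithUnusedExpert(nn.Module):
    # Stand-in for a MoE layer: `unused_expert` never participates in the
    # forward pass, so it ends up with `grad is None` after `backward()`.
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.unused_expert = nn.Linear(8, 8)

    def forward(self, x):
        return self.used(x)


class _GradientCheckMixin:
    test_all_params_have_gradient = True  # strict default, as in `ModelTesterMixin`

    def _run_training_step_and_check(self, model):
        model(torch.randn(4, 8)).sum().backward()
        if self.test_all_params_have_gradient:
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.assertTrue(param.grad is not None, f"{name} has no gradient!")


class DenseModelTest(_GradientCheckMixin, unittest.TestCase):
    # A dense model uses every parameter each forward pass, so the default applies.
    def test_training_gradient(self):
        self._run_training_step_and_check(nn.Linear(8, 8))


class MoELikeModelTest(_GradientCheckMixin, unittest.TestCase):
    # Opting out mirrors the `SwitchTransformersModelTest` change; with the
    # default of `True` this test would fail on `unused_expert.weight` / `.bias`.
    test_all_params_have_gradient = False

    def test_training_gradient(self):
        self._run_training_step_and_check(ModelWithUnusedExpert())


if __name__ == "__main__":
    unittest.main()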
