
Commit e01a61a

winglian authored and ArthurZucker committed
FSDP grad accum fix (#34645)
* add gradient accumulation steps tests for FSDP
* invert the no_sync context to fix training for FSDP
1 parent ccbd57a commit e01a61a

File tree

2 files changed: +13 -1 lines changed


src/transformers/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -2474,7 +2474,7 @@ def _inner_training_loop(
                 # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                 context = (
                     functools.partial(self.accelerator.no_sync, model=model)
-                    if i == len(batch_samples) - 1
+                    if i != len(batch_samples) - 1
                     else contextlib.nullcontext
                 )
                 with context():
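
For context, the corrected condition follows the usual gradient-accumulation pattern: cross-rank gradient synchronization is suppressed with `no_sync` on every micro-batch except the last one in the accumulation window, and the final micro-batch (the one followed by the optimizer step) synchronizes normally. Below is a minimal sketch of that pattern, assuming an Accelerate-style training loop; `accumulation_step`, `compute_loss`, and `batch_samples` are illustrative names, not part of the Trainer API.

import contextlib
import functools

def accumulation_step(accelerator, model, optimizer, batch_samples, compute_loss):
    """Run one optimizer step over a window of micro-batches (illustrative sketch)."""
    for i, batch in enumerate(batch_samples):
        # Skip gradient sync on all but the last micro-batch. Before the fix the
        # condition was `==`, so sync was skipped only on the final micro-batch --
        # exactly the one that must synchronize before the optimizer step.
        context = (
            functools.partial(accelerator.no_sync, model=model)
            if i != len(batch_samples) - 1
            else contextlib.nullcontext
        )
        with context():
            loss = compute_loss(model, batch) / len(batch_samples)
            accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()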

tests/fsdp/test_fsdp.py

Lines changed: 12 additions & 0 deletions
@@ -224,6 +224,18 @@ def test_basic_run(self, sharding_strategy, dtype):
         cmd = launcher + script + args + fsdp_args
         execute_subprocess_async(cmd, env=self.get_env())
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    @require_torch_multi_accelerator
+    @slow
+    def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
+        launcher = get_launcher(distributed=True, use_accelerate=False)
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
+        fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        cmd = launcher + script + args + fsdp_args
+        execute_subprocess_async(cmd, env=self.get_env())
+
     @parameterized.expand(dtypes)
     @require_torch_multi_accelerator
     @slow
