Skip to content

Commit

Permalink
Fix issue with ratio evaluation steps and auto find batch size (huggingface#25436)
Browse files Browse the repository at this point in the history

* Fully rebased solution

* 500
  • Loading branch information
muellerzr authored and blbadger committed Nov 8, 2023
1 parent 286e6ca commit 0a99b2f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def setup(self, args, state, model, **kwargs):
# keep track of model topology and gradients, unsupported on TPU
_watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))

def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
Expand Down
25 changes: 17 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,14 +1586,6 @@ def _inner_training_loop(
f" {args.max_steps}"
)

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps and args.logging_steps < 1:
args.logging_steps = math.ceil(max_steps * args.logging_steps)
if args.eval_steps and args.eval_steps < 1:
args.eval_steps = math.ceil(max_steps * args.eval_steps)
if args.save_steps and args.save_steps < 1:
args.save_steps = math.ceil(max_steps * args.save_steps)

if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
Expand Down Expand Up @@ -1627,6 +1619,23 @@ def _inner_training_loop(
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
Expand Down
17 changes: 13 additions & 4 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class TrainerState:
During training, represents the number of update steps completed.
max_steps (`int`, *optional*, defaults to 0):
The number of update steps to do during the current training.
logging_steps (`int`, *optional*, defaults to 500):
Log every X update steps.
eval_steps (`int`, *optional*):
Run an evaluation every X steps.
save_steps (`int`, *optional*, defaults to 500):
Save checkpoint every X update steps.
total_flos (`float`, *optional*, defaults to 0):
The total number of floating operations done by the model since the beginning of training (stored as floats
to avoid overflow).
Expand All @@ -77,6 +83,9 @@ class TrainerState:
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
logging_steps: int = 500
eval_steps: int = 500
save_steps: int = 500
num_train_epochs: int = 0
total_flos: float = 0
log_history: List[Dict[str, float]] = None
Expand Down Expand Up @@ -421,22 +430,22 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
# Log
if state.global_step == 1 and args.logging_first_step:
control.should_log = True
if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
control.should_log = True

# Evaluate
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step % args.eval_steps == 0
and state.global_step % state.eval_steps == 0
and args.eval_delay <= state.global_step
):
control.should_evaluate = True

# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
and state.save_steps > 0
and state.global_step % state.save_steps == 0
):
control.should_save = True

Expand Down

0 comments on commit 0a99b2f

Please sign in to comment.