diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 456f8d9d06ae60..112c1a7a552593 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -746,7 +746,7 @@ def setup(self, args, state, model, **kwargs):
             # keep track of model topology and gradients, unsupported on TPU
             _watch_model = os.getenv("WANDB_WATCH", "false")
             if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
-                self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))
+                self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 82fd41874d73bb..e3971dec474a67 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1586,14 +1586,6 @@ def _inner_training_loop(
                 f" {args.max_steps}"
             )
 
-        # Compute absolute values for logging, eval, and save if given as ratio
-        if args.logging_steps and args.logging_steps < 1:
-            args.logging_steps = math.ceil(max_steps * args.logging_steps)
-        if args.eval_steps and args.eval_steps < 1:
-            args.eval_steps = math.ceil(max_steps * args.eval_steps)
-        if args.save_steps and args.save_steps < 1:
-            args.save_steps = math.ceil(max_steps * args.save_steps)
-
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
                 # nn.DataParallel(model) replicates the model, creating new variables and module
@@ -1627,6 +1619,23 @@ def _inner_training_loop(
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None
 
+        # Compute absolute values for logging, eval, and save if given as ratio
+        if args.logging_steps is not None:
+            if args.logging_steps < 1:
+                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+            else:
+                self.state.logging_steps = args.logging_steps
+        if args.eval_steps is not None:
+            if args.eval_steps < 1:
+                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+            else:
+                self.state.eval_steps = args.eval_steps
+        if args.save_steps is not None:
+            if args.save_steps < 1:
+                self.state.save_steps = math.ceil(max_steps * args.save_steps)
+            else:
+                self.state.save_steps = args.save_steps
+
         # Activate gradient checkpointing if needed
         if args.gradient_checkpointing:
             self.model.gradient_checkpointing_enable()
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index bade6841ed1be4..49b12ea558d401 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -53,6 +53,12 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (`int`, *optional*, defaults to 0):
             The number of update steps to do during the current training.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Log every X update steps.
+        eval_steps (`int`, *optional*):
+            Run an evaluation every X steps.
+        save_steps (`int`, *optional*, defaults to 500):
+            Save a checkpoint every X update steps.
         total_flos (`float`, *optional*, defaults to 0):
             The total number of floating operations done by the model since the beginning of training (stored as floats
             to avoid overflow).
@@ -77,6 +83,9 @@ class TrainerState:
     epoch: Optional[float] = None
     global_step: int = 0
     max_steps: int = 0
+    logging_steps: int = 500
+    eval_steps: int = 500
+    save_steps: int = 500
     num_train_epochs: int = 0
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
@@ -421,13 +430,13 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         # Log
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
-        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
             control.should_log = True
 
         # Evaluate
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and state.global_step % args.eval_steps == 0
+            and state.global_step % state.eval_steps == 0
             and args.eval_delay <= state.global_step
         ):
             control.should_evaluate = True
@@ -435,8 +444,8 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         # Save
         if (
             args.save_strategy == IntervalStrategy.STEPS
-            and args.save_steps > 0
-            and state.global_step % args.save_steps == 0
+            and state.save_steps > 0
+            and state.global_step % state.save_steps == 0
         ):
             control.should_save = True
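For reference, the block added to `_inner_training_loop` resolves any of `logging_steps`, `eval_steps`, or `save_steps` given as a ratio in `(0, 1)` against `max_steps` once the `TrainerState` has been created, and copies plain integers over unchanged; the flow callback and the W&B integration then read the resolved values from `state` instead of `args`. A minimal standalone sketch of that conversion (the helper name `resolve_interval` is illustrative only, not a transformers API):

```python
# Illustrative sketch only: mirrors the ratio-to-absolute-steps conversion in the diff above.
import math
from typing import Optional, Union


def resolve_interval(value: Optional[Union[int, float]], max_steps: int) -> Optional[int]:
    """Ratios in (0, 1) become absolute step counts; integers pass through; None stays None."""
    if value is None:
        return None
    if value < 1:
        # e.g. logging_steps=0.1 with max_steps=1000 -> log every 100 steps
        return math.ceil(max_steps * value)
    return int(value)


assert resolve_interval(0.1, 1000) == 100    # ratio resolved against max_steps
assert resolve_interval(500, 1000) == 500    # integer interval unchanged
assert resolve_interval(None, 1000) is None  # unset stays unset
```

Keeping the resolved values on `state` rather than overwriting `args.*_steps` (as the removed code did) means `TrainingArguments` is no longer mutated during training, which is why `DefaultFlowCallback.on_step_end` and `WandbCallback.setup` now consult `state.logging_steps`, `state.eval_steps`, and `state.save_steps`.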