diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ae1b2458524817..f7c3836d4afd42 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1488,6 +1488,7 @@ def train( raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False