From d842f2d5b9bd4e361644c332bf9dc7f9b064f581 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 7 Sep 2022 20:01:30 +0800 Subject: [PATCH] update the train_batch_size in case HPO change batch_size_per_device (#18918) Signed-off-by: Wang, Yi A Signed-off-by: Wang, Yi A --- src/transformers/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ae1b2458524817..f7c3836d4afd42 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1488,6 +1488,7 @@ def train( raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False