Fix RNG reload in resume training from epoch checkpoint (huggingface#17055)

* Fix RNG reload in resume training from epoch checkpoint

* Fix test
sgugger authored May 3, 2022
1 parent 6e17ba6 commit 1c9fcd0
Showing 3 changed files with 62 additions and 23 deletions.
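For context: the Trainer stores the Python, NumPy and PyTorch RNG states alongside every checkpoint and restores them through _load_rng_state when training resumes, so a resumed run draws the same random numbers as an uninterrupted one. A minimal standalone sketch of that save/restore idea (the helper names and the rng_state.pth filename are illustrative here, not the Trainer's exact internals):

import os
import random

import numpy as np
import torch


def save_rng_state(checkpoint_dir):
    # Capture the state of every RNG a training loop typically consumes.
    rng_state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        rng_state["cuda"] = torch.cuda.get_rng_state_all()
    torch.save(rng_state, os.path.join(checkpoint_dir, "rng_state.pth"))


def load_rng_state(checkpoint_dir):
    # Restore the saved states so the resumed run sees the same random draws.
    rng_state = torch.load(os.path.join(checkpoint_dir, "rng_state.pth"))
    random.setstate(rng_state["python"])
    np.random.set_state(rng_state["numpy"])
    torch.set_rng_state(rng_state["torch"])
    if torch.cuda.is_available() and "cuda" in rng_state:
        torch.cuda.set_rng_state_all(rng_state["cuda"])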
8 changes: 6 additions & 2 deletions src/transformers/modeling_utils.py
@@ -789,13 +789,16 @@ def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int
         Returns:
             `int`: The total number of tokens.
         """
+        if not hasattr(self, "warnings_issued"):
+            self.warnings_issued = {}
         if self.main_input_name in input_dict:
             return input_dict[self.main_input_name].numel()
-        else:
+        elif "estimate_tokens" not in self.warnings_issued:
             logger.warning(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
-            return 0
+            self.warnings_issued["estimate_tokens"] = True
+        return 0

     def floating_point_ops(
         self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
@@ -895,6 +898,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        self.warnings_issued = {}

     def post_init(self):
         """
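The warnings_issued dict added above acts as a warn-once guard: the missing-token-estimate warning is logged a single time per model instance instead of on every forward pass during FLOPs accounting, and the hasattr check covers models created before the attribute existed. A small self-contained sketch of the same pattern, with a made-up TokenCounter class standing in for the model mixin:

import logging

import torch

logger = logging.getLogger(__name__)


class TokenCounter:
    main_input_name = "input_ids"

    def __init__(self):
        # One entry per warning key, mirroring warnings_issued above.
        self.warnings_issued = {}

    def estimate_tokens(self, input_dict):
        if self.main_input_name in input_dict:
            return input_dict[self.main_input_name].numel()
        elif "estimate_tokens" not in self.warnings_issued:
            logger.warning("Could not estimate the number of tokens of the input")
            self.warnings_issued["estimate_tokens"] = True
        return 0


counter = TokenCounter()
batch = {"attention_mask": torch.ones(2, 8, dtype=torch.long)}
counter.estimate_tokens(batch)  # logs the warning once
counter.estimate_tokens(batch)  # silent, the key is already in warnings_issued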
6 changes: 5 additions & 1 deletion src/transformers/trainer.py
@@ -1151,7 +1151,8 @@ def train(
             kwargs:
                 Additional keyword arguments used to hide deprecated arguments
         """
-        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
+        if resume_from_checkpoint is False:
+            resume_from_checkpoint = None

         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -1395,6 +1396,9 @@ def train(
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

+            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+                self._load_rng_state(resume_from_checkpoint)
+
             step = -1
             for step, inputs in enumerate(epoch_iterator):

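The new block calls _load_rng_state at the top of the first resumed epoch when steps_trained_in_current_epoch is 0, which is exactly what an epoch-level checkpoint produces: there are no already-seen batches to skip, so without this call the RNG state stored in the checkpoint was not reloaded on resume. A hedged end-to-end sketch of the scenario the fix targets; the toy model, dataset, output directory and checkpoint number are illustrative and not part of the commit:

import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    # 64 samples of y = 2 * x + noise, enough for a few quick epochs.
    def __init__(self, length=64):
        self.x = torch.randn(length, 1)
        self.y = 2 * self.x + 0.1 * torch.randn(length, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i]}


class ToyModel(nn.Module):
    # Returns (loss, prediction) so the Trainer can train a plain nn.Module.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        y = self.linear(input_x)
        if labels is None:
            return (y,)
        return (nn.functional.mse_loss(y, labels), y)


# save_strategy="epoch" writes one checkpoint per epoch, named checkpoint-<global_step>.
args = TrainingArguments(output_dir="toy_out", save_strategy="epoch", num_train_epochs=3, report_to=[])
trainer = Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset())
trainer.train()

# Resume from the first epoch checkpoint: with this fix, the RNG state saved in that
# folder is restored before the second epoch's batches are drawn.
trainer = Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset())
trainer.train(resume_from_checkpoint="toy_out/checkpoint-8")  # 8 = steps per epoch with the default batch size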
71 changes: 51 additions & 20 deletions tests/trainer/test_trainer.py
@@ -58,7 +58,6 @@
     require_torch_bf16,
     require_torch_gpu,
     require_torch_multi_gpu,
-    require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
     require_wandb,
@@ -162,11 +161,12 @@ def __call__(self, eval_pred):


 class RegressionModelConfig(PretrainedConfig):
-    def __init__(self, a=0, b=0, double_output=False, **kwargs):
+    def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
         super().__init__(**kwargs)
         self.a = a
         self.b = b
         self.double_output = double_output
+        self.random_torch = random_torch
         self.hidden_size = 1


@@ -264,14 +264,18 @@ def __init__(self, config):
         super().__init__(config)
         self.a = nn.Parameter(torch.tensor(config.a).float())
         self.b = nn.Parameter(torch.tensor(config.b).float())
+        self.random_torch = config.random_torch

     def forward(self, input_x, labels=None, **kwargs):
         y = input_x * self.a + self.b
-        torch_rand = torch.randn(1).squeeze()
+        if self.random_torch:
+            torch_rand = torch.randn(1).squeeze()
         np_rand = np.random.rand()
         rand_rand = random.random()

-        y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
+        if self.random_torch:
+            y += 0.05 * torch_rand
+        y += 0.05 * torch.tensor(np_rand + rand_rand)

         if labels is None:
             return (y,)
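The regression model above deliberately mixes three randomness sources (Python's random, NumPy and torch), so resuming only reproduces training if every saved RNG state is restored; the new random_torch flag simply drops the torch source when DataParallel would make its call order nondeterministic across GPUs. For reference, a short check that reseeding all three sources makes such a mixture repeatable; set_seed is the real transformers helper, while the mix function is illustrative:

import random

import numpy as np
import torch

from transformers import set_seed


def mix():
    # Draw from the same three randomness sources the test model uses.
    return 0.05 * torch.randn(1).item() + 0.05 * (np.random.rand() + random.random())


set_seed(42)
first = [mix() for _ in range(3)]
set_seed(42)
second = [mix() for _ in range(3)]
assert first == second  # identical once every RNG state is reset the same way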
@@ -1016,33 +1020,60 @@ def test_can_resume_training(self):
             trainer.train(resume_from_checkpoint=True)
         self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))

-    @require_torch_non_multi_gpu
     def test_resume_training_with_randomness(self):
-        # This test will fail flakily for more than 1 GPUs since the result will be slightly more different
-        # TODO: investigate why it fails for 2 GPUs?
+        # With more than 1 GPU, the randomness introduced in the model interacts with DataParallel (which the
+        # Trainer uses whenever more than 1 GPU is available), so the calls to the torch RNG happen in a random
+        # order (sometimes GPU 0 calls first and sometimes GPU 1).
+        random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1

         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
         train_dataset = RegressionDataset(length=128)
         eval_dataset = RegressionDataset()

-        config = RegressionModelConfig(a=0, b=2)
-        model = RegressionRandomPreTrainedModel(config)
+        with self.subTest("Test every step"):
+            config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+            model = RegressionRandomPreTrainedModel(config)

-        tmp_dir = self.get_auto_remove_tmp_dir()
-        args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
-        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+            tmp_dir = self.get_auto_remove_tmp_dir()
+            args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)

-        trainer.train()
-        (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()

-        model = RegressionRandomPreTrainedModel(config)
-        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
-        trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
-        (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            model = RegressionRandomPreTrainedModel(config)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+            trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+
+            self.assertAlmostEqual(a, a1, delta=1e-8)
+            self.assertAlmostEqual(b, b1, delta=1e-8)
+
+        with self.subTest("Test every epoch"):
+            config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+            model = RegressionRandomPreTrainedModel(config)
+
+            tmp_dir = self.get_auto_remove_tmp_dir()
+            args = RegressionTrainingArguments(tmp_dir, save_strategy="epoch", learning_rate=0.1)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+
+            model = RegressionRandomPreTrainedModel(config)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+            checkpoints = [d for d in os.listdir(tmp_dir) if d.startswith("checkpoint-")]
+            # There should be one checkpoint per epoch.
+            self.assertEqual(len(checkpoints), 3)
+            checkpoint_dir = sorted(checkpoints, key=lambda x: int(x.replace("checkpoint-", "")))[0]
+
+            trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, checkpoint_dir))
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()

-        self.assertAlmostEqual(a, a1, delta=1e-8)
-        self.assertAlmostEqual(b, b1, delta=1e-8)
+            self.assertAlmostEqual(a, a1, delta=1e-8)
+            self.assertAlmostEqual(b, b1, delta=1e-8)

     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
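One detail worth noting in the new test: epoch checkpoints are still named checkpoint-<global_step>, so picking the first epoch's checkpoint needs a numeric sort; plain lexicographic sorting would put checkpoint-15 before checkpoint-5. A quick illustration of the sort key used above:

checkpoints = ["checkpoint-15", "checkpoint-5", "checkpoint-10"]

# Lexicographic order compares characters, so "1" < "5" regardless of the step count.
print(sorted(checkpoints))
# ['checkpoint-10', 'checkpoint-15', 'checkpoint-5']

# Sorting on the integer suffix, as the test does, recovers chronological order.
print(sorted(checkpoints, key=lambda x: int(x.replace("checkpoint-", ""))))
# ['checkpoint-5', 'checkpoint-10', 'checkpoint-15']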
