Skip to content

Commit 2cdc147

Browse files
committed
test fixes
1 parent 47c142d commit 2cdc147

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
5454
@pytest.mark.parametrize(
5555
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
5656
)
57-
@pytest.mark.parametrize("save_last", [False, True])
58-
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):
57+
@pytest.mark.parametrize("save_last", [False, True, "link"])
58+
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last):
5959
class TestModel(BoringModel):
6060
def __init__(self):
6161
super().__init__()
@@ -79,8 +79,8 @@ def training_step(self, batch, batch_idx):
7979
)
8080
trainer.fit(model)
8181

82-
if save_last:
83-
expected = expected
82+
# save_last=True: last epochs are saved every step (so double the save calls)
83+
expected = expected * 2 if save_last is True else expected
8484
assert save_mock.call_count == expected
8585

8686

0 commit comments

Comments
 (0)