Skip to content

Commit f1a1284

Browse files
committed
use tmpdir argument for location to save ckpt
1 parent d3254aa commit f1a1284

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/tests_pytorch/core/test_saving.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,25 @@
66
from tests_pytorch.helpers.runif import RunIf
77

88

9-
def create_boring_checkpoint():
9+
def create_boring_checkpoint(tmpdir):
1010
model = BoringModel()
1111
trainer = pl.Trainer(accelerator="auto", max_epochs=1, enable_model_summary=False, enable_progress_bar=False)
1212
trainer.fit(model)
13-
trainer.save_checkpoint("./boring.ckpt")
13+
trainer.save_checkpoint(f"{tmpdir}/boring.ckpt")
1414

1515

1616
@pytest.mark.parametrize("map_location", ("cpu", torch.device("cpu"), lambda storage, loc: storage, {"cpu": "cpu"}))
17-
def test_load_from_checkpoint_map_location_cpu(map_location):
18-
create_boring_checkpoint()
19-
model = BoringModel.load_from_checkpoint("./boring.ckpt", map_location=map_location)
17+
def test_load_from_checkpoint_map_location_cpu(tmpdir, map_location):
18+
create_boring_checkpoint(tmpdir)
19+
model = BoringModel.load_from_checkpoint(f"{tmpdir}/boring.ckpt", map_location=map_location)
2020
assert model.device.type == "cpu"
2121

2222

2323
@RunIf(min_cuda_gpus=1)
2424
@pytest.mark.parametrize(
2525
"map_location", ("cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
2626
)
27-
def test_load_from_checkpoint_map_location_gpu(map_location):
28-
create_boring_checkpoint()
29-
model = BoringModel.load_from_checkpoint("./boring.ckpt", map_location=map_location)
27+
def test_load_from_checkpoint_map_location_gpu(tmpdir, map_location):
28+
create_boring_checkpoint(tmpdir)
29+
model = BoringModel.load_from_checkpoint(f"{tmpdir}/boring.ckpt", map_location=map_location)
3030
assert model.device.type == "cuda"

0 commit comments

Comments
 (0)