Skip to content

Commit d0d4ed6

Browse files
authored
✅ Reduce run time for Mapde test (#627)
- Reduce run time for Mapde test.
1 parent e4deac4 commit d0d4ed6

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

tests/models/test_arch_mapde.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22
import numpy as np
33
import torch
44

5-
from tiatoolbox import utils
65
from tiatoolbox.models import MapDe
76
from tiatoolbox.models.architecture import fetch_pretrained_weights
7+
from tiatoolbox.utils import env_detection as toolbox_env
8+
from tiatoolbox.utils.misc import select_device
89
from tiatoolbox.wsicore.wsireader import WSIReader
910

11+
ON_GPU = toolbox_env.has_gpu()
12+
1013

1114
def _load_mapde(tmp_path, name):
1215
"""Loads MapDe model with specified weights."""
1316
model = MapDe()
1417
fetch_pretrained_weights(name, f"{tmp_path}/weights.pth")
15-
map_location = utils.misc.select_device(utils.env_detection.has_gpu())
18+
map_location = select_device(ON_GPU)
1619
pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
1720
model.load_state_dict(pretrained)
1821

@@ -34,14 +37,10 @@ def test_functionality(remote_sample, tmp_path):
3437
(0, 0, 252, 252), resolution=0.50, units="mpp", coord_space="resolution"
3538
)
3639

37-
model = _load_mapde(tmp_path=tmp_path, name="mapde-crchisto")
40+
model = _load_mapde(tmp_path=tmp_path, name="mapde-conic")
3841
patch = model.preproc(patch)
3942
batch = torch.from_numpy(patch)[None]
40-
output = model.infer_batch(model, batch, on_gpu=False)
41-
output = model.postproc(output[0])
42-
assert np.all(output[0:2] == [[99, 178], [64, 218]])
43-
44-
model = _load_mapde(tmp_path=tmp_path, name="mapde-conic")
45-
output = model.infer_batch(model, batch, on_gpu=False)
43+
model = model.to(select_device(ON_GPU))
44+
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
4645
output = model.postproc(output[0])
4746
assert np.all(output[0:2] == [[19, 171], [53, 89]])

0 commit comments

Comments
 (0)