22import numpy as np
33import torch
44
5- from tiatoolbox import utils
65from tiatoolbox .models import MapDe
76from tiatoolbox .models .architecture import fetch_pretrained_weights
7+ from tiatoolbox .utils import env_detection as toolbox_env
8+ from tiatoolbox .utils .misc import select_device
89from tiatoolbox .wsicore .wsireader import WSIReader
910
11+ ON_GPU = toolbox_env .has_gpu ()
12+
1013
1114def _load_mapde (tmp_path , name ):
1215 """Loads MapDe model with specified weights."""
1316 model = MapDe ()
1417 fetch_pretrained_weights (name , f"{ tmp_path } /weights.pth" )
15- map_location = utils . misc . select_device (utils . env_detection . has_gpu () )
18+ map_location = select_device (ON_GPU )
1619 pretrained = torch .load (f"{ tmp_path } /weights.pth" , map_location = map_location )
1720 model .load_state_dict (pretrained )
1821
@@ -34,14 +37,10 @@ def test_functionality(remote_sample, tmp_path):
3437 (0 , 0 , 252 , 252 ), resolution = 0.50 , units = "mpp" , coord_space = "resolution"
3538 )
3639
37- model = _load_mapde (tmp_path = tmp_path , name = "mapde-crchisto " )
40+ model = _load_mapde (tmp_path = tmp_path , name = "mapde-conic " )
3841 patch = model .preproc (patch )
3942 batch = torch .from_numpy (patch )[None ]
40- output = model .infer_batch (model , batch , on_gpu = False )
41- output = model .postproc (output [0 ])
42- assert np .all (output [0 :2 ] == [[99 , 178 ], [64 , 218 ]])
43-
44- model = _load_mapde (tmp_path = tmp_path , name = "mapde-conic" )
45- output = model .infer_batch (model , batch , on_gpu = False )
43+ model = model .to (select_device (ON_GPU ))
44+ output = model .infer_batch (model , batch , on_gpu = ON_GPU )
4645 output = model .postproc (output [0 ])
4746 assert np .all (output [0 :2 ] == [[19 , 171 ], [53 , 89 ]])
0 commit comments