✅ Reuse models and datasets in tests #641

Merged (5 commits, Jul 19, 2023)
Changes from 4 commits
11 changes: 5 additions & 6 deletions tests/models/test_arch_mapde.py
@@ -11,24 +11,23 @@
 ON_GPU = toolbox_env.has_gpu()


-def _load_mapde(tmp_path, name):
+def _load_mapde(name):
     """Loads MapDe model with specified weights."""
     model = MapDe()
-    fetch_pretrained_weights(name, f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights(name)
     map_location = select_device(ON_GPU)
-    pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
+    pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)

     return model


-def test_functionality(remote_sample, tmp_path):
+def test_functionality(remote_sample):
     """Functionality test for MapDe.

     Tests the functionality of MapDe model for inference at the patch level.

     """
-    tmp_path = str(tmp_path)
     sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
     reader = WSIReader.open(sample_wsi)

@@ -37,7 +36,7 @@ def test_functionality(remote_sample, tmp_path):
         (0, 0, 252, 252), resolution=0.50, units="mpp", coord_space="resolution"
     )

-    model = _load_mapde(tmp_path=tmp_path, name="mapde-conic")
+    model = _load_mapde(name="mapde-conic")
     patch = model.preproc(patch)
     batch = torch.from_numpy(patch)[None]
     model = model.to(select_device(ON_GPU))
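The same refactor repeats across the architecture tests that follow: rather than downloading weights into a pytest tmp_path, each test lets fetch_pretrained_weights cache the file and return its location. A minimal sketch of how the updated helper reads end to end (the import paths below are assumptions based on the usual tiatoolbox layout, not part of this diff):

    # Hedged sketch of the refactored test helper, not taken verbatim from the PR.
    import torch

    from tiatoolbox.models.architecture import fetch_pretrained_weights
    from tiatoolbox.models.architecture.mapde import MapDe
    from tiatoolbox.utils.misc import select_device

    ON_GPU = False  # assumption: run the sketch on CPU


    def _load_mapde(name):
        """Load a MapDe model with the named pretrained weights."""
        model = MapDe()
        # No destination path: the weights are cached under TIATOOLBOX_HOME
        # and the local path is returned.
        weights_path = fetch_pretrained_weights(name)
        pretrained = torch.load(weights_path, map_location=select_device(ON_GPU))
        model.load_state_dict(pretrained)
        return model


    model = _load_mapde(name="mapde-conic")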
4 changes: 2 additions & 2 deletions tests/models/test_arch_micronet.py
@@ -29,9 +29,9 @@ def test_functionality(remote_sample, tmp_path):
     model = MicroNet()
     patch = model.preproc(patch)
     batch = torch.from_numpy(patch)[None]
-    fetch_pretrained_weights("micronet-consep", f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights("micronet-consep")
     map_location = select_device(ON_GPU)
-    pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
+    pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=ON_GPU)
     output, _ = model.postproc(output[0])
5 changes: 2 additions & 3 deletions tests/models/test_arch_nuclick.py
@@ -20,8 +20,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog):
     tile_path = pathlib.Path(remote_sample("patch-extraction-vf"))
     img = imread(tile_path)

-    _pretrained_path = f"{tmp_path}/weights.pth"
-    fetch_pretrained_weights("nuclick_original-pannuke", _pretrained_path)
+    weights_path = fetch_pretrained_weights("nuclick_original-pannuke")

     # test creation
     _ = NuClick(num_input_channels=5, num_output_channels=1)
@@ -46,7 +45,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog):
     batch = torch.from_numpy(batch[np.newaxis, ...])

     model = NuClick(num_input_channels=5, num_output_channels=1)
-    pretrained = torch.load(_pretrained_path, map_location="cpu")
+    pretrained = torch.load(weights_path, map_location="cpu")
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=ON_GPU)
     postproc_masks = model.postproc(
13 changes: 6 additions & 7 deletions tests/models/test_arch_sccnn.py
@@ -8,24 +8,23 @@
 from tiatoolbox.wsicore.wsireader import WSIReader


-def _load_sccnn(tmp_path, name):
+def _load_sccnn(name):
     """Loads SCCNN model with specified weights."""
     model = SCCNN()
-    fetch_pretrained_weights(name, f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights(name)
     map_location = utils.misc.select_device(utils.env_detection.has_gpu())
-    pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
+    pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)

     return model


-def test_functionality(remote_sample, tmp_path):
+def test_functionality(remote_sample):
     """Functionality test for SCCNN.

     Tests the functionality of SCCNN model for inference at the patch level.

     """
-    tmp_path = str(tmp_path)
     sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
     reader = WSIReader.open(sample_wsi)

@@ -34,12 +33,12 @@ def test_functionality(remote_sample, tmp_path):
         (30, 30, 61, 61), resolution=0.25, units="mpp", coord_space="resolution"
     )
     batch = torch.from_numpy(patch)[None]
-    model = _load_sccnn(tmp_path=tmp_path, name="sccnn-crchisto")
+    model = _load_sccnn(name="sccnn-crchisto")
     output = model.infer_batch(model, batch, on_gpu=False)
     output = model.postproc(output[0])
     assert np.all(output == [[8, 7]])

-    model = _load_sccnn(tmp_path=tmp_path, name="sccnn-conic")
+    model = _load_sccnn(name="sccnn-conic")
     output = model.infer_batch(model, batch, on_gpu=False)
     output = model.postproc(output[0])
     assert np.all(output == [[7, 8]])
5 changes: 2 additions & 3 deletions tests/models/test_arch_unet.py
@@ -20,8 +20,7 @@ def test_functional_unet(remote_sample, tmp_path):
     # convert to pathlib Path to prevent wsireader complaint
     mini_wsi_svs = pathlib.Path(remote_sample("wsi2_4k_4k_svs"))

-    _pretrained_path = f"{tmp_path}/weights.pth"
-    fetch_pretrained_weights("fcn-tissue_mask", _pretrained_path)
+    pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask")

     reader = WSIReader.open(mini_wsi_svs)
     with pytest.raises(ValueError, match=r".*Unknown encoder*"):
@@ -47,7 +46,7 @@ def test_functional_unet(remote_sample, tmp_path):
     batch = torch.from_numpy(batch)

     model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
-    pretrained = torch.load(_pretrained_path, map_location="cpu")
+    pretrained = torch.load(pretrained_weights, map_location="cpu")
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=ON_GPU)
     output = output[0]
7 changes: 6 additions & 1 deletion tests/models/test_dataset.py
@@ -61,12 +61,17 @@ def test_dataset_abc():
 def test_kather_dataset_default(tmp_path):
     """Test for kather patch dataset with default parameters."""
     # test kather with default init
+    dataset_path = os.path.join(
+        rcParam["TIATOOLBOX_HOME"], "dataset", "kather100k-validation"
+    )
+    shutil.rmtree(dataset_path, ignore_errors=True)
+
     _ = KatherPatchDataset()
     # kather with default data path skip download
     _ = KatherPatchDataset()

     # remove generated data
-    shutil.rmtree(rcParam["TIATOOLBOX_HOME"])
+    shutil.rmtree(dataset_path, ignore_errors=False)


 def test_kather_nonexisting_dir():
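The point of this change is scoping: the test now removes only the extracted Kather validation folder instead of wiping TIATOOLBOX_HOME, so model weights cached under the same directory by other tests survive. A condensed sketch of the resulting clean-up (imports added for completeness; they mirror what the test module already uses):

    import os
    import shutil

    from tiatoolbox import rcParam

    dataset_path = os.path.join(
        rcParam["TIATOOLBOX_HOME"], "dataset", "kather100k-validation"
    )

    # Before this PR the whole cache went away, cached model weights included:
    # shutil.rmtree(rcParam["TIATOOLBOX_HOME"])

    # After: only the dataset folder created by the test is removed.
    shutil.rmtree(dataset_path, ignore_errors=False)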
19 changes: 9 additions & 10 deletions tests/models/test_hovernet.py
@@ -15,9 +15,8 @@
 from tiatoolbox.wsicore.wsireader import WSIReader


-def test_functionality(remote_sample, tmp_path):
+def test_functionality(remote_sample):
     """Functionality test."""
-    tmp_path = str(tmp_path)
     sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
     reader = WSIReader.open(sample_wsi)

@@ -27,8 +26,8 @@ def test_functionality(remote_sample, tmp_path):
     )
     batch = torch.from_numpy(patch)[None]
     model = HoVerNet(num_types=6, mode="fast")
-    fetch_pretrained_weights("hovernet_fast-pannuke", f"{tmp_path}/weights.pth")
-    pretrained = torch.load(f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
+    pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=False)
     output = [v[0] for v in output]
@@ -41,8 +40,8 @@ def test_functionality(remote_sample, tmp_path):
     )
     batch = torch.from_numpy(patch)[None]
     model = HoVerNet(num_types=5, mode="fast")
-    fetch_pretrained_weights("hovernet_fast-monusac", f"{tmp_path}/weights.pth")
-    pretrained = torch.load(f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
+    pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=False)
     output = [v[0] for v in output]
@@ -55,8 +54,8 @@ def test_functionality(remote_sample, tmp_path):
     )
     batch = torch.from_numpy(patch)[None]
     model = HoVerNet(num_types=5, mode="original")
-    fetch_pretrained_weights("hovernet_original-consep", f"{tmp_path}/weights.pth")
-    pretrained = torch.load(f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights("hovernet_original-consep")
+    pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=False)
     output = [v[0] for v in output]
@@ -69,8 +68,8 @@ def test_functionality(remote_sample, tmp_path):
     )
     batch = torch.from_numpy(patch)[None]
     model = HoVerNet(num_types=None, mode="original")
-    fetch_pretrained_weights("hovernet_original-kumar", f"{tmp_path}/weights.pth")
-    pretrained = torch.load(f"{tmp_path}/weights.pth")
+    weights_path = fetch_pretrained_weights("hovernet_original-kumar")
+    pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=False)
     output = [v[0] for v in output]
4 changes: 2 additions & 2 deletions tests/models/test_hovernetplus.py
@@ -24,8 +24,8 @@ def test_functionality(remote_sample, tmp_path):
     assert len(model.decoder["hv"]) > 0, "Decoder must contain hv branch."
     assert len(model.decoder["tp"]) > 0, "Decoder must contain tp branch."
     assert len(model.decoder["ls"]) > 0, "Decoder must contain ls branch."
-    fetch_pretrained_weights("hovernetplus-oed", f"{tmp_path}/weigths.pth")
-    pretrained = torch.load(f"{tmp_path}/weigths.pth")
+    weights_path = fetch_pretrained_weights("hovernetplus-oed")
+    pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
     output = model.infer_batch(model, batch, on_gpu=False)
     assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
6 changes: 2 additions & 4 deletions tests/models/test_nucleus_instance_segmentor.py
@@ -519,9 +519,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path):
     mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
     imwrite(mini_wsi_jpg, thumb)

-    fetch_pretrained_weights(
-        "hovernet_fast-pannuke", str(tmp_path.joinpath("hovernet_fast-pannuke.pth"))
-    )
+    pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke")

     # resolution for travis testing, not the correct ones
     config = {
@@ -550,7 +548,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path):
             "--img-input",
             str(mini_wsi_jpg),
             "--pretrained-weights",
-            str(tmp_path.joinpath("hovernet_fast-pannuke.pth")),
+            str(pretrained_weights),
             "--num-loader-workers",
             str(0),
             "--num-postproc-workers",
6 changes: 3 additions & 3 deletions tests/models/test_patch_predictor.py
@@ -11,7 +11,7 @@
 import torch
 from click.testing import CliRunner

-from tiatoolbox import cli, rcParam
+from tiatoolbox import cli
 from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
 from tiatoolbox.models.architecture.vanilla import CNNModel
 from tiatoolbox.models.dataset import (
@@ -205,7 +205,6 @@ def test_patch_dataset_crash(tmp_path):
         match="Cannot load image data from",
     ):
         _ = PatchDataset(imgs)
-    _rm_dir(rcParam["TIATOOLBOX_HOME"])

     # preproc func for not defined dataset
     with pytest.raises(
@@ -657,8 +656,9 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path):
     # remove prev generated data
     _rm_dir(save_dir_path)
     os.makedirs(save_dir_path)
+
     pretrained_weights = os.path.join(
-        rcParam["TIATOOLBOX_HOME"],
+        save_dir_path,
         "tmp_pretrained_weigths",
         "resnet18-kather100k.pth",
     )
6 changes: 2 additions & 4 deletions tests/models/test_semantic_segmentation.py
@@ -765,9 +765,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path):
     sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8)
     imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk)
     sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg"
-    fetch_pretrained_weights(
-        "fcn-tissue_mask", str(tmp_path.joinpath("fcn-tissue_mask.pth"))
-    )
+    pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask")

     config = {
         "input_resolutions": [{"units": "mpp", "resolution": 2.0}],
@@ -789,7 +787,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path):
             "--img-input",
             str(mini_wsi_svs),
             "--pretrained-weights",
-            str(tmp_path.joinpath("fcn-tissue_mask.pth")),
+            str(pretrained_weights),
             "--mode",
             "wsi",
             "--masks",
16 changes: 16 additions & 0 deletions tests/test_utils.py
@@ -17,6 +17,7 @@

 from tests.test_annotation_stores import cell_polygon
 from tiatoolbox import rcParam, utils
+from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import misc
 from tiatoolbox.utils.exceptions import FileNotSupported
 from tiatoolbox.utils.transforms import locsize2bounds
@@ -1472,3 +1473,18 @@ def test_from_multi_head_dat_type_dict(tmp_path):
     assert len(result) == 1
     result = store.query(where=lambda x: x["type"][0:4] == "cell")
     assert len(result) == 2
+
+
+def test_fetch_pretrained_weights(tmp_path):
+    """Test fetching pretrained weights for a model."""
+
+    file_path = os.path.join(tmp_path, "test_fetch_pretrained_weights.pth")
+    if os.path.exists(file_path):
+        os.remove(file_path)
+
+    fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path)
+    assert os.path.exists(file_path)
+    assert os.path.getsize(file_path) > 0
+
+    with pytest.raises(ValueError, match="does not exist"):
+        fetch_pretrained_weights("abc", file_path)
23 changes: 18 additions & 5 deletions tiatoolbox/models/architecture/__init__.py
@@ -16,7 +16,9 @@
 PRETRAINED_INFO = rcParam["pretrained_model_info"]


-def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool = True):
+def fetch_pretrained_weights(
+    model_name: str, save_path: str = None, overwrite: bool = False
+) -> pathlib.Path:
     """Get the pretrained model information from yml file.

     Args:
@@ -29,9 +31,22 @@ def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool =
         overwrite (bool):
             Overwrite existing downloaded weights.

+    Returns:
+        pathlib.Path:
+            The local path to the cached pretrained weights after downloading.
+
     """
+    if model_name not in PRETRAINED_INFO:
+        raise ValueError(f"Pretrained model `{model_name}` does not exist")
+
     info = PRETRAINED_INFO[model_name]
+
+    if save_path is None:
+        file_name = info["url"].split("/")[-1]
+        save_path = os.path.join(rcParam["TIATOOLBOX_HOME"], "models/", file_name)
+
     download_data(info["url"], save_path, overwrite)
+    return pathlib.Path(save_path)


 def get_pretrained_model(
@@ -111,11 +126,9 @@ def get_pretrained_model(
     model.preproc_func = predefined_preproc_func(info["dataset"])

     if pretrained_weights is None:
-        file_name = info["url"].split("/")[-1]
-        pretrained_weights = os.path.join(
-            rcParam["TIATOOLBOX_HOME"], "models/", file_name
+        pretrained_weights = fetch_pretrained_weights(
+            pretrained_model, overwrite=overwrite
         )
-        fetch_pretrained_weights(pretrained_model, pretrained_weights, overwrite)

     # ! assume to be saved in single GPU mode
     # always load on to the CPU
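For reference, a hedged usage sketch of the updated function based on the diff above: save_path is now optional, overwrite defaults to False, the cached file lives under TIATOOLBOX_HOME/models/, and the local path is returned as a pathlib.Path.

    from tiatoolbox.models.architecture import fetch_pretrained_weights

    # Default call: download on first use and return the cached local path.
    weights_path = fetch_pretrained_weights("fcn-tissue_mask")

    # An explicit destination and a forced re-download are still supported.
    weights_path = fetch_pretrained_weights(
        "fcn-tissue_mask",
        save_path="/tmp/fcn-tissue_mask.pth",
        overwrite=True,
    )

    # Unknown model names now raise a ValueError ("... does not exist").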
2 changes: 1 addition & 1 deletion tiatoolbox/models/dataset/info.py
@@ -85,7 +85,7 @@ def __init__(

         if save_dir_path is None:  # pragma: no cover
             save_dir_path = Path(rcParam["TIATOOLBOX_HOME"], "dataset")
-        if not os.path.exists(save_dir_path):
+        if not os.path.exists(os.path.join(save_dir_path, "kather100k-validation")):
             save_zip_path = os.path.join(save_dir_path, "Kather.zip")
             url = (
                 "https://tiatoolbox.dcs.warwick.ac.uk/datasets"
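Combined with the test_dataset.py change above, the effect is that constructing KatherPatchDataset with the default path downloads the archive only when the extracted kather100k-validation folder is missing. A short sketch of the reuse behaviour this enables (the import path is an assumption; the class itself is defined in this module):

    from tiatoolbox.models.dataset import KatherPatchDataset

    ds = KatherPatchDataset()  # first construction downloads and extracts the data
    ds = KatherPatchDataset()  # second construction finds the folder and skips the download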