Merged (29 commits)
1eb20fe  DOC: Fix HoVerNet plus documentation (shaneahmed, Feb 18, 2022)
465c09d  DEV: Add models/architecture/micronet (shaneahmed, Feb 18, 2022)
298212d  Merge branch 'develop' into enh-add-micronet (shaneahmed, Feb 18, 2022)
43c836c  DEV: Fix deepsource errors. (shaneahmed, Feb 18, 2022)
63e1ec9  DEV: Add docstring for weights initialization. (shaneahmed, Feb 18, 2022)
3c3a201  Merge remote-tracking branch 'origin/develop' into enh-add-micronet (shaneahmed, Feb 18, 2022)
375630d  DOC: Update docstrings for various Micronet functions. (shaneahmed, Feb 21, 2022)
da230ba  NEW: Upload micronet consep model. (shaneahmed, Feb 21, 2022)
0704067  DEV: Update models/architecture/micronet.py to work with tiatoolbox (shaneahmed, Feb 21, 2022)
17c15f0  TST: Add test for micronet. (shaneahmed, Feb 21, 2022)
a748981  MAINT: Fix spelling errors in tests/models/test_hovernet (shaneahmed, Feb 21, 2022)
65d2a0a  BUG: Fix deepsource bugs (shaneahmed, Feb 21, 2022)
80ca408  Merge remote-tracking branch 'origin/develop' into enh-add-micronet (shaneahmed, Feb 21, 2022)
2144d9a  BUG: Fix deepsource errors (shaneahmed, Feb 21, 2022)
47cd5c4  DEV: Update the forward function to accept args and kwargs (shaneahmed, Feb 21, 2022)
b2d1167  DEV: Update abc to include input_images in modelabc (shaneahmed, Feb 22, 2022)
70fca32  TST: Add map_location to tests (shaneahmed, Feb 22, 2022)
5423ea6  DEV: Update forward function input. (shaneahmed, Feb 22, 2022)
a5ebf0e  DOC: Update docs to include micronet (shaneahmed, Feb 22, 2022)
8beba3d  DOC: update docstring for micronet (shaneahmed, Feb 22, 2022)
e4c2733  BUG: Skip potential bug error (shaneahmed, Feb 22, 2022)
dcd66fd  DOC: update docstring (shaneahmed, Feb 22, 2022)
efa17a5  DOC: Update docstring (shaneahmed, Feb 22, 2022)
d077db3  DOC: Add performance table to micronet (shaneahmed, Feb 22, 2022)
7496f17  TST: Add tests to improve coverage. (shaneahmed, Feb 22, 2022)
6695279  DEV: Update __init__ to include fetch_pretrained_weights (shaneahmed, Feb 22, 2022)
9af1c00  DEV: Rearrange group functions to clean the code. (shaneahmed, Feb 23, 2022)
78eb186  DEV: Update MicroNet to add local test and post process nuclei (shaneahmed, Feb 24, 2022)
3236432  TST: Fix micronet test (shaneahmed, Feb 24, 2022)
17 changes: 9 additions & 8 deletions docs/pretrained.rst
@@ -113,7 +113,7 @@ They share the same input output configuration defined below:
patch_input_shape=(1024, 1024),
patch_output_shape=(512, 512),
stride_shape=(256, 256),
save_resolution={'units': 'mpp', 'resolution': 8.0}
save_resolution={'units': 'mpp', 'resolution': 8.0}
)


@@ -144,7 +144,7 @@ They share the same input output configuration defined below:
patch_input_shape=(1024, 1024),
patch_output_shape=(512, 512),
stride_shape=(256, 256),
save_resolution={'units': 'mpp', 'resolution': 0.25}
save_resolution={'units': 'mpp', 'resolution': 0.25}
)


@@ -183,7 +183,7 @@ input output configuration:
patch_input_shape=(256, 256),
patch_output_shape=(164, 164),
stride_shape=(164, 164),
save_resolution={'units': 'mpp', 'resolution': 0.25}
save_resolution={'units': 'mpp', 'resolution': 0.25}
)

.. collapse:: Model names
@@ -217,7 +217,7 @@ input output configuration:
patch_input_shape=(256, 256),
patch_output_shape=(164, 164),
stride_shape=(164, 164),
save_resolution={'units': 'mpp', 'resolution': 0.25}
save_resolution={'units': 'mpp', 'resolution': 0.25}
)

.. collapse:: Model names
@@ -251,12 +251,13 @@ input output configuration:
patch_input_shape=(270, 270),
patch_output_shape=(80, 80),
stride_shape=(80, 80),
save_resolution={'units': 'mpp', 'resolution': 0.25}
save_resolution={'units': 'mpp', 'resolution': 0.25}
)

.. collapse:: Model names

- hovernet_original-consep
- micronet_hovernet-consep


--------------------
@@ -285,7 +286,7 @@ input output configuration:
patch_input_shape=(270, 270),
patch_output_shape=(80, 80),
stride_shape=(80, 80),
save_resolution={'units': 'mpp', 'resolution': 0.25}
save_resolution={'units': 'mpp', 'resolution': 0.25}
)

.. collapse:: Model names
@@ -325,9 +326,9 @@ input output configuration:
patch_input_shape=(256, 256),
patch_output_shape=(164, 164),
stride_shape=(164, 164),
save_resolution={'units': 'mpp', 'resolution': 0.5}
save_resolution={'units': 'mpp', 'resolution': 0.5}
)

.. collapse:: Model names

- hovernetplus-oed
- hovernetplus-oed
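The `pretrained.rst` configurations above pair a larger `patch_input_shape` with a smaller, centred `patch_output_shape`, and set `stride_shape` equal to the output shape. A minimal sketch (not tiatoolbox code; the helper names are ours) of the arithmetic this implies: each patch discards a fixed context border per side, and outputs tile without gaps when the stride does not exceed the output size.

```python
# Sketch of the patch-geometry arithmetic behind the configs above.
# Values (256/164/164, 270/80/80) are taken from the docs diff.

def context_border(input_side: int, output_side: int) -> int:
    """Pixels of context discarded on each side of a patch prediction."""
    assert (input_side - output_side) % 2 == 0, "border must be symmetric"
    return (input_side - output_side) // 2

def output_is_gapless(output_side: int, stride: int) -> bool:
    """Outputs tile without gaps when the stride fits within the output."""
    return stride <= output_side

# The 256/164/164 configuration used by several models above:
print(context_border(256, 164))     # 46 pixels of context per side
print(output_is_gapless(164, 164))  # True: adjacent outputs abut exactly
```

The HoVerNet "original" configs (270 in, 80 out, stride 80) follow the same pattern with a much larger 95-pixel border.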
1 change: 1 addition & 0 deletions docs/usage.rst
@@ -110,6 +110,7 @@ Neural Network Architectures
- :obj:`Simplified U-Nets <tiatoolbox.models.architecture.unet>`
- :obj:`HoVerNet <tiatoolbox.models.architecture.hovernet.HoVerNet>`
- :obj:`HoVerNet+ <tiatoolbox.models.architecture.hovernetplus.HoVerNetPlus>`
- :obj:`MicroNet <tiatoolbox.models.architecture.micronet.MicroNet>`

Pipelines:
- :obj:`IDARS <tiatoolbox.models.architecture.idars>`
98 changes: 98 additions & 0 deletions tests/models/test_arch_micronet.py
@@ -0,0 +1,98 @@
# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****

"""Unit test package for MicroNet."""

import pathlib

import numpy as np
import pytest
import torch

from tiatoolbox import utils
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.micronet import MicroNet
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import WSIReader


def test_functionality(remote_sample, tmp_path):
"""Functionality test."""
tmp_path = str(tmp_path)
sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
reader = WSIReader.open(sample_wsi)

# * test MicroNet with pretrained CoNSeP weights
patch = reader.read_bounds(
(0, 0, 252, 252), resolution=0.25, units="mpp", coord_space="resolution"
)
patch = MicroNet.preproc(patch)
batch = torch.from_numpy(patch)[None]
model = MicroNet()
fetch_pretrained_weights("micronet_hovernet-consep", f"{tmp_path}/weights.pth")
map_location = utils.misc.select_device(utils.env_detection.has_gpu())
pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output, _ = model.postproc(output[0])
assert np.max(np.unique(output)) == 33


def test_value_error():
    """Test that a ValueError is raised if num_class < 2."""
with pytest.raises(ValueError, match="Number of classes should be >=2"):
_ = MicroNet(num_class=1)


@pytest.mark.skipif(
toolbox_env.running_on_travis() or not toolbox_env.has_gpu(),
reason="Local test on machine with GPU.",
)
def test_micronet_output(remote_sample, tmp_path):
"""Tests the output of MicroNet."""
svs_1_small = pathlib.Path(remote_sample("svs-1-small"))
micronet_output = pathlib.Path(remote_sample("micronet-output"))
pretrained_model = "micronet_hovernet-consep"
batch_size = 5
num_loader_workers = 0
num_postproc_workers = 0

predictor = SemanticSegmentor(
pretrained_model=pretrained_model,
batch_size=batch_size,
num_loader_workers=num_loader_workers,
num_postproc_workers=num_postproc_workers,
)

output = predictor.predict(
imgs=[
svs_1_small,
],
save_dir=tmp_path / "output",
)

output = np.load(output[0][1] + ".raw.0.npy")
output_on_server = np.load(str(micronet_output))
output_on_server = np.round(output_on_server, decimals=3)
new_output = np.round(output[500:1000, 1000:1500, :], decimals=3)
true_values = output_on_server == new_output
percent_true = np.count_nonzero(true_values) / np.size(output_on_server)
assert percent_true > 0.999
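The closing assertions of `test_micronet_output` compare against a reference prediction with a tolerance: both arrays are rounded to 3 decimals, and more than 99.9% of elements must match exactly after rounding. A self-contained sketch of that check (the helper name is ours, not tiatoolbox's):

```python
# Agreement check used in spirit by test_micronet_output above: round both
# arrays, then report the fraction of elements that match exactly.
import numpy as np

def agreement_fraction(a: np.ndarray, b: np.ndarray, decimals: int = 3) -> float:
    """Fraction of elements equal after rounding both inputs."""
    a_r = np.round(a, decimals=decimals)
    b_r = np.round(b, decimals=decimals)
    return np.count_nonzero(a_r == b_r) / a_r.size

x = np.array([0.1234, 0.5678, 0.9999])
y = np.array([0.1254, 0.5678, 0.9999])  # first element differs at the 3rd decimal
print(agreement_fraction(x, y))  # 2 of 3 elements agree after rounding
```

Rounding before comparison makes the test robust to small floating-point drift between the CI machine and the machine that produced the reference output.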
30 changes: 15 additions & 15 deletions tests/models/test_hovernet.py
@@ -32,23 +32,23 @@
ResidualBlock,
TFSamepaddingLayer,
)
from tiatoolbox.wsicore.wsireader import get_wsireader
from tiatoolbox.wsicore.wsireader import WSIReader


def test_functionality(remote_sample, tmp_path):
"""Functionality test."""
tmp_path = str(tmp_path)
sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
reader = get_wsireader(sample_wsi)
reader = WSIReader.open(sample_wsi)

# * test fast mode (architecture used in Pannuke paper)
# * test fast mode (architecture used in PanNuke paper)
patch = reader.read_bounds(
[0, 0, 256, 256], resolution=0.25, units="mpp", coord_space="resolution"
(0, 0, 256, 256), resolution=0.25, units="mpp", coord_space="resolution"
)
batch = torch.from_numpy(patch)[None]
model = HoVerNet(num_types=6, mode="fast")
fetch_pretrained_weights("hovernet_fast-pannuke", f"{tmp_path}/weigths.pth")
pretrained = torch.load(f"{tmp_path}/weigths.pth")
fetch_pretrained_weights("hovernet_fast-pannuke", f"{tmp_path}/weights.pth")
pretrained = torch.load(f"{tmp_path}/weights.pth")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = [v[0] for v in output]
@@ -57,12 +57,12 @@ def test_functionality(remote_sample, tmp_path):

# * test fast mode (architecture used for MoNuSAC data)
patch = reader.read_bounds(
[0, 0, 256, 256], resolution=0.25, units="mpp", coord_space="resolution"
(0, 0, 256, 256), resolution=0.25, units="mpp", coord_space="resolution"
)
batch = torch.from_numpy(patch)[None]
model = HoVerNet(num_types=5, mode="fast")
fetch_pretrained_weights("hovernet_fast-monusac", f"{tmp_path}/weigths.pth")
pretrained = torch.load(f"{tmp_path}/weigths.pth")
fetch_pretrained_weights("hovernet_fast-monusac", f"{tmp_path}/weights.pth")
pretrained = torch.load(f"{tmp_path}/weights.pth")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = [v[0] for v in output]
@@ -71,12 +71,12 @@ def test_functionality(remote_sample, tmp_path):

# * test original mode on CoNSeP dataset (architecture used in HoVerNet paper)
patch = reader.read_bounds(
[0, 0, 270, 270], resolution=0.25, units="mpp", coord_space="resolution"
(0, 0, 270, 270), resolution=0.25, units="mpp", coord_space="resolution"
)
batch = torch.from_numpy(patch)[None]
model = HoVerNet(num_types=5, mode="original")
fetch_pretrained_weights("hovernet_original-consep", f"{tmp_path}/weigths.pth")
pretrained = torch.load(f"{tmp_path}/weigths.pth")
fetch_pretrained_weights("hovernet_original-consep", f"{tmp_path}/weights.pth")
pretrained = torch.load(f"{tmp_path}/weights.pth")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = [v[0] for v in output]
@@ -85,12 +85,12 @@ def test_functionality(remote_sample, tmp_path):

# * test original mode on Kumar dataset (architecture used in HoVerNet paper)
patch = reader.read_bounds(
[0, 0, 270, 270], resolution=0.25, units="mpp", coord_space="resolution"
(0, 0, 270, 270), resolution=0.25, units="mpp", coord_space="resolution"
)
batch = torch.from_numpy(patch)[None]
model = HoVerNet(num_types=None, mode="original")
fetch_pretrained_weights("hovernet_original-kumar", f"{tmp_path}/weigths.pth")
pretrained = torch.load(f"{tmp_path}/weigths.pth")
fetch_pretrained_weights("hovernet_original-kumar", f"{tmp_path}/weights.pth")
pretrained = torch.load(f"{tmp_path}/weights.pth")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = [v[0] for v in output]
33 changes: 27 additions & 6 deletions tiatoolbox/data/pretrained_model.yaml
@@ -621,7 +621,7 @@ hovernet_fast-pannuke:
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
input_resolutions:
- {"units": "mpp", "resolution": 0.25}
output_resolutions:
- {"units": "mpp", "resolution": 0.25}
@@ -644,7 +644,7 @@ hovernet_fast-monusac:
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
input_resolutions:
- {"units": "mpp", "resolution": 0.25}
output_resolutions:
- {"units": "mpp", "resolution": 0.25}
@@ -667,7 +667,7 @@ hovernet_original-consep:
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
input_resolutions:
- {"units": "mpp", "resolution": 0.25}
output_resolutions:
- {"units": "mpp", "resolution": 0.25}
@@ -690,7 +690,7 @@ hovernet_original-kumar:
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
input_resolutions:
- {"units": "mpp", "resolution": 0.25}
output_resolutions:
- {"units": "mpp", "resolution": 0.25}
@@ -712,16 +712,37 @@ hovernetplus-oed:
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
input_resolutions:
- {"units": "mpp", "resolution": 0.50}
output_resolutions:
- {"units": "mpp", "resolution": 0.50}
- {"units": "mpp", "resolution": 0.50}
- {"units": "mpp", "resolution": 0.50}
- {"units": "mpp", "resolution": 0.50}
- {"units": "mpp", "resolution": 0.50}
margin: 128
tile_shape: [2048, 2048]
patch_input_shape: [256, 256]
patch_output_shape: [164, 164]
stride_shape: [164, 164]
save_resolution: {'units': 'mpp', 'resolution': 0.50}

micronet_hovernet-consep:
url: https://tiatoolbox.dcs.warwick.ac.uk/models/seg/micronet_hovernet-consep.pth
architecture:
class: micronet.MicroNet
kwargs:
num_input_channels: 3
num_class: 2
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
- {"units": "mpp", "resolution": 0.25}
output_resolutions:
- {"units": "mpp", "resolution": 0.25}
margin: 128
tile_shape: [2048, 2048]
patch_input_shape: [252, 252]
patch_output_shape: [252, 252]
stride_shape: [150, 150]
save_resolution: {'units': 'mpp', 'resolution': 0.25}
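Note that the new `micronet_hovernet-consep` ioconfig sets `patch_input_shape` equal to `patch_output_shape` (252) but a smaller `stride_shape` (150), so neighbouring patches overlap by 102 pixels. A hedged sketch of how such a stride tiles a region; this is our own illustration of the arithmetic, and tiatoolbox's actual tiling (including how it clamps the final patch) may differ:

```python
# Illustration (assumption-level, not tiatoolbox code): start coordinates of
# 252-pixel patches at stride 150 across a 2048-pixel tile, clamping the last
# patch so coverage ends exactly at the boundary.

def patch_starts(length: int, patch: int, stride: int) -> list:
    """Start coordinates of patches covering [0, length)."""
    starts = list(range(0, max(length - patch, 0) + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)  # clamp final patch to the boundary
    return starts

starts = patch_starts(2048, 252, 150)
print(starts[:3])        # [0, 150, 300]
print(252 - 150)         # 102 pixels of overlap between neighbours
print(starts[-1] + 252)  # 2048: the tile is fully covered
```

Overlapping outputs let the segmentor discard unreliable predictions near patch edges while still covering every pixel.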
2 changes: 2 additions & 0 deletions tiatoolbox/data/remote_samples.yaml
@@ -98,3 +98,5 @@ files:
url: [*modelroot, "predictions/bcss/wsi4_4k_4k.mask.npy"]
small_svs_tissue_mask:
url: [*modelroot, "predictions/tissue_mask/small_svs_tissue_mask.npy"]
micronet-output:
url: [*modelroot, "predictions/nuclei_mask/micronet-output.npy"]
2 changes: 1 addition & 1 deletion tiatoolbox/models/architecture/__init__.py
@@ -32,7 +32,7 @@
from tiatoolbox.models.dataset.classification import predefined_preproc_func
from tiatoolbox.utils.misc import download_data

__all__ = ["get_pretrained_model"]
__all__ = ["get_pretrained_model", "fetch_pretrained_weights"]
PRETRAINED_INFO = rcParam["pretrained_model_info"]


4 changes: 2 additions & 2 deletions tiatoolbox/models/architecture/hovernet.py
@@ -559,7 +559,7 @@ def _proc_np_hv(np_map: np.ndarray, hv_map: np.ndarray, fx: float = 1):
return proced_pred

@staticmethod
def _get_instance_info(pred_inst, pred_type=None):
def get_instance_info(pred_inst, pred_type=None):
"""To collect instance information and store it within a dictionary.

Args:
@@ -703,7 +703,7 @@ def postproc(raw_maps: List[np.ndarray]):

pred_type = tp_map
pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
nuc_inst_info_dict = HoVerNet._get_instance_info(pred_inst, pred_type)
nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)

return pred_inst, nuc_inst_info_dict
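The rename above promotes `_get_instance_info` to the public `get_instance_info`, which takes a labelled instance map and collects per-nucleus records. A self-contained sketch of the kind of summary such a function produces; this is our own toy version, not tiatoolbox's implementation, and the real method also records contours and type predictions:

```python
# Toy version of per-instance info extraction from a labelled mask: for each
# nucleus id (> 0), record a bounding box and a centroid.
import numpy as np

def instance_info(pred_inst: np.ndarray) -> dict:
    info = {}
    for inst_id in np.unique(pred_inst):
        if inst_id == 0:  # 0 is background
            continue
        rows, cols = np.nonzero(pred_inst == inst_id)
        info[int(inst_id)] = {
            "box": [int(rows.min()), int(cols.min()),
                    int(rows.max()), int(cols.max())],
            "centroid": [float(rows.mean()), float(cols.mean())],
        }
    return info

mask = np.zeros((6, 6), dtype=int)
mask[1:3, 1:3] = 1  # a 2x2 nucleus labelled 1
mask[4:6, 4:6] = 2  # a 2x2 nucleus labelled 2
print(instance_info(mask)[1]["centroid"])  # [1.5, 1.5]
```

Making this helper public lets downstream code re-run instance summarisation without going through the full `postproc` pipeline.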
