Skip to content
3 changes: 2 additions & 1 deletion tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
DeepFeatureExtractor,
IOSegmentorConfig,
)
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import get_wsireader

ON_GPU = False
ON_GPU = not toolbox_env.running_on_travis() and toolbox_env.has_gpu()

# ----------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def _test_predictor_output(
for idx, probabilities_ in enumerate(probabilities):
probabilities_max = max(probabilities_)
assert (
np.abs(probabilities_max - probabilities_check[idx]) <= 1e-6
np.abs(probabilities_max - probabilities_check[idx]) <= 5e-6
and predictions[idx] == predictions_check[idx]
), (
pretrained_model,
Expand Down
10 changes: 1 addition & 9 deletions tests/models/test_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.py",
cache_count_path=f"{save_dir}/count.py",
free_prediction=False,
)
assert np.sum(canvas - _output) < 1.0e-8
# a second rerun to test overlapping count,
Expand All @@ -386,7 +385,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.py",
cache_count_path=f"{save_dir}/count.py",
free_prediction=False,
)
assert np.sum(canvas - _output) < 1.0e-8
# otherwise this will leave a dangling file pointer
Expand All @@ -402,7 +400,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.py",
cache_count_path=f"{save_dir}/count.py",
free_prediction=False,
)
del canvas # skipcq

Expand All @@ -414,7 +411,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.1.py",
cache_count_path=f"{save_dir}/count.1.py",
free_prediction=False,
)
with pytest.raises(ValueError, match=r".*`save_path` does not match.*"):
semantic_segmentor.merge_prediction(
Expand All @@ -423,7 +419,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.1.py",
cache_count_path=f"{save_dir}/count.py",
free_prediction=False,
)

with pytest.raises(ValueError, match=r".*`cache_count_path` does not match.*"):
Expand All @@ -433,7 +428,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.py",
cache_count_path=f"{save_dir}/count.1.py",
free_prediction=False,
)
# * test non HW predictions
with pytest.raises(ValueError, match=r".*Prediction is no HW or HWC.*"):
Expand All @@ -443,7 +437,6 @@ def test_functional_segmentor_merging(tmp_path):
[[0, 0, 2, 2], [2, 2, 4, 4]],
save_path=f"{save_dir}/raw.py",
cache_count_path=f"{save_dir}/count.1.py",
free_prediction=False,
)

_rm_dir(save_dir)
Expand All @@ -460,7 +453,6 @@ def test_functional_segmentor_merging(tmp_path):
],
[[0, 0, 2, 2], [2, 2, 4, 4], [0, 4, 2, 6], [4, 0, 6, 2]],
save_path=None,
free_prediction=False,
)
assert np.sum(canvas - _output) < 1.0e-8
del canvas # skipcq
Expand Down Expand Up @@ -702,7 +694,7 @@ def test_behavior_tissue_mask_local(remote_sample, tmp_path):
_test_pred = np.load(f"{save_dir}/raw/0.raw.0.npy")
_test_pred = (_test_pred[..., 1] > 0.75) * 255
# divide by 255 to binarize
assert np.mean(np.abs(_cache_pred[..., 0] - _test_pred) / 255) < 1.0e-3
assert np.mean(_cache_pred[..., 0] == _test_pred) > 0.99

_rm_dir(save_dir)
# mainly to test prediction on tile
Expand Down
1 change: 0 additions & 1 deletion tiatoolbox/models/engine/nucleus_instance_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def _process_tile_predictions(
head_tile_shape[::-1],
head_predictions,
head_locations,
free_prediction=True,
)
head_raws.append(head_raw)
_, inst_dict = postproc(head_raws)
Expand Down
64 changes: 37 additions & 27 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,26 +366,31 @@ class SemanticSegmentor:

Args:
model (nn.Module): Use externally defined PyTorch model for prediction with.
weights already loaded. Default is `None`. If provided,
`pretrained_model` argument is ignored.
weights already loaded. Default is `None`. If provided,
`pretrained_model` argument is ignored.
pretrained_model (str): Name of the existing models supported by tiatoolbox
for processing the data. For a full list of pretrained models, refer to the
`docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
By default, the corresponding pretrained weights will also be
downloaded. However, you can override with your own set of weights
via the `pretrained_weights` argument. Argument is case insensitive.
for processing the data. For a full list of pretrained models, refer to the
`docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
By default, the corresponding pretrained weights will also be
downloaded. However, you can override with your own set of weights
via the `pretrained_weights` argument. Argument is case insensitive.
pretrained_weights (str): Path to the weight of the corresponding
`pretrained_model`.
`pretrained_model`.
batch_size (int) : Number of images fed into the model each time.
num_loader_workers (int) : Number of workers to load the data.
Take note that they will also perform preprocessing.
num_postproc_workers (int) : This value is there to maintain input
compatibility with `tiatoolbox.models.classification` and is
not used.
compatibility with `tiatoolbox.models.classification` and is
not used.
verbose (bool): Whether to output logging information.
dataset_class (obj): Dataset class to be used instead of default.
auto_generate_mask (bool): To automatically generate tile/WSI tissue mask
if is not provided.
if is not provided.

Attributes:
process_prediction_per_batch (bool): A flag to denote whether post
processing for inference output is applied after each batch or
after finishing an entire tile or WSI.

Examples:
>>> # Sample output of a network
Expand Down Expand Up @@ -425,6 +430,10 @@ def __init__(
self.ioconfig = ioconfig
self.model = model

# local variables for flagging mode within class;
# subclasses should overwrite these to alter specific behaviors
self.process_prediction_per_batch = True

# for runtime, such as after wrapping with nn.DataParallel
self._cache_dir = None
self._loader = None
Expand Down Expand Up @@ -667,14 +676,19 @@ def _predict_one_wsi(
sample_infos = np.split(sample_infos, batch_size, axis=0)

sample_outputs = list(zip(sample_infos, sample_outputs))
cum_output.extend(sample_outputs)
# TODO: detach or hook this into a parallel process
self._process_predictions(
cum_output, wsi_reader, ioconfig, save_path, cache_dir
)
if self.process_prediction_per_batch:
self._process_predictions(
sample_outputs, wsi_reader, ioconfig, save_path, cache_dir
)
else:
cum_output.extend(sample_outputs)
pbar.update()
pbar.close()

self._process_predictions(
cum_output, wsi_reader, ioconfig, save_path, cache_dir
)

# clean up the cache directories
shutil.rmtree(cache_dir)

Expand All @@ -689,6 +703,8 @@ def _process_predictions(
"""Define how the aggregated predictions are processed.

This includes merging the prediction if necessary and also saving afterwards.
Note that items within `cum_batch_predictions` will be consumed during
the operation.

Args:
cum_batch_predictions (list): List of batch predictions. Each item
Expand All @@ -701,6 +717,9 @@ def _process_predictions(
cache_dir (str): Root path to cache current WSI data.

"""
if len(cum_batch_predictions) == 0:
return

# assume predictions is of length N; each item has L output elements
locations, predictions = list(zip(*cum_batch_predictions))
# Nx4 (N x [tl_x, tl_y, br_x, br_y]), denotes the location of output patch
Expand Down Expand Up @@ -729,7 +748,6 @@ def _process_predictions(
merged_locations,
save_path=sub_save_path,
cache_count_path=sub_count_path,
free_prediction=True,
)

@staticmethod
Expand All @@ -739,7 +757,6 @@ def merge_prediction(
locations: Union[List, np.ndarray],
save_path: Union[str, pathlib.Path] = None,
cache_count_path: Union[str, pathlib.Path] = None,
free_prediction: bool = True,
):
"""Merge patch-level predictions to form a 2-dimensional prediction map.

Expand All @@ -760,9 +777,6 @@ def merge_prediction(
save_path (str): Location to save the assembled image.
cache_count_path (str): Location to store the canvas for counting
how many times each pixel get overlapped when assembling.
free_prediction (bool): If this is `True`, `predictions` will
be modified in place and each patch will be replace with `None`
once processed. This is to save memory when assembling.

Returns:
:class:`numpy.ndarray`: An image contains merged data.
Expand All @@ -778,7 +792,6 @@ def merge_prediction(
... [0, 0, 2, 2],
... [2, 2, 4, 4]],
... save_path=None,
... free_prediction=False,
... )
array([[1, 1, 0, 0],
[1, 1, 0, 0],
Expand Down Expand Up @@ -848,7 +861,7 @@ def index(arr, tl, br):
return arr[tl[0] : br[0], tl[1] : br[1]]

patch_infos = list(zip(locations, predictions))
for patch_idx, patch_info in enumerate(patch_infos):
for _, patch_info in enumerate(patch_infos):
# position is assumed to be in XY coordinate
(bound_in_wsi, prediction) = patch_info
# convert to XY to YX, and in tl, br
Expand Down Expand Up @@ -898,10 +911,6 @@ def index(arr, tl, br):
new_avg_pred = (old_raw_pred + patch_pred) / new_count
index(cum_canvas, tl_in_wsi, br_in_wsi)[:] = new_avg_pred
index(count_canvas, tl_in_wsi, br_in_wsi)[:] = new_count

# remove prediction without altering list ordering or length
if free_prediction:
patch_infos[patch_idx] = None
if not is_on_drive:
cum_canvas /= count_canvas + 1.0e-6
return cum_canvas
Expand Down Expand Up @@ -1192,6 +1201,7 @@ def __init__(
auto_generate_mask=auto_generate_mask,
dataset_class=dataset_class,
)
self.process_prediction_per_batch = False

def _process_predictions(
self,
Expand Down