Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions tests/test_wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
apply_affine_transformation,
apply_bspline_transform,
estimate_bspline_transform,
match_histograms,
Expand Down Expand Up @@ -187,17 +188,18 @@ def test_warning(
moving_image,
fixed_mask,
moving_mask,
caplog,
):
"""Test for displaying warning in prealignment function."""
fixed_img = imread(pathlib.Path(fixed_image))
moving_img = imread(pathlib.Path(moving_image))
fixed_mask = imread(pathlib.Path(fixed_mask))
moving_mask = imread(pathlib.Path(moving_mask))
fixed_img, moving_img = fixed_img[:, :, 0], moving_img[:, :, 0]
with pytest.warns(UserWarning):
_ = prealignment(
fixed_img, moving_img, fixed_mask, moving_mask, dice_overlap=0.9
)

_ = prealignment(fixed_img, moving_img, fixed_mask, moving_mask, dice_overlap=0.9)

assert "Not able to find the best transformation" in caplog.text


def test_match_histogram_inputs():
Expand All @@ -213,8 +215,14 @@ def test_match_histogram_inputs():
def test_match_histograms():
"""Test for preprocessing/normalization of an image pair."""
image_a = np.random.randint(256, size=(256, 256))
image_b = np.random.randint(256, size=(256, 256))
_, _ = match_histograms(image_a, image_b, 3)
image_b = np.zeros(shape=(256, 256), dtype=int)
out_a, out_b = match_histograms(image_a, image_b, 3)
assert np.all(out_a == image_a)
assert np.all(out_b == 255)

out_a, out_b = match_histograms(image_b, image_a, 3)
assert np.all(out_a == 255)
assert np.all(out_b == image_a)

image_a = np.random.randint(256, size=(256, 256, 1))
image_b = np.random.randint(256, size=(256, 256, 1))
Expand Down Expand Up @@ -438,31 +446,32 @@ def test_bspline_transform(fixed_image, moving_image, fixed_mask, moving_mask):
"""Test for estimate_bspline_transform function."""
fixed_img = imread(fixed_image)
moving_img = imread(moving_image)
fixed_msk = imread(fixed_mask)
moving_msk = imread(moving_mask)
fixed_mask_ = imread(fixed_mask)
moving_mask_ = imread(moving_mask)

rigid_transform = np.array(
[[-0.99683, -0.00333, 338.69983], [-0.03201, -0.98420, 770.22941], [0, 0, 1]]
)
moving_img = cv2.warpAffine(
moving_img, rigid_transform[0:-1][:], fixed_img.shape[:2][::-1]
)
moving_msk = cv2.warpAffine(
moving_msk, rigid_transform[0:-1][:], fixed_img.shape[:2][::-1]
)
moving_img = apply_affine_transformation(fixed_img, moving_img, rigid_transform)
moving_mask_ = apply_affine_transformation(fixed_img, moving_mask_, rigid_transform)

# Grayscale images as input
transform = estimate_bspline_transform(
fixed_img[:, :, 0], moving_img[:, :, 0], fixed_msk[:, :, 0], moving_msk[:, :, 0]
fixed_img[:, :, 0],
moving_img[:, :, 0],
fixed_mask_[:, :, 0],
moving_mask_[:, :, 0],
)
_ = apply_bspline_transform(fixed_img[:, :, 0], moving_img[:, :, 0], transform)

# RGB images as input
transform = estimate_bspline_transform(fixed_img, moving_img, fixed_msk, moving_msk)
transform = estimate_bspline_transform(
fixed_img, moving_img, fixed_mask_, moving_mask_
)

_ = apply_bspline_transform(fixed_img, moving_img, transform)
registered_msk = apply_bspline_transform(fixed_msk, moving_msk, transform)
mask_overlap = dice(fixed_msk, registered_msk)
registered_msk = apply_bspline_transform(fixed_mask_, moving_mask_, transform)
mask_overlap = dice(fixed_mask_, registered_msk)
assert mask_overlap > 0.75


Expand Down
66 changes: 45 additions & 21 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import warnings
from numbers import Number
from typing import Dict, Tuple, Union

Expand All @@ -13,6 +12,7 @@
from skimage.registration import phase_cross_correlation
from skimage.util import img_as_float

from tiatoolbox import logger
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils.metrics import dice
from tiatoolbox.utils.transforms import imresize
Expand Down Expand Up @@ -86,6 +86,32 @@ def compute_center_of_mass(mask: np.ndarray) -> tuple:
return (x_coord_center, y_coord_center)


def apply_affine_transformation(fixed_img, moving_img, transform_initializer):
"""Apply affine transformation using OpenCV.

Args:
fixed_img (:class:`numpy.ndarray`):
A fixed image.
moving_img (:class:`numpy.ndarray`):
A moving image.
transform_initializer (:class:`numpy.ndarray`):
A rigid transformation matrix.

Returns:
:class:`numpy.ndarray`:
A transformed image.

Examples:
>>> moving_image = apply_affine_transformation(
... fixed_image, moving_image, transform_initializer
... )

"""
return cv2.warpAffine(
moving_img, transform_initializer[0:-1][:], fixed_img.shape[:2][::-1]
)


def prealignment(
fixed_img: np.ndarray,
moving_img: np.ndarray,
Expand Down Expand Up @@ -210,18 +236,16 @@ def prealignment(
pre_transform = all_transform[all_dice.index(dice_after)]

# Apply transformation to both image and mask
moving_img = cv2.warpAffine(
orig_moving_img, pre_transform[0:-1][:], orig_fixed_img.shape[:2][::-1]
)
moving_mask = cv2.warpAffine(
moving_mask, pre_transform[0:-1][:], fixed_img.shape[:2][::-1]
moving_img = apply_affine_transformation(
orig_fixed_img, orig_moving_img, pre_transform
)
moving_mask = apply_affine_transformation(fixed_img, moving_mask, pre_transform)

return pre_transform, moving_img, moving_mask, dice_after

warnings.warn(
logger.warning(
"Not able to find the best transformation for pre-alignment. "
"Try changing the values for 'dice_overlap' and 'rotation_step'.",
stacklevel=2,
)
return np.eye(3), moving_img, moving_mask, dice_before

Expand Down Expand Up @@ -850,12 +874,13 @@ def perform_dfbregister(
)

# Apply transformation
moving_img = cv2.warpAffine(
moving_img, tissue_transform[0:-1][:], fixed_img.shape[:2][::-1]
moving_img = apply_affine_transformation(
fixed_img, moving_img, tissue_transform
)
moving_mask = cv2.warpAffine(
moving_mask, tissue_transform[0:-1][:], fixed_img.shape[:2][::-1]
moving_mask = apply_affine_transformation(
fixed_img, moving_mask, tissue_transform
)

return tissue_transform, moving_img, moving_mask

def perform_dfbregister_block_wise(
Expand Down Expand Up @@ -963,11 +988,9 @@ def perform_dfbregister_block_wise(
)

# Apply transformation
moving_img = cv2.warpAffine(
moving_img, block_transform[0:-1][:], fixed_img.shape[:2][::-1]
)
moving_mask = cv2.warpAffine(
moving_mask, block_transform[0:-1][:], fixed_img.shape[:2][::-1]
moving_img = apply_affine_transformation(fixed_img, moving_img, block_transform)
moving_mask = apply_affine_transformation(
fixed_img, moving_mask, block_transform
)

return block_transform, moving_img, moving_mask
Expand Down Expand Up @@ -1025,12 +1048,13 @@ def register(
)
else:
# Apply transformation to both image and mask
moving_img = cv2.warpAffine(
moving_img, transform_initializer[0:-1][:], fixed_img.shape[:2][::-1]
moving_img = apply_affine_transformation(
fixed_img, moving_img, transform_initializer
)
moving_mask = cv2.warpAffine(
moving_mask, transform_initializer[0:-1][:], fixed_img.shape[:2][::-1]
moving_mask = apply_affine_transformation(
fixed_img, moving_mask, transform_initializer
)

before_dice = dice(fixed_mask, moving_mask)

# Estimate transform using tissue regions
Expand Down