diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 0e9759aaf1..9339897d7a 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -150,3 +150,5 @@ def reduce_meta_tensor(meta_tensor): return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata) ForkingPickler.register(MetaTensor, reduce_meta_tensor) + +from .ultrasound_confidence_map import UltrasoundConfidenceMap diff --git a/monai/data/ultrasound_confidence_map.py b/monai/data/ultrasound_confidence_map.py new file mode 100644 index 0000000000..4970288b98 --- /dev/null +++ b/monai/data/ultrasound_confidence_map.py @@ -0,0 +1,376 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Literal, Optional, Tuple + +import numpy as np + +from monai.utils import min_version, optional_import + +__all__ = ["UltrasoundConfidenceMap"] + +cv2, _ = optional_import("cv2") +Oct2Py, _ = optional_import("oct2py", "5.6.0", min_version, "Oct2Py") +csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix") +spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve") +hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert") + + +class UltrasoundConfidenceMap: + """Compute confidence map from an ultrasound image. + This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005. 
+ It generates a confidence map by setting source and sink points in the image and computing the probability + for random walks to reach the source for each pixel. + + Args: + alpha (float, optional): Alpha parameter. Defaults to 2.0. + beta (float, optional): Beta parameter. Defaults to 90.0. + gamma (float, optional): Gamma parameter. Defaults to 0.05. + mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'. + sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be provided when calling the transform. + """ + + def __init__( + self, + alpha: float = 2.0, + beta: float = 90.0, + gamma: float = 0.05, + mode: Literal["RF", "B"] = "B", + sink_mode: Literal["all", "mid", "min", "mask"] = "all", + backend: Literal["scipy", "octave"] = "scipy", + ): + + # The hyperparameters for confidence map estimation + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.mode = mode + self.sink_mode = sink_mode + self.backend = backend + + # The precision to use for all computations + self.eps = np.finfo("float64").eps + + # Store sink indices for external use + self._sink_indices = np.array([], dtype="float64") + + if self.backend == "octave": + # Octave instance for computing the confidence map + self.oc = Oct2Py() + + def sub2ind(self, size: Tuple[int], rows: np.ndarray, cols: np.ndarray) -> np.ndarray: + """Converts row and column subscripts into linear indices, + basically the copy of the MATLAB function of the same name. + https://www.mathworks.com/help/matlab/ref/sub2ind.html + + This function is Pythonic so the indices start at 0.
+ + Args: + size Tuple[int]: Size of the matrix + rows (np.ndarray): Row indices + cols (np.ndarray): Column indices + + Returns: + indices (np.ndarray): 1-D array of linear indices + """ + indices = rows + cols * size[0] + return indices + + def get_seed_and_labels( + self, data: np.ndarray, sink_mode: str = "all", sink_mask: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Get the seed and label arrays for the max-flow algorithm + + Args: + data: Input array + sink_mode (str, optional): Sink mode. Defaults to 'all'. + sink_mask (np.ndarray, optional): Sink mask. Defaults to None. + + Returns: + Tuple[np.ndarray, np.ndarray]: Seed and label arrays + """ + + # Seeds and labels (boundary conditions) + seeds = np.array([], dtype="float64") + labels = np.array([], dtype="float64") + + # Indices for all columns + sc = np.arange(data.shape[1], dtype="float64") + + # SOURCE ELEMENTS - 1st matrix row + # Indices for 1st row, it will be broadcasted with sc + sr_up = np.array([0]) + seed = self.sub2ind(data.shape, sr_up, sc).astype("float64") + seed = np.unique(seed) + seeds = np.concatenate((seeds, seed)) + + # Label 1 + label = np.ones_like(seed) + labels = np.concatenate((labels, label)) + + # Create seeds for sink elements + + if sink_mode == "all": + # All elements in the last row + sr_down = np.ones_like(sc) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc).astype("float64") + + elif sink_mode == "mid": + # Middle element in the last row + sc_down = np.array([data.shape[1] // 2]) + sr_down = np.ones_like(sc_down) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + elif sink_mode == "min": + # Minimum element in the last row (excluding 10% from the edges); the explicit end index keeps the slice non-empty when the margin rounds down to 0 + ten_percent = int(data.shape[1] * 0.1) + min_val = np.min(data[-1, ten_percent : data.shape[1] - ten_percent]) +
min_idxs = np.where(data[-1, ten_percent : data.shape[1] - ten_percent] == min_val)[0] + ten_percent + sc_down = min_idxs + sr_down = np.ones_like(sc_down) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + elif sink_mode == "mask": + # All elements in the mask + coords = np.where(sink_mask != 0) + sr_down = coords[0] + sc_down = coords[1] + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + seed = np.unique(seed) + seeds = np.concatenate((seeds, seed)) + + # Label 2 + label = np.ones_like(seed) * 2 + labels = np.concatenate((labels, label)) + + return seeds, labels + + def normalize(self, inp: np.ndarray) -> np.ndarray: + """Normalize an array to [0, 1]""" + return (inp - np.min(inp)) / (np.ptp(inp) + self.eps) + + def attenuation_weighting(self, A: np.ndarray, alpha: float) -> np.ndarray: + """Compute attenuation weighting + + Args: + A (np.ndarray): Image + alpha: Attenuation coefficient (see publication) + + Returns: + W (np.ndarray): Weighting expressing depth-dependent attenuation + """ + + # Create depth vector and repeat it for each column + Dw = np.linspace(0, 1, A.shape[0], dtype="float64") + Dw = np.tile(Dw.reshape(-1, 1), (1, A.shape[1])) + + W = 1.0 - np.exp(-alpha * Dw) # Compute exp inline + + return W + + def confidence_laplacian( + self, P: np.ndarray, A: np.ndarray, beta: float, gamma: float + ) -> csc_matrix: # type: ignore + """Compute 6-Connected Laplacian for confidence estimation problem + + Args: + P (np.ndarray): The index matrix of the image with boundary padding. + A (np.ndarray): The padded image. + beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function. + gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.
+ + Returns: + L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation. + """ + + m, _ = P.shape + + P = P.T.flatten() + A = A.T.flatten() + + p = np.where(P > 0)[0] + + i = P[p] - 1 # Index vector + j = P[p] - 1 # Index vector + # Entries vector, initially for diagonal + s = np.zeros_like(p, dtype="float64") + + vl = 0 # Vertical edges length + + edge_templates = [ + -1, # Vertical edges + 1, + m - 1, # Diagonal edges + m + 1, + -m - 1, + -m + 1, + m, # Horizontal edges + -m, + ] + + vertical_end = None + diagonal_end = None + + for iter_idx, k in enumerate(edge_templates): + + Q = P[p + k] + + q = np.where(Q > 0)[0] + + ii = P[p[q]] - 1 + i = np.concatenate((i, ii)) + jj = Q[q] - 1 + j = np.concatenate((j, jj)) + W = np.abs(A[p[ii]] - A[p[jj]]) # Intensity derived weight + s = np.concatenate((s, W)) + + if iter_idx == 1: + vertical_end = s.shape[0] # Vertical edges length + elif iter_idx == 5: + diagonal_end = s.shape[0] # Diagonal edges length + + # Normalize weights + s = self.normalize(s) + + # Horizontal penalty + s[:vertical_end] += gamma + # s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2) since the diagonal edges are longer yet does not exist in the original code + + # Normalize differences + s = self.normalize(s) + + # Gaussian weighting function + s = -( + (np.exp(-beta * s, dtype="float64")) + 1.0e-6 + ) # --> This epsilon changes results drastically default: 1.e-6 + + # Create Laplacian, diagonal missing + L = csc_matrix((s, (i, j))) + + # Reset diagonal weights to zero for summing + # up the weighted edge degree in the next step + L.setdiag(0) + + # Weighted edge degree + D = np.abs(L.sum(axis=0).A)[0] + + # Finalize Laplacian by completing the diagonal + L.setdiag(D) + + return L + + def _solve_linear_system(self, D, rhs, tol=1.0e-8, mode="scipy"): + + if mode == "scipy": + X = spsolve(D, rhs) + + elif mode == "octave": + X = self.oc.mldivide(D, rhs)[:, 0] + + return X + + def 
confidence_estimation(self, A, seeds, labels, beta, gamma, backend): + """Compute confidence map + + Args: + A (np.ndarray): Processed image. + seeds (np.ndarray): Seeds for the random walks framework. These are indices of the source and sink nodes. + labels (np.ndarray): Labels for the random walks framework. These represent the classes or groups of the seeds. + beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function. + gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. + + Returns: + map: Confidence map which shows the probability of each pixel belonging to the source or sink group. + """ + + # Index matrix with boundary padding + G = np.arange(1, A.shape[0] * A.shape[1] + 1).reshape(A.shape[1], A.shape[0]).T + pad = 1 + + G = np.pad(G, (pad, pad), "constant", constant_values=(0, 0)) + B = np.pad(A, (pad, pad), "constant", constant_values=(0, 0)) + + # Laplacian + D = self.confidence_laplacian(G, B, beta, gamma) + + # Select marked columns from Laplacian to create L_M and B^T + B = D[:, seeds] + + # Select marked nodes to create B^T + N = np.sum(G > 0).item() + i_U = np.setdiff1d(np.arange(N), seeds.astype(int)) # Index of unmarked nodes + B = B[i_U, :] + + # Remove marked nodes from Laplacian by deleting rows and cols + keep_indices = np.setdiff1d(np.arange(D.shape[0]), seeds) + D = csc_matrix(D[keep_indices, :][:, keep_indices]) + + # Define M matrix + M = np.zeros((seeds.shape[0], 1), dtype="float64") + M[:, 0] = labels == 1 + + # Right-handside (-B^T*M) + rhs = -B @ M # type: ignore + + # Solve linear system + x = self._solve_linear_system(D, rhs, tol=1.0e-3, mode=backend) + + # Prepare output + probabilities = np.zeros((N,), dtype="float64") + # Probabilities for unmarked nodes + probabilities[i_U] = x + # Max probability for marked node + probabilities[seeds[labels == 1].astype(int)] = 1.0 + + # Final reshape with same size as input image (no padding) + probabilities = 
probabilities.reshape((A.shape[1], A.shape[0])).T + + return probabilities + + def __call__(self, data: np.ndarray, sink_mask: Optional[np.ndarray] = None) -> np.ndarray: + """Compute the confidence map + + Args: + data (np.ndarray): RF ultrasound data (one scanline per column) + + Returns: + map (np.ndarray): Confidence map + """ + + # Normalize data + data = data.astype("float64") + data = self.normalize(data) + + if self.mode == "RF": + # MATLAB hilbert applies the Hilbert transform to columns + data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore + + seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask) + + # Attenuation with Beer-Lambert + W = self.attenuation_weighting(data, self.alpha) + + # Apply weighting directly to image + # Same as applying it individually during the formation of the + # Laplacian + data = data * W + + # Find confidence values + map_ = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma, self.backend) + + return map_ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cdad6ec6c3..477ec7a8bd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -129,6 +129,7 @@ ShiftIntensity, StdShiftIntensity, ThresholdIntensity, + UltrasoundConfidenceMapTransform, ) from .intensity.dictionary import ( AdjustContrastd, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index cdcada6dda..292edfb356 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -17,7 +17,7 @@ from abc import abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import partial -from typing import Any, Tuple, Literal +from typing import Any, Literal, Tuple from warnings import warn import numpy as np @@ -26,6 +26,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import
get_track_meta +from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform @@ -37,10 +38,7 @@ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype skimage, _ = optional_import("skimage", "0.19.0", min_version) -cv2, _ = optional_import("cv2") -Oct2Py, _ = optional_import("oct2py", "5.6.0", min_version, "Oct2Py") -csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix") -hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert") + __all__ = [ "RandGaussianNoise", @@ -81,7 +79,7 @@ "RandIntensityRemap", "ForegroundMask", "ComputeHoVerMaps", - "UltrasoundConfidenceMap", + "UltrasoundConfidenceMapTransform", ] @@ -2583,7 +2581,8 @@ def __call__(self, mask: NdarrayOrTensor): hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta()) return hv_maps -class UltrasoundConfidenceMap(Transform): + +class UltrasoundConfidenceMapTransform(Transform): """Compute confidence map from an ultrasound image. This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005. It generates a confidence map by setting source and sink points in the image and computing the probability @@ -2594,6 +2593,7 @@ class UltrasoundConfidenceMap(Transform): beta (float, optional): Beta parameter. Defaults to 90.0. gamma (float, optional): Gamma parameter. Defaults to 0.05. mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'. + sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be provided when calling the transform.
""" def __init__( @@ -2602,286 +2602,47 @@ def __init__( beta: float = 90.0, gamma: float = 0.05, mode: Literal["RF", "B"] = "B", - ): - + sink_mode: Literal["all", "mid", "min", "mask"] = "all", + backend: Literal["scipy", "octave"] = "scipy", + ) -> None: self.alpha = alpha self.beta = beta self.gamma = gamma self.mode = mode + self.sink_mode = sink_mode + self.backend = backend - # The precision to use for all computations - self.eps = np.finfo("float64").eps - - # Octave instance for computing the confidence map - self.oc = Oct2Py() - - def sub2ind(self, size: Tuple[int], rows: np.ndarray, cols: np.ndarray) -> np.ndarray: - """Converts row and column subscripts into linear indices, - basically the copy of the MATLAB function of the same name. - https://www.mathworks.com/help/matlab/ref/sub2ind.html - - This function is Pythonic so the indices start at 0. - - Args: - size Tuple[int]: Size of the matrix - rows (np.ndarray): Row indices - cols (np.ndarray): Column indices - - Returns: - indices (np.ndarray): 1-D array of linear indices - """ - indices = rows + cols * size[0] - return indices - - def get_seed_and_labels(self, data : np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Get the seed and label arrays for the max-flow algorithm - - Args: - data: Input array - - Returns: - Tuple[np.ndarray, np.ndarray]: Seed and label arrays - """ - - # Seeds and labels (boundary conditions) - seeds = np.array([], dtype="float64") - labels = np.array([], dtype="float64") - - # Indices for all columns - sc = np.arange(data.shape[1], dtype="float64") - - # SOURCE ELEMENTS - 1st matrix row - # Indices for 1st row, it will be broadcasted with sc - sr_up = np.array([0]) - seed = self.sub2ind(data.shape, sr_up, sc).astype("float64") - seed = np.unique(seed) - seeds = np.concatenate((seeds, seed)) - - # Label 1 - label = np.ones_like(seed) - labels = np.concatenate((labels, label)) - - # SINK ELEMENTS - last image row - sr_down = np.ones_like(sc) * (data.shape[0] - 1) - seed = 
self.sub2ind(data.shape, sr_down, sc).astype("float64") - - seed = np.unique(seed) - seeds = np.concatenate((seeds, seed)) - - # Label 2 - label = np.ones_like(seed) * 2 - labels = np.concatenate((labels, label)) - - return seeds, labels - - def normalize(self, inp: np.ndarray) -> np.ndarray: - """Normalize an array to [0, 1]""" - return (inp - np.min(inp)) / (np.ptp(inp) + self.eps) - - def attenuation_weighting(self, A: np.ndarray, alpha: float) -> np.ndarray: - """Compute attenuation weighting - - Args: - A (np.ndarray): Image - alpha: Attenuation coefficient (see publication) - - Returns: - W (np.ndarray): Weighting expressing depth-dependent attenuation - """ - - # Create depth vector and repeat it for each column - Dw = np.linspace(0, 1, A.shape[0], dtype="float64") - Dw = np.tile(Dw.reshape(-1, 1), (1, A.shape[1])) - - W = 1.0 - np.exp(-alpha * Dw) # Compute exp inline - - return W - - def confidence_laplacian( - self, P: np.ndarray, A: np.ndarray, beta: float, gamma: float - ) -> csc_matrix: # type: ignore - """Compute 6-Connected Laplacian for confidence estimation problem - - Args: - P (np.ndarray): The index matrix of the image with boundary padding. - A (np.ndarray): The padded image. - beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function. - gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. - - Returns: - L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation. - """ - - m, _ = P.shape - - P = P.T.flatten() - A = A.T.flatten() - - p = np.where(P > 0)[0] - - i = P[p] - 1 # Index vector - j = P[p] - 1 # Index vector - # Entries vector, initially for diagonal - s = np.zeros_like(p, dtype="float64") - - vl = 0 # Vertical edges length + if self.mode not in ["B", "RF"]: + raise ValueError(f"Unknown mode: {self.mode}. 
Supported modes are 'B' and 'RF'.") - edge_templates = [ - -1, # Vertical edges - 1, - m - 1, # Diagonal edges - m + 1, - -m - 1, - -m + 1, - m, # Horizontal edges - -m, - ] + if self.backend not in ["scipy", "octave"]: + raise ValueError(f"Unknown backend: {self.backend}. Supported modes are 'scipy' and 'octave'.") - vertical_end = None - diagonal_end = None - - for iter_idx, k in enumerate(edge_templates): - - Q = P[p + k] - - q = np.where(Q > 0)[0] - - ii = P[p[q]] - 1 - i = np.concatenate((i, ii)) - jj = Q[q] - 1 - j = np.concatenate((j, jj)) - W = np.abs(A[p[ii]] - A[p[jj]]) # Intensity derived weight - s = np.concatenate((s, W)) - - if iter_idx == 1: - vertical_end = s.shape[0] # Vertical edges length - elif iter_idx == 5: - diagonal_end = s.shape[0] # Diagonal edges length - - # Normalize weights - s = self.normalize(s) - - # Horizontal penalty - s[:vertical_end] += gamma - #s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2) since the diagonal edges are longer yet does not exist in the original code - - # Normalize differences - s = self.normalize(s) - - # Gaussian weighting function - s = -( - (np.exp(-beta * s, dtype="float64")) + 1.0e-6 - ) # --> This epsilon changes results drastically default: 1.e-6 - - # Create Laplacian, diagonal missing - L = csc_matrix((s, (i, j))) - - # Reset diagonal weights to zero for summing - # up the weighted edge degree in the next step - L.setdiag(0) - - # Weighted edge degree - D = np.abs(L.sum(axis=0).A)[0] - - # Finalize Laplacian by completing the diagonal - L.setdiag(D) - - return L - - def confidence_estimation(self, A, seeds, labels, beta, gamma): - """Compute confidence map - - Args: - A (np.ndarray): Processed image. - seeds (np.ndarray): Seeds for the random walks framework. These are indices of the source and sink nodes. - labels (np.ndarray): Labels for the random walks framework. These represent the classes or groups of the seeds. 
- beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function. - gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. - - Returns: - map: Confidence map which shows the probability of each pixel belonging to the source or sink group. - """ - - # Index matrix with boundary padding - G = np.arange(1, A.shape[0] * A.shape[1] + 1).reshape(A.shape[1], A.shape[0]).T - pad = 1 - - G = np.pad(G, (pad, pad), "constant", constant_values=(0, 0)) - B = np.pad(A, (pad, pad), "constant", constant_values=(0, 0)) - - # Laplacian - D = self.confidence_laplacian(G, B, beta, gamma) - - # Select marked columns from Laplacian to create L_M and B^T - B = D[:, seeds] - - # Select marked nodes to create B^T - N = np.sum(G > 0).item() - i_U = np.setdiff1d(np.arange(N), seeds.astype(int)) # Index of unmarked nodes - B = B[i_U, :] - - # Remove marked nodes from Laplacian by deleting rows and cols - keep_indices = np.setdiff1d(np.arange(D.shape[0]), seeds) - D = csc_matrix(D[keep_indices, :][:, keep_indices]) - - # Define M matrix - M = np.zeros((seeds.shape[0], 1), dtype="float64") - M[:, 0] = labels == 1 - - # Right-handside (-B^T*M) - rhs = -B @ M # type: ignore - - # Solve system exactly - x = self.oc.mldivide(D, rhs)[:, 0] - - # Prepare output - probabilities = np.zeros((N,), dtype="float64") - # Probabilities for unmarked nodes - probabilities[i_U] = x - # Max probability for marked node - probabilities[seeds[labels == 1].astype(int)] = 1.0 - - # Final reshape with same size as input image (no padding) - probabilities = probabilities.reshape((A.shape[1], A.shape[0])).T - - return probabilities - - def __call__(self, data: np.ndarray, downsample=None) -> np.ndarray: - """Compute the confidence map - - Args: - data (np.ndarray): RF ultrasound data (one scanline per column) - - Returns: - map (np.ndarray): Confidence map - """ - - # Normalize data - data = data.astype("float64") - data = self.normalize(data) + if 
self.sink_mode not in ["all", "mid", "min", "mask"]: + raise ValueError( + f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'." + ) - if self.mode == "RF": - # MATLAB hilbert applies the Hilbert transform to columns - data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore + self._compute_conf_map = UltrasoundConfidenceMap(self.alpha, self.beta, self.gamma, self.mode, self.sink_mode, self.backend) - org_H, org_W = data.shape - if downsample is not None: - data = cv2.resize(data, (org_W // downsample, org_H // downsample), interpolation=cv2.INTER_CUBIC) + def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor: - seeds, labels = self.get_seed_and_labels(data) + if self.sink_mode == "mask" and mask is None: + raise ValueError("Mask must be provided when sink mode is 'mask'.") - # Attenuation with Beer-Lambert - W = self.attenuation_weighting(data, self.alpha) + img = convert_to_tensor(img, track_meta=get_track_meta()) + img_np, *_ = convert_data_type(img, np.ndarray) - # Apply weighting directly to image - # Same as applying it individually during the formation of the - # Laplacian - data = data * W + mask_np = None + if mask is not None: + mask = convert_to_tensor(mask, dtype=torch.bool, track_meta=get_track_meta()) + mask_np, *_ = convert_data_type(mask, np.ndarray) - # Find condidence values - map_ = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma) + # Convert to grayscale + if img_np.ndim == 3: + img_np = skimage.color.rgb2gray(img_np) - if downsample is not None: - map_ = cv2.resize(map_, (org_W, org_H), interpolation=cv2.INTER_CUBIC) + # Compute confidence map (solver backend is forwarded from the constructor) + conf_map = self._compute_conf_map(img_np, mask_np) - return map_ + return convert_to_dst_type(src=conf_map, dst=img)[0] diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py new file mode 100644 index 0000000000..f69407daa5 --- /dev/null +++
b/tests/test_ultrasound_confidence_map_transform.py @@ -0,0 +1,650 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch + +from monai.transforms import UltrasoundConfidenceMapTransform +from tests.utils import assert_allclose, is_tf32_env + +TEST_INPUT = np.array( + [ + [1, 2, 3, 23, 13, 22, 5, 1, 2, 3], + [1, 2, 3, 12, 4, 6, 9, 1, 2, 3], + [1, 2, 3, 8, 7, 10, 11, 1, 2, 3], + [1, 2, 3, 14, 15, 16, 17, 1, 2, 3], + [1, 2, 3, 18, 19, 20, 21, 1, 2, 3], + [1, 2, 3, 24, 25, 26, 27, 1, 2, 3], + [1, 2, 3, 28, 29, 30, 31, 1, 2, 3], + [1, 2, 3, 32, 33, 34, 35, 1, 2, 3], + [1, 2, 3, 36, 37, 38, 39, 1, 2, 3], + [1, 2, 3, 40, 41, 42, 43, 1, 2, 3], + ] +) + +TEST_MASK = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) + + +SINK_ALL_OUTPUT = np.array( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [ + 0.97514489, + 0.96762971, + 0.96164186, + 0.95463443, + 0.9941512, + 0.99023054, + 0.98559401, + 0.98230057, + 0.96601224, + 0.95119599, + ], + [ + 0.92960533, + 0.92638451, + 0.9056675, + 0.9487176, + 0.9546961, + 0.96165853, + 0.96172303, + 0.92686401, 
+ 0.92122613, + 0.89957239, + ], + [ + 0.86490963, + 0.85723665, + 0.83798141, + 0.90816201, + 0.90816097, + 0.90815301, + 0.9081427, + 0.85933627, + 0.85146935, + 0.82948586, + ], + [ + 0.77430346, + 0.76731372, + 0.74372311, + 0.89128774, + 0.89126885, + 0.89125066, + 0.89123521, + 0.76858589, + 0.76106647, + 0.73807776, + ], + [ + 0.66098109, + 0.65327697, + 0.63090644, + 0.33086588, + 0.3308383, + 0.33081937, + 0.33080718, + 0.6557468, + 0.64825099, + 0.62593375, + ], + [ + 0.52526945, + 0.51832586, + 0.49709412, + 0.25985059, + 0.25981009, + 0.25977729, + 0.25975222, + 0.52118958, + 0.51426328, + 0.49323164, + ], + [ + 0.3697845, + 0.36318971, + 0.34424661, + 0.17386804, + 0.17382046, + 0.17377993, + 0.17374668, + 0.36689317, + 0.36036096, + 0.3415582, + ], + [ + 0.19546374, + 0.1909659, + 0.17319999, + 0.08423318, + 0.08417993, + 0.08413242, + 0.08409104, + 0.19393909, + 0.18947485, + 0.17185031, + ], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] +) + +SINK_MID_OUTPUT = np.array( + [ + [ + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + [ + 9.99996103e-01, + 9.99994823e-01, + 9.99993550e-01, + 9.99930863e-01, + 9.99990782e-01, + 9.99984683e-01, + 9.99979000e-01, + 9.99997804e-01, + 9.99995985e-01, + 9.99994325e-01, + ], + [ + 9.99989344e-01, + 9.99988600e-01, + 9.99984099e-01, + 9.99930123e-01, + 9.99926598e-01, + 9.99824297e-01, + 9.99815032e-01, + 9.99991228e-01, + 9.99990881e-01, + 9.99988462e-01, + ], + [ + 9.99980787e-01, + 9.99979264e-01, + 9.99975828e-01, + 9.59669286e-01, + 9.59664779e-01, + 9.59656566e-01, + 9.59648332e-01, + 9.99983882e-01, + 9.99983038e-01, + 9.99980732e-01, + ], + [ + 9.99970181e-01, + 9.99969032e-01, + 9.99965730e-01, + 9.45197806e-01, + 9.45179593e-01, + 9.45163629e-01, + 9.45151458e-01, + 9.99973352e-01, + 9.99973254e-01, + 9.99971098e-01, + ], + [ + 9.99958608e-01, + 9.99957307e-01, + 
9.99953444e-01, + 4.24743523e-01, + 4.24713305e-01, + 4.24694646e-01, + 4.24685271e-01, + 9.99960948e-01, + 9.99961829e-01, + 9.99960347e-01, + ], + [ + 9.99946675e-01, + 9.99945139e-01, + 9.99940312e-01, + 3.51353224e-01, + 3.51304003e-01, + 3.51268260e-01, + 3.51245366e-01, + 9.99947688e-01, + 9.99950165e-01, + 9.99949512e-01, + ], + [ + 9.99935877e-01, + 9.99934088e-01, + 9.99928982e-01, + 2.51197134e-01, + 2.51130273e-01, + 2.51080014e-01, + 2.51045852e-01, + 9.99936187e-01, + 9.99939716e-01, + 9.99940022e-01, + ], + [ + 9.99927846e-01, + 9.99925911e-01, + 9.99920188e-01, + 1.31550973e-01, + 1.31462736e-01, + 1.31394558e-01, + 1.31346069e-01, + 9.99927275e-01, + 9.99932142e-01, + 9.99933313e-01, + ], + [ + 9.99924204e-01, + 9.99922004e-01, + 9.99915767e-01, + 3.04861147e-04, + 1.95998056e-04, + 0.00000000e00, + 2.05182682e-05, + 9.99923115e-01, + 9.99928835e-01, + 9.99930535e-01, + ], + ] +) + + +SINK_MIN_OUTPUT = np.array( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [ + 0.99997545, + 0.99996582, + 0.99995245, + 0.99856594, + 0.99898314, + 0.99777223, + 0.99394423, + 0.98588113, + 0.97283215, + 0.96096504, + ], + [ + 0.99993872, + 0.99993034, + 0.9998832, + 0.9986147, + 0.99848741, + 0.9972981, + 0.99723719, + 0.94157173, + 0.9369832, + 0.91964243, + ], + [ + 0.99990802, + 0.99989475, + 0.99986873, + 0.98610197, + 0.98610047, + 0.98609749, + 0.98609423, + 0.88741275, + 0.88112911, + 0.86349156, + ], + [ + 0.99988924, + 0.99988509, + 0.99988698, + 0.98234089, + 0.98233591, + 0.98233065, + 0.98232562, + 0.81475172, + 0.80865978, + 0.79033138, + ], + [ + 0.99988418, + 0.99988484, + 0.99988323, + 0.86796555, + 0.86795874, + 0.86795283, + 0.86794756, + 0.72418193, + 0.71847704, + 0.70022037, + ], + [ + 0.99988241, + 0.99988184, + 0.99988103, + 0.85528225, + 0.85527303, + 0.85526389, + 0.85525499, + 0.61716519, + 0.61026209, + 0.59503671, + ], + [ + 0.99988015, + 0.99987985, + 0.99987875, + 0.84258114, + 0.84257121, + 0.84256042, + 0.84254897, + 
0.48997924, + 0.49083978, + 0.46891561, + ], + [ + 0.99987865, + 0.99987827, + 0.9998772, + 0.83279589, + 0.83278624, + 0.83277384, + 0.83275897, + 0.36345545, + 0.33690244, + 0.35696828, + ], + [ + 0.99987796, + 0.99987756, + 0.99987643, + 0.82873223, + 0.82872648, + 0.82871803, + 0.82870711, + 0.0, + 0.26106012, + 0.29978657, + ], + ] +) + +SINK_MASK_OUTPUT = np.array( + [ + [ + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.86416400e-01, + 7.93271181e-01, + 5.81341234e-01, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.98395623e-01, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.66733297e-01, + 2.80741490e-01, + 4.14078784e-02, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 7.91676486e-04, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.86244537e-04, + 1.53413401e-04, + 7.85806495e-05, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 5.09797387e-06, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 9.62904581e-07, + 7.23946225e-07, + 3.68824440e-07, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 4.79525316e-08, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.50939343e-10, + 1.17724874e-10, + 6.21760843e-11, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 6.08922784e-10, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.57593754e-13, + 1.94066716e-13, + 9.83784370e-14, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 9.80828665e-12, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 4.22323494e-16, + 3.17556633e-16, + 1.60789400e-16, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.90789819e-13, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 7.72677888e-19, + 5.83029424e-19, + 2.95946659e-19, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 
4.97038275e-15, + ], + [ + 2.71345908e-24, + 5.92006757e-24, + 2.25580089e-23, + 3.82601970e-18, + 3.82835349e-18, + 3.83302158e-18, + 3.84002606e-18, + 8.40760586e-16, + 1.83433696e-15, + 1.11629633e-15, + ], + ] +) + + +class TestUltrasoundConfidenceMapTransform(unittest.TestCase): + def setUp(self): + self.input_img_torch = torch.rand((256, 256)) # mock input image (torch tensor) + self.input_mask_torch = torch.ones((256, 256), dtype=torch.bool) # mock mask (torch tensor) + + # create numpy versions of image and mask + self.input_img_np = self.input_img_torch.numpy() + self.input_mask_np = self.input_mask_torch.numpy() + + def test_parameters(self): + + # Unknown mode + with self.assertRaises(ValueError): + UltrasoundConfidenceMapTransform(mode="unknown") + + # Unknown backend + with self.assertRaises(ValueError): + UltrasoundConfidenceMapTransform(backend="unknown") + + # Unknown sink_mode + with self.assertRaises(ValueError): + UltrasoundConfidenceMapTransform(sink_mode="unknown") + + def test_sink_all(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + self.assertIsInstance(result_torch, np.ndarray) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_mid(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + self.assertIsInstance(result_torch, np.ndarray) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_min(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + 
self.assertIsInstance(result_torch, np.ndarray) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_mask(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + + # This should not raise an exception for torch tensor with mask + result_torch = transform(self.input_img_torch, self.input_mask_torch) + self.assertIsInstance(result_torch, np.ndarray) + + # This should not raise an exception for numpy array with mask + result_np = transform(self.input_img_np, self.input_mask_np) + self.assertIsInstance(result_np, np.ndarray) + + # This should raise an exception for torch tensor without mask + with self.assertRaises(ValueError): + transform(self.input_img_torch) + + # This should raise an exception for numpy array without mask + with self.assertRaises(ValueError): + transform(self.input_img_np) + + def test_func(self): + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="all", backend="scipy" + ) + output = transform(self.input_img_np) + assert_allclose(output, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mid", backend="scipy" + ) + output = transform(self.input_img_np) + assert_allclose(output, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="min", backend="scipy" + ) + output = transform(self.input_img_np) + assert_allclose(output, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mask", backend="scipy" + ) + output = transform(self.input_img_np, self.input_mask_np) + assert_allclose(output, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, 
beta=90.0, gamma=0.05, mode="B", sink_mode="all", backend="scipy" + ) + input_img_torch = torch.from_numpy(self.input_img_np) + output = transform(input_img_torch) + assert_allclose(output, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mid", backend="scipy" + ) + input_img_torch = torch.from_numpy(self.input_img_np) + output = transform(input_img_torch) + assert_allclose(output, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="min", backend="scipy" + ) + input_img_torch = torch.from_numpy(self.input_img_np) + output = transform(input_img_torch) + assert_allclose(output, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform( + alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mask", backend="scipy" + ) + input_img_torch = torch.from_numpy(self.input_img_np) + input_mask_torch = torch.from_numpy(self.input_mask_np) + output = transform(input_img_torch, input_mask_torch) + assert_allclose(output, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4)