
Commit 3410794

Add ultrasound confidence map to transforms (#6709)
### Description

This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005 to compute confidence maps on ultrasound images.

### Possible Problems

- I am not entirely sure if the "transforms" section is the right place for this method, but I found it the most suitable since it is not "deep learning" and it is "pre-processing" in a way.
- The current version of the implementation requires GNU Octave to be installed and defined in the path. I am aware this is an odd dependency, yet using SciPy does not provide satisfactory results in terms of speed. If this kind of dependency is not suitable, I also have a pure SciPy implementation, but it runs about 15x slower, which is too slow for real-time use. I am open to any feedback.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Vahit Bugra YESILKAYNAK <bugrayesilkaynak@gmail.com>
1 parent 644c9e5 commit 3410794
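
For context, a minimal usage sketch of the new transform (the synthetic input is illustrative; the import path follows the `monai/data/__init__.py` change below):

```python
import numpy as np

from monai.data import UltrasoundConfidenceMap

# Synthetic B-mode-like image, one scanline per column (H x W).
us_image = np.random.rand(256, 128)

# Defaults mirror the transform's signature: alpha=2.0, beta=90.0, gamma=0.05.
transform = UltrasoundConfidenceMap(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="all")
confidence = transform(us_image)  # confidence map in [0, 1], same H x W shape
```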

File tree

11 files changed: +1203 −5 lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ itk>=5.2
 nibabel
 parameterized
 scikit-image>=0.19.0
+scipy>=1.7.1
 tensorboard
 commonmark==0.9.1
 recommonmark==0.6.0

docs/source/installation.md

Lines changed: 4 additions & 4 deletions
@@ -10,8 +10,8 @@
 - [Uninstall the packages](#uninstall-the-packages)
 - [From conda-forge](#from-conda-forge)
 - [From GitHub](#from-github)
-  - [Option 1 (as a part of your system-wide module)](#option-1-as-a-part-of-your-system-wide-module)
-  - [Option 2 (editable installation)](#option-2-editable-installation)
+  - [Option 1 (as a part of your system-wide module):](#option-1-as-a-part-of-your-system-wide-module)
+  - [Option 2 (editable installation):](#option-2-editable-installation)
 - [Validating the install](#validating-the-install)
 - [MONAI version string](#monai-version-string)
 - [From DockerHub](#from-dockerhub)

@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are

 ```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
 ```

-which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
+which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
 `gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.

 - `pip install 'monai[all]'` installs all the optional dependencies.
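
These extras matter because MONAI soft-imports optional packages rather than requiring them; the new confidence-map module guards its SciPy usage with the same `optional_import` pattern, roughly:

```python
from monai.utils import min_version, optional_import

# Soft import: returns the attribute and an availability flag instead of
# raising at import time when scipy is missing.
hilbert, has_hilbert = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")

if not has_hilbert:
    raise RuntimeError("scipy>=1.7.1 is required for RF-mode envelope detection")
```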

monai/config/deviceconfig.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ def get_optional_config_values():
     output["ITK"] = get_package_version("itk")
     output["Nibabel"] = get_package_version("nibabel")
     output["scikit-image"] = get_package_version("skimage")
+    output["scipy"] = get_package_version("scipy")
     output["Pillow"] = get_package_version("PIL")
     output["Tensorboard"] = get_package_version("tensorboard")
     output["gdown"] = get_package_version("gdown")

monai/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -150,3 +150,5 @@ def reduce_meta_tensor(meta_tensor):
     return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)

 ForkingPickler.register(MetaTensor, reduce_meta_tensor)
+
+from .ultrasound_confidence_map import UltrasoundConfidenceMap
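
With the re-export in place, the `sink_mode="mask"` variant noted in the docstring can be exercised like this (a sketch; the mask placement is arbitrary):

```python
import numpy as np

from monai.data import UltrasoundConfidenceMap

us_image = np.random.rand(256, 128)

# Any nonzero pixel in the mask becomes a sink seed.
sink_mask = np.zeros_like(us_image)
sink_mask[-2:, :] = 1.0

cm = UltrasoundConfidenceMap(sink_mode="mask")
confidence = cm(us_image, sink_mask=sink_mask)
```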
monai/data/ultrasound_confidence_map.py

Lines changed: 352 additions & 0 deletions

@@ -0,0 +1,352 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray

from monai.utils import min_version, optional_import

__all__ = ["UltrasoundConfidenceMap"]

cv2, _ = optional_import("cv2")
csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix")
spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve")
hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")

class UltrasoundConfidenceMap:
    """Compute confidence map from an ultrasound image.
    This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.
    It generates a confidence map by setting source and sink points in the image and computing the probability
    for random walks to reach the source for each pixel.

    Args:
        alpha (float, optional): Alpha parameter. Defaults to 2.0.
        beta (float, optional): Beta parameter. Defaults to 90.0.
        gamma (float, optional): Gamma parameter. Defaults to 0.05.
        mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
        sink_mode (str, optional): Sink mode. Can be 'all', 'mid', 'min', or 'mask'. Defaults to 'all'.
            If 'mask' is selected, a sink mask must be provided when calling the transform.
    """

    def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all"):
        # The hyperparameters for confidence map estimation
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.mode = mode
        self.sink_mode = sink_mode

        # The precision to use for all computations
        self.eps = np.finfo("float64").eps

        # Store sink indices for external use
        self._sink_indices = np.array([], dtype="float64")

    def sub2ind(self, size: tuple[int, ...], rows: NDArray, cols: NDArray) -> NDArray:
        """Converts row and column subscripts into linear indices,
        basically a copy of the MATLAB function of the same name.
        https://www.mathworks.com/help/matlab/ref/sub2ind.html

        This function is Pythonic so the indices start at 0.

        Args:
            size (Tuple[int]): Size of the matrix
            rows (NDArray): Row indices
            cols (NDArray): Column indices

        Returns:
            indices (NDArray): 1-D array of linear indices
        """
        indices: NDArray = rows + cols * size[0]
        return indices

    def get_seed_and_labels(
        self, data: NDArray, sink_mode: str = "all", sink_mask: NDArray | None = None
    ) -> tuple[NDArray, NDArray]:
        """Get the seed and label arrays for the random walks algorithm

        Args:
            data: Input array
            sink_mode (str, optional): Sink mode. Defaults to 'all'.
            sink_mask (NDArray, optional): Sink mask. Defaults to None.

        Returns:
            Tuple[NDArray, NDArray]: Seed and label arrays
        """

        # Seeds and labels (boundary conditions)
        seeds = np.array([], dtype="float64")
        labels = np.array([], dtype="float64")

        # Indices for all columns
        sc = np.arange(data.shape[1], dtype="float64")

        # SOURCE ELEMENTS - 1st matrix row
        # Indices for the 1st row; they will be broadcast with sc
        sr_up = np.array([0])
        seed = self.sub2ind(data.shape, sr_up, sc).astype("float64")
        seed = np.unique(seed)
        seeds = np.concatenate((seeds, seed))

        # Label 1
        label = np.ones_like(seed)
        labels = np.concatenate((labels, label))

        # Create seeds for sink elements

        if sink_mode == "all":
            # All elements in the last row
            sr_down = np.ones_like(sc) * (data.shape[0] - 1)
            self._sink_indices = np.array([sr_down, sc], dtype="int32")
            seed = self.sub2ind(data.shape, sr_down, sc).astype("float64")

        elif sink_mode == "mid":
            # Middle element in the last row
            sc_down = np.array([data.shape[1] // 2])
            sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)
            self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
            seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

        elif sink_mode == "min":
            # Minimum element in the last row (excluding 10% from the edges)
            ten_percent = int(data.shape[1] * 0.1)
            min_val = np.min(data[-1, ten_percent:-ten_percent])
            min_idxs = np.where(data[-1, ten_percent:-ten_percent] == min_val)[0] + ten_percent
            sc_down = min_idxs
            sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)
            self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
            seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

        elif sink_mode == "mask":
            # All elements in the mask
            coords = np.where(sink_mask != 0)
            sr_down = coords[0]
            sc_down = coords[1]
            self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
            seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

        seed = np.unique(seed)
        seeds = np.concatenate((seeds, seed))

        # Label 2
        label = np.ones_like(seed) * 2
        labels = np.concatenate((labels, label))

        return seeds, labels

    def normalize(self, inp: NDArray) -> NDArray:
        """Normalize an array to [0, 1]"""
        normalized_array: NDArray = (inp - np.min(inp)) / (np.ptp(inp) + self.eps)
        return normalized_array

    def attenuation_weighting(self, img: NDArray, alpha: float) -> NDArray:
        """Compute attenuation weighting

        Args:
            img (NDArray): Image
            alpha: Attenuation coefficient (see publication)

        Returns:
            w (NDArray): Weighting expressing depth-dependent attenuation
        """

        # Create depth vector and repeat it for each column
        dw = np.linspace(0, 1, img.shape[0], dtype="float64")
        dw = np.tile(dw.reshape(-1, 1), (1, img.shape[1]))

        w: NDArray = 1.0 - np.exp(-alpha * dw)  # Compute exp inline

        return w

    def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, beta: float, gamma: float):
        """Compute 6-Connected Laplacian for confidence estimation problem

        Args:
            padded_index (NDArray): The index matrix of the image with boundary padding.
            padded_image (NDArray): The padded image.
            beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function.
            gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.

        Returns:
            L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation.
        """

        m, _ = padded_index.shape

        padded_index = padded_index.T.flatten()
        padded_image = padded_image.T.flatten()

        p = np.where(padded_index > 0)[0]

        i = padded_index[p] - 1  # Index vector
        j = padded_index[p] - 1  # Index vector
        # Entries vector, initially for diagonal
        s = np.zeros_like(p, dtype="float64")

        edge_templates = [
            -1,  # Vertical edges
            1,
            m - 1,  # Diagonal edges
            m + 1,
            -m - 1,
            -m + 1,
            m,  # Horizontal edges
            -m,
        ]

        vertical_end = None

        for iter_idx, k in enumerate(edge_templates):
            neigh_idxs = padded_index[p + k]

            q = np.where(neigh_idxs > 0)[0]

            ii = padded_index[p[q]] - 1
            i = np.concatenate((i, ii))
            jj = neigh_idxs[q] - 1
            j = np.concatenate((j, jj))
            w = np.abs(padded_image[p[ii]] - padded_image[p[jj]])  # Intensity derived weight
            s = np.concatenate((s, w))

            if iter_idx == 1:
                vertical_end = s.shape[0]  # Vertical edges length
            elif iter_idx == 5:
                diagonal_end = s.shape[0]  # Diagonal edges length (only used by the commented-out penalty below)

        # Normalize weights
        s = self.normalize(s)

        # Horizontal penalty
        s[:vertical_end] += gamma
        # s[vertical_end:diagonal_end] += gamma * np.sqrt(2)
        # --> In the paper the diagonal penalty is gamma * sqrt(2) since the diagonal edges
        # are longer, but this does not exist in the original (reference) code.

        # Normalize differences
        s = self.normalize(s)

        # Gaussian weighting function
        s = -(
            (np.exp(-beta * s, dtype="float64")) + 1.0e-6
        )  # --> This epsilon changes results drastically; default: 1.e-6

        # Create Laplacian, diagonal missing
        lap = csc_matrix((s, (i, j)))

        # Reset diagonal weights to zero for summing
        # up the weighted edge degree in the next step
        lap.setdiag(0)

        # Weighted edge degree
        diag = np.abs(lap.sum(axis=0).A)[0]

        # Finalize Laplacian by completing the diagonal
        lap.setdiag(diag)

        return lap

    def _solve_linear_system(self, lap, rhs):
        x = spsolve(lap, rhs)

        return x

    def confidence_estimation(self, img, seeds, labels, beta, gamma):
        """Compute confidence map

        Args:
            img (NDArray): Processed image.
            seeds (NDArray): Seeds for the random walks framework. These are indices of the source and sink nodes.
            labels (NDArray): Labels for the random walks framework. These represent the classes or groups of the seeds.
            beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function.
            gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.

        Returns:
            map: Confidence map which shows the probability of each pixel belonging to the source or sink group.
        """

        # Index matrix with boundary padding
        idx = np.arange(1, img.shape[0] * img.shape[1] + 1).reshape(img.shape[1], img.shape[0]).T
        pad = 1

        padded_idx = np.pad(idx, (pad, pad), "constant", constant_values=(0, 0))
        padded_img = np.pad(img, (pad, pad), "constant", constant_values=(0, 0))

        # Laplacian
        lap = self.confidence_laplacian(padded_idx, padded_img, beta, gamma)

        # Select marked columns from Laplacian to create L_M and B^T
        b = lap[:, seeds]

        # Select marked nodes to create B^T
        n = np.sum(padded_idx > 0).item()
        i_u = np.setdiff1d(np.arange(n), seeds.astype(int))  # Index of unmarked nodes
        b = b[i_u, :]

        # Remove marked nodes from Laplacian by deleting rows and cols
        keep_indices = np.setdiff1d(np.arange(lap.shape[0]), seeds)
        lap = csc_matrix(lap[keep_indices, :][:, keep_indices])

        # Define M matrix
        m = np.zeros((seeds.shape[0], 1), dtype="float64")
        m[:, 0] = labels == 1

        # Right-hand side (-B^T*M)
        rhs = -b @ m  # type: ignore

        # Solve linear system
        x = self._solve_linear_system(lap, rhs)

        # Prepare output
        probabilities = np.zeros((n,), dtype="float64")
        # Probabilities for unmarked nodes
        probabilities[i_u] = x
        # Max probability for marked nodes
        probabilities[seeds[labels == 1].astype(int)] = 1.0

        # Final reshape with same size as input image (no padding)
        probabilities = probabilities.reshape((img.shape[1], img.shape[0])).T

        return probabilities

    def __call__(self, data: NDArray, sink_mask: NDArray | None = None) -> NDArray:
        """Compute the confidence map

        Args:
            data (NDArray): RF ultrasound data (one scanline per column) [H x W] 2D array
            sink_mask (NDArray, optional): Sink mask for sink_mode 'mask'; any nonzero pixel becomes a sink.
                Defaults to None.

        Returns:
            map (NDArray): Confidence map [H x W] 2D array
        """

        # Normalize data
        data = data.astype("float64")
        data = self.normalize(data)

        if self.mode == "RF":
            # MATLAB hilbert applies the Hilbert transform to columns
            data = np.abs(hilbert(data, axis=0)).astype("float64")  # type: ignore

        seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask)

        # Attenuation with Beer-Lambert
        w = self.attenuation_weighting(data, self.alpha)

        # Apply weighting directly to image
        # Same as applying it individually during the formation of the Laplacian
        data = data * w

        # Find confidence values
        map_: NDArray = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma)

        return map_
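
To make the solve in `confidence_estimation` concrete, here is a self-contained toy version of the same random-walk Dirichlet problem on a 4-node chain (pure SciPy; `lap_u`, `unmarked`, and the chain graph are illustrative, not part of the commit):

```python
import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import spsolve

# 4-node chain 0 -- 1 -- 2 -- 3 with unit edge weights; Laplacian L = D - W.
lap = csc_matrix(
    np.array(
        [
            [1.0, -1.0, 0.0, 0.0],
            [-1.0, 2.0, -1.0, 0.0],
            [0.0, -1.0, 2.0, -1.0],
            [0.0, 0.0, -1.0, 1.0],
        ]
    )
)

seeds = np.array([0, 3])   # node 0 = source, node 3 = sink
labels = np.array([1, 2])  # same label convention as the transform
unmarked = np.setdiff1d(np.arange(4), seeds)

# B couples unmarked rows to seed columns; m flags the source seeds.
b = lap[unmarked, :][:, seeds]
m = (labels == 1).astype(np.float64).reshape(-1, 1)

# Solve L_U x = -B m: x is the probability a random walk reaches the source first.
lap_u = csc_matrix(lap[unmarked, :][:, unmarked])
x = spsolve(lap_u, np.asarray(-b @ m).ravel())

print(x)  # ~[0.667, 0.333]: confidence decays with distance from the source row
```

In the transform, the same structure is built from the padded image graph, with the first row as the source seeds and the `sink_mode`-selected pixels as label 2.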
