Skip to content

Commit 06cbd70

Browse files
MrGranddyge85evzpre-commit-ci[bot]ericspodKumoLiu
authored
fix implementation mistakes and add conjugate gradients solver (#7876)
Fixes #6767 ### Description Fixed two minor implementation differences between the official MATLAB code and the MONAI implementation, to confirm also added a new test case where two images and their confidence maps calculated by the official code is added to tests and the MONAI implementation results are checked against there results created by the official code. Also fixing the issue: added the conjugate gradients solver option. Now the users can utilize it to run the algorithm faster with a trade-off of accuracy of the end result, a range of speed-ups can be achieved with little to no quality loss by tweaking the parameters, the optimal parameters between quality and speed in my experience is set as the default parameters, namely 'cg_tol' and 'cg_maxiter'. For the CG solver installing PyAMG (https://github.com/pyamg/pyamg) is a requirement, this is because we use it to generate a preconditioner, without it CG does not provide any speed-ups, even slows down the algorithm. This part can be changed if the requirement is not ideal, yet this was the best solution as far as my knowledge goes. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: MrGranddy <bugrayesilkaynak@gmail.com> Signed-off-by: Vahit Buğra YEŞİLKAYNAK <bugrayesilkaynak@gmail.com> Co-authored-by: ge85evz <vanessa_share@nevarro.ifl.campar.in.tum.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent ac84a4e commit 06cbd70

File tree

11 files changed

+531
-455
lines changed

11 files changed

+531
-455
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ onnx>=1.13.0
4040
onnxruntime; python_version <= '3.10'
4141
zarr
4242
huggingface_hub
43+
pyamg>=5.0.0

docs/source/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,6 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
258258
```
259259

260260
which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
261-
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, and `huggingface_hub` respectively.
261+
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
262262

263263
- `pip install 'monai[all]'` installs all the optional dependencies.

monai/data/ultrasound_confidence_map.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
cv2, _ = optional_import("cv2")
2222
csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix")
2323
spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve")
24+
cg, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "cg")
2425
hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")
26+
ruge_stuben_solver, _ = optional_import("pyamg", "5.0.0", min_version, "ruge_stuben_solver")
2527

2628

2729
class UltrasoundConfidenceMap:
@@ -30,22 +32,43 @@ class UltrasoundConfidenceMap:
3032
It generates a confidence map by setting source and sink points in the image and computing the probability
3133
for random walks to reach the source for each pixel.
3234
35+
The official code is available at:
36+
https://campar.in.tum.de/Main/AthanasiosKaramalisCode
37+
3338
Args:
3439
alpha (float, optional): Alpha parameter. Defaults to 2.0.
3540
beta (float, optional): Beta parameter. Defaults to 90.0.
3641
gamma (float, optional): Gamma parameter. Defaults to 0.05.
3742
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
3843
sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when calling
3944
the transform. Can be 'all', 'mid', 'min', or 'mask'.
45+
use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.
46+
cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.
47+
Will be used only if `use_cg` is True.
48+
cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.
49+
Will be used only if `use_cg` is True.
4050
"""
4151

42-
def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all"):
52+
def __init__(
53+
self,
54+
alpha: float = 2.0,
55+
beta: float = 90.0,
56+
gamma: float = 0.05,
57+
mode="B",
58+
sink_mode="all",
59+
use_cg=False,
60+
cg_tol=1e-6,
61+
cg_maxiter=200,
62+
):
4363
# The hyperparameters for confidence map estimation
4464
self.alpha = alpha
4565
self.beta = beta
4666
self.gamma = gamma
4767
self.mode = mode
4868
self.sink_mode = sink_mode
69+
self.use_cg = use_cg
70+
self.cg_tol = cg_tol
71+
self.cg_maxiter = cg_maxiter
4972

5073
# The precision to use for all computations
5174
self.eps = np.finfo("float64").eps
@@ -228,17 +251,18 @@ def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, bet
228251
s = self.normalize(s)
229252

230253
# Horizontal penalty
231-
s[:vertical_end] += gamma
232-
# s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2)
233-
# since the diagonal edges are longer yet does not exist in the original code
254+
s[vertical_end:] += gamma
255+
# Here there is a difference between the official MATLAB code and the paper
256+
# on the edge penalty. We directly implement what the official code does.
234257

235258
# Normalize differences
236259
s = self.normalize(s)
237260

238261
# Gaussian weighting function
239262
s = -(
240-
(np.exp(-beta * s, dtype="float64")) + 1.0e-6
241-
) # --> This epsilon changes results drastically default: 1.e-6
263+
(np.exp(-beta * s, dtype="float64")) + 1e-5
264+
) # --> This epsilon changes results drastically default: 10e-6
265+
# Please notice that it is not 1e-6, it is 10e-6 which is actually different.
242266

243267
# Create Laplacian, diagonal missing
244268
lap = csc_matrix((s, (i, j)))
@@ -256,7 +280,14 @@ def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, bet
256280
return lap
257281

258282
def _solve_linear_system(self, lap, rhs):
259-
x = spsolve(lap, rhs)
283+
284+
if self.use_cg:
285+
lap_sparse = lap.tocsr()
286+
ml = ruge_stuben_solver(lap_sparse, coarse_solver="pinv")
287+
m = ml.aspreconditioner(cycle="V")
288+
x, _ = cg(lap, rhs, tol=self.cg_tol, maxiter=self.cg_maxiter, M=m)
289+
else:
290+
x = spsolve(lap, rhs)
260291

261292
return x
262293

monai/transforms/intensity/array.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2789,21 +2789,42 @@ class UltrasoundConfidenceMapTransform(Transform):
27892789
It generates a confidence map by setting source and sink points in the image and computing the probability
27902790
for random walks to reach the source for each pixel.
27912791
2792+
The official code is available at:
2793+
https://campar.in.tum.de/Main/AthanasiosKaramalisCode
2794+
27922795
Args:
27932796
alpha (float, optional): Alpha parameter. Defaults to 2.0.
27942797
beta (float, optional): Beta parameter. Defaults to 90.0.
27952798
gamma (float, optional): Gamma parameter. Defaults to 0.05.
27962799
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
27972800
sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when
27982801
calling the transform. Can be one of 'all', 'mid', 'min', 'mask'.
2802+
use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.
2803+
cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.
2804+
Will be used only if `use_cg` is True.
2805+
cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.
2806+
Will be used only if `use_cg` is True.
27992807
"""
28002808

2801-
def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all") -> None:
2809+
def __init__(
2810+
self,
2811+
alpha: float = 2.0,
2812+
beta: float = 90.0,
2813+
gamma: float = 0.05,
2814+
mode="B",
2815+
sink_mode="all",
2816+
use_cg=False,
2817+
cg_tol: float = 1.0e-6,
2818+
cg_maxiter: int = 200,
2819+
):
28022820
self.alpha = alpha
28032821
self.beta = beta
28042822
self.gamma = gamma
28052823
self.mode = mode
28062824
self.sink_mode = sink_mode
2825+
self.use_cg = use_cg
2826+
self.cg_tol = cg_tol
2827+
self.cg_maxiter = cg_maxiter
28072828

28082829
if self.mode not in ["B", "RF"]:
28092830
raise ValueError(f"Unknown mode: {self.mode}. Supported modes are 'B' and 'RF'.")
@@ -2813,7 +2834,9 @@ def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05,
28132834
f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'."
28142835
)
28152836

2816-
self._compute_conf_map = UltrasoundConfidenceMap(self.alpha, self.beta, self.gamma, self.mode, self.sink_mode)
2837+
self._compute_conf_map = UltrasoundConfidenceMap(
2838+
self.alpha, self.beta, self.gamma, self.mode, self.sink_mode, self.use_cg, self.cg_tol, self.cg_maxiter
2839+
)
28172840

28182841
def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor:
28192842
"""Compute confidence map from an ultrasound image.

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ zarr
5757
lpips==0.1.4
5858
nvidia-ml-py
5959
huggingface_hub
60+
pyamg>=5.0.0

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ all =
8484
lpips==0.1.4
8585
nvidia-ml-py
8686
huggingface_hub
87+
pyamg>=5.0.0
8788
nibabel =
8889
nibabel
8990
ninja =
@@ -162,6 +163,8 @@ pynvml =
162163
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
163164
huggingface_hub =
164165
huggingface_hub
166+
pyamg =
167+
pyamg>=5.0.0
165168

166169
[flake8]
167170
select = B,C,E,F,N,P,T4,W,B9

0 commit comments

Comments
 (0)