Skip to content

Commit ba3c72c

Browse files
johnzielkeKumoLiu
andauthored
Fix inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoised (#7584)
Fixes inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoised ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: John Zielke <j.l.zielke@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent c9fed96 commit ba3c72c

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

monai/transforms/intensity/array.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,15 +1840,19 @@ class RandGibbsNoise(RandomizableTransform):
18401840
18411841
Args:
18421842
prob (float): probability of applying the transform.
1843-
alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
1843+
alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
18441844
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
18451845
If a length-2 list is given as [a,b] then the value of alpha will be
18461846
sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
1847+
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
18471848
"""
18481849

18491850
backend = GibbsNoise.backend
18501851

1851-
def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None:
1852+
def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None:
1853+
if isinstance(alpha, float):
1854+
alpha = (0, alpha)
1855+
alpha = ensure_tuple(alpha)
18521856
if len(alpha) != 2:
18531857
raise ValueError("alpha length must be 2.")
18541858
if alpha[1] > 1 or alpha[0] < 0:

monai/transforms/intensity/dictionary.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,10 +1423,11 @@ class RandGibbsNoised(RandomizableTransform, MapTransform):
14231423
keys: 'image', 'label', or ['image', 'label'] depending on which data
14241424
you need to transform.
14251425
prob (float): probability of applying the transform.
1426-
alpha (float, List[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
1426+
alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
14271427
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
14281428
If a length-2 list is given as [a,b] then the value of alpha will be sampled
14291429
uniformly from the interval [a,b].
1430+
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
14301431
allow_missing_keys: do not raise exception if key is missing.
14311432
"""
14321433

@@ -1436,7 +1437,7 @@ def __init__(
14361437
self,
14371438
keys: KeysCollection,
14381439
prob: float = 0.1,
1439-
alpha: Sequence[float] = (0.0, 1.0),
1440+
alpha: float | Sequence[float] = (0.0, 1.0),
14401441
allow_missing_keys: bool = False,
14411442
) -> None:
14421443
MapTransform.__init__(self, keys, allow_missing_keys)

tests/test_rand_gibbs_noise.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ def test_alpha(self, im_shape, input_type):
9090
self.assertGreaterEqual(t.sampled_alpha, 0.5)
9191
self.assertLessEqual(t.sampled_alpha, 0.51)
9292

93+
@parameterized.expand(TEST_CASES)
94+
def test_alpha_single_value(self, im_shape, input_type):
95+
im = self.get_data(im_shape, input_type)
96+
alpha = 0.01
97+
t = RandGibbsNoise(1.0, alpha)
98+
_ = t(deepcopy(im))
99+
self.assertGreaterEqual(t.sampled_alpha, 0)
100+
self.assertLessEqual(t.sampled_alpha, 0.01)
101+
93102

94103
if __name__ == "__main__":
95104
unittest.main()

tests/test_rand_gibbs_noised.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ def test_alpha(self, im_shape, input_type):
105105
_ = t(deepcopy(data))
106106
self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51)
107107

108+
@parameterized.expand(TEST_CASES)
109+
def test_alpha_single_value(self, im_shape, input_type):
110+
data = self.get_data(im_shape, input_type)
111+
alpha = 0.01
112+
t = RandGibbsNoised(KEYS, 1.0, alpha)
113+
_ = t(deepcopy(data))
114+
self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01)
115+
108116

109117
if __name__ == "__main__":
110118
unittest.main()

0 commit comments

Comments
 (0)