Skip to content

Commit

Permalink
Fix inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoi…
Browse files Browse the repository at this point in the history
…sed (#7584)

Fixes inconsistent alpha parameter/docs for
RandGibbsNoise/RandGibbsNoised

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: John Zielke <j.l.zielke@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
johnzielke and KumoLiu authored Mar 27, 2024
1 parent c9fed96 commit ba3c72c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
8 changes: 6 additions & 2 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,15 +1840,19 @@ class RandGibbsNoise(RandomizableTransform):
Args:
prob (float): probability of applying the transform.
alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be
sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
"""

backend = GibbsNoise.backend

def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None:
def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None:
if isinstance(alpha, float):
alpha = (0, alpha)
alpha = ensure_tuple(alpha)
if len(alpha) != 2:
raise ValueError("alpha length must be 2.")
if alpha[1] > 1 or alpha[0] < 0:
Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,10 +1423,11 @@ class RandGibbsNoised(RandomizableTransform, MapTransform):
keys: 'image', 'label', or ['image', 'label'] depending on which data
you need to transform.
prob (float): probability of applying the transform.
alpha (float, List[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be sampled
uniformly from the interval [a,b].
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
allow_missing_keys: do not raise exception if key is missing.
"""

Expand All @@ -1436,7 +1437,7 @@ def __init__(
self,
keys: KeysCollection,
prob: float = 0.1,
alpha: Sequence[float] = (0.0, 1.0),
alpha: float | Sequence[float] = (0.0, 1.0),
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_rand_gibbs_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ def test_alpha(self, im_shape, input_type):
self.assertGreaterEqual(t.sampled_alpha, 0.5)
self.assertLessEqual(t.sampled_alpha, 0.51)

@parameterized.expand(TEST_CASES)
def test_alpha_single_value(self, im_shape, input_type):
im = self.get_data(im_shape, input_type)
alpha = 0.01
t = RandGibbsNoise(1.0, alpha)
_ = t(deepcopy(im))
self.assertGreaterEqual(t.sampled_alpha, 0)
self.assertLessEqual(t.sampled_alpha, 0.01)


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions tests/test_rand_gibbs_noised.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def test_alpha(self, im_shape, input_type):
_ = t(deepcopy(data))
self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51)

@parameterized.expand(TEST_CASES)
def test_alpha_single_value(self, im_shape, input_type):
data = self.get_data(im_shape, input_type)
alpha = 0.01
t = RandGibbsNoised(KEYS, 1.0, alpha)
_ = t(deepcopy(data))
self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01)


if __name__ == "__main__":
unittest.main()

0 comments on commit ba3c72c

Please sign in to comment.