Skip to content

Commit

Permalink
Add warning in RandHistogramShift (#5877)
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <yunl@nvidia.com>

Fixes #5875 .

### Description

Add warning in `RandHistogramShift` when the image's intensity is a
single value.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu authored Jan 19, 2023
1 parent bdf5e1e commit 2c5c89f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
8 changes: 7 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,9 +1468,15 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
if self.reference_control_points is None or self.floating_control_points is None:
raise RuntimeError("please call the `randomize()` function first.")
img_t = convert_to_tensor(img, track_meta=False)
img_min, img_max = img_t.min(), img_t.max()
if img_min == img_max:
warn(
f"The image's intensity is a single value {img_min}. "
"The original image is simply returned, no histogram shift is done."
)
return img
xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t)
yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t)
img_min, img_max = img_t.min(), img_t.max()
reference_control_points_scaled = xp * (img_max - img_min) + img_min
floating_control_points_scaled = yp * (img_max - img_min) + img_min
img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_rand_histogram_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
]
)

WARN_TESTS = []
for p in TEST_NDARRAYS:
WARN_TESTS.append(
[
{"num_control_points": 5, "prob": 1.0},
{"img": p(np.zeros(8).reshape((1, 2, 2, 2)))},
np.zeros(8).reshape((1, 2, 2, 2)),
]
)


class TestRandHistogramShift(unittest.TestCase):
@parameterized.expand(TESTS)
Expand Down Expand Up @@ -71,6 +81,12 @@ def test_interp(self):
self.assertEqual(yi.shape, (3, 2))
assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]]))

@parameterized.expand(WARN_TESTS)
def test_warn(self, input_param, input_data, expected_val):
with self.assertWarns(Warning):
result = RandHistogramShift(**input_param)(**input_data)
assert_allclose(result, expected_val, type_test="tensor")


if __name__ == "__main__":
unittest.main()

0 comments on commit 2c5c89f

Please sign in to comment.