From 2c5c89f86ee718b31ce2397c699cd67a8c78623f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 20 Jan 2023 03:12:53 +0800 Subject: [PATCH] Add warning in `RandHistogramShift` (#5877) Signed-off-by: KumoLiu Fixes #5875 . ### Description Add warning in `RandHistogramShift` when the image's intensity is a single value. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu --- monai/transforms/intensity/array.py | 8 +++++++- tests/test_rand_histogram_shift.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 67afd7f3c3..774acf1e31 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1468,9 +1468,15 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") img_t = convert_to_tensor(img, track_meta=False) + img_min, img_max = img_t.min(), img_t.max() + if img_min == img_max: + warn( + f"The image's intensity is a single value {img_min}. " + "The original image is simply returned, no histogram shift is done." + ) + return img xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t) yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t) - img_min, img_max = img_t.min(), img_t.max() reference_control_points_scaled = xp * (img_max - img_min) + img_min floating_control_points_scaled = yp * (img_max - img_min) + img_min img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled) diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index 64dd6a926b..318dad9dfa 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -44,6 +44,16 @@ ] ) +WARN_TESTS = [] +for p in TEST_NDARRAYS: + WARN_TESTS.append( + [ + {"num_control_points": 5, "prob": 1.0}, + {"img": p(np.zeros(8).reshape((1, 2, 2, 2)))}, + np.zeros(8).reshape((1, 2, 2, 2)), + ] + ) + class TestRandHistogramShift(unittest.TestCase): @parameterized.expand(TESTS) @@ -71,6 +81,12 @@ def test_interp(self): self.assertEqual(yi.shape, (3, 2)) assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]])) + @parameterized.expand(WARN_TESTS) + def test_warn(self, input_param, input_data, expected_val): + with self.assertWarns(Warning): + result = RandHistogramShift(**input_param)(**input_data) + assert_allclose(result, expected_val, type_test="tensor") + if __name__ == "__main__": unittest.main()