diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 67afd7f3c3..774acf1e31 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1468,9 +1468,15 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") img_t = convert_to_tensor(img, track_meta=False) + img_min, img_max = img_t.min(), img_t.max() + if img_min == img_max: + warn( + f"The image's intensity is a single value {img_min}. " + "The original image is simply returned, no histogram shift is done." + ) + return img xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t) yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t) - img_min, img_max = img_t.min(), img_t.max() reference_control_points_scaled = xp * (img_max - img_min) + img_min floating_control_points_scaled = yp * (img_max - img_min) + img_min img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled) diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index 64dd6a926b..318dad9dfa 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -44,6 +44,16 @@ ] ) +WARN_TESTS = [] +for p in TEST_NDARRAYS: + WARN_TESTS.append( + [ + {"num_control_points": 5, "prob": 1.0}, + {"img": p(np.zeros(8).reshape((1, 2, 2, 2)))}, + np.zeros(8).reshape((1, 2, 2, 2)), + ] + ) + class TestRandHistogramShift(unittest.TestCase): @parameterized.expand(TESTS) @@ -71,6 +81,12 @@ def test_interp(self): self.assertEqual(yi.shape, (3, 2)) assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]])) + @parameterized.expand(WARN_TESTS) + def test_warn(self, input_param, input_data, expected_val): + with self.assertWarns(Warning): + result = RandHistogramShift(**input_param)(**input_data) + assert_allclose(result, expected_val, type_test="tensor") + if __name__ == "__main__": unittest.main()