Skip to content

Commit 530cc1f

Browse files
authored
Fix ImageFilter to allow Gaussian filter without filter_size (#8189)
Fixes #8127 Update `ImageFilter` to handle Gaussian filter without requiring `filter_size`. * Modify `monai/transforms/utility/array.py` to allow Gaussian filter without `filter_size`. - Adjust `_check_filter_format` method to skip `filter_size` check for Gaussian filter. Indeed Gauss filter is the only one in the list that doesn't require a filter_size. * Add unit test in `tests/test_image_filter.py` for Gaussian filter without `filter_size`. - Verify output shape matches input shape. Note that this method is compliant with the dictionnary version since this one load the fixed version. Signed-off-by: Eloi <eloi.navet@gmail.com> --------- Signed-off-by: Eloi Navet <eloi.navet@labri.fr> Signed-off-by: Eloi <eloi.navet@gmail.com> Signed-off-by: Eloi eloi.navet@gmail.com
1 parent c1ceea3 commit 530cc1f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

monai/transforms/utility/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,9 +1609,9 @@ def _check_all_values_uneven(self, x: tuple) -> None:
16091609

16101610
def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:
16111611
if isinstance(filter, str):
1612-
if not filter_size:
1612+
if filter != "gauss" and not filter_size: # Gauss is the only filter that does not require `filter_size`
16131613
raise ValueError("`filter_size` must be specified when specifying filters by string.")
1614-
if filter_size % 2 == 0:
1614+
if filter_size and filter_size % 2 == 0:
16151615
raise ValueError("`filter_size` should be a single uneven integer.")
16161616
if filter not in self.supported_filters:
16171617
raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.")

tests/test_image_filter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def test_pass_empty_metadata_dict(self):
134134
out_tensor = filter(image)
135135
self.assertTrue(isinstance(out_tensor, MetaTensor))
136136

137+
def test_gaussian_filter_without_filter_size(self):
138+
"Test Gaussian filter without specifying filter_size"
139+
filter = ImageFilter("gauss", sigma=2)
140+
out_tensor = filter(SAMPLE_IMAGE_2D)
141+
self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])
142+
137143

138144
class TestImageFilterDict(unittest.TestCase):
139145

0 commit comments

Comments
 (0)