Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform, LazyTransform, Mu
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will remain
unchanged.
warn: if `True` prints a warning if a class is not present in the label.

"""

Expand All @@ -1163,6 +1164,7 @@ def __init__(
image_threshold: float = 0.0,
indices: list[NdarrayOrTensor] | None = None,
allow_smaller: bool = False,
warn: bool = True,
) -> None:
self.spatial_size = spatial_size
self.ratios = ratios
Expand All @@ -1174,6 +1176,7 @@ def __init__(
self.centers: list[list[int]] | None = None
self.indices = indices
self.allow_smaller = allow_smaller
self.warn = warn

def randomize(
self,
Expand All @@ -1198,7 +1201,7 @@ def randomize(
if _shape is None:
raise ValueError("label or image must be provided to infer the output spatial shape.")
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller
self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn
)

@LazyTransform.lazy_evaluation.setter # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,8 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSa
the requested ROI in any dimension. If `True`, any smaller dimensions will remain
unchanged.
allow_missing_keys: don't raise exception if key is missing.
warn: if `True` prints a warning if a class is not present in the label.


"""

Expand All @@ -979,6 +981,7 @@ def __init__(
indices_key: str | None = None,
allow_smaller: bool = False,
allow_missing_keys: bool = False,
warn: bool = True,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
self.label_key = label_key
Expand All @@ -991,6 +994,7 @@ def __init__(
num_samples=num_samples,
image_threshold=image_threshold,
allow_smaller=allow_smaller,
warn=warn,
)

def set_random_state(
Expand Down
4 changes: 3 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def generate_label_classes_crop_centers(
ratios: list[float | int] | None = None,
rand_state: np.random.RandomState | None = None,
allow_smaller: bool = False,
warn: bool = True,
) -> list[list[int]]:
"""
Generate valid sample locations based on the specified ratios of label classes.
Expand All @@ -551,6 +552,7 @@ def generate_label_classes_crop_centers(
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).
warn: if `True` prints a warning if a class is not present in the label.

"""
if rand_state is None:
Expand All @@ -567,7 +569,7 @@ def generate_label_classes_crop_centers(
raise ValueError(f"ratios should not contain negative number, got {ratios_}.")

for i, array in enumerate(indices):
if len(array) == 0:
if len(array) == 0 and warn:
warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.")
ratios_[i] = 0

Expand Down