Skip to content

Commit 244148d

Browse files
ytl0623ericspodpre-commit-ci[bot]Yu0610KumoLiu
authored
Add function in monai.transforms.utils.py (#7712)
Fixes #6704 ### Description Combined `map_classes_to_indices` and `generate_label_classes_crop_centers` to a single function `map_and_generate_sampling_centers` in monai.transforms.utils.py. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yu <146002968+Yu0610@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 96bfda0 commit 244148d

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@
671671
in_bounds,
672672
is_empty,
673673
is_positive,
674+
map_and_generate_sampling_centers,
674675
map_binary_to_indices,
675676
map_classes_to_indices,
676677
map_spatial_axes,

monai/transforms/utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
"in_bounds",
109109
"is_empty",
110110
"is_positive",
111+
"map_and_generate_sampling_centers",
111112
"map_binary_to_indices",
112113
"map_classes_to_indices",
113114
"map_spatial_axes",
@@ -368,6 +369,70 @@ def check_non_lazy_pending_ops(
368369
warnings.warn(msg)
369370

370371

372+
def map_and_generate_sampling_centers(
373+
label: NdarrayOrTensor,
374+
spatial_size: Sequence[int] | int,
375+
num_samples: int,
376+
label_spatial_shape: Sequence[int] | None = None,
377+
num_classes: int | None = None,
378+
image: NdarrayOrTensor | None = None,
379+
image_threshold: float = 0.0,
380+
max_samples_per_class: int | None = None,
381+
ratios: list[float | int] | None = None,
382+
rand_state: np.random.RandomState | None = None,
383+
allow_smaller: bool = False,
384+
warn: bool = True,
385+
) -> tuple[tuple]:
386+
"""
387+
Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates.
388+
This calls `map_classes_to_indices` to get indices from `label`, gets the shape from `label_spatial_shape`
389+
is given otherwise from the labels, calls `generate_label_classes_crop_centers`, and returns its results.
390+
391+
Args:
392+
label: use the label data to get the indices of every class.
393+
spatial_size: spatial size of the ROIs to be sampled.
394+
num_samples: total sample centers to be generated.
395+
label_spatial_shape: spatial shape of the original label data to unravel selected centers.
396+
indices: sequence of pre-computed foreground indices of every class in 1 dimension.
397+
num_classes: number of classes for argmax label, not necessary for One-Hot label.
398+
image: if image is not None, only return the indices of every class that are within the valid
399+
region of the image (``image > image_threshold``).
400+
image_threshold: if enabled `image`, use ``image > image_threshold`` to
401+
determine the valid image content area and select class indices only in this area.
402+
max_samples_per_class: maximum length of indices in each class to reduce memory consumption.
403+
Default is None, no subsampling.
404+
ratios: ratios of every class in the label to generate crop centers, including background class.
405+
if None, every class will have the same ratio to generate crop centers.
406+
rand_state: numpy randomState object to align with other modules.
407+
allow_smaller: if `False`, an exception will be raised if the image is smaller than
408+
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
409+
match the cropped size (i.e., no cropping in that dimension).
410+
warn: if `True` prints a warning if a class is not present in the label.
411+
Returns:
412+
Tuple of crop centres
413+
"""
414+
if label is None:
415+
raise ValueError("label must not be None.")
416+
indices = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class)
417+
418+
if label_spatial_shape is not None:
419+
_shape = label_spatial_shape
420+
elif isinstance(label, monai.data.MetaTensor):
421+
_shape = label.peek_pending_shape()
422+
else:
423+
_shape = label.shape[1:]
424+
425+
if _shape is None:
426+
raise ValueError(
427+
"label_spatial_shape or label with a known shape must be provided to infer the output spatial shape."
428+
)
429+
centers = generate_label_classes_crop_centers(
430+
spatial_size, num_samples, _shape, indices, ratios, rand_state, allow_smaller, warn
431+
)
432+
433+
return ensure_tuple(centers)
434+
435+
371436
def map_binary_to_indices(
372437
label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0
373438
) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
from copy import deepcopy
16+
17+
import numpy as np
18+
from parameterized import parameterized
19+
20+
from monai.transforms import map_and_generate_sampling_centers
21+
from monai.utils.misc import set_determinism
22+
from tests.utils import TEST_NDARRAYS, assert_allclose
23+
24+
TEST_CASE_1 = [
25+
# test Argmax data
26+
{
27+
"label": (np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),
28+
"spatial_size": [2, 2, 2],
29+
"num_samples": 2,
30+
"label_spatial_shape": [3, 3, 3],
31+
"num_classes": 3,
32+
"image": None,
33+
"ratios": [0, 1, 2],
34+
"image_threshold": 0.0,
35+
},
36+
tuple,
37+
2,
38+
3,
39+
]
40+
41+
TEST_CASE_2 = [
42+
{
43+
"label": (
44+
np.array(
45+
[
46+
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
47+
[[0, 1, 0], [0, 0, 1], [1, 0, 0]],
48+
[[0, 0, 1], [1, 0, 0], [0, 1, 0]],
49+
]
50+
)
51+
),
52+
"spatial_size": [2, 2, 2],
53+
"num_samples": 1,
54+
"ratios": None,
55+
"label_spatial_shape": [3, 3, 3],
56+
"image": None,
57+
"image_threshold": 0.0,
58+
},
59+
tuple,
60+
1,
61+
3,
62+
]
63+
64+
65+
class TestMapAndGenerateSamplingCenters(unittest.TestCase):
66+
67+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
68+
def test_map_and_generate_sampling_centers(self, input_data, expected_type, expected_count, expected_shape):
69+
results = []
70+
for p in TEST_NDARRAYS + (None,):
71+
input_data = deepcopy(input_data)
72+
if p is not None:
73+
input_data["label"] = p(input_data["label"])
74+
set_determinism(0)
75+
result = map_and_generate_sampling_centers(**input_data)
76+
self.assertIsInstance(result, expected_type)
77+
self.assertEqual(len(result), expected_count)
78+
self.assertEqual(len(result[0]), expected_shape)
79+
# check for consistency between numpy, torch and torch.cuda
80+
results.append(result)
81+
if len(results) > 1:
82+
for x, y in zip(result[0], result[-1]):
83+
assert_allclose(x, y, type_test=False)
84+
85+
86+
if __name__ == "__main__":
87+
unittest.main()

0 commit comments

Comments
 (0)