Skip to content

Commit c2b2fd6

Browse files
KumoLiuericspodpre-commit-ci[bot]
authored andcommitted
Add cache option in GridPatchDataset (Project-MONAI#7180)
Part of Project-MONAI#6904 ### Description - Fix inefficient patching in `PatchDataset` - Add cache option in `GridPatchDataset` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
1 parent d9f8c3e commit c2b2fd6

File tree

3 files changed

+242
-46
lines changed

3 files changed

+242
-46
lines changed

monai/data/grid_dataset.py

Lines changed: 185 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,30 @@
1111

1212
from __future__ import annotations
1313

14-
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
14+
import sys
15+
import warnings
16+
from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence
1517
from copy import deepcopy
18+
from multiprocessing.managers import ListProxy
19+
from multiprocessing.pool import ThreadPool
20+
from typing import TYPE_CHECKING
1621

1722
import numpy as np
23+
import torch
1824

1925
from monai.config import KeysCollection
2026
from monai.config.type_definitions import NdarrayTensor
21-
from monai.data.dataset import Dataset
2227
from monai.data.iterable_dataset import IterableDataset
23-
from monai.data.utils import iter_patch
24-
from monai.transforms import apply_transform
25-
from monai.utils import NumpyPadMode, ensure_tuple, first
28+
from monai.data.utils import iter_patch, pickle_hashing
29+
from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous
30+
from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import
31+
32+
if TYPE_CHECKING:
33+
from tqdm import tqdm
34+
35+
has_tqdm = True
36+
else:
37+
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
2638

2739
__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]
2840

@@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset):
184196
see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
185197
transform: a callable data transform operates on the patches.
186198
with_coordinates: whether to yield the coordinates of each patch, default to `True`.
199+
cache: whether to use cache mache mechanism, default to `False`.
200+
see also: :py:class:`monai.data.CacheDataset`.
201+
cache_num: number of items to be cached. Default is `sys.maxsize`.
202+
will take the minimum of (cache_num, data_length x cache_rate, data_length).
203+
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
204+
will take the minimum of (cache_num, data_length x cache_rate, data_length).
205+
num_workers: the number of worker threads if computing cache in the initialization.
206+
If num_workers is None then the number returned by os.cpu_count() is used.
207+
If a value less than 1 is specified, 1 will be used instead.
208+
progress: whether to display a progress bar.
209+
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
210+
default to `True`. if the random transforms don't modify the cached content
211+
(for example, randomly crop from the cached image and deepcopy the crop region)
212+
or if every cache item is only used once in a `multi-processing` environment,
213+
may set `copy=False` for better performance.
214+
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
215+
it may help improve the performance of following logic.
216+
hash_func: a callable to compute hash from data items to be cached.
217+
defaults to `monai.data.utils.pickle_hashing`.
187218
188219
"""
189220

@@ -193,27 +224,148 @@ def __init__(
193224
patch_iter: Callable,
194225
transform: Callable | None = None,
195226
with_coordinates: bool = True,
227+
cache: bool = False,
228+
cache_num: int = sys.maxsize,
229+
cache_rate: float = 1.0,
230+
num_workers: int | None = 1,
231+
progress: bool = True,
232+
copy_cache: bool = True,
233+
as_contiguous: bool = True,
234+
hash_func: Callable[..., bytes] = pickle_hashing,
196235
) -> None:
197236
super().__init__(data=data, transform=None)
237+
if transform is not None and not isinstance(transform, Compose):
238+
transform = Compose(transform)
198239
self.patch_iter = patch_iter
199240
self.patch_transform = transform
200241
self.with_coordinates = with_coordinates
242+
self.set_num = cache_num
243+
self.set_rate = cache_rate
244+
self.progress = progress
245+
self.copy_cache = copy_cache
246+
self.as_contiguous = as_contiguous
247+
self.hash_func = hash_func
248+
self.num_workers = num_workers
249+
if self.num_workers is not None:
250+
self.num_workers = max(int(self.num_workers), 1)
251+
self._cache: list | ListProxy = []
252+
self._cache_other: list | ListProxy = []
253+
self.cache = cache
254+
self.first_random: int | None = None
255+
if self.patch_transform is not None:
256+
self.first_random = self.patch_transform.get_index_of_first(
257+
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
258+
)
201259

202-
def __iter__(self):
203-
for image in super().__iter__():
204-
for patch, *others in self.patch_iter(image):
205-
out_patch = patch
206-
if self.patch_transform is not None:
207-
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
208-
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
209-
yield out_patch, others[0]
210-
else:
211-
yield out_patch
260+
if self.cache:
261+
if isinstance(data, Iterator):
262+
raise TypeError("Data can not be iterator when cache is True")
263+
self.set_data(data) # type: ignore
264+
265+
def set_data(self, data: Sequence) -> None:
266+
"""
267+
Set the input data and run deterministic transforms to generate cache content.
268+
269+
Note: should call this func after an entire epoch and must set `persistent_workers=False`
270+
in PyTorch DataLoader, because it needs to create new worker processes based on new
271+
generated cache content.
272+
273+
"""
274+
self.data = data
275+
276+
# only compute cache for the unique items of dataset, and record the last index for duplicated items
277+
mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
278+
self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping))
279+
self._hash_keys = list(mapping)[: self.cache_num]
280+
indices = list(mapping.values())[: self.cache_num]
281+
self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore
282+
283+
def _fill_cache(self, indices=None) -> list:
284+
"""
285+
Compute and fill the cache content from data source.
286+
287+
Args:
288+
indices: target indices in the `self.data` source to compute cache.
289+
if None, use the first `cache_num` items.
290+
291+
"""
292+
if self.cache_num <= 0:
293+
return []
294+
if indices is None:
295+
indices = list(range(self.cache_num))
296+
if self.progress and not has_tqdm:
297+
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
298+
299+
pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v)
300+
with ThreadPool(self.num_workers) as p:
301+
return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
302+
303+
def _load_cache_item(self, idx: int):
304+
"""
305+
Args:
306+
idx: the index of the input data sequence.
307+
"""
308+
item = self.data[idx] # type: ignore
309+
patch_cache, other_cache = [], []
310+
for patch, *others in self.patch_iter(item):
311+
if self.first_random is not None:
312+
patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore
313+
314+
if self.as_contiguous:
315+
patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format)
316+
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
317+
other_cache.append(others[0])
318+
patch_cache.append(patch)
319+
return patch_cache, other_cache
320+
321+
def _generate_patches(self, src, **apply_args):
322+
"""
323+
yield patches optionally post-processed by transform.
212324
325+
Args:
326+
src: a iterable of image patches.
327+
apply_args: other args for `self.patch_transform`.
328+
329+
"""
330+
for patch, *others in src:
331+
out_patch = patch
332+
if self.patch_transform is not None:
333+
out_patch = self.patch_transform(patch, **apply_args)
334+
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
335+
yield out_patch, others[0]
336+
else:
337+
yield out_patch
213338

214-
class PatchDataset(Dataset):
339+
def __iter__(self):
340+
if self.cache:
341+
cache_index = None
342+
for image in super().__iter__():
343+
key = self.hash_func(image)
344+
if key in self._hash_keys:
345+
# if existing in cache, try to get the index in cache
346+
cache_index = self._hash_keys.index(key)
347+
if cache_index is None:
348+
# no cache for this index, execute all the transforms directly
349+
yield from self._generate_patches(self.patch_iter(image))
350+
else:
351+
if self._cache is None:
352+
raise RuntimeError(
353+
"Cache buffer is not initialized, please call `set_data()` before epoch begins."
354+
)
355+
data = self._cache[cache_index] # type: ignore
356+
other = self._cache_other[cache_index] # type: ignore
357+
358+
# load data from cache and execute from the first random transform
359+
data = deepcopy(data) if self.copy_cache else data
360+
yield from self._generate_patches(zip(data, other), start=self.first_random)
361+
else:
362+
for image in super().__iter__():
363+
yield from self._generate_patches(self.patch_iter(image))
364+
365+
366+
class PatchDataset(IterableDataset):
215367
"""
216-
returns a patch from an image dataset.
368+
Yields patches from data read from an image dataset.
217369
The patches are generated by a user-specified callable `patch_func`,
218370
and are optionally post-processed by `transform`.
219371
For example, to generate random patch samples from an image dataset:
@@ -263,26 +415,26 @@ def __init__(
263415
samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
264416
transform: transform applied to each patch.
265417
"""
266-
super().__init__(data=data, transform=transform)
418+
super().__init__(data=data, transform=None)
267419

268420
self.patch_func = patch_func
269421
if samples_per_image <= 0:
270422
raise ValueError("sampler_per_image must be a positive integer.")
271423
self.samples_per_image = int(samples_per_image)
424+
self.patch_transform = transform
272425

273426
def __len__(self) -> int:
274-
return len(self.data) * self.samples_per_image
275-
276-
def _transform(self, index: int):
277-
image_id = int(index / self.samples_per_image)
278-
image = self.data[image_id]
279-
patches = self.patch_func(image)
280-
if len(patches) != self.samples_per_image:
281-
raise RuntimeWarning(
282-
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
283-
)
284-
patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1)
285-
patch = patches[patch_id]
286-
if self.transform is not None:
287-
patch = apply_transform(self.transform, patch, map_items=False)
288-
return patch
427+
return len(self.data) * self.samples_per_image # type: ignore
428+
429+
def __iter__(self):
430+
for image in super().__iter__():
431+
patches = self.patch_func(image)
432+
if len(patches) != self.samples_per_image:
433+
raise RuntimeWarning(
434+
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
435+
)
436+
for patch in patches:
437+
out_patch = patch
438+
if self.patch_transform is not None:
439+
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
440+
yield out_patch

tests/test_grid_dataset.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,18 @@ def test_shape(self):
108108
self.assertEqual(sorted(output), sorted(expected))
109109

110110
def test_loading_array(self):
111-
set_determinism(seed=1234)
112111
# test sequence input data with images
113112
images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
114113
# image level
115-
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
114+
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)
116115
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
117116
ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)
118117
# use the grid patch dataset
119118
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
120119
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
121120
np.testing.assert_allclose(
122121
item[0],
123-
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
122+
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
124123
rtol=1e-4,
125124
)
126125
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -129,9 +128,7 @@ def test_loading_array(self):
129128
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
130129
np.testing.assert_allclose(
131130
item[0],
132-
np.array(
133-
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
134-
),
131+
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
135132
rtol=1e-3,
136133
)
137134
np.testing.assert_allclose(
@@ -164,7 +161,7 @@ def test_loading_dict(self):
164161
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
165162
np.testing.assert_allclose(
166163
item[0]["image"],
167-
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
164+
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
168165
rtol=1e-4,
169166
)
170167
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -173,15 +170,53 @@ def test_loading_dict(self):
173170
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
174171
np.testing.assert_allclose(
175172
item[0]["image"],
176-
np.array(
177-
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
178-
),
173+
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
179174
rtol=1e-3,
180175
)
181176
np.testing.assert_allclose(
182177
item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
183178
)
184179

180+
def test_set_data(self):
181+
from monai.transforms import Compose, Lambda, RandLambda
182+
183+
images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
184+
185+
transform = Compose(
186+
[Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False
187+
)
188+
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
189+
dataset = GridPatchDataset(
190+
data=images,
191+
patch_iter=patch_iter,
192+
transform=transform,
193+
cache=True,
194+
cache_rate=1.0,
195+
copy_cache=not sys.platform == "linux",
196+
)
197+
198+
num_workers = 2 if sys.platform == "linux" else 0
199+
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
200+
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
201+
np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
202+
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
203+
# simulate another epoch, the cache content should not be modified
204+
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
205+
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
206+
np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
207+
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
208+
209+
# update the datalist and fill the cache content
210+
data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)]
211+
dataset.set_data(data=data_list2)
212+
# rerun with updated cache content
213+
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
214+
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
215+
np.testing.assert_allclose(
216+
item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4
217+
)
218+
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
219+
185220

186221
if __name__ == "__main__":
187222
unittest.main()

0 commit comments

Comments
 (0)