1111
1212from __future__ import annotations
1313
14- from collections .abc import Callable , Generator , Hashable , Iterable , Mapping , Sequence
14+ import sys
15+ import warnings
16+ from collections .abc import Callable , Generator , Hashable , Iterable , Iterator , Mapping , Sequence
1517from copy import deepcopy
18+ from multiprocessing .managers import ListProxy
19+ from multiprocessing .pool import ThreadPool
20+ from typing import TYPE_CHECKING
1621
1722import numpy as np
23+ import torch
1824
1925from monai .config import KeysCollection
2026from monai .config .type_definitions import NdarrayTensor
21- from monai .data .dataset import Dataset
2227from monai .data .iterable_dataset import IterableDataset
23- from monai .data .utils import iter_patch
24- from monai .transforms import apply_transform
25- from monai .utils import NumpyPadMode , ensure_tuple , first
28+ from monai .data .utils import iter_patch , pickle_hashing
29+ from monai .transforms import Compose , RandomizableTrait , Transform , apply_transform , convert_to_contiguous
30+ from monai .utils import NumpyPadMode , ensure_tuple , first , min_version , optional_import
31+
32+ if TYPE_CHECKING :
33+ from tqdm import tqdm
34+
35+ has_tqdm = True
36+ else :
37+ tqdm , has_tqdm = optional_import ("tqdm" , "4.47.0" , min_version , "tqdm" )
2638
2739__all__ = ["PatchDataset" , "GridPatchDataset" , "PatchIter" , "PatchIterd" ]
2840
@@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset):
184196 see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
185197 transform: a callable data transform operates on the patches.
186198 with_coordinates: whether to yield the coordinates of each patch, default to `True`.
199+ cache: whether to use cache mache mechanism, default to `False`.
200+ see also: :py:class:`monai.data.CacheDataset`.
201+ cache_num: number of items to be cached. Default is `sys.maxsize`.
202+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
203+ cache_rate: percentage of cached data in total, default is 1.0 (cache all).
204+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
205+ num_workers: the number of worker threads if computing cache in the initialization.
206+ If num_workers is None then the number returned by os.cpu_count() is used.
207+ If a value less than 1 is specified, 1 will be used instead.
208+ progress: whether to display a progress bar.
209+ copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
210+ default to `True`. if the random transforms don't modify the cached content
211+ (for example, randomly crop from the cached image and deepcopy the crop region)
212+ or if every cache item is only used once in a `multi-processing` environment,
213+ may set `copy=False` for better performance.
214+ as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
215+ it may help improve the performance of following logic.
216+ hash_func: a callable to compute hash from data items to be cached.
217+ defaults to `monai.data.utils.pickle_hashing`.
187218
188219 """
189220
@@ -193,27 +224,148 @@ def __init__(
193224 patch_iter : Callable ,
194225 transform : Callable | None = None ,
195226 with_coordinates : bool = True ,
227+ cache : bool = False ,
228+ cache_num : int = sys .maxsize ,
229+ cache_rate : float = 1.0 ,
230+ num_workers : int | None = 1 ,
231+ progress : bool = True ,
232+ copy_cache : bool = True ,
233+ as_contiguous : bool = True ,
234+ hash_func : Callable [..., bytes ] = pickle_hashing ,
196235 ) -> None :
197236 super ().__init__ (data = data , transform = None )
237+ if transform is not None and not isinstance (transform , Compose ):
238+ transform = Compose (transform )
198239 self .patch_iter = patch_iter
199240 self .patch_transform = transform
200241 self .with_coordinates = with_coordinates
242+ self .set_num = cache_num
243+ self .set_rate = cache_rate
244+ self .progress = progress
245+ self .copy_cache = copy_cache
246+ self .as_contiguous = as_contiguous
247+ self .hash_func = hash_func
248+ self .num_workers = num_workers
249+ if self .num_workers is not None :
250+ self .num_workers = max (int (self .num_workers ), 1 )
251+ self ._cache : list | ListProxy = []
252+ self ._cache_other : list | ListProxy = []
253+ self .cache = cache
254+ self .first_random : int | None = None
255+ if self .patch_transform is not None :
256+ self .first_random = self .patch_transform .get_index_of_first (
257+ lambda t : isinstance (t , RandomizableTrait ) or not isinstance (t , Transform )
258+ )
201259
202- def __iter__ (self ):
203- for image in super ().__iter__ ():
204- for patch , * others in self .patch_iter (image ):
205- out_patch = patch
206- if self .patch_transform is not None :
207- out_patch = apply_transform (self .patch_transform , patch , map_items = False )
208- if self .with_coordinates and len (others ) > 0 : # patch_iter to yield at least 2 items: patch, coords
209- yield out_patch , others [0 ]
210- else :
211- yield out_patch
260+ if self .cache :
261+ if isinstance (data , Iterator ):
262+ raise TypeError ("Data can not be iterator when cache is True" )
263+ self .set_data (data ) # type: ignore
264+
265+ def set_data (self , data : Sequence ) -> None :
266+ """
267+ Set the input data and run deterministic transforms to generate cache content.
268+
269+ Note: should call this func after an entire epoch and must set `persistent_workers=False`
270+ in PyTorch DataLoader, because it needs to create new worker processes based on new
271+ generated cache content.
272+
273+ """
274+ self .data = data
275+
276+ # only compute cache for the unique items of dataset, and record the last index for duplicated items
277+ mapping = {self .hash_func (v ): i for i , v in enumerate (self .data )}
278+ self .cache_num = min (int (self .set_num ), int (len (mapping ) * self .set_rate ), len (mapping ))
279+ self ._hash_keys = list (mapping )[: self .cache_num ]
280+ indices = list (mapping .values ())[: self .cache_num ]
281+ self ._cache , self ._cache_other = zip (* self ._fill_cache (indices )) # type: ignore
282+
283+ def _fill_cache (self , indices = None ) -> list :
284+ """
285+ Compute and fill the cache content from data source.
286+
287+ Args:
288+ indices: target indices in the `self.data` source to compute cache.
289+ if None, use the first `cache_num` items.
290+
291+ """
292+ if self .cache_num <= 0 :
293+ return []
294+ if indices is None :
295+ indices = list (range (self .cache_num ))
296+ if self .progress and not has_tqdm :
297+ warnings .warn ("tqdm is not installed, will not show the caching progress bar." )
298+
299+ pfunc = tqdm if self .progress and has_tqdm else (lambda v , ** _ : v )
300+ with ThreadPool (self .num_workers ) as p :
301+ return list (pfunc (p .imap (self ._load_cache_item , indices ), total = len (indices ), desc = "Loading dataset" ))
302+
303+ def _load_cache_item (self , idx : int ):
304+ """
305+ Args:
306+ idx: the index of the input data sequence.
307+ """
308+ item = self .data [idx ] # type: ignore
309+ patch_cache , other_cache = [], []
310+ for patch , * others in self .patch_iter (item ):
311+ if self .first_random is not None :
312+ patch = self .patch_transform (patch , end = self .first_random , threading = True ) # type: ignore
313+
314+ if self .as_contiguous :
315+ patch = convert_to_contiguous (patch , memory_format = torch .contiguous_format )
316+ if self .with_coordinates and len (others ) > 0 : # patch_iter to yield at least 2 items: patch, coords
317+ other_cache .append (others [0 ])
318+ patch_cache .append (patch )
319+ return patch_cache , other_cache
320+
321+ def _generate_patches (self , src , ** apply_args ):
322+ """
323+ yield patches optionally post-processed by transform.
212324
325+ Args:
326+ src: a iterable of image patches.
327+ apply_args: other args for `self.patch_transform`.
328+
329+ """
330+ for patch , * others in src :
331+ out_patch = patch
332+ if self .patch_transform is not None :
333+ out_patch = self .patch_transform (patch , ** apply_args )
334+ if self .with_coordinates and len (others ) > 0 : # patch_iter to yield at least 2 items: patch, coords
335+ yield out_patch , others [0 ]
336+ else :
337+ yield out_patch
213338
214- class PatchDataset (Dataset ):
339+ def __iter__ (self ):
340+ if self .cache :
341+ cache_index = None
342+ for image in super ().__iter__ ():
343+ key = self .hash_func (image )
344+ if key in self ._hash_keys :
345+ # if existing in cache, try to get the index in cache
346+ cache_index = self ._hash_keys .index (key )
347+ if cache_index is None :
348+ # no cache for this index, execute all the transforms directly
349+ yield from self ._generate_patches (self .patch_iter (image ))
350+ else :
351+ if self ._cache is None :
352+ raise RuntimeError (
353+ "Cache buffer is not initialized, please call `set_data()` before epoch begins."
354+ )
355+ data = self ._cache [cache_index ] # type: ignore
356+ other = self ._cache_other [cache_index ] # type: ignore
357+
358+ # load data from cache and execute from the first random transform
359+ data = deepcopy (data ) if self .copy_cache else data
360+ yield from self ._generate_patches (zip (data , other ), start = self .first_random )
361+ else :
362+ for image in super ().__iter__ ():
363+ yield from self ._generate_patches (self .patch_iter (image ))
364+
365+
366+ class PatchDataset (IterableDataset ):
215367 """
216- returns a patch from an image dataset.
368+ Yields patches from data read from an image dataset.
217369 The patches are generated by a user-specified callable `patch_func`,
218370 and are optionally post-processed by `transform`.
219371 For example, to generate random patch samples from an image dataset:
@@ -263,26 +415,26 @@ def __init__(
263415 samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
264416 transform: transform applied to each patch.
265417 """
266- super ().__init__ (data = data , transform = transform )
418+ super ().__init__ (data = data , transform = None )
267419
268420 self .patch_func = patch_func
269421 if samples_per_image <= 0 :
270422 raise ValueError ("sampler_per_image must be a positive integer." )
271423 self .samples_per_image = int (samples_per_image )
424+ self .patch_transform = transform
272425
273426 def __len__ (self ) -> int :
274- return len (self .data ) * self .samples_per_image
275-
276- def _transform (self , index : int ):
277- image_id = int (index / self .samples_per_image )
278- image = self .data [image_id ]
279- patches = self .patch_func (image )
280- if len (patches ) != self .samples_per_image :
281- raise RuntimeWarning (
282- f"`patch_func` must return a sequence of length: samples_per_image={ self .samples_per_image } ."
283- )
284- patch_id = (index - image_id * self .samples_per_image ) * (- 1 if index < 0 else 1 )
285- patch = patches [patch_id ]
286- if self .transform is not None :
287- patch = apply_transform (self .transform , patch , map_items = False )
288- return patch
427+ return len (self .data ) * self .samples_per_image # type: ignore
428+
429+ def __iter__ (self ):
430+ for image in super ().__iter__ ():
431+ patches = self .patch_func (image )
432+ if len (patches ) != self .samples_per_image :
433+ raise RuntimeWarning (
434+ f"`patch_func` must return a sequence of length: samples_per_image={ self .samples_per_image } ."
435+ )
436+ for patch in patches :
437+ out_patch = patch
438+ if self .patch_transform is not None :
439+ out_patch = apply_transform (self .patch_transform , patch , map_items = False )
440+ yield out_patch
0 commit comments