Skip to content

Commit fa15eec

Browse files
KumoLiu and ericspod authored
simplify list_data_collate and collate_meta_tensor (#7165)
Fixes #5917

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 85243f5 commit fa15eec

File tree

1 file changed

+31
-11
lines changed

1 file changed

+31
-11
lines changed

monai/data/utils.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@
5050
issequenceiterable,
5151
look_up_option,
5252
optional_import,
53+
pytorch_after,
5354
)
5455

56+
if pytorch_after(1, 13):
57+
# import private code for reuse purposes, comment in case things break in the future
58+
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
5559
pd, _ = optional_import("pandas")
5660
DataFrame, _ = optional_import("pandas", name="DataFrame")
5761
nib, _ = optional_import("nibabel")
@@ -444,22 +448,31 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
444448
return data
445449

446450

451+
def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
452+
"""
453+
Collate a sequence of meta tensors into a single batched metatensor. This is called by `collate_meta_tensor`
454+
and so should not be used as a collate function directly in dataloaders.
455+
"""
456+
collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
457+
collated = collate_fn(batch) # type: ignore
458+
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
459+
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
460+
if common_:
461+
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
462+
collated.meta = default_collate(meta_dicts)
463+
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
464+
collated.is_batch = True
465+
return collated
466+
467+
447468
def collate_meta_tensor(batch):
448469
"""collate a sequence of meta tensor sequences/dictionaries into
449470
a single batched metatensor or a dictionary of batched metatensor"""
450471
if not isinstance(batch, Sequence):
451472
raise NotImplementedError()
452473
elem_0 = first(batch)
453474
if isinstance(elem_0, MetaObj):
454-
collated = default_collate(batch)
455-
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
456-
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
457-
if common_:
458-
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
459-
collated.meta = default_collate(meta_dicts)
460-
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
461-
collated.is_batch = True
462-
return collated
475+
return collate_meta_tensor_fn(batch)
463476
if isinstance(elem_0, Mapping):
464477
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
465478
if isinstance(elem_0, (tuple, list)):
@@ -479,9 +492,16 @@ def list_data_collate(batch: Sequence):
479492
Use this collate function when applying transforms that can generate batch data.
480493
481494
"""
495+
496+
if pytorch_after(1, 13):
497+
# needs to go here to avoid circular import
498+
from monai.data.meta_tensor import MetaTensor
499+
500+
default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
482501
elem = batch[0]
483502
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
484503
key = None
504+
collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor
485505
try:
486506
if config.USE_META_DICT:
487507
data = pickle_operations(data) # bc 0.9.0
@@ -490,9 +510,9 @@ def list_data_collate(batch: Sequence):
490510
for k in elem:
491511
key = k
492512
data_for_batch = [d[key] for d in data]
493-
ret[key] = collate_meta_tensor(data_for_batch)
513+
ret[key] = collate_fn(data_for_batch)
494514
else:
495-
ret = collate_meta_tensor(data)
515+
ret = collate_fn(data)
496516
return ret
497517
except RuntimeError as re:
498518
re_str = str(re)

0 commit comments

Comments
 (0)