Commit b6d6d77

Authored by staydelight, pre-commit-ci[bot], and KumoLiu

Add a mapping function in image_reader.py and image_writer.py (#7769)
Add a function to create a JSON file that maps input and output paths. Fixes #7557.

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: staydelight <kevin295643815697236@gmail.com>
Co-authored-by: staydelight <kevin295643815697236@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent b62d1e1 commit b6d6d77
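
For orientation, the new transform is meant to run after `SaveImage` (with `savepath_in_metadict=True`) and append one input/output pair per image to a JSON file. A minimal usage sketch, mirroring the pipeline used in the new test; the input path and output directory below are hypothetical:

```python
# Minimal usage sketch of the new transform (hypothetical paths, not part of the commit).
from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping

pipeline = Compose(
    [
        LoadImage(image_only=True),
        # savepath_in_metadict=True stores the written path in the image metadata ("saved_to")
        SaveImage(output_dir="./out", output_ext=".nii.gz", savepath_in_metadict=True),
        # appends {"input": ..., "output": ...} to mapping.json
        WriteFileMapping(mapping_file_path="mapping.json"),
    ]
)
pipeline("./img.nii.gz")  # assumes this NIfTI file exists
```

Based on the expectation encoded in `tests/test_mapping_file.py`, `mapping.json` would then hold a list such as `[{"input": "./img.nii.gz", "output": "./out/img/img_trans.nii.gz"}]`.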

7 files changed: +351 -6 lines changed

docs/source/transforms.rst

Lines changed: 12 additions & 0 deletions
```diff
@@ -554,6 +554,12 @@ IO
     :members:
     :special-members: __call__
 
+`WriteFileMapping`
+""""""""""""""""""
+.. autoclass:: WriteFileMapping
+    :members:
+    :special-members: __call__
+
 
 NVIDIA Tool Extension (NVTX)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1642,6 +1648,12 @@ IO (Dict)
     :members:
     :special-members: __call__
 
+`WriteFileMappingd`
+"""""""""""""""""""
+.. autoclass:: WriteFileMappingd
+    :members:
+    :special-members: __call__
+
 Post-processing (Dict)
 ^^^^^^^^^^^^^^^^^^^^^^
```

monai/transforms/__init__.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -238,8 +238,18 @@
 )
 from .inverse import InvertibleTransform, TraceableTransform
 from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
-from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
-from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
+from .io.dictionary import (
+    LoadImaged,
+    LoadImageD,
+    LoadImageDict,
+    SaveImaged,
+    SaveImageD,
+    SaveImageDict,
+    WriteFileMappingd,
+    WriteFileMappingD,
+    WriteFileMappingDict,
+)
 from .lazy.array import ApplyPending
 from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
 from .lazy.functional import apply_pending
```

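Since `monai/transforms/__init__.py` now re-exports the dictionary aliases as well, all three spellings should refer to the same class. A quick sanity check (a sketch, not part of the commit):

```python
# Sketch: the new aliases exported above all resolve to the same transform class.
from monai.transforms import WriteFileMappingD, WriteFileMappingDict, WriteFileMappingd

assert WriteFileMappingD is WriteFileMappingd
assert WriteFileMappingDict is WriteFileMappingd
print("WriteFileMappingd aliases OK")
```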
monai/transforms/io/array.py

Lines changed: 58 additions & 2 deletions
```diff
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import inspect
+import json
 import logging
 import sys
 import traceback
@@ -45,11 +46,19 @@
 from monai.transforms.utility.array import EnsureChannelFirst
 from monai.utils import GridSamplePadMode
 from monai.utils import ImageMetaKey as Key
-from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
+from monai.utils import (
+    MetaKeys,
+    OptionalImportError,
+    convert_to_dst_type,
+    ensure_tuple,
+    look_up_option,
+    optional_import,
+)
 
 nib, _ = optional_import("nibabel")
 Image, _ = optional_import("PIL.Image")
 nrrd, _ = optional_import("nrrd")
+FileLock, has_filelock = optional_import("filelock", name="FileLock")
 
 __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
 
@@ -505,7 +514,7 @@ def __call__(
             else:
                 self._data_index += 1
                 if self.savepath_in_metadict and meta_data is not None:
-                    meta_data["saved_to"] = filename
+                    meta_data[MetaKeys.SAVED_TO] = filename
                 return img
         msg = "\n".join([f"{e}" for e in err])
         raise RuntimeError(
@@ -514,3 +523,50 @@ def __call__(
            " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
            f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
        )
+
+
+class WriteFileMapping(Transform):
+    """
+    Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
+    This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
+
+    Args:
+        mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
+    """
+
+    def __init__(self, mapping_file_path: Path | str = "mapping.json"):
+        self.mapping_file_path = Path(mapping_file_path)
+
+    def __call__(self, img: NdarrayOrTensor):
+        """
+        Args:
+            img: The input image with metadata.
+        """
+        if isinstance(img, MetaTensor):
+            meta_data = img.meta
+
+        if MetaKeys.SAVED_TO not in meta_data:
+            raise KeyError(
+                "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
+            )
+
+        input_path = meta_data[Key.FILENAME_OR_OBJ]
+        output_path = meta_data[MetaKeys.SAVED_TO]
+        log_data = {"input": input_path, "output": output_path}
+
+        if has_filelock:
+            with FileLock(str(self.mapping_file_path) + ".lock"):
+                self._write_to_file(log_data)
+        else:
+            self._write_to_file(log_data)
+        return img
+
+    def _write_to_file(self, log_data):
+        try:
+            with self.mapping_file_path.open("r") as f:
+                existing_log_data = json.load(f)
+        except (FileNotFoundError, json.JSONDecodeError):
+            existing_log_data = []
+        existing_log_data.append(log_data)
+        with self.mapping_file_path.open("w") as f:
+            json.dump(existing_log_data, f, indent=4)
```

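As the new `__call__` shows, `WriteFileMapping` only needs a `MetaTensor` whose metadata already contains `filename_or_obj` and `saved_to`; without the latter it raises a `KeyError`. A standalone sketch of both paths, using made-up metadata rather than real files:

```python
# Standalone behaviour sketch for WriteFileMapping (illustrative metadata, no real image files needed).
import torch

from monai.data import MetaTensor
from monai.transforms import WriteFileMapping

mapper = WriteFileMapping(mapping_file_path="mapping.json")

ok = MetaTensor(
    torch.zeros(1, 8, 8, 8),
    meta={"filename_or_obj": "case_001.nii.gz", "saved_to": "out/case_001/case_001_trans.nii.gz"},
)
mapper(ok)  # appends {"input": ..., "output": ...} to mapping.json (using a .lock file if filelock is installed)

missing = MetaTensor(torch.zeros(1, 8, 8, 8), meta={"filename_or_obj": "case_002.nii.gz"})
try:
    mapper(missing)
except KeyError as e:  # raised because "saved_to" was never recorded by SaveImage
    print(e)
```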
monai/transforms/io/dictionary.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -17,16 +17,17 @@
 
 from __future__ import annotations
 
+from collections.abc import Hashable, Mapping
 from pathlib import Path
 from typing import Callable
 
 import numpy as np
 
 import monai
-from monai.config import DtypeLike, KeysCollection
+from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
 from monai.data import image_writer
 from monai.data.image_reader import ImageReader
-from monai.transforms.io.array import LoadImage, SaveImage
+from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
 from monai.transforms.transform import MapTransform, Transform
 from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
 from monai.utils.enums import PostFix
@@ -320,5 +321,31 @@ def __call__(self, data):
         return d
 
 
+class WriteFileMappingd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: :py:class:`monai.transforms.compose.MapTransform`
+        mapping_file_path: Path to the JSON file where the mappings will be saved.
+            Defaults to "mapping.json".
+        allow_missing_keys: don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.mapping = WriteFileMapping(mapping_file_path)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.mapping(d[key])
+        return d
+
+
 LoadImageD = LoadImageDict = LoadImaged
 SaveImageD = SaveImageDict = SaveImaged
+WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
```

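A dictionary-pipeline counterpart of the array example above; the `"image"` key and the file paths are hypothetical, and `savepath_in_metadict` is assumed to be available on `SaveImaged` as it is on `SaveImage`:

```python
# Dictionary-based usage sketch (hypothetical "image" key and file paths).
from monai.transforms import Compose, LoadImaged, SaveImaged, WriteFileMappingd

pipeline = Compose(
    [
        LoadImaged(keys="image", image_only=True),
        SaveImaged(keys="image", output_dir="./out", output_ext=".nii.gz", savepath_in_metadict=True),
        WriteFileMappingd(keys="image", mapping_file_path="mapping.json"),
    ]
)
pipeline({"image": "./img.nii.gz"})  # assumes this NIfTI file exists
```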
monai/utils/enums.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
     SPATIAL_SHAPE = "spatial_shape"  # optional key for the length in each spatial dimension
     SPACE = "space"  # possible values of space type are defined in `SpaceKeys`
     ORIGINAL_CHANNEL_DIM = "original_channel_dim"  # an integer or float("nan")
+    SAVED_TO = "saved_to"
 
 
 class ColorOrder(StrEnum):
```

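`MetaKeys.SAVED_TO` is a plain `StrEnum` member, so it compares equal to the `"saved_to"` string that `SaveImage` now writes into the metadata; a tiny sketch:

```python
# Sketch: the new MetaKeys member is interchangeable with the plain string key.
from monai.utils.enums import MetaKeys

assert MetaKeys.SAVED_TO == "saved_to"
# e.g. after SaveImage(..., savepath_in_metadict=True):
#     output_path = img.meta[MetaKeys.SAVED_TO]
```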
tests/test_mapping_file.py

Lines changed: 117 additions & 0 deletions
```diff
@@ -0,0 +1,117 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset
+from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
+from monai.utils import optional_import
+
+nib, has_nib = optional_import("nibabel")
+
+
+def create_input_file(temp_dir, name):
+    test_image = np.random.rand(128, 128, 128)
+    output_ext = ".nii.gz"
+    input_file = os.path.join(temp_dir, name + output_ext)
+    nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
+    return input_file
+
+
+def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):
+    return Compose(
+        [
+            LoadImage(image_only=True),
+            SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict),
+            WriteFileMapping(mapping_file_path=mapping_file_path),
+        ]
+    )
+
+
+@unittest.skipUnless(has_nib, "nibabel required")
+class TestWriteFileMapping(unittest.TestCase):
+    def setUp(self):
+        self.temp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.temp_dir)
+
+    @parameterized.expand([(True,), (False,)])
+    def test_mapping_file(self, savepath_in_metadict):
+        mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
+        name = "test_image"
+        input_file = create_input_file(self.temp_dir, name)
+        output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz")
+
+        transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)
+
+        if savepath_in_metadict:
+            transform(input_file)
+            self.assertTrue(os.path.exists(mapping_file_path))
+            with open(mapping_file_path) as f:
+                mapping_data = json.load(f)
+            self.assertEqual(len(mapping_data), 1)
+            self.assertEqual(mapping_data[0]["input"], input_file)
+            self.assertEqual(mapping_data[0]["output"], output_file)
+        else:
+            with self.assertRaises(RuntimeError) as cm:
+                transform(input_file)
+            cause_exception = cm.exception.__cause__
+            self.assertIsInstance(cause_exception, KeyError)
+            self.assertIn(
+                "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.",
+                str(cause_exception),
+            )
+
+    def test_multiprocess_mapping_file(self):
+        num_images = 50
+
+        single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json")
+        multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json")
+
+        data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)]
+
+        # single process
+        single_transform = create_transform(self.temp_dir, single_mapping_file)
+        single_dataset = Dataset(data=data, transform=single_transform)
+        single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)
+        for _ in single_loader:
+            pass
+
+        # multiple processes
+        multi_transform = create_transform(self.temp_dir, multi_mapping_file)
+        multi_dataset = Dataset(data=data, transform=multi_transform)
+        multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)
+        for _ in multi_loader:
+            pass
+
+        with open(single_mapping_file) as f:
+            single_mapping_data = json.load(f)
+        with open(multi_mapping_file) as f:
+            multi_mapping_data = json.load(f)
+
+        single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data}
+        multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data}
+
+        self.assertEqual(single_set, multi_set)
+
+
+if __name__ == "__main__":
+    unittest.main()
```

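To exercise just the new test locally (assuming a MONAI dev checkout with `nibabel` and `parameterized` installed; the PR checklist's `./runtests.sh --quick --unittests --disttests` would also pick it up), something like the following should work:

```python
# Sketch: run only the new test case with the standard unittest runner.
import unittest

from tests.test_mapping_file import TestWriteFileMapping

suite = unittest.TestLoader().loadTestsFromTestCase(TestWriteFileMapping)
unittest.TextTestRunner(verbosity=2).run(suite)
```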