Commit 6239882

Merge remote-tracking branch 'tawald/master'
2 parents: 75c0e46 + 9a83a2e

1 file changed: +80 -0 lines changed

nnunetv2/training/dataloading/utils.py

@@ -1,13 +1,92 @@
+from __future__ import annotations
 import multiprocessing
 import os
 from multiprocessing import Pool
 from typing import List
+from pathlib import Path
+from warnings import warn
 
 import numpy as np
 from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
 from nnunetv2.configuration import default_num_processes
 
 
+def find_broken_image_and_labels(
+        path_to_data_dir: str | Path,
+) -> tuple[set[str], set[str]]:
+    """
+    Iterates through all numpys and tries to read them once to see if a ValueError is raised.
+    If so, the case id is added to the respective set and returned for potential fixing.
+
+    :param path_to_data_dir: Path/str to the preprocessed directory containing the npys and npzs.
+    :returns: Tuple of a set containing the case ids of the broken npy images and a set of the case ids of broken npy segmentations.
+    """
+    content = os.listdir(path_to_data_dir)
+    unique_ids = [c[:-4] for c in content if c.endswith(".npz")]
+    failed_data_ids = set()
+    failed_seg_ids = set()
+    for unique_id in unique_ids:
+        # Try reading data
+        try:
+            np.load(path_to_data_dir / (unique_id + ".npy"), "r")
+        except ValueError:
+            failed_data_ids.add(unique_id)
+        # Try reading seg
+        try:
+            np.load(path_to_data_dir / (unique_id + "_seg.npy"), "r")
+        except ValueError:
+            failed_seg_ids.add(unique_id)
+
+    return failed_data_ids, failed_seg_ids
+
+
+def try_fix_broken_npy(path_do_data_dir: Path, case_ids: set[str], fix_image: bool):
+    """
+    Receives broken case ids and tries to fix them by re-extracting the npz file (up to 5 times).
+
+    :param case_ids: Set of case ids that are broken.
+    :param path_do_data_dir: Path to the preprocessed directory containing the npys and npzs.
+    :raises ValueError: If the npy file could not be unpacked after 5 tries.
+    """
+    for case_id in case_ids:
+        for i in range(5):
+            try:
+                key = "data" if fix_image else "seg"
+                suffix = ".npy" if fix_image else "_seg.npy"
+                read_npz = np.load(path_do_data_dir / (case_id + ".npz"), "r")[key]
+                np.save(path_do_data_dir / (case_id + suffix), read_npz)
+                # Try loading the just saved file.
+                np.load(path_do_data_dir / (case_id + suffix), "r")
+                break
+            except ValueError:
+                if i == 4:
+                    raise ValueError(
+                        f"Could not unpack {case_id + suffix} after 5 tries!"
+                    )
+                continue
+
+def verify_or_stratify_npys(path_to_data_dir: str | Path) -> None:
+    """
+    Re-reads the npy files after unpacking. Should there be a loading issue with any file, it tries to unpack that file again and overwrite the existing one.
+    If the new file cannot be saved correctly within 5 tries, it raises an error naming the broken file. Does the same for images and segmentations.
+    :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs.
+    :raises ValueError: If the npy file could not be unpacked after 5 tries.
+        Otherwise an obscure error would be raised later during training (depending on when the broken file is sampled).
+    """
+    path_to_data_dir = Path(path_to_data_dir)
+    # Check for broken image and segmentation npys
+    failed_data_ids, failed_seg_ids = find_broken_image_and_labels(path_to_data_dir)
+
+    if len(failed_data_ids) != 0 or len(failed_seg_ids) != 0:
+        warn(
+            f"Found {len(failed_data_ids)} faulty data npys and {len(failed_seg_ids)} faulty segmentation npys!\n"
+            + f"Faulty images: {failed_data_ids}; Faulty segmentations: {failed_seg_ids}\n"
+            + "Trying to fix them now."
+        )
+        # Try to fix the broken npys by re-extracting the npz. If that fails, raise an error.
+        try_fix_broken_npy(path_to_data_dir, failed_data_ids, fix_image=True)
+        try_fix_broken_npy(path_to_data_dir, failed_seg_ids, fix_image=False)
+
 def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None:
     try:
         a = np.load(npz_file)  # inexpensive, no compression is done here. This just reads metadata
@@ -34,6 +113,7 @@ def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_exis
                                        [unpack_segmentation] * len(npz_files),
                                        [overwrite_existing] * len(npz_files))
                   )
+    verify_or_stratify_npys(folder)
 
 
 def get_case_identifiers(folder: str) -> List[str]:
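
For reference, verify_or_stratify_npys now runs automatically at the end of unpack_dataset, but the two helpers it builds on can also be exercised on their own against an already unpacked preprocessed folder. Below is a minimal standalone sketch of that use; the folder path is a placeholder, and the import path simply assumes the merged nnunetv2/training/dataloading/utils.py shown above.

from pathlib import Path

from nnunetv2.training.dataloading.utils import (
    find_broken_image_and_labels,
    try_fix_broken_npy,
)

# Placeholder: point this at an unpacked preprocessed dataset folder containing the npz/npy files.
preprocessed_dir = Path("/path/to/preprocessed_dataset")

# Detect npy files that raise a ValueError when read.
failed_data_ids, failed_seg_ids = find_broken_image_and_labels(preprocessed_dir)
print(f"Broken images: {failed_data_ids}; broken segmentations: {failed_seg_ids}")

# Re-extract the broken npys from their npz archives; raises ValueError after 5 failed attempts.
try_fix_broken_npy(preprocessed_dir, failed_data_ids, fix_image=True)
try_fix_broken_npy(preprocessed_dir, failed_seg_ids, fix_image=False)

Calling unpack_dataset(folder) directly has the same net effect, since the verification step is now part of it.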
