+ from __future__ import annotations
import multiprocessing
import os
from multiprocessing import Pool
from typing import List
+ from pathlib import Path
+ from warnings import warn

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
from nnunetv2.configuration import default_num_processes

+
+ def find_broken_image_and_labels(
+     path_to_data_dir: str | Path,
+ ) -> tuple[set[str], set[str]]:
+     """
+     Iterates through all npys and tries to read each once to see if a ValueError is raised.
+     If so, the case id is added to the respective set and returned for potential fixing.
+
+     :param path_to_data_dir: Path/str to the preprocessed directory containing the npys and npzs.
+     :returns: Tuple of a set containing the case ids of the broken npy images and a set of the case ids of broken npy segmentations.
+     """
+     path_to_data_dir = Path(path_to_data_dir)  # accept str as well; the / operator below needs a Path
+     content = os.listdir(path_to_data_dir)
+     unique_ids = [c[:-4] for c in content if c.endswith(".npz")]
+     failed_data_ids = set()
+     failed_seg_ids = set()
+     for unique_id in unique_ids:
+         # Try reading data
+         try:
+             np.load(path_to_data_dir / (unique_id + ".npy"), "r")
+         except ValueError:
+             failed_data_ids.add(unique_id)
+         # Try reading seg
+         try:
+             np.load(path_to_data_dir / (unique_id + "_seg.npy"), "r")
+         except ValueError:
+             failed_seg_ids.add(unique_id)
+
+     return failed_data_ids, failed_seg_ids
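
For illustration, a standalone scan of a preprocessed folder could look like the sketch below (the directory path is a hypothetical placeholder, not part of this PR):

    from pathlib import Path

    preprocessed_dir = Path("nnUNet_preprocessed/Dataset001/nnUNetPlans_3d_fullres")  # hypothetical path
    broken_images, broken_segs = find_broken_image_and_labels(preprocessed_dir)
    print(f"{len(broken_images)} broken image npys, {len(broken_segs)} broken seg npys")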
+
+
+ def try_fix_broken_npy(path_to_data_dir: Path, case_ids: set[str], fix_image: bool):
+     """
+     Receives broken case ids and tries to fix them by re-extracting the npz file (up to 5 times).
+
+     :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs.
+     :param case_ids: Set of case ids that are broken.
+     :param fix_image: If True, fix the image npy; if False, fix the segmentation npy.
+     :raises ValueError: If the npy file could not be unpacked after 5 tries.
+     """
+     for case_id in case_ids:
+         for i in range(5):
+             try:
+                 key = "data" if fix_image else "seg"
+                 suffix = ".npy" if fix_image else "_seg.npy"
+                 read_npz = np.load(path_to_data_dir / (case_id + ".npz"), "r")[key]
+                 np.save(path_to_data_dir / (case_id + suffix), read_npz)
+                 # Try loading the just saved npy to confirm the fix worked.
+                 np.load(path_to_data_dir / (case_id + suffix), "r")
+                 break
+             except ValueError:
+                 if i == 4:
+                     raise ValueError(
+                         f"Could not unpack {case_id + suffix} after 5 tries!"
+                     )
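
The retry relies on the compressed .npz archive still being intact; the re-extraction step boils down to the sketch below (data_dir and case_id are hypothetical placeholders, the "data"/"seg" keys match the ones used above):

    import numpy as np
    from pathlib import Path

    data_dir = Path("nnUNet_preprocessed/Dataset001/nnUNetPlans_3d_fullres")  # hypothetical
    case_id = "case_0000"  # hypothetical
    archive = np.load(data_dir / (case_id + ".npz"))
    np.save(data_dir / (case_id + ".npy"), archive["data"])     # rewrite the image npy
    np.save(data_dir / (case_id + "_seg.npy"), archive["seg"])  # rewrite the seg npy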
+
+
+ def verify_or_stratify_npys(path_to_data_dir: str | Path) -> None:
+     """
+     Re-reads the npy files after unpacking. Should any of them raise a loading issue, it tries to unpack that file again and overwrites the existing one.
+     If the new file still does not save correctly after 5 attempts, it raises an error naming the broken file. Does the same for images and segmentations.
+
+     :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs.
+     :raises ValueError: If an npy file could not be unpacked after 5 tries.
+         Otherwise an obscure error would be raised later during training (depending on when the broken file is sampled).
+     """
+     path_to_data_dir = Path(path_to_data_dir)
+     # Check for broken image and segmentation npys
+     failed_data_ids, failed_seg_ids = find_broken_image_and_labels(path_to_data_dir)
+
+     if len(failed_data_ids) != 0 or len(failed_seg_ids) != 0:
+         warn(
+             f"Found {len(failed_data_ids)} faulty data npys and {len(failed_seg_ids)} faulty seg npys!\n"
+             f"Faulty images: {failed_data_ids}; faulty segmentations: {failed_seg_ids}\n"
+             "Trying to fix them now."
+         )
+         # Try to fix the broken npys by re-extracting the npz. If that fails, raise an error.
+         try_fix_broken_npy(path_to_data_dir, failed_data_ids, fix_image=True)
+         try_fix_broken_npy(path_to_data_dir, failed_seg_ids, fix_image=False)
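
Since the check only reads files and re-extracts the ones that fail, it can also be run standalone against an already-unpacked folder (hypothetical path):

    verify_or_stratify_npys("nnUNet_preprocessed/Dataset001/nnUNetPlans_3d_fullres")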
+
def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None:
    try:
        a = np.load(npz_file)  # inexpensive, no compression is done here. This just reads metadata
@@ -34,6 +113,7 @@ def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_exis
                                       [unpack_segmentation] * len(npz_files),
                                       [overwrite_existing] * len(npz_files))
                  )
+         verify_or_stratify_npys(folder)


def get_case_identifiers(folder: str) -> List[str]:
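
With this change, the verification runs automatically as the final step of unpack_dataset, so a caller needs nothing beyond the usual call (hypothetical folder path):

    unpack_dataset("nnUNet_preprocessed/Dataset001/nnUNetPlans_3d_fullres",
                   unpack_segmentation=True, overwrite_existing=False)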