@@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
418418 If provided, only the matched files will be included. For example, to include the file name
419419 "image_0001.dcm", the regular expression could be `".*image_(\\ d+).dcm"`. Default to `""`.
420420 Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
421+ to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
422+ Default is False. CuPy and Kvikio are required for this option.
423+ In practical use, it's recommended to add a warm up call before the actual loading.
424+ A related tutorial will be prepared in the future, and the document will be updated accordingly.
421425 kwargs: additional args for `pydicom.dcmread` API. more details about available args:
422426 https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
423427 If the `get_data` function will be called
@@ -434,6 +438,7 @@ def __init__(
434438 prune_metadata : bool = True ,
435439 label_dict : dict | None = None ,
436440 fname_regex : str = "" ,
441+ to_gpu : bool = False ,
437442 ** kwargs ,
438443 ):
439444 super ().__init__ ()
@@ -444,6 +449,33 @@ def __init__(
444449 self .prune_metadata = prune_metadata
445450 self .label_dict = label_dict
446451 self .fname_regex = fname_regex
452+ if to_gpu and (not has_cp or not has_kvikio ):
453+ warnings .warn (
454+ "PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
455+ )
456+ to_gpu = False
457+
458+ if to_gpu :
459+ self .warmup_kvikio ()
460+
461+ self .to_gpu = to_gpu
462+
463+ def warmup_kvikio (self ):
464+ """
465+ Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
466+ This can accelerate the data loading process when `to_gpu` is set to True.
467+ """
468+ if has_cp and has_kvikio :
469+ a = cp .arange (100 )
470+ with tempfile .NamedTemporaryFile () as tmp_file :
471+ tmp_file_name = tmp_file .name
472+ f = kvikio .CuFile (tmp_file_name , "w" )
473+ f .write (a )
474+ f .close ()
475+
476+ b = cp .empty_like (a )
477+ f = kvikio .CuFile (tmp_file_name , "r" )
478+ f .read (b )
447479
448480 def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
449481 """
@@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
475507 img_ = []
476508
477509 filenames : Sequence [PathLike ] = ensure_tuple (data )
510+ self .filenames = list (filenames )
478511 kwargs_ = self .kwargs .copy ()
512+ if self .to_gpu :
513+ kwargs ["defer_size" ] = "100 KB"
479514 kwargs_ .update (kwargs )
480515
481516 self .has_series = False
482517
483- for name in filenames :
518+ for i , name in enumerate ( filenames ) :
484519 name = f"{ name } "
485520 if Path (name ).is_dir ():
486521 # read DICOM series
@@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
489524 else :
490525 series_slcs = [slc for slc in glob .glob (os .path .join (name , "*" )) if pydicom .misc .is_dicom (slc )]
491526 slices = []
527+ loaded_slc_names = []
492528 for slc in series_slcs :
493529 try :
494530 slices .append (pydicom .dcmread (fp = slc , ** kwargs_ ))
531+ loaded_slc_names .append (slc )
495532 except pydicom .errors .InvalidDicomError as e :
496533 warnings .warn (f"Failed to read { slc } with exception: \n { e } ." , stacklevel = 2 )
497- img_ .append (slices if len (slices ) > 1 else slices [0 ])
498534 if len (slices ) > 1 :
499535 self .has_series = True
536+ img_ .append (slices )
537+ self .filenames [i ] = loaded_slc_names # type: ignore
538+ else :
539+ img_ .append (slices [0 ]) # type: ignore
540+ self .filenames [i ] = loaded_slc_names [0 ] # type: ignore
500541 else :
501542 ds = pydicom .dcmread (fp = name , ** kwargs_ )
502- img_ .append (ds )
503- return img_ if len (filenames ) > 1 else img_ [0 ]
543+ img_ .append (ds ) # type: ignore
544+ if len (filenames ) == 1 :
545+ return img_ [0 ]
546+ return img_
504547
505- def _combine_dicom_series (self , data : Iterable ):
548+ def _combine_dicom_series (self , data : Iterable , filenames : Sequence [ PathLike ] ):
506549 """
507550 Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
508551 dimension as the last dimension.
@@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable):
522565 """
523566 slices : list = []
524567 # for a dicom series
525- for slc_ds in data :
568+ for slc_ds , filename in zip ( data , filenames ) :
526569 if hasattr (slc_ds , "InstanceNumber" ):
527- slices .append (slc_ds )
570+ slices .append (( slc_ds , filename ) )
528571 else :
529- warnings .warn (f"slice: { slc_ds .filename } does not have InstanceNumber tag, skip it." )
530- slices = sorted (slices , key = lambda s : s .InstanceNumber )
531-
572+ warnings .warn (f"slice: { filename } does not have InstanceNumber tag, skip it." )
573+ slices = sorted (slices , key = lambda s : s [0 ].InstanceNumber )
532574 if len (slices ) == 0 :
533575 raise ValueError ("the input does not have valid slices." )
534576
535- first_slice = slices [0 ]
577+ first_slice , first_filename = slices [0 ]
536578 average_distance = 0.0
537- first_array = self ._get_array_data (first_slice )
579+ first_array = self ._get_array_data (first_slice , first_filename )
538580 shape = first_array .shape
539- spacing = getattr (first_slice , "PixelSpacing" , [1.0 , 1.0 , 1.0 ] )
581+ spacing = getattr (first_slice , "PixelSpacing" , [1.0 ] * len ( shape ) )
540582 prev_pos = getattr (first_slice , "ImagePositionPatient" , (0.0 , 0.0 , 0.0 ))[2 ]
541583 stack_array = [first_array ]
542584 for idx in range (1 , len (slices )):
543- slc_array = self ._get_array_data (slices [idx ])
585+ slc_array = self ._get_array_data (slices [idx ][ 0 ], slices [ idx ][ 1 ] )
544586 slc_shape = slc_array .shape
545- slc_spacing = getattr (slices [idx ], "PixelSpacing" , ( 1.0 , 1.0 , 1.0 ))
546- slc_pos = getattr (slices [idx ], "ImagePositionPatient" , (0.0 , 0.0 , float (idx )))[2 ]
587+ slc_spacing = getattr (slices [idx ][ 0 ] , "PixelSpacing" , [ 1.0 ] * len ( shape ))
588+ slc_pos = getattr (slices [idx ][ 0 ] , "ImagePositionPatient" , (0.0 , 0.0 , float (idx )))[2 ]
547589 if not np .allclose (slc_spacing , spacing ):
548590 warnings .warn (f"the list contains slices that have different spacings { spacing } and { slc_spacing } ." )
549591 if shape != slc_shape :
@@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable):
555597 if len (slices ) > 1 :
556598 average_distance /= len (slices ) - 1
557599 spacing .append (average_distance )
558- stack_array = np .stack (stack_array , axis = - 1 )
600+ if self .to_gpu :
601+ stack_array = cp .stack (stack_array , axis = - 1 )
602+ else :
603+ stack_array = np .stack (stack_array , axis = - 1 )
559604 stack_metadata = self ._get_meta_dict (first_slice )
560605 stack_metadata ["spacing" ] = np .asarray (spacing )
561- if hasattr (slices [- 1 ], "ImagePositionPatient" ):
562- stack_metadata ["lastImagePositionPatient" ] = np .asarray (slices [- 1 ].ImagePositionPatient )
606+ if hasattr (slices [- 1 ][ 0 ] , "ImagePositionPatient" ):
607+ stack_metadata ["lastImagePositionPatient" ] = np .asarray (slices [- 1 ][ 0 ] .ImagePositionPatient )
563608 stack_metadata [MetaKeys .SPATIAL_SHAPE ] = shape + (len (slices ),)
564609 else :
565610 stack_array = stack_array [0 ]
@@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
597642 if self .has_series is True :
598643 # a list, all objects within a list belong to one dicom series
599644 if not isinstance (data [0 ], list ):
600- dicom_data .append (self ._combine_dicom_series (data ))
645+ # input is a dir, self.filenames is a list of list of filenames
646+ dicom_data .append (self ._combine_dicom_series (data , self .filenames [0 ])) # type: ignore
601647 # a list of list, each inner list represents a dicom series
602648 else :
603- for series in data :
604- dicom_data .append (self ._combine_dicom_series (series ))
649+ for i , series in enumerate ( data ) :
650+ dicom_data .append (self ._combine_dicom_series (series , self . filenames [ i ])) # type: ignore
605651 else :
606652 # a single pydicom dataset object
607653 if not isinstance (data , list ):
608654 data = [data ]
609- for d in data :
655+ for i , d in enumerate ( data ) :
610656 if hasattr (d , "SegmentSequence" ):
611- data_array , metadata = self ._get_seg_data (d )
657+ data_array , metadata = self ._get_seg_data (d , self . filenames [ i ] )
612658 else :
613- data_array = self ._get_array_data (d )
659+ data_array = self ._get_array_data (d , self . filenames [ i ] )
614660 metadata = self ._get_meta_dict (d )
615661 metadata [MetaKeys .SPATIAL_SHAPE ] = data_array .shape
616662 dicom_data .append ((data_array , metadata ))
617663
664+ # TODO: the actual type is list[np.ndarray | cp.ndarray]
665+ # should figure out how to define correct types without having cupy not found error
666+ # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
618667 img_array : list [np .ndarray ] = []
619668 compatible_meta : dict = {}
620669
621670 for data_array , metadata in ensure_tuple (dicom_data ):
622- img_array .append (np .ascontiguousarray (np .swapaxes (data_array , 0 , 1 ) if self .swap_ij else data_array ))
671+ if self .swap_ij :
672+ data_array = cp .swapaxes (data_array , 0 , 1 ) if self .to_gpu else np .swapaxes (data_array , 0 , 1 )
673+ img_array .append (cp .ascontiguousarray (data_array ) if self .to_gpu else np .ascontiguousarray (data_array ))
623674 affine = self ._get_affine (metadata , self .affine_lps_to_ras )
624675 metadata [MetaKeys .SPACE ] = SpaceKeys .RAS if self .affine_lps_to_ras else SpaceKeys .LPS
625676 if self .swap_ij :
@@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
641692
642693 _copy_compatible_dict (metadata , compatible_meta )
643694
644- return _stack_images (img_array , compatible_meta ), compatible_meta
695+ return _stack_images (img_array , compatible_meta , to_cupy = self . to_gpu ), compatible_meta
645696
646697 def _get_meta_dict (self , img ) -> dict :
647698 """
@@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
713764 affine = orientation_ras_lps (affine )
714765 return affine
715766
716- def _get_frame_data (self , img ) -> Iterator :
767+ def _get_frame_data (self , img , filename , array_data ) -> Iterator :
717768 """
718769 yield frames and description from the segmentation image.
719770 This function is adapted from Highdicom:
@@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator:
751802 """
752803
753804 if not hasattr (img , "PerFrameFunctionalGroupsSequence" ):
754- raise NotImplementedError (
755- f"To read dicom seg: { img .filename } , 'PerFrameFunctionalGroupsSequence' is required."
756- )
805+ raise NotImplementedError (f"To read dicom seg: { filename } , 'PerFrameFunctionalGroupsSequence' is required." )
757806
758807 frame_seg_nums = []
759808 for f in img .PerFrameFunctionalGroupsSequence :
760809 if not hasattr (f , "SegmentIdentificationSequence" ):
761810 raise NotImplementedError (
762- f"To read dicom seg: { img . filename } , 'SegmentIdentificationSequence' is required for each frame."
811+ f"To read dicom seg: { filename } , 'SegmentIdentificationSequence' is required for each frame."
763812 )
764813 frame_seg_nums .append (int (f .SegmentIdentificationSequence [0 ].ReferencedSegmentNumber ))
765814
766- frame_seg_nums_arr = np .array (frame_seg_nums )
815+ frame_seg_nums_arr = cp . array ( frame_seg_nums ) if self . to_gpu else np .array (frame_seg_nums )
767816
768817 seg_descriptions = {int (f .SegmentNumber ): f for f in img .SegmentSequence }
769818
770- for i in np .unique (frame_seg_nums_arr ):
771- indices = np .where (frame_seg_nums_arr == i )[0 ]
772- yield (img . pixel_array [indices , ...], seg_descriptions [i ])
819+ for i in np .unique (frame_seg_nums_arr ) if not self . to_gpu else cp . unique ( frame_seg_nums_arr ) :
820+ indices = np .where (frame_seg_nums_arr == i )[0 ] if not self . to_gpu else cp . where ( frame_seg_nums_arr == i )[ 0 ]
821+ yield (array_data [indices , ...], seg_descriptions [i ])
773822
774- def _get_seg_data (self , img ):
823+ def _get_seg_data (self , img , filename ):
775824 """
776825 Get the array data and metadata of the segmentation image.
777826
778827 Aegs:
779828 img: a Pydicom dataset object that has attribute "SegmentSequence".
829+ filename: the file path of the image.
780830
781831 """
782832
783833 metadata = self ._get_meta_dict (img )
784834 n_classes = len (img .SegmentSequence )
785- spatial_shape = list (img .pixel_array .shape )
835+ array_data = self ._get_array_data (img , filename )
836+ spatial_shape = list (array_data .shape )
786837 spatial_shape [0 ] = spatial_shape [0 ] // n_classes
787838
788839 if self .label_dict is not None :
789840 metadata ["labels" ] = self .label_dict
790- all_segs = np .zeros ([* spatial_shape , len (self .label_dict )])
841+ if self .to_gpu :
842+ all_segs = cp .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
843+ else :
844+ all_segs = np .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
791845 else :
792846 metadata ["labels" ] = {}
793- all_segs = np .zeros ([* spatial_shape , n_classes ])
847+ if self .to_gpu :
848+ all_segs = cp .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
849+ else :
850+ all_segs = np .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
794851
795- for i , (frames , description ) in enumerate (self ._get_frame_data (img )):
852+ for i , (frames , description ) in enumerate (self ._get_frame_data (img , filename , array_data )):
796853 segment_label = getattr (description , "SegmentLabel" , f"label_{ i } " )
797854 class_name = getattr (description , "SegmentDescription" , segment_label )
798855 if class_name not in metadata ["labels" ].keys ():
@@ -840,19 +897,79 @@ def _get_seg_data(self, img):
840897
841898 return all_segs , metadata
842899
843- def _get_array_data (self , img ):
900+ def _get_array_data_from_gpu (self , img , filename ):
901+ """
902+ Get the raw array data of the image. This function is used when `to_gpu` is set to True.
903+
904+ Args:
905+ img: a Pydicom dataset object.
906+ filename: the file path of the image.
907+
908+ """
909+ rows = getattr (img , "Rows" , None )
910+ columns = getattr (img , "Columns" , None )
911+ bits_allocated = getattr (img , "BitsAllocated" , None )
912+ samples_per_pixel = getattr (img , "SamplesPerPixel" , 1 )
913+ number_of_frames = getattr (img , "NumberOfFrames" , 1 )
914+ pixel_representation = getattr (img , "PixelRepresentation" , 1 )
915+
916+ if rows is None or columns is None or bits_allocated is None :
917+ warnings .warn (
918+ f"dicom data: { filename } does not have Rows, Columns or BitsAllocated, falling back to CPU loading."
919+ )
920+
921+ if not hasattr (img , "pixel_array" ):
922+ raise ValueError (f"dicom data: { filename } does not have pixel_array." )
923+ data = img .pixel_array
924+
925+ return data
926+
927+ if bits_allocated == 8 :
928+ dtype = cp .int8 if pixel_representation == 1 else cp .uint8
929+ elif bits_allocated == 16 :
930+ dtype = cp .int16 if pixel_representation == 1 else cp .uint16
931+ elif bits_allocated == 32 :
932+ dtype = cp .int32 if pixel_representation == 1 else cp .uint32
933+ else :
934+ raise ValueError ("Unsupported BitsAllocated value" )
935+
936+ bytes_per_pixel = bits_allocated // 8
937+ total_pixels = rows * columns * samples_per_pixel * number_of_frames
938+ expected_pixel_data_length = total_pixels * bytes_per_pixel
939+
940+ pixel_data_tag = pydicom .tag .Tag (0x7FE0 , 0x0010 )
941+ if pixel_data_tag not in img :
942+ raise ValueError (f"dicom data: { filename } does not have pixel data." )
943+
944+ offset = img .get_item (pixel_data_tag , keep_deferred = True ).value_tell
945+
946+ with kvikio .CuFile (filename , "r" ) as f :
947+ buffer = cp .empty (expected_pixel_data_length , dtype = cp .int8 )
948+ f .read (buffer , expected_pixel_data_length , offset )
949+
950+ new_shape = (number_of_frames , rows , columns ) if number_of_frames > 1 else (rows , columns )
951+ data = buffer .view (dtype ).reshape (new_shape )
952+
953+ return data
954+
955+ def _get_array_data (self , img , filename ):
844956 """
845957 Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
846- will be rescaled. The output data has the dtype np. float32 if the rescaling is applied.
958+ will be rescaled. The output data has the dtype float32 if the rescaling is applied.
847959
848960 Args:
849961 img: a Pydicom dataset object.
962+ filename: the file path of the image.
850963
851964 """
852965 # process Dicom series
853- if not hasattr (img , "pixel_array" ):
854- raise ValueError (f"dicom data: { img .filename } does not have pixel_array." )
855- data = img .pixel_array
966+
967+ if self .to_gpu :
968+ data = self ._get_array_data_from_gpu (img , filename )
969+ else :
970+ if not hasattr (img , "pixel_array" ):
971+ raise ValueError (f"dicom data: { filename } does not have pixel_array." )
972+ data = img .pixel_array
856973
857974 slope , offset = 1.0 , 0.0
858975 rescale_flag = False
@@ -862,8 +979,14 @@ def _get_array_data(self, img):
862979 if hasattr (img , "RescaleIntercept" ):
863980 offset = img .RescaleIntercept
864981 rescale_flag = True
982+
865983 if rescale_flag :
866- data = data .astype (np .float32 ) * slope + offset
984+ if self .to_gpu :
985+ slope = cp .asarray (slope , dtype = cp .float32 )
986+ offset = cp .asarray (offset , dtype = cp .float32 )
987+ data = data .astype (cp .float32 ) * slope + offset
988+ else :
989+ data = data .astype (np .float32 ) * slope + offset
867990
868991 return data
869992
@@ -884,8 +1007,6 @@ class NibabelReader(ImageReader):
8841007 Default is False. CuPy and Kvikio are required for this option.
8851008 Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
8861009 and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887- In practical use, it's recommended to add a warm up call before the actual loading.
888- A related tutorial will be prepared in the future, and the document will be updated accordingly.
8891010 kwargs: additional args for `nibabel.load` API. more details about available args:
8901011 https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
8911012
0 commit comments