Skip to content

Commit

Permalink
Fixes #7557
Browse files Browse the repository at this point in the history
Add a function to create a JSON file that maps input and output paths.

Signed-off-by: staydelight <kevin295643815697236@gmail.com>
  • Loading branch information
staydelight committed May 14, 2024
1 parent daf2e71 commit 9bc4cd6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 27 deletions.
98 changes: 72 additions & 26 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@

from __future__ import annotations

import json
import logging
import sys
import glob
import os
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from monai.apps.utils import get_logger
import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

Expand Down Expand Up @@ -51,6 +54,16 @@
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

# Log record layout handed to ``get_logger`` below.
DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s"

# NOTE(review): ``logger`` is bound twice in a row; the second assignment replaces the name
# bound by ``get_logger``. Presumably both calls resolve to the same ``logging`` logger for
# ``__name__`` -- TODO confirm against ``monai.apps.utils.get_logger``.
logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT)
logger = logging.getLogger(__name__)
# Attach a stdout handler that uses a shorter format than DEFAULT_FMT (no filename/line number).
# NOTE(review): if ``get_logger`` already attached a handler to this logger, every message
# may be emitted twice -- TODO confirm and keep only one of the two setups.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
logger.addHandler(handler)
# Executed at import time: this module's logger accepts DEBUG records and above.
logger.setLevel(logging.DEBUG)


__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


Expand Down Expand Up @@ -98,8 +111,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
kwargs: additional args for actual `read` API of 3rd party libs.
"""
#self.update_json(input_file=data)
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


@abstractmethod
def get_data(self, img) -> tuple[np.ndarray, dict]:
"""
Expand Down Expand Up @@ -147,6 +162,24 @@ def _stack_images(image_list: list, meta_dict: dict):
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
return np.stack(image_list, axis=0)

def update_json(input_file=None, output_file=None):
    """
    Record an input/output path mapping in ``img-label.json`` in the current working directory.

    The file holds a JSON list of ``{"image": ..., "label": [...]}`` records.
    Passing ``input_file`` appends a new record with an empty ``"label"`` list.
    Passing only ``output_file`` appends it to the ``"label"`` list of the most
    recent record; it is ignored when no record exists yet.

    Args:
        input_file: a path, or a list/tuple of paths, that was read. Paths are stored as strings;
            a list/tuple is stored as a JSON list.
        output_file: a path that was written for the most recently recorded input.

    Raises:
        json.JSONDecodeError: if the existing record file is not valid JSON.
    """
    record_path = "img-label.json"

    def _as_text(path):
        # ``json`` cannot serialize ``pathlib.Path`` (or bytes) objects, so store plain strings.
        try:
            return os.fsdecode(path)
        except TypeError:
            return str(path)

    # A missing, empty, or whitespace-only file is treated as "no records yet".
    try:
        with open(record_path, encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        content = ""
    records = json.loads(content) if content.strip() else []

    if input_file:
        if isinstance(input_file, (list, tuple)):
            image = [_as_text(p) for p in input_file]
        else:
            image = _as_text(input_file)
        records.append({"image": image, "label": []})
    elif output_file and records:
        records[-1]["label"].append(_as_text(output_file))

    # Serialize before opening the file for writing: a serialization error must not
    # leave a truncated or half-written record file behind.
    payload = json.dumps(records, indent=4)
    with open(record_path, "w", encoding="utf-8") as f:
        f.write(payload)


@require_pkg(pkg_name="itk")
class ITKReader(ImageReader):
Expand All @@ -168,8 +201,8 @@ class ITKReader(ImageReader):
series_name: the name of the DICOM series if there are multiple ones.
used when loading DICOM series.
reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array.
If ``False``, the spatial indexing convention is reversed to be compatible with ITK;
otherwise, the spatial indexing follows the numpy convention. Default is ``False``.
If ``False``, the spatial indexing follows the numpy convention;
otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``.
This option does not affect the metadata.
series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice).
This flag is checked only when loading DICOM series. Default is ``False``.
Expand Down Expand Up @@ -225,6 +258,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_ = []

filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand Down Expand Up @@ -332,6 +366,25 @@ def _get_affine(self, img, lps_to_ras: bool = True):
affine[:sr, -1] = origin[:sr]
if lps_to_ras:
affine = orientation_ras_lps(affine)
logger.debug("lps is changed to ras")

# Log the origin, direction, and spacing values.

logger.info("\nOrigin[:sr]:")
logger.info(", ".join(f"{x:.10f}" for x in origin[:sr]))

logger.info("\nDirection[:sr, :sr]:")
for row in direction[:sr, :sr]:
logger.info(", ".join(f"{x:.15f}" for x in row))

logger.info("\nSpacing[:sr]:")
logger.info(", ".join(f"{x:.15f}" for x in spacing[:sr]))


# affine = numpy.round(affine, decimals=5)

logger.debug(f"Affine matrix:\n{affine}")

return affine

def _get_spatial_shape(self, img):
Expand Down Expand Up @@ -404,12 +457,8 @@ class PydicomReader(ImageReader):
label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.
Keys of the dict are the classes, and values are the corresponding class number. For example:
for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
fname_regex: a regular expression to match the file names when the input is a folder.
If provided, only the matched files will be included. For example, to include the file name
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread
If the `get_data` function will be called
(for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument
`stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,
Expand All @@ -423,7 +472,6 @@ def __init__(
swap_ij: bool = True,
prune_metadata: bool = True,
label_dict: dict | None = None,
fname_regex: str = "",
**kwargs,
):
super().__init__()
Expand All @@ -433,7 +481,6 @@ def __init__(
self.swap_ij = swap_ij
self.prune_metadata = prune_metadata
self.label_dict = label_dict
self.fname_regex = fname_regex

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Expand Down Expand Up @@ -465,6 +512,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_ = []

filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)

Expand All @@ -474,16 +522,9 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
name = f"{name}"
if Path(name).is_dir():
# read DICOM series
if self.fname_regex is not None:
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
else:
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
slices = []
for slc in series_slcs:
try:
slices.append(pydicom.dcmread(fp=slc, **kwargs_))
except pydicom.errors.InvalidDicomError as e:
warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
series_slcs = glob.glob(os.path.join(name, "*"))
series_slcs = [slc for slc in series_slcs if "LICENSE" not in slc]
slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs]
img_.append(slices if len(slices) > 1 else slices[0])
if len(slices) > 1:
self.has_series = True
Expand Down Expand Up @@ -913,9 +954,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
"""
logger.info(f"Reading NIfTI data from: {data}")
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand Down Expand Up @@ -1076,13 +1119,14 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
img = np.load(name, allow_pickle=True, **kwargs_)
if Path(name).name.endswith(".npz"):
# load expected items from NPZ file
npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys
npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys
for k in npz_keys:
img_.append(img[k])
else:
Expand Down Expand Up @@ -1173,6 +1217,7 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
img_: list[PILImage.Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand Down Expand Up @@ -1297,10 +1342,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
"""
img_: list = []
filenames: Sequence[PathLike] = ensure_tuple(data)
update_json(input_file=filenames)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_))
nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_))
img_.append(nrrd_image)
return img_ if len(filenames) > 1 else img_[0]

Expand All @@ -1323,7 +1369,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
header = dict(i.header)
if self.index_order == "C":
header = self._convert_f_to_c_order(header)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)

if self.affine_lps_to_ras:
header = self._switch_lps_ras(header)
Expand All @@ -1344,7 +1390,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:

return _stack_images(img_array, compatible_meta), compatible_meta

def _get_affine(self, header: dict) -> np.ndarray:
def _get_affine(self, img: NrrdImage) -> np.ndarray:
"""
Get the affine matrix of the image, it can be used to correct
spacing, orientation or execute spatial transforms.
Expand All @@ -1353,8 +1399,8 @@ def _get_affine(self, header: dict) -> np.ndarray:
img: A `NrrdImage` loaded from image file
"""
direction = header["space directions"]
origin = header["space origin"]
direction = img.header["space directions"]
origin = img.header["space origin"]

x, y = direction.shape
affine_diam = min(x, y) + 1
Expand Down
28 changes: 27 additions & 1 deletion monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import TYPE_CHECKING, Any, cast

import numpy as np
import os
import json

from monai.apps.utils import get_logger
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
Expand Down Expand Up @@ -196,6 +198,25 @@ def write(self, filename: PathLike, verbose: bool = True, **kwargs):
if verbose:
logger.info(f"writing: {filename}")

def update_json(self, input_file=None, output_file=None):
    """
    Record an input/output path mapping in ``img-label.json`` in the current working directory.

    The file holds a JSON list of ``{"image": ..., "label": [...]}`` records.
    Passing ``input_file`` appends a new record with an empty ``"label"`` list.
    Passing only ``output_file`` appends it to the ``"label"`` list of the most
    recent record; it is ignored when no record exists yet.

    Args:
        input_file: a path, or a list/tuple of paths, that was read. Paths are stored as strings;
            a list/tuple is stored as a JSON list.
        output_file: a path that was written for the most recently recorded input.

    Raises:
        json.JSONDecodeError: if the existing record file is not valid JSON.
    """
    record_path = "img-label.json"

    def _as_text(path):
        # ``json`` cannot serialize ``pathlib.Path`` (or bytes) objects, so store plain strings.
        try:
            return os.fsdecode(path)
        except TypeError:
            return str(path)

    # A missing, empty, or whitespace-only file is treated as "no records yet".
    try:
        with open(record_path, encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        content = ""
    records = json.loads(content) if content.strip() else []

    if input_file:
        if isinstance(input_file, (list, tuple)):
            image = [_as_text(p) for p in input_file]
        else:
            image = _as_text(input_file)
        records.append({"image": image, "label": []})
    elif output_file and records:
        records[-1]["label"].append(_as_text(output_file))

    # Serialize before opening the file for writing: a serialization error must not
    # leave a truncated or half-written record file behind.
    payload = json.dumps(records, indent=4)
    with open(record_path, "w", encoding="utf-8") as f:
        f.write(payload)


@classmethod
def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray:
"""
Expand Down Expand Up @@ -276,7 +297,7 @@ def resample_if_needed(
# convert back at the end
if isinstance(output_array, MetaTensor):
output_array.applied_operations = []
data_array, *_ = convert_data_type(output_array, output_type=orig_type)
data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore
affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore
return data_array[0], affine

Expand Down Expand Up @@ -462,7 +483,9 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs):
- https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809
"""
logger.info(f"ITKWriter is processing the file: {filename}")
super().write(filename, verbose=verbose)
super().update_json(output_file=filename)
self.data_obj = self.create_backend_obj(
cast(NdarrayOrTensor, self.data_obj),
channel_dim=self.channel_dim,
Expand Down Expand Up @@ -625,7 +648,9 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs):
- https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save
"""
logger.info(f"NibabelWriter is processing the file: {filename}")
super().write(filename, verbose=verbose)
super().update_json(output_file=filename)
self.data_obj = self.create_backend_obj(
cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs
)
Expand Down Expand Up @@ -771,6 +796,7 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs):
- https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save
"""
super().write(filename, verbose=verbose)
super().update_json(output_file=filename)
self.data_obj = self.create_backend_obj(
data_array=self.data_obj,
dtype=self.output_dtype,
Expand Down

0 comments on commit 9bc4cd6

Please sign in to comment.