Skip to content

Commit

Permalink
Merge branch 'testing_training'
Browse files Browse the repository at this point in the history
  • Loading branch information
Abe404 committed Jul 27, 2023
2 parents bedb9dc + 1ca29e5 commit 2f4094f
Show file tree
Hide file tree
Showing 25 changed files with 1,971 additions and 395 deletions.
2 changes: 1 addition & 1 deletion painter/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ setuptools==46.1.3
natsort==6.0.0
scikit-image==0.19.1
PyWavelets>=1.1.1
Pillow==9.0.1
Pillow==10.0.0
scipy>=1.4.1
qimage2ndarray==1.8.3
nibabel==5.0.0
Expand Down
15 changes: 1 addition & 14 deletions painter/src/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,7 @@ def get_seg_metrics(seg_dir, gt_dir, fname):

gt = im_utils.load_seg(os.path.join(gt_dir, fname))
gt = gt.astype(bool).astype(int)

# if they dont match in shape assume the gt needs the depth moving from last to first
if gt.shape != seg.shape:
# these transformmations are taken from the server code that
# loads and segments an image.
# here we assume the gt has the same orientation as the original image
# hence the same transofmrations need to be applied to get the gt
# to align to the segmented image.

# FIXME TODO: Consider removing this soon.
gt = np.rot90(gt, k=3)
gt = np.moveaxis(gt, -1, 0) # depth moved to beginning
# reverse lr and ud
gt = gt[::-1, :, ::-1]
assert gt.shape == seg.shape
m = metrics_from_binary_masks(seg, gt)
m.fname = fname
return m
Expand Down
38 changes: 13 additions & 25 deletions painter/src/im_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ def load_image_with_header(image_path):
header = image.header
affine = image.affine
image = np.array(image.dataobj)

image = np.rot90(image, k=3)
image = np.moveaxis(image, -1, 0) # depth moved to beginning
# reverse lr and ud
image = image[::-1, :, ::-1]
return image.astype(int), affine, header

def load_image(image_path):
Expand All @@ -66,16 +61,8 @@ def load_image(image_path):
if image_path.endswith('.nii.gz'):
image = nib.load(image_path)
image = np.array(image.dataobj)
image = np.rot90(image, k=3)
image = np.moveaxis(image, -1, 0) # depth moved to beginning
# reverse lr and ud
image = image[::-1, :, ::-1]
elif image_path.endswith('.nrrd'):
image, _header = nrrd.read(image_path)
image = np.rot90(image, k=3)
image = np.moveaxis(image, -1, 0) # depth moved to beginning
# reverse lr and ud
image = image[::-1, :, ::-1]
else:
raise Exception(f"Unhandled file ending {image_path}")
image = image.astype(int)
Expand Down Expand Up @@ -125,6 +112,12 @@ def annot_slice_to_pixmap(slice_np):


def get_outline_pixmap(seg_slice, annot_slice):

assert seg_slice.shape == annot_slice[0].shape, (
'get_outline_pixmap: '
f'seg_slice shape {seg_slice.shape} should match '
f' annot_slice shape {annot_slice.shape}')

seg_map = (seg_slice > 0).astype(int)
annot_plus = (annot_slice[1] > 0).astype(int)
annot_minus = (annot_slice[0] > 0).astype(int)
Expand All @@ -149,7 +142,7 @@ def seg_slice_to_pixmap(slice_np):
return QtGui.QPixmap.fromImage(q_image)

def get_slice(volume, slice_idx, mode):
if mode == 'axial':
if mode == 'sagittal':
if len(volume.shape) > 3:
slice_idx = (volume.shape[1] - slice_idx) - 1
# if more than 3 presume first is channel dimension
Expand All @@ -164,14 +157,15 @@ def get_slice(volume, slice_idx, mode):
# slice_data = volume[:, :, :, slice_idx]
#else:
# slice_data = volume[:, slice_idx, :]
elif mode == 'sagittal':
elif mode == 'axial':
if len(volume.shape) > 3:
# if more than 3 presume first is channel dimension
slice_data = volume[:, :, :, slice_idx]
else:
slice_data = volume[:, :, slice_idx]
else:
raise Exception(f"Unhandled slice mode: {mode}")
# not sure why I had to rot90. Based on visual inspection
return slice_data


Expand All @@ -183,15 +177,17 @@ def store_annot_slice(annot_pixmap, annot_data, slice_idx, mode):
slice_rgb_np = np.array(qimage2ndarray.rgb_view(annot_pixmap.toImage()))
fg = slice_rgb_np[:, :, 0] > 0
bg = slice_rgb_np[:, :, 1] > 0
if mode == 'axial':


if mode == 'sagittal':
slice_idx = (annot_data.shape[1] - slice_idx) - 1
annot_data[0, slice_idx] = bg
annot_data[1, slice_idx] = fg
elif mode == 'coronal':
raise Exception("not yet implemented")
# annot_data[0, :, slice_idx, :] = bg
# annot_data[1, :, slice_idx, :] = fg
elif mode == 'sagittal':
elif mode == 'axial':
#slice_idx = (annot_data.shape[3] - slice_idx) - 1
annot_data[0, :, :, slice_idx] = bg
annot_data[1, :, :, slice_idx] = fg
Expand Down Expand Up @@ -284,14 +280,6 @@ def save_corrected_segmentation_from_data(seg_data, annot_data, image_affine,
annot_plus = (annot_data[1] > 0).astype(int)
annot_minus = (annot_data[0] > 0).astype(int)
corrected = (((seg_map + annot_plus) - annot_minus) > 0)

# These operations are the inverse of what is done to an image when it is
# loaded. I am performing them to make the segmentation algin with the
# original image.
corrected = corrected[::-1, :, ::-1] # reverse lr and ud
corrected = np.moveaxis(corrected, 0, -1) # depth moved to end
corrected = np.rot90(corrected, k=1) # rotate 90.

corrected_nifty = nib.Nifti1Image(corrected.astype(np.int8),
image_affine, image_header)
output_dir = os.path.dirname(output_path)
Expand Down
4 changes: 2 additions & 2 deletions painter/src/slice_nav.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def update_range(self, new_image, mode):
""" update range of slices based on shape of input image
and view mode """
if mode == 'axial':
slice_count = new_image.shape[0]
slice_count = new_image.shape[-1]
elif mode == 'coronal':
slice_count = new_image.shape[1]
elif mode == 'sagittal':
slice_count = new_image.shape[2]
slice_count = new_image.shape[0]
else:
raise Exception(f"Unhandled mode:{mode}")

Expand Down
30 changes: 30 additions & 0 deletions trainer/.pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

[MESSAGES CONTROL]

disable=trailing-newlines,
too-many-arguments,
too-many-statements,
trailing-whitespace,
too-many-locals,
invalid-name,
too-many-public-methods,
unnecessary-dunder-call,
consider-using-with,
import-error,
broad-except,
missing-function-docstring,
missing-class-docstring,
too-many-instance-attributes,
wrong-import-position,
too-many-branches




[TYPECHECK]

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=numpy.*,torch.*

40 changes: 35 additions & 5 deletions trainer/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Form a batch
Copyright (C) 2021 Abraham George Smith
Copyright (C) 2021-2023 Abraham George Smith
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand All @@ -17,6 +17,19 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import numpy as np
import torch
import im_utils

def pad_patches_to_dimension(patches, max_d, max_h, max_w):
patches_padded = []
for patch in patches:
if torch.is_tensor(patch):
patch = patch.numpy()
new_im_patch, _was_padded = im_utils.maybe_pad_image_to_pad_size(
patch, (max_d, max_h, max_w))
patches_padded.append(new_im_patch)
return np.array(patches_padded)


def collate_fn(batch):
num_items = len(batch)
Expand All @@ -26,16 +39,33 @@ def collate_fn(batch):
batch_segs = []
batch_classes = []
ignore_masks = []

for i in range(num_items):
item = batch[i]
class_data = {}
im_patches.append(item[0])
batch_fgs.append(item[1])
batch_bgs.append(item[2])
ignore_masks.append(item[3])
batch_segs.append(item[4])
batch_classes.append(item[5])

im_patches = np.array(im_patches)
return im_patches, batch_fgs, batch_bgs, ignore_masks, batch_segs, batch_classes
max_d = max(i.shape[0] for i in im_patches)
max_h = max(i.shape[1] for i in im_patches)
max_w = max(i.shape[2] for i in im_patches)

im_patches = pad_patches_to_dimension(im_patches, max_d, max_h, max_w)

batch_fgs_padded = []
# for the list of fg annotations for each item in the batch
for fgs in batch_fgs:
batch_fgs_padded.append(pad_patches_to_dimension(fgs, max_d, max_h, max_w))

batch_bgs_padded = []
# for the list of bg annotations for each item in the batch
for bgs in batch_bgs:
batch_bgs_padded.append(pad_patches_to_dimension(bgs, max_d, max_h, max_w))

ignore_masks_padded = pad_patches_to_dimension(ignore_masks,
max_d-34, max_h-34, max_w-34)

return (im_patches, batch_fgs_padded, batch_bgs_padded,
ignore_masks_padded, batch_segs, batch_classes)
Loading

0 comments on commit 2f4094f

Please sign in to comment.