11#!/usr/bin/env python3
22
3+ import warnings
34from pathlib import Path
45from functools import partial
56from multiprocessing import Pool
1112import torch .nn as nn
1213import torch .nn .functional as F
1314from tqdm import tqdm
15+ from skimage .io import imsave
1416from torch import Tensor , einsum
1517
1618tqdm_ = partial (tqdm , ncols = 125 ,
@@ -105,6 +107,21 @@ def probs2one_hot(probs: Tensor) -> Tensor:
105107 return res
106108
107109
110+ # Save the raw predictions
111+ def save_images (segs : Tensor , names : Iterable [str ], root : Path ) -> None :
112+ for seg , name in zip (segs , names ):
113+ save_path = (root / name ).with_suffix (".png" )
114+ save_path .parent .mkdir (parents = True , exist_ok = True )
115+
116+ if len (seg .shape ) == 2 :
117+ imsave (str (save_path ), seg .detach ().cpu ().numpy ().astype (np .uint8 ))
118+ elif len (seg .shape ) == 3 :
119+ np .save (str (save_path ), seg .detach ().cpu ().numpy ())
120+ else :
121+ raise ValueError ("How did you get here" )
122+
123+
124+ # Save a fancy looking figure
108125def saveImages (net , img_batch , batch_size , epoch , dataset , mode , device ):
109126 path = Path ('results/images/' ) / dataset / mode
110127 path .mkdir (parents = True , exist_ok = True )
@@ -115,31 +132,41 @@ def saveImages(net, img_batch, batch_size, epoch, dataset, mode, device):
115132
116133 log_dice = torch .zeros ((len (img_batch )), device = device )
117134
118- tq_iter = tqdm_ (enumerate (img_batch ), total = len (img_batch ), desc = desc )
119- for j , data in tq_iter :
120- img = data ["img" ].to (device )
121- weak_mask = data ["weak_mask" ].to (device )
122- full_mask = data ["full_mask" ].to (device )
135+ with warnings .catch_warnings ():
136+ warnings .filterwarnings ("ignore" , category = UserWarning )
137+
138+ tq_iter = tqdm_ (enumerate (img_batch ), total = len (img_batch ), desc = desc )
139+ for j , data in tq_iter :
140+ img = data ["img" ].to (device )
141+ weak_mask = data ["weak_mask" ].to (device )
142+ full_mask = data ["full_mask" ].to (device )
143+
144+ logits = net (img )
145+ probs = F .softmax (5 * logits , dim = 1 )
146+
147+ segmentation = probs2class (probs )[:, None , ...].float ()
148+ log_dice [j ] = dice_coef (probs2one_hot (probs ), full_mask )[0 , 1 ] # 1st item, 2nd class
123149
124- logits = net (img )
125- probs = F .softmax (5 * logits , dim = 1 )
150+ out = torch .cat ((img , segmentation , weak_mask [:, [1 ], ...]))
126151
127- segmentation = probs2class (probs )[:, None , ...].float ()
128- log_dice [j ] = dice_coef (probs2one_hot (probs ), full_mask )[0 , 1 ] # 1st item, 2nd class
152+ torchvision .utils .save_image (out .data , path / f"{ j } _Ep_{ epoch :04d} .png" ,
153+ nrow = batch_size ,
154+ padding = 2 ,
155+ normalize = False ,
156+ range = None ,
157+ scale_each = False ,
158+ pad_value = 0 )
129159
130- out = torch .cat ((img , segmentation , weak_mask [:, [1 ], ...]))
160+ predicted_class : Tensor = probs2class (probs )
161+ filenames : List [str ] = [Path (p ).stem for p in data ["path" ]]
131162
132- torchvision .utils .save_image (out .data , path / f"{ j } _Ep_{ epoch :04d} .png" ,
133- nrow = batch_size ,
134- padding = 2 ,
135- normalize = False ,
136- range = None ,
137- scale_each = False ,
138- pad_value = 0 )
163+ save_images (predicted_class ,
164+ filenames ,
165+ Path ("results/raw_images" ) / mode / f"iter{ epoch :03d} " )
139166
140- tq_iter .set_postfix ({"DSC" : f"{ log_dice [:j + 1 ].mean ():05.3f} " })
141- tq_iter .update (1 )
142- tq_iter .close ()
167+ tq_iter .set_postfix ({"DSC" : f"{ log_dice [:j + 1 ].mean ():05.3f} " })
168+ tq_iter .update (1 )
169+ tq_iter .close ()
143170
144171
145172# Metrics
0 commit comments