Skip to content

Commit 29f74be

Browse files
committed
Save more images for inspection
1 parent b02316e commit 29f74be

File tree

2 files changed

+48
-21
lines changed

2 files changed

+48
-21
lines changed

code/main_centroid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def runTraining(args):
159159
def main():
160160
parser = argparse.ArgumentParser()
161161

162-
parser.add_argument('--epochs', default=200, type=int)
162+
parser.add_argument('--epochs', default=30, type=int)
163163
parser.add_argument('--dataset', default='TOY2', choices=['TOY2'])
164164
parser.add_argument('--mode', default='quadratic', choices=['quadratic', 'logbarrier'])
165165

code/utils/utils.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22

3+
import warnings
34
from pathlib import Path
45
from functools import partial
56
from multiprocessing import Pool
@@ -11,6 +12,7 @@
1112
import torch.nn as nn
1213
import torch.nn.functional as F
1314
from tqdm import tqdm
15+
from skimage.io import imsave
1416
from torch import Tensor, einsum
1517

1618
tqdm_ = partial(tqdm, ncols=125,
@@ -105,6 +107,21 @@ def probs2one_hot(probs: Tensor) -> Tensor:
105107
return res
106108

107109

110+
# Save the raw predictions
111+
def save_images(segs: Tensor, names: Iterable[str], root: Path) -> None:
112+
for seg, name in zip(segs, names):
113+
save_path = (root / name).with_suffix(".png")
114+
save_path.parent.mkdir(parents=True, exist_ok=True)
115+
116+
if len(seg.shape) == 2:
117+
imsave(str(save_path), seg.detach().cpu().numpy().astype(np.uint8))
118+
elif len(seg.shape) == 3:
119+
np.save(str(save_path), seg.detach().cpu().numpy())
120+
else:
121+
raise ValueError("How did you get here")
122+
123+
124+
# Save a fancy looking figure
108125
def saveImages(net, img_batch, batch_size, epoch, dataset, mode, device):
109126
path = Path('results/images/') / dataset / mode
110127
path.mkdir(parents=True, exist_ok=True)
@@ -115,31 +132,41 @@ def saveImages(net, img_batch, batch_size, epoch, dataset, mode, device):
115132

116133
log_dice = torch.zeros((len(img_batch)), device=device)
117134

118-
tq_iter = tqdm_(enumerate(img_batch), total=len(img_batch), desc=desc)
119-
for j, data in tq_iter:
120-
img = data["img"].to(device)
121-
weak_mask = data["weak_mask"].to(device)
122-
full_mask = data["full_mask"].to(device)
135+
with warnings.catch_warnings():
136+
warnings.filterwarnings("ignore", category=UserWarning)
137+
138+
tq_iter = tqdm_(enumerate(img_batch), total=len(img_batch), desc=desc)
139+
for j, data in tq_iter:
140+
img = data["img"].to(device)
141+
weak_mask = data["weak_mask"].to(device)
142+
full_mask = data["full_mask"].to(device)
143+
144+
logits = net(img)
145+
probs = F.softmax(5 * logits, dim=1)
146+
147+
segmentation = probs2class(probs)[:, None, ...].float()
148+
log_dice[j] = dice_coef(probs2one_hot(probs), full_mask)[0, 1] # 1st item, 2nd class
123149

124-
logits = net(img)
125-
probs = F.softmax(5 * logits, dim=1)
150+
out = torch.cat((img, segmentation, weak_mask[:, [1], ...]))
126151

127-
segmentation = probs2class(probs)[:, None, ...].float()
128-
log_dice[j] = dice_coef(probs2one_hot(probs), full_mask)[0, 1] # 1st item, 2nd class
152+
torchvision.utils.save_image(out.data, path / f"{j}_Ep_{epoch:04d}.png",
153+
nrow=batch_size,
154+
padding=2,
155+
normalize=False,
156+
range=None,
157+
scale_each=False,
158+
pad_value=0)
129159

130-
out = torch.cat((img, segmentation, weak_mask[:, [1], ...]))
160+
predicted_class: Tensor = probs2class(probs)
161+
filenames: List[str] = [Path(p).stem for p in data["path"]]
131162

132-
torchvision.utils.save_image(out.data, path / f"{j}_Ep_{epoch:04d}.png",
133-
nrow=batch_size,
134-
padding=2,
135-
normalize=False,
136-
range=None,
137-
scale_each=False,
138-
pad_value=0)
163+
save_images(predicted_class,
164+
filenames,
165+
Path("results/raw_images") / mode / f"iter{epoch:03d}")
139166

140-
tq_iter.set_postfix({"DSC": f"{log_dice[:j+1].mean():05.3f}"})
141-
tq_iter.update(1)
142-
tq_iter.close()
167+
tq_iter.set_postfix({"DSC": f"{log_dice[:j+1].mean():05.3f}"})
168+
tq_iter.update(1)
169+
tq_iter.close()
143170

144171

145172
# Metrics

0 commit comments

Comments
 (0)