Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 4484155

Browse files
authored
added ability to score test images to disk for dutch F3 (#232)
* added ability to score test images to disk for dutch F3 * gitpython package fix from staging
1 parent b75e647 commit 4484155

File tree

4 files changed

+124
-31
lines changed

4 files changed

+124
-31
lines changed

cv_lib/cv_lib/event_handlers/tensorboard_handlers.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
import logging
77
import logging.config
88

9-
try:
10-
from tensorboardX import SummaryWriter
11-
except ImportError:
12-
raise RuntimeError("No tensorboardX package is found. Please install with the command: \npip install tensorboardX")
9+
from tensorboardX import SummaryWriter
1310

1411

1512
def create_summary_writer(log_dir):
@@ -52,16 +49,22 @@ def log_lr(summary_writer, optimizer, log_interval, engine):
5249
def log_metrics(summary_writer, train_engine, log_interval, engine, metrics_dict=_DEFAULT_METRICS):
5350
metrics = engine.state.metrics
5451
for m in metrics_dict:
55-
summary_writer.add_scalar(metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval))
52+
summary_writer.add_scalar(
53+
metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval)
54+
)
5655

5756

58-
def create_image_writer(summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x):
57+
def create_image_writer(
58+
summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x
59+
):
5960
logger = logging.getLogger(__name__)
6061

6162
def write_to(engine):
6263
try:
6364
data_tensor = transform_func(engine.state.output[output_variable])
64-
image_grid = torchvision.utils.make_grid(data_tensor, normalize=normalize, scale_each=True)
65+
image_grid = torchvision.utils.make_grid(
66+
data_tensor, normalize=normalize, scale_each=True
67+
)
6568
summary_writer.add_image(label, image_grid, engine.state.epoch)
6669
except KeyError:
6770
logger.warning("Predictions and or ground truth labels not available to report")

environment/anaconda/local/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies:
2626
- toolz==0.10.0
2727
- tabulate==0.8.2
2828
- Jinja2==2.10.3
29-
- gitpython==3.0.5
29+
- gitpython==3.0.6
3030
- tensorboard==2.0.1
3131
- tensorboardx==1.9
3232
- invoke==1.3.0

experiments/interpretation/dutchf3_patch/local/test.py

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
"""
1111
Modified version of the Alaudah testing script
1212
Runs only on single GPU
13-
14-
Estimated time to run on single V100: 5 hours
1513
"""
1614

1715
import itertools
@@ -25,9 +23,16 @@
2523
import numpy as np
2624
import torch
2725
import torch.nn.functional as F
26+
from PIL import Image
2827
from albumentations import Compose, Normalize, PadIfNeeded, Resize
2928
from cv_lib.utils import load_log_configuration
3029
from cv_lib.segmentation import models
30+
from cv_lib.segmentation.dutchf3.utils import (
31+
current_datetime,
32+
generate_path,
33+
git_branch,
34+
git_hash,
35+
)
3136
from deepseismic_interpretation.dutchf3.data import (
3237
add_patch_depth_channels,
3338
get_seismic_labels,
@@ -39,6 +44,8 @@
3944
from torch.utils import data
4045
from toolz import take
4146

47+
from matplotlib import cm
48+
4249

4350
_CLASS_NAMES = [
4451
"upper_ns",
@@ -57,9 +64,9 @@ def __init__(self, n_classes):
5764

5865
def _fast_hist(self, label_true, label_pred, n_class):
5966
mask = (label_true >= 0) & (label_true < n_class)
60-
hist = np.bincount(n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,).reshape(
61-
n_class, n_class
62-
)
67+
hist = np.bincount(
68+
n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,
69+
).reshape(n_class, n_class)
6370
return hist
6471

6572
def update(self, label_trues, label_preds):
@@ -99,6 +106,21 @@ def reset(self):
99106
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
100107

101108

109+
def normalize(array):
110+
"""
111+
Normalizes a segmentation mask array to be in [0,1] range
112+
"""
113+
min = array.min()
114+
return (array - min) / (array.max() - min)
115+
116+
117+
def mask_to_disk(mask, fname):
118+
"""
119+
write segmentation mask to disk using a particular colormap
120+
"""
121+
Image.fromarray(cm.gist_earth(normalize(mask), bytes=True)).save(fname)
122+
123+
102124
def _transform_CHW_to_HWC(numpy_array):
103125
return np.moveaxis(numpy_array, 0, -1)
104126

@@ -180,7 +202,9 @@ def _compose_processing_pipeline(depth, aug=None):
180202

181203

182204
def _generate_batches(h, w, ps, patch_size, stride, batch_size=64):
183-
hdc_wdx_generator = itertools.product(range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),)
205+
hdc_wdx_generator = itertools.product(
206+
range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),
207+
)
184208
for batch_indexes in itertoolz.partition_all(batch_size, hdc_wdx_generator):
185209
yield batch_indexes
186210

@@ -191,7 +215,9 @@ def _output_processing_pipeline(config, output):
191215
_, _, h, w = output.shape
192216
if config.TEST.POST_PROCESSING.SIZE != h or config.TEST.POST_PROCESSING.SIZE != w:
193217
output = F.interpolate(
194-
output, size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,), mode="bilinear",
218+
output,
219+
size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,),
220+
mode="bilinear",
195221
)
196222

197223
if config.TEST.POST_PROCESSING.CROP_PIXELS > 0:
@@ -206,7 +232,15 @@ def _output_processing_pipeline(config, output):
206232

207233

208234
def _patch_label_2d(
209-
model, img, pre_processing, output_processing, patch_size, stride, batch_size, device, num_classes,
235+
model,
236+
img,
237+
pre_processing,
238+
output_processing,
239+
patch_size,
240+
stride,
241+
batch_size,
242+
device,
243+
num_classes,
210244
):
211245
"""Processes a whole section
212246
"""
@@ -221,14 +255,19 @@ def _patch_label_2d(
221255
# generate output:
222256
for batch_indexes in _generate_batches(h, w, ps, patch_size, stride, batch_size=batch_size):
223257
batch = torch.stack(
224-
[pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,) for hdx, wdx in batch_indexes],
258+
[
259+
pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,)
260+
for hdx, wdx in batch_indexes
261+
],
225262
dim=0,
226263
)
227264

228265
model_output = model(batch.to(device))
229266
for (hdx, wdx), output in zip(batch_indexes, model_output.detach().cpu()):
230267
output = output_processing(output)
231-
output_p[:, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,] += output
268+
output_p[
269+
:, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,
270+
] += output
232271

233272
# crop the output_p in the middle
234273
output = output_p[:, :, ps:-ps, ps:-ps]
@@ -253,12 +292,22 @@ def to_image(label_mask, n_classes=6):
253292

254293

255294
def _evaluate_split(
256-
split, section_aug, model, pre_processing, output_processing, device, running_metrics_overall, config, debug=False
295+
split,
296+
section_aug,
297+
model,
298+
pre_processing,
299+
output_processing,
300+
device,
301+
running_metrics_overall,
302+
config,
303+
debug=False,
257304
):
258305
logger = logging.getLogger(__name__)
259306

260307
TestSectionLoader = get_test_loader(config)
261-
test_set = TestSectionLoader(config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,)
308+
test_set = TestSectionLoader(
309+
config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,
310+
)
262311

263312
n_classes = test_set.n_classes
264313

@@ -268,6 +317,21 @@ def _evaluate_split(
268317
logger.info("Running in Debug/Test mode")
269318
test_loader = take(1, test_loader)
270319

320+
try:
321+
output_dir = generate_path(
322+
config.OUTPUT_DIR + "_test",
323+
git_branch(),
324+
git_hash(),
325+
config.MODEL.NAME,
326+
current_datetime(),
327+
)
328+
except TypeError:
329+
output_dir = generate_path(
330+
config.OUTPUT_DIR + "_test",
331+
config.MODEL.NAME,
332+
current_datetime(),
333+
)
334+
271335
running_metrics_split = runningScore(n_classes)
272336

273337
# testing mode:
@@ -295,6 +359,10 @@ def _evaluate_split(
295359
running_metrics_split.update(gt, pred)
296360
running_metrics_overall.update(gt, pred)
297361

362+
# dump images to disk for review
363+
mask_to_disk(pred.squeeze(), os.path.join(output_dir, f"{i}_pred.png"))
364+
mask_to_disk(gt.squeeze(), os.path.join(output_dir, f"{i}_gt.png"))
365+
298366
# get scores
299367
score, class_iou = running_metrics_split.get_scores()
300368

@@ -350,12 +418,16 @@ def test(*options, cfg=None, debug=False):
350418
running_metrics_overall = runningScore(n_classes)
351419

352420
# Augmentation
353-
section_aug = Compose([Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)])
421+
section_aug = Compose(
422+
[Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)]
423+
)
354424

355425
patch_aug = Compose(
356426
[
357427
Resize(
358-
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
428+
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
429+
config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
430+
always_apply=True,
359431
),
360432
PadIfNeeded(
361433
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,

experiments/interpretation/dutchf3_patch/local/train.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def run(*options, cfg=None, debug=False):
111111
[
112112
Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1),
113113
Resize(
114-
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
114+
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
115+
config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
116+
always_apply=True,
115117
),
116118
PadIfNeeded(
117119
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
@@ -151,9 +153,14 @@ def run(*options, cfg=None, debug=False):
151153
n_classes = train_set.n_classes
152154

153155
train_loader = data.DataLoader(
154-
train_set, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS, shuffle=True,
156+
train_set,
157+
batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
158+
num_workers=config.WORKERS,
159+
shuffle=True,
160+
)
161+
val_loader = data.DataLoader(
162+
val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,
155163
)
156-
val_loader = data.DataLoader(val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,)
157164

158165
model = getattr(models, config.MODEL.NAME).get_seg_model(config)
159166

@@ -170,14 +177,18 @@ def run(*options, cfg=None, debug=False):
170177
)
171178

172179
try:
173-
output_dir = generate_path(config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),)
180+
output_dir = generate_path(
181+
config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),
182+
)
174183
except TypeError:
175184
output_dir = generate_path(config.OUTPUT_DIR, config.MODEL.NAME, current_datetime(),)
176185

177186
summary_writer = create_summary_writer(log_dir=path.join(output_dir, config.LOG_DIR))
178187

179188
snapshot_duration = scheduler_step * len(train_loader)
180-
scheduler = CosineAnnealingScheduler(optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration)
189+
scheduler = CosineAnnealingScheduler(
190+
optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration
191+
)
181192

182193
# weights are inversely proportional to the frequency of the classes in the
183194
# training set
@@ -190,7 +201,8 @@ def run(*options, cfg=None, debug=False):
190201
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
191202

192203
trainer.add_event_handler(
193-
Events.ITERATION_COMPLETED, logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
204+
Events.ITERATION_COMPLETED,
205+
logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
194206
)
195207
trainer.add_event_handler(Events.EPOCH_STARTED, logging_handlers.log_lr(optimizer))
196208
trainer.add_event_handler(
@@ -208,7 +220,9 @@ def _select_pred_and_mask(model_out_dict):
208220
prepare_batch,
209221
metrics={
210222
"nll": Loss(criterion, output_transform=_select_pred_and_mask),
211-
"pixacc": pixelwise_accuracy(n_classes, output_transform=_select_pred_and_mask, device=device),
223+
"pixacc": pixelwise_accuracy(
224+
n_classes, output_transform=_select_pred_and_mask, device=device
225+
),
212226
"cacc": class_accuracy(n_classes, output_transform=_select_pred_and_mask),
213227
"mca": mean_class_accuracy(n_classes, output_transform=_select_pred_and_mask),
214228
"ciou": class_iou(n_classes, output_transform=_select_pred_and_mask),
@@ -267,11 +281,15 @@ def _tensor_to_numpy(pred_tensor):
267281
)
268282
evaluator.add_event_handler(
269283
Events.EPOCH_COMPLETED,
270-
create_image_writer(summary_writer, "Validation/Mask", "mask", transform_func=transform_func),
284+
create_image_writer(
285+
summary_writer, "Validation/Mask", "mask", transform_func=transform_func
286+
),
271287
)
272288
evaluator.add_event_handler(
273289
Events.EPOCH_COMPLETED,
274-
create_image_writer(summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred),
290+
create_image_writer(
291+
summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred
292+
),
275293
)
276294

277295
def snapshot_function():

0 commit comments

Comments
 (0)