10
10
"""
11
11
Modified version of the Alaudah testing script
12
12
Runs only on single GPU
13
-
14
- Estimated time to run on single V100: 5 hours
15
13
"""
16
14
17
15
import itertools
25
23
import numpy as np
26
24
import torch
27
25
import torch .nn .functional as F
26
+ from PIL import Image
28
27
from albumentations import Compose , Normalize , PadIfNeeded , Resize
29
28
from cv_lib .utils import load_log_configuration
30
29
from cv_lib .segmentation import models
30
+ from cv_lib .segmentation .dutchf3 .utils import (
31
+ current_datetime ,
32
+ generate_path ,
33
+ git_branch ,
34
+ git_hash ,
35
+ )
31
36
from deepseismic_interpretation .dutchf3 .data import (
32
37
add_patch_depth_channels ,
33
38
get_seismic_labels ,
39
44
from torch .utils import data
40
45
from toolz import take
41
46
47
+ from matplotlib import cm
48
+
42
49
43
50
_CLASS_NAMES = [
44
51
"upper_ns" ,
@@ -57,9 +64,9 @@ def __init__(self, n_classes):
57
64
58
65
def _fast_hist (self , label_true , label_pred , n_class ):
59
66
mask = (label_true >= 0 ) & (label_true < n_class )
60
- hist = np .bincount (n_class * label_true [ mask ]. astype ( int ) + label_pred [ mask ], minlength = n_class ** 2 ,). reshape (
61
- n_class , n_class
62
- )
67
+ hist = np .bincount (
68
+ n_class * label_true [ mask ]. astype ( int ) + label_pred [ mask ], minlength = n_class ** 2 ,
69
+ ). reshape ( n_class , n_class )
63
70
return hist
64
71
65
72
def update (self , label_trues , label_preds ):
@@ -99,6 +106,21 @@ def reset(self):
99
106
self .confusion_matrix = np .zeros ((self .n_classes , self .n_classes ))
100
107
101
108
109
+ def normalize (array ):
110
+ """
111
+ Normalizes a segmentation mask array to be in [0,1] range
112
+ """
113
+ min = array .min ()
114
+ return (array - min ) / (array .max () - min )
115
+
116
+
117
+ def mask_to_disk (mask , fname ):
118
+ """
119
+ write segmentation mask to disk using a particular colormap
120
+ """
121
+ Image .fromarray (cm .gist_earth (normalize (mask ), bytes = True )).save (fname )
122
+
123
+
102
124
def _transform_CHW_to_HWC (numpy_array ):
103
125
return np .moveaxis (numpy_array , 0 , - 1 )
104
126
@@ -180,7 +202,9 @@ def _compose_processing_pipeline(depth, aug=None):
180
202
181
203
182
204
def _generate_batches (h , w , ps , patch_size , stride , batch_size = 64 ):
183
- hdc_wdx_generator = itertools .product (range (0 , h - patch_size + ps , stride ), range (0 , w - patch_size + ps , stride ),)
205
+ hdc_wdx_generator = itertools .product (
206
+ range (0 , h - patch_size + ps , stride ), range (0 , w - patch_size + ps , stride ),
207
+ )
184
208
for batch_indexes in itertoolz .partition_all (batch_size , hdc_wdx_generator ):
185
209
yield batch_indexes
186
210
@@ -191,7 +215,9 @@ def _output_processing_pipeline(config, output):
191
215
_ , _ , h , w = output .shape
192
216
if config .TEST .POST_PROCESSING .SIZE != h or config .TEST .POST_PROCESSING .SIZE != w :
193
217
output = F .interpolate (
194
- output , size = (config .TEST .POST_PROCESSING .SIZE , config .TEST .POST_PROCESSING .SIZE ,), mode = "bilinear" ,
218
+ output ,
219
+ size = (config .TEST .POST_PROCESSING .SIZE , config .TEST .POST_PROCESSING .SIZE ,),
220
+ mode = "bilinear" ,
195
221
)
196
222
197
223
if config .TEST .POST_PROCESSING .CROP_PIXELS > 0 :
@@ -206,7 +232,15 @@ def _output_processing_pipeline(config, output):
206
232
207
233
208
234
def _patch_label_2d (
209
- model , img , pre_processing , output_processing , patch_size , stride , batch_size , device , num_classes ,
235
+ model ,
236
+ img ,
237
+ pre_processing ,
238
+ output_processing ,
239
+ patch_size ,
240
+ stride ,
241
+ batch_size ,
242
+ device ,
243
+ num_classes ,
210
244
):
211
245
"""Processes a whole section
212
246
"""
@@ -221,14 +255,19 @@ def _patch_label_2d(
221
255
# generate output:
222
256
for batch_indexes in _generate_batches (h , w , ps , patch_size , stride , batch_size = batch_size ):
223
257
batch = torch .stack (
224
- [pipe (img_p , _extract_patch (hdx , wdx , ps , patch_size ), pre_processing ,) for hdx , wdx in batch_indexes ],
258
+ [
259
+ pipe (img_p , _extract_patch (hdx , wdx , ps , patch_size ), pre_processing ,)
260
+ for hdx , wdx in batch_indexes
261
+ ],
225
262
dim = 0 ,
226
263
)
227
264
228
265
model_output = model (batch .to (device ))
229
266
for (hdx , wdx ), output in zip (batch_indexes , model_output .detach ().cpu ()):
230
267
output = output_processing (output )
231
- output_p [:, :, hdx + ps : hdx + ps + patch_size , wdx + ps : wdx + ps + patch_size ,] += output
268
+ output_p [
269
+ :, :, hdx + ps : hdx + ps + patch_size , wdx + ps : wdx + ps + patch_size ,
270
+ ] += output
232
271
233
272
# crop the output_p in the middle
234
273
output = output_p [:, :, ps :- ps , ps :- ps ]
@@ -253,12 +292,22 @@ def to_image(label_mask, n_classes=6):
253
292
254
293
255
294
def _evaluate_split (
256
- split , section_aug , model , pre_processing , output_processing , device , running_metrics_overall , config , debug = False
295
+ split ,
296
+ section_aug ,
297
+ model ,
298
+ pre_processing ,
299
+ output_processing ,
300
+ device ,
301
+ running_metrics_overall ,
302
+ config ,
303
+ debug = False ,
257
304
):
258
305
logger = logging .getLogger (__name__ )
259
306
260
307
TestSectionLoader = get_test_loader (config )
261
- test_set = TestSectionLoader (config .DATASET .ROOT , split = split , is_transform = True , augmentations = section_aug ,)
308
+ test_set = TestSectionLoader (
309
+ config .DATASET .ROOT , split = split , is_transform = True , augmentations = section_aug ,
310
+ )
262
311
263
312
n_classes = test_set .n_classes
264
313
@@ -268,6 +317,21 @@ def _evaluate_split(
268
317
logger .info ("Running in Debug/Test mode" )
269
318
test_loader = take (1 , test_loader )
270
319
320
+ try :
321
+ output_dir = generate_path (
322
+ config .OUTPUT_DIR + "_test" ,
323
+ git_branch (),
324
+ git_hash (),
325
+ config .MODEL .NAME ,
326
+ current_datetime (),
327
+ )
328
+ except TypeError :
329
+ output_dir = generate_path (
330
+ config .OUTPUT_DIR + "_test" ,
331
+ config .MODEL .NAME ,
332
+ current_datetime (),
333
+ )
334
+
271
335
running_metrics_split = runningScore (n_classes )
272
336
273
337
# testing mode:
@@ -295,6 +359,10 @@ def _evaluate_split(
295
359
running_metrics_split .update (gt , pred )
296
360
running_metrics_overall .update (gt , pred )
297
361
362
+ # dump images to disk for review
363
+ mask_to_disk (pred .squeeze (), os .path .join (output_dir , f"{ i } _pred.png" ))
364
+ mask_to_disk (gt .squeeze (), os .path .join (output_dir , f"{ i } _gt.png" ))
365
+
298
366
# get scores
299
367
score , class_iou = running_metrics_split .get_scores ()
300
368
@@ -350,12 +418,16 @@ def test(*options, cfg=None, debug=False):
350
418
running_metrics_overall = runningScore (n_classes )
351
419
352
420
# Augmentation
353
- section_aug = Compose ([Normalize (mean = (config .TRAIN .MEAN ,), std = (config .TRAIN .STD ,), max_pixel_value = 1 ,)])
421
+ section_aug = Compose (
422
+ [Normalize (mean = (config .TRAIN .MEAN ,), std = (config .TRAIN .STD ,), max_pixel_value = 1 ,)]
423
+ )
354
424
355
425
patch_aug = Compose (
356
426
[
357
427
Resize (
358
- config .TRAIN .AUGMENTATIONS .RESIZE .HEIGHT , config .TRAIN .AUGMENTATIONS .RESIZE .WIDTH , always_apply = True ,
428
+ config .TRAIN .AUGMENTATIONS .RESIZE .HEIGHT ,
429
+ config .TRAIN .AUGMENTATIONS .RESIZE .WIDTH ,
430
+ always_apply = True ,
359
431
),
360
432
PadIfNeeded (
361
433
min_height = config .TRAIN .AUGMENTATIONS .PAD .HEIGHT ,
0 commit comments