1
+ import os
2
+ import sys
3
+ import argparse
4
+ import torch
5
+ from dataset import *
6
+ from helper import *
7
+ from model import *
8
+ from fastai .conv_learner import *
9
+ from fastai .dataset import *
10
+ from pathlib import Path
11
+ import json
12
+
13
+ sys .path .append ('./fastai' )
14
+
15
+ if torch .cuda .is_available ():
16
+ torch .backends .cudnn .enabled = True
17
+
18
+ def main ():
19
+
20
+ parser = argparse .ArgumentParser ()
21
+ parser .add_argument ('--data_dir' , type = str , default = 'mass_roads' , \
22
+ help = 'Path to the dataset' )
23
+ parser .add_argument ('--learning_rate' , type = float , default = '0.1' , \
24
+ help = 'Learning Rate' )
25
+ parser .add_argument ('--model_dir' , type = str , default = 'models/' , \
26
+ help = 'Path to the complete trained model file(.h5)' )
27
+ parser .add_argument ('--mode' , type = str , default = 'test' , \
28
+ help = 'Learning Rate' )
29
+ parser .add_argument ('--num_epochs' , type = int , default = '1' , \
30
+ help = 'Number of epochs' )
31
+ parser .add_argument ('--cycle_len' , type = int , default = '2' , \
32
+ help = 'Cycle Length' )
33
+ parser .add_argument ('--test_img' , type = str , default = 'test.png' , \
34
+ help = 'Test Image Path' )
35
+
36
+ args = parser .parse_args ()
37
+ cwd = os .getcwd ()
38
+
39
+ f = resnet34
40
+ cut , lr_cut = model_meta [f ]
41
+
42
+ if args .mode == 'train' :
43
+ new_base_path = convert_and_resize (args .data_dir )
44
+ train_x , train_y , valid_x , valid_y , test_x , test_y \
45
+ = get_file_list (os .path .join (cwd , new_base_path ))
46
+
47
+ PATH = Path (os .path .join (cwd , new_base_path ))
48
+ img_size = 1024
49
+ batch_size = 2
50
+ aug_tfms = []
51
+
52
+ tfms = tfms_from_model (resnet34 , img_size , crop_type = CropType .NO , \
53
+ tfm_y = TfmType .NO , aug_tfms = aug_tfms )
54
+ datasets = ImageData .get_ds (MatchedFilesDataset , (train_x , train_y ),
55
+ (valid_x , valid_y ), tfms , path = PATH )
56
+ md = ImageData (PATH , datasets , batch_size , num_workers = 4 , classes = None )
57
+ denorm = md .trn_ds .denorm
58
+ x , y = next (iter (md .trn_dl ))
59
+ print (x .shape , y .shape )
60
+ m_base = get_base ()
61
+ m = to_gpu (Unet34 (m_base ))
62
+ models = UpsampleModel (m )
63
+
64
+ learn = ConvLearner (md , models )
65
+ learn .opt_fn = optim .Adam
66
+ learn .crit = mask_loss
67
+ learn .metrics = [mask_acc , dice ]
68
+ learn .freeze_to (1 )
69
+ learn .load (os .path .join (cwd , 'models/1024Deepglobe-tmp' ))
70
+ print ('Started Training...' )
71
+ learn .fit (args .learning_rate , args .num_epochs , cycle_len = args .cycle_len , use_clr = (20 ,4 ))
72
+ learn .save (os .path .join (cwd , 'models/Mnih-final-1024' ))
73
+
74
+ elif args .mode == 'test' :
75
+ PATH = Path ('./' )
76
+ img_size = 1024
77
+ batch_size = 1
78
+ aug_tfms = []
79
+ m_base = get_base ()
80
+ m = to_gpu (Unet34 (m_base ))
81
+ models = UpsampleModel (m )
82
+
83
+ t_img = [args .test_img ]
84
+ save_path = [args .test_img .split ('/' )[- 1 ]]
85
+ print (save_path , args .test_img )
86
+ img = Image .open (args .test_img ).resize ((1024 ,1024 )).save ('1024_' + save_path [0 ])
87
+
88
+ tfms = tfms_from_model (resnet34 , img_size , crop_type = CropType .NO , tfm_y = TfmType .NO , aug_tfms = aug_tfms )
89
+ datasets = ImageData .get_ds (MatchedFilesDataset , (t_img , t_img ), (t_img , t_img ), tfms , path = PATH )
90
+ md = ImageData (PATH , datasets , batch_size , num_workers = 4 , classes = None )
91
+ denorm = md .trn_ds .denorm
92
+
93
+ learn = ConvLearner (md , models )
94
+ learn .load (os .path .join (cwd , 'models/1024DeepGlobe-Mnih-tmp' ))
95
+
96
+ x , _ = next (iter (md .trn_dl ))
97
+ start = time .time ()
98
+ py = to_np (learn .model (V (x )))
99
+ end = time .time ()
100
+ print ('Prediction Time' , (end - start ), 'seconds' )
101
+ s = py [0 ][0 ]* 255.0
102
+ cv2 .imwrite ('./' + 'masked_' + save_path [0 ], s )
103
+
104
+ inp = './1024_' + save_path [0 ]
105
+ out = './masked_' + save_path [0 ]
106
+ img = cv2 .imread (inp )
107
+ mask = cv2 .imread (out , 0 )
108
+ copy = img .copy ()
109
+ new = np .zeros (img .shape , img .dtype )
110
+ new [:,:] = (255 , 10 , 10 )
111
+ new_mask = cv2 .bitwise_and (new , new , mask = mask )
112
+ cv2 .addWeighted (new_mask , 1 , img , 0.6 , 0 , img )
113
+ cv2 .imwrite ('./' + 'overlay_' + save_path [0 ], img )
114
+
115
+ else :
116
+ pass
117
+
118
+ if __name__ == '__main__' :
119
+ main ()
0 commit comments