Skip to content

Commit 4c1a3dc

Browse files
Add main.py
1 parent a2eaa79 commit 4c1a3dc

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

main.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import os
2+
import sys
3+
import argparse
4+
import torch
5+
from dataset import *
6+
from helper import *
7+
from model import *
8+
from fastai.conv_learner import *
9+
from fastai.dataset import *
10+
from pathlib import Path
11+
import json
12+
13+
sys.path.append('./fastai')
14+
15+
if torch.cuda.is_available():
16+
torch.backends.cudnn.enabled = True
17+
18+
def main():
19+
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument('--data_dir', type=str, default='mass_roads', \
22+
help='Path to the dataset')
23+
parser.add_argument('--learning_rate', type=float, default='0.1', \
24+
help='Learning Rate')
25+
parser.add_argument('--model_dir', type=str, default='models/', \
26+
help='Path to the complete trained model file(.h5)')
27+
parser.add_argument('--mode', type=str, default='test', \
28+
help='Learning Rate')
29+
parser.add_argument('--num_epochs', type=int, default='1', \
30+
help='Number of epochs')
31+
parser.add_argument('--cycle_len', type=int, default='2', \
32+
help='Cycle Length')
33+
parser.add_argument('--test_img', type=str, default='test.png', \
34+
help='Test Image Path')
35+
36+
args = parser.parse_args()
37+
cwd = os.getcwd()
38+
39+
f = resnet34
40+
cut, lr_cut = model_meta[f]
41+
42+
if args.mode == 'train':
43+
new_base_path = convert_and_resize(args.data_dir)
44+
train_x, train_y, valid_x, valid_y, test_x, test_y \
45+
= get_file_list(os.path.join(cwd, new_base_path))
46+
47+
PATH = Path(os.path.join(cwd, new_base_path))
48+
img_size = 1024
49+
batch_size = 2
50+
aug_tfms = []
51+
52+
tfms = tfms_from_model(resnet34, img_size, crop_type=CropType.NO, \
53+
tfm_y=TfmType.NO, aug_tfms=aug_tfms)
54+
datasets = ImageData.get_ds(MatchedFilesDataset, (train_x, train_y),
55+
(valid_x, valid_y), tfms, path=PATH)
56+
md = ImageData(PATH, datasets, batch_size, num_workers=4, classes=None)
57+
denorm = md.trn_ds.denorm
58+
x, y = next(iter(md.trn_dl))
59+
print (x.shape, y.shape)
60+
m_base = get_base()
61+
m = to_gpu(Unet34(m_base))
62+
models = UpsampleModel(m)
63+
64+
learn = ConvLearner(md, models)
65+
learn.opt_fn=optim.Adam
66+
learn.crit=mask_loss
67+
learn.metrics=[mask_acc, dice]
68+
learn.freeze_to(1)
69+
learn.load(os.path.join(cwd, 'models/1024Deepglobe-tmp'))
70+
print ('Started Training...')
71+
learn.fit(args.learning_rate, args.num_epochs, cycle_len=args.cycle_len, use_clr=(20,4))
72+
learn.save(os.path.join(cwd, 'models/Mnih-final-1024'))
73+
74+
elif args.mode == 'test':
75+
PATH = Path('./')
76+
img_size = 1024
77+
batch_size = 1
78+
aug_tfms = []
79+
m_base = get_base()
80+
m = to_gpu(Unet34(m_base))
81+
models = UpsampleModel(m)
82+
83+
t_img = [args.test_img]
84+
save_path = [args.test_img.split('/')[-1]]
85+
print (save_path, args.test_img)
86+
img = Image.open(args.test_img).resize((1024,1024)).save('1024_' + save_path[0])
87+
88+
tfms = tfms_from_model(resnet34, img_size, crop_type=CropType.NO, tfm_y=TfmType.NO, aug_tfms=aug_tfms)
89+
datasets = ImageData.get_ds(MatchedFilesDataset, (t_img, t_img), (t_img, t_img), tfms, path=PATH)
90+
md = ImageData(PATH, datasets, batch_size, num_workers=4, classes=None)
91+
denorm = md.trn_ds.denorm
92+
93+
learn = ConvLearner(md, models)
94+
learn.load(os.path.join(cwd, 'models/1024DeepGlobe-Mnih-tmp'))
95+
96+
x, _ = next(iter(md.trn_dl))
97+
start = time.time()
98+
py = to_np(learn.model(V(x)))
99+
end = time.time()
100+
print ('Prediction Time', (end-start), 'seconds')
101+
s = py[0][0]*255.0
102+
cv2.imwrite('./' + 'masked_' + save_path[0], s)
103+
104+
inp = './1024_'+ save_path[0]
105+
out = './masked_' + save_path[0]
106+
img = cv2.imread(inp)
107+
mask = cv2.imread(out, 0)
108+
copy = img.copy()
109+
new = np.zeros(img.shape, img.dtype)
110+
new[:,:] = (255, 10, 10)
111+
new_mask = cv2.bitwise_and(new, new, mask=mask)
112+
cv2.addWeighted(new_mask, 1, img, 0.6, 0, img)
113+
cv2.imwrite('./' + 'overlay_' + save_path[0], img)
114+
115+
else:
116+
pass
117+
118+
if __name__ == '__main__':
119+
main()

0 commit comments

Comments
 (0)