-
Notifications
You must be signed in to change notification settings - Fork 10
/
create_letter_samples.py
executable file
·148 lines (118 loc) · 4.06 KB
/
create_letter_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python
"""create_patches.py
Create labeled characters for training the character classification network.
Author: Amit Aides, Ahmad Kiswani
License: See attached license file
"""
from __future__ import division
import argparse
import AUVSItargets
import AUVSItargets.global_settings as gs
import cv2
import glob
from joblib import Parallel, delayed
import os
import random
import shutil
import traceback
# Select which image set is processed. With RESIZED the 'resized_images'
# folder is read and gs.resized_K is used as K; otherwise 'renamed_images'
# and gs.K. K is handed to AUVSItargets.Image (presumably the camera
# intrinsics matching that image set -- confirm).
RESIZED = False
K, SUB_PATH = (
    (gs.resized_K, 'resized_images') if RESIZED
    else (gs.K, 'renamed_images')
)
def main(jobs):
    """Build the letter-classification training set in DATA_PATH/train_letter.

    Every ``*.jpg`` in ``DATA_PATH/SUB_PATH`` (sorted by path) is loaded as an
    ``AUVSItargets.Image`` together with its flight-data JSON, cut into
    patches, and ``create_patch`` is run on each patch through joblib. Each
    successful result is written as two files with a running 7-digit index:
    ``NNNNNNN.jpg`` (the mask returned by ``create_patch``) and
    ``NNNNNNN.txt`` (the image filename and the label index, tab-separated).
    Any existing ``train_letter`` folder is deleted first.

    Args:
        jobs: ``n_jobs`` value passed to ``joblib.Parallel``.

    Raises:
        IOError: if OpenCV fails to write a mask image.
    """
    #
    # Setup the paths.
    #
    src_folder = os.path.join(gs.DATA_PATH, SUB_PATH)
    imgs_paths = sorted(glob.glob(os.path.join(src_folder, '*.jpg')))
    img_names = [
        os.path.splitext(os.path.basename(path))[0] for path in imgs_paths
    ]
    # NOTE(review): the JSON is always looked up with a 'resized_' prefix,
    # even when RESIZED is False -- confirm this is intended.
    data_paths = [
        os.path.join(gs.DATA_PATH, 'flight_data', 'resized_' + name + '.json')
        for name in img_names
    ]
    if not imgs_paths:
        print('No images found in {}'.format(src_folder))

    #
    # Delete any old dst folder
    #
    dst_folder = os.path.join(gs.DATA_PATH, 'train_letter')
    if os.path.exists(dst_folder):
        shutil.rmtree(dst_folder)
    os.makedirs(dst_folder)

    img_index = 0
    # A single worker pool is shared by all images instead of starting a new
    # one per image.
    # NOTE(review): create_patch draws labels/angles with the stdlib `random`
    # module inside the workers; depending on the joblib backend, workers may
    # start from identical random states -- worth confirming sample diversity.
    with Parallel(n_jobs=jobs) as parallel:
        for img_path, data_path in zip(imgs_paths, data_paths):
            print('Extracting patches from image {}'.format(img_path))
            img = AUVSItargets.Image(img_path, data_path, K=K)
            patches = img.createPatches(
                patch_size=gs.PATCH_SIZE, patch_shift=1000
            )
            results = parallel(
                delayed(create_patch)(
                    patch, img, img.latitude, img.longitude, img.yaw
                )
                for patch in patches
            )
            for mask, letter_label in results:
                if mask is None:
                    # create_patch could not segment this patch; skip it.
                    continue
                filename = '{:07}'.format(img_index)
                img_index += 1
                mask_path = os.path.join(dst_folder, filename + '.jpg')
                # cv2.imwrite reports failure via its return value only;
                # fail loudly rather than emit a label without an image.
                if not cv2.imwrite(mask_path, mask):
                    raise IOError('Failed to write {}'.format(mask_path))
                label_path = os.path.join(dst_folder, filename + '.txt')
                with open(label_path, 'w') as fp:
                    fp.write('{}.jpg\t{}'.format(filename, letter_label))
def create_patch(patch, img, latitude, longitude, yaw):
    """Optionally paste a random letter target on *patch* and segment it.

    A label is drawn uniformly from ``gs.LETTER_LABELS``. Unless it is
    ``'no target'``, a random ``'Rectangle'`` target is generated by
    ``AUVSItargets.randomTarget`` with orientation ``yaw + dangle`` (dangle
    drawn from 45..315 for ``'rotated letter'``, otherwise from -5..5), pasted
    onto the patch with ``img.pastePatch``, and the patch is cropped to the
    square returned by ``AUVSItargets.squareCoords``. The (possibly
    untouched) patch is then segmented with
    ``AUVSItargets.KMEANS.getLetterMask`` and tight-cropped.

    Args:
        patch: one element of ``img.createPatches(...)``; sliced as
            ``patch[a:b, c:d, ...]``, so presumably an image array.
        img: the ``AUVSItargets.Image`` the patch was taken from.
        latitude: passed through to ``AUVSItargets.randomTarget``.
        longitude: passed through to ``AUVSItargets.randomTarget``.
        yaw: base orientation to which the random offset is added.

    Returns:
        ``(mask, label_index)``, where ``label_index`` is the position of the
        drawn label in ``gs.LETTER_LABELS``, or ``(None, None)`` when the
        segmentation returned no mask or raised.
    """
    #
    # Letters are pasted on a rectangle to improve the segmentation accuracy.
    #
    letter_label = random.choice(gs.LETTER_LABELS)
    target_label = gs.SHAPE_LABELS.index('Rectangle')

    if letter_label != 'no target':
        if letter_label == 'rotated letter':
            # Clearly rotated sample.
            dangle = random.randint(45, 315)
        else:
            # Nearly upright sample: small jitter only.
            dangle = random.randint(-5, 5)

        #
        # Paste a rotated random target on the patch
        #
        target, _, _ = AUVSItargets.randomTarget(
            altitude=0,
            longitude=longitude,
            latitude=latitude,
            target_label=target_label,
            orientation=yaw + dangle,
        )
        br = img.pastePatch(patch=patch, target=target)
        br = AUVSItargets.squareCoords(br, noise=False)
        # br[1]:br[3] slices the first axis and br[0]:br[2] the second
        # (presumably rows/columns of an image array).
        patch = patch[br[1]:br[3], br[0]:br[2], ...]

    #
    # Mask out the letter and tight crop.
    #
    try:
        kmean_mask, _ = AUVSItargets.KMEANS.getLetterMask(patch)
        if kmean_mask is None:
            # The segmentation failed.
            return None, None
        mask = AUVSItargets.tightCrop(kmean_mask)
    except Exception:
        # Best effort: report and let the caller skip this sample, without
        # swallowing KeyboardInterrupt/SystemExit.
        print(traceback.format_exc())
        return None, None

    return mask, gs.LETTER_LABELS.index(letter_label)
if __name__ == '__main__':
    # Let argparse build the usage line: it prepends "usage: " itself and
    # appends the option summary, so only the program name is supplied.
    cmdline = argparse.ArgumentParser(
        prog='./{}'.format(os.path.basename(__file__)),
        description="Create letter patches",
    )
    cmdline.add_argument(
        "--jobs",
        "-j",
        action="store",
        help="Number of cores to use; -1 uses all cores (default=1).",
        type=int,
        dest="jobs",
        default=1,
    )
    args = cmdline.parse_args()
    # joblib rejects n_jobs == 0; report it as a usage error up front.
    if args.jobs == 0:
        cmdline.error("--jobs must be a non-zero integer")
    main(jobs=args.jobs)