Commit f85737e

training example for retinanet

1 parent 67a23d4 commit f85737e

File tree: 2 files changed, +118 −6 lines

image_detection/retinanet_svhn.py
image_detection/svhn_data.py

image_detection/retinanet_svhn.py (new file, +105 lines)
import cv2
import numpy as np
import tensorflow as tf

import svhn_data
from fastestimator.architecture.retinanet import RetinaNet, get_fpn_anchor_box, get_target
from fastestimator.estimator.estimator import Estimator
from fastestimator.pipeline.dynamic.preprocess import AbstractPreprocessing as AbstractPreprocessingD
from fastestimator.pipeline.dynamic.preprocess import ImageReader
from fastestimator.pipeline.pipeline import Pipeline
from fastestimator.pipeline.static.preprocess import Minmax

class Network:
    def __init__(self):
        self.model = RetinaNet(input_shape=(64, 64, 3), num_classes=10)
        self.optimizer = tf.optimizers.Adam()
        self.loss = MyLoss()

    def train_op(self, batch):
        # one training step: forward pass, joint loss, gradient update
        with tf.GradientTape() as tape:
            predictions = self.model(batch["image"])
            loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return predictions, loss

    def eval_op(self, batch):
        # validation step: inference-mode forward pass, no gradient update
        predictions = self.model(batch["image"], training=False)
        loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        return predictions, loss

class MyPipeline(Pipeline):
    def edit_feature(self, feature):
        # normalize box corners to [0, 1] relative to the original image size
        height, width = feature["image"].shape[0], feature["image"].shape[1]
        feature["x1"], feature["y1"] = feature["x1"] / width, feature["y1"] / height
        feature["x2"], feature["y2"] = feature["x2"] / width, feature["y2"] / height
        feature["image"] = cv2.resize(feature["image"], (64, 64))
        # build FPN anchor boxes for the resized image and match them to ground truth
        anchorbox = get_fpn_anchor_box(input_shape=feature["image"].shape)
        target_cls, target_loc = get_target(anchorbox, feature["label"], feature["x1"], feature["y1"],
                                            feature["x2"], feature["y2"], num_classes=10)
        feature["target_cls"], feature["target_loc"] = target_cls, target_loc
        return feature

class String2List(AbstractPreprocessingD):
    # converts a string such as '[1, 2, 3]' into np.array([1, 2, 3])
    def transform(self, data):
        data = np.array([int(x) for x in data[1:-1].split(',')])
        return data
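
# Why String2List is needed for 'label' and the coordinate columns: svhn_data.py
# writes list-valued cells to the csv, which pandas serializes as strings such
# as '[1, 2]'; the transform above parses them back, e.g. (illustrative call)
#     String2List().transform('[1, 2]')  ->  np.array([1, 2])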

class MyLoss(tf.losses.Loss):
    def call(self, y_true, y_pred):
        cls_gt, loc_gt = tuple(y_true)
        cls_pred, loc_pred = tuple(y_pred)
        focal_loss, obj_idx = self.focal_loss(cls_gt, cls_pred, num_classes=10)
        smooth_l1_loss = self.smooth_l1(loc_gt, loc_pred, obj_idx)
        return focal_loss + smooth_l1_loss

    def focal_loss(self, cls_gt, cls_pred, num_classes, alpha=0.25, gamma=2.0):
        # cls_gt has shape [B, A]; cls_pred has shape [B, A, K]
        obj_idx = tf.where(tf.greater_equal(cls_gt, 0))  # indices of object anchors
        obj_bg_idx = tf.where(tf.greater_equal(cls_gt, -1))  # indices of object and background anchors
        cls_gt = tf.one_hot(cls_gt, num_classes)
        cls_gt = tf.gather_nd(cls_gt, obj_bg_idx)
        cls_pred = tf.gather_nd(cls_pred, obj_bg_idx)
        # count the selected anchors of each image in the batch, for per-image normalization
        _, idx, count = tf.unique_with_counts(obj_bg_idx[:, 0])
        object_count = tf.gather_nd(count, tf.reshape(idx, (-1, 1)))
        object_count = tf.tile(tf.reshape(object_count, (-1, 1)), [1, num_classes])
        object_count = tf.cast(object_count, tf.float32)
        # flatten so the loss is computed per anchor-class probability
        cls_gt = tf.reshape(cls_gt, (-1, 1))
        cls_pred = tf.reshape(cls_pred, (-1, 1))
        object_count = tf.reshape(object_count, (-1, 1))
        # compute the focal weight on each selected anchor box
        alpha_factor = tf.ones_like(cls_gt) * alpha
        alpha_factor = tf.where(tf.equal(cls_gt, 1), alpha_factor, 1 - alpha_factor)
        focal_weight = tf.where(tf.equal(cls_gt, 1), 1 - cls_pred, cls_pred)
        focal_weight = alpha_factor * focal_weight ** gamma / object_count
        focal_loss = tf.losses.BinaryCrossentropy()(cls_gt, cls_pred, sample_weight=focal_weight)
        return focal_loss, obj_idx

    def smooth_l1(self, loc_gt, loc_pred, obj_idx):
        # loc_gt and loc_pred have shape [B, A, 4]; regression applies only to object anchors
        loc_gt = tf.gather_nd(loc_gt, obj_idx)
        loc_pred = tf.gather_nd(loc_pred, obj_idx)
        loc_gt = tf.reshape(loc_gt, (-1, 1))
        loc_pred = tf.reshape(loc_pred, (-1, 1))
        loc_diff = tf.abs(loc_gt - loc_pred)
        smooth_l1_loss = tf.where(tf.less(loc_diff, 1), 0.5 * loc_diff ** 2, loc_diff - 0.5)
        smooth_l1_loss = tf.reduce_mean(smooth_l1_loss)
        return smooth_l1_loss
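
# Reference for the two loss terms above: they follow the RetinaNet losses of
# Lin et al. (2017). With p_t the predicted probability of an anchor's true
# class, the classification term is the focal loss
#     FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
# (realized above as the alpha_factor/focal_weight sample weights fed into the
# binary cross entropy), and the box term applies, per offset difference x,
#     smooth_L1(x) = 0.5 * x^2    if |x| < 1
#                    |x| - 0.5    otherwise
# which is the tf.where branch in smooth_l1().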

def get_estimator():
    train_csv, test_csv, path = svhn_data.load_data()

    pipeline = MyPipeline(batch_size=256,
                          feature_name=["image", "label", "x1", "y1", "x2", "y2", "target_cls", "target_loc"],
                          train_data=train_csv,
                          validation_data=test_csv,
                          transform_dataset=[[ImageReader(parent_path=path)], [String2List()], [String2List()],
                                             [String2List()], [String2List()], [String2List()], [], []],
                          transform_train=[[Minmax()], [], [], [], [], [], [], []],
                          padded_batch=True)

    estimator = Estimator(network=Network(),
                          pipeline=pipeline,
                          epochs=10)
    return estimator
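
A minimal driver for this example might look like the sketch below; estimator.fit() is an assumption about this FastEstimator version's training entry point (hypothetical, not shown in the commit):

    from retinanet_svhn import get_estimator

    estimator = get_estimator()  # load_data() inside downloads and indexes SVHN on first run
    estimator.fit()              # assumed training entry point; trains for 10 epochs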

image_detection/svhn_data.py (+13 −6 lines)
@@ -1,6 +1,8 @@
 import os
 import tarfile
 import tempfile
+from operator import add
+
 import h5py
 import numpy as np
 import pandas as pd

@@ -16,8 +18,7 @@ def get_bbox(index, hdf5_data):
     item = hdf5_data['digitStruct']['bbox'][index].item()
     for key in ['label', 'left', 'top', 'width', 'height']:
         attr = hdf5_data[item][key]
-        values = [hdf5_data[attr.value[i].item()].value[0][0]
-                  for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]]
+        values = [int(hdf5_data[attr.value[i].item()].value[0][0]) for i in range(len(attr))] if len(attr) > 1 else [int(attr.value[0][0])]
         attrs[key] = values
     return attrs

@@ -31,10 +32,15 @@ def img_boundingbox_data_constructor(data_folder, mode, csv_path):
         if j % logging_interval == 0:
             print("retrieving bounding box for %s: %f%%" % (mode, j/num_example*100))
         img_name = get_name(j, f)
-        row_dict = get_bbox(j, f)
-        row_dict['img_name'] = os.path.join(mode, img_name)
+        bbox = get_bbox(j, f)
+        row_dict = {'image': os.path.join(mode, img_name),
+                    'label': bbox["label"],
+                    'x1': bbox["left"],
+                    'y1': bbox["top"],
+                    'x2': list(map(add, bbox["left"], bbox["width"])),
+                    'y2': list(map(add, bbox["top"], bbox["height"]))}
         row_list.append(row_dict)
-    bbox_df = pd.DataFrame(row_list, columns=['img_name','label','left','top','width','height'])
+    bbox_df = pd.DataFrame(row_list, columns=['image','label','x1','y1','x2','y2'])
     bbox_df.to_csv(csv_path, index=False)
     return bbox_df

@@ -44,7 +50,7 @@ def load_data(path=None):
     if not os.path.exists(path):
         os.mkdir(path)
     train_csv = os.path.join(path, "train_data.csv")
-    test_csv = os.path.join(path, "eval_data.csv")
+    test_csv = os.path.join(path, "test_data.csv")
     if not (os.path.exists(os.path.join(path, "train.tar.gz")) and os.path.exists(os.path.join(path, "test.tar.gz"))):
         print("downloading data to %s" % path)
         wget.download('http://ufldl.stanford.edu/housenumbers/train.tar.gz', path)

@@ -57,6 +63,7 @@ def load_data(path=None):
     test_file.extractall(path=path)
     train_file.extractall(path=path)
     if not (os.path.exists(train_csv) and os.path.exists(test_csv)):
+        print("constructing bounding box data...")
         train_folder = os.path.join(path, "train")
         test_folder = os.path.join(path, "test")
         img_boundingbox_data_constructor(train_folder, "train", train_csv)
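
The new row construction converts SVHN's per-digit (left, top, width, height) annotations into the corner coordinates the pipeline expects. A quick check of the element-wise mapping, with made-up values for a two-digit image:

    from operator import add

    left, top = [10, 40], [5, 8]
    width, height = [12, 14], [20, 22]
    x2 = list(map(add, left, width))    # [22, 54]
    y2 = list(map(add, top, height))    # [25, 30]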
