Skip to content

Commit

Permalink
training example for retinanet
Browse files Browse the repository at this point in the history
  • Loading branch information
vbvg2008 committed Jun 26, 2019
1 parent 67a23d4 commit f85737e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
105 changes: 105 additions & 0 deletions image_detection/retinanet_svhn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from fastestimator.pipeline.dynamic.preprocess import AbstractPreprocessing as AbstractPreprocessingD
from fastestimator.architecture.retinanet import RetinaNet, get_fpn_anchor_box, get_target
from fastestimator.pipeline.dynamic.preprocess import ImageReader
from fastestimator.pipeline.static.preprocess import Minmax
from fastestimator.estimator.estimator import Estimator
from fastestimator.pipeline.pipeline import Pipeline
from fastestimator.estimator.trace import Accuracy
import tensorflow as tf
import numpy as np
import svhn_data
import cv2

class Network:
    """Holds the RetinaNet model, its optimizer and loss, and defines one
    training step and one evaluation step over a batch dict."""

    def __init__(self):
        # 64x64 RGB crops, 10 digit classes (SVHN digits 0-9).
        self.model = RetinaNet(input_shape=(64, 64, 3), num_classes=10)
        self.optimizer = tf.optimizers.Adam()
        self.loss = MyLoss()

    def train_op(self, batch):
        """Run one gradient update on `batch`; returns (predictions, loss).

        `batch` must provide "image", "target_cls" and "target_loc".
        """
        with tf.GradientTape() as tape:
            # training=True so train-mode layers (BatchNorm/Dropout) behave
            # correctly inside this custom loop; the original call left
            # `training` unset, which Keras resolves to inference mode here.
            predictions = self.model(batch["image"], training=True)
            loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return predictions, loss

    def eval_op(self, batch):
        """Forward pass in inference mode; returns (predictions, loss)."""
        predictions = self.model(batch["image"], training=False)
        loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        return predictions, loss

class MyPipeline(Pipeline):
    """Pipeline that normalizes box coordinates to [0, 1], resizes images to
    the fixed network input size, and attaches RetinaNet anchor targets."""

    def edit_feature(self, feature):
        img = feature["image"]
        height, width = img.shape[0], img.shape[1]
        # Normalize box corners by the ORIGINAL image size, before resizing,
        # so the coordinates stay valid for the resized image as well.
        for key, denom in (("x1", width), ("y1", height), ("x2", width), ("y2", height)):
            feature[key] = feature[key] / denom
        feature["image"] = cv2.resize(img, (64, 64))
        # Anchors are generated for the resized (64x64) input shape.
        anchorbox = get_fpn_anchor_box(input_shape=feature["image"].shape)
        target_cls, target_loc = get_target(anchorbox,
                                            feature["label"],
                                            feature["x1"],
                                            feature["y1"],
                                            feature["x2"],
                                            feature["y2"],
                                            num_classes=10)
        feature["target_cls"] = target_cls
        feature["target_loc"] = target_loc
        return feature

class String2List(AbstractPreprocessingD):
    """Preprocessing op that converts a stringified list such as '[1, 2, 3]'
    (as stored in the CSV) into np.array([1, 2, 3])."""

    def transform(self, data):
        """Parse a bracketed, comma-separated string of integers.

        Args:
            data: string of the form '[1, 2, 3]'.

        Returns:
            1-D numpy array of ints; an empty array for '[]'.
        """
        inner = data[1:-1].strip()
        # Guard the empty-list case: ''.split(',') yields [''] and int('')
        # raised ValueError in the original implementation.
        if not inner:
            return np.array([], dtype=np.int64)
        return np.array([int(token) for token in inner.split(',')])

class MyLoss(tf.losses.Loss):
    """RetinaNet loss: focal loss on anchor classification plus smooth-L1 on
    box regression, returned as their sum."""

    def call(self, y_true, y_pred):
        # y_true is (cls_gt, loc_gt): per-anchor class ids [B, A] and box
        # regression targets [B, A, 4]; y_pred is (cls_pred, loc_pred).
        cls_gt, loc_gt = tuple(y_true)
        cls_pred, loc_pred = tuple(y_pred)
        # obj_idx (positive anchors only) is reused so the localization loss
        # is computed only on anchors matched to an object.
        focal_loss, obj_idx = self.focal_loss(cls_gt, cls_pred, num_classes=10)
        smooth_l1_loss = self.smooth_l1(loc_gt, loc_pred, obj_idx)
        return focal_loss+smooth_l1_loss

    def focal_loss(self, cls_gt, cls_pred, num_classes, alpha=0.25, gamma=2.0):
        #cls_gt has shape [B, A], cls_pred is in [B, A, K]
        # Label convention implied by the comparisons below: ids >= 0 are
        # objects, -1 is background, anything below -1 is ignored.
        obj_idx = tf.where(tf.greater_equal(cls_gt, 0)) #index of object
        obj_bg_idx = tf.where(tf.greater_equal(cls_gt, -1)) #index of object and background
        # one_hot maps background (-1) to an all-zero row, i.e. "no class".
        cls_gt = tf.one_hot(cls_gt, num_classes)
        cls_gt = tf.gather_nd(cls_gt, obj_bg_idx)
        cls_pred = tf.gather_nd(cls_pred, obj_bg_idx)
        #getting the object count for each image in batch
        # NOTE(review): this counts object+background anchors per image (rows
        # of obj_bg_idx), not just positive anchors — confirm this is the
        # intended focal-loss normalizer.
        _, idx, count = tf.unique_with_counts(obj_bg_idx[:,0])
        object_count = tf.gather_nd(count, tf.reshape(idx, (-1, 1)))
        # Tile so each of the K class columns is divided by its image's count.
        object_count = tf.tile(tf.reshape(object_count,(-1, 1)), [1,num_classes])
        object_count = tf.cast(object_count, tf.float32)
        #reshape to the correct shape
        cls_gt = tf.reshape(cls_gt, (-1, 1))
        cls_pred = tf.reshape(cls_pred, (-1, 1))
        object_count = tf.reshape(object_count, (-1, 1))
        # compute the focal weight on each selected anchor box:
        # alpha for positives, (1 - alpha) for negatives, times the
        # (1 - p_t)^gamma modulating factor, normalized per image.
        alpha_factor = tf.ones_like(cls_gt) * alpha
        alpha_factor = tf.where(tf.equal(cls_gt, 1), alpha_factor, 1 - alpha_factor)
        focal_weight = tf.where(tf.equal(cls_gt, 1), 1 - cls_pred, cls_pred)
        focal_weight = alpha_factor * focal_weight ** gamma / object_count
        # Weighted binary cross-entropy over all (anchor, class) entries.
        focal_loss = tf.losses.BinaryCrossentropy()(cls_gt, cls_pred, sample_weight=focal_weight)
        return focal_loss, obj_idx

    def smooth_l1(self, loc_gt, loc_pred, obj_idx):
        #loc_gt and loc_pred have shape [B, A, 4]
        # Restrict to positive anchors, then flatten so the mean runs over
        # every coordinate of every matched box.
        loc_gt = tf.gather_nd(loc_gt, obj_idx)
        loc_pred = tf.gather_nd(loc_pred, obj_idx)
        loc_gt = tf.reshape(loc_gt, (-1, 1))
        loc_pred = tf.reshape(loc_pred, (-1, 1))
        loc_diff = tf.abs(loc_gt - loc_pred)
        # Standard smooth-L1: quadratic below |diff| = 1, linear above.
        smooth_l1_loss = tf.where(tf.less(loc_diff,1), 0.5 * loc_diff**2, loc_diff-0.5)
        smooth_l1_loss = tf.reduce_mean(smooth_l1_loss)
        return smooth_l1_loss

def get_estimator():
    """Build the Estimator that trains RetinaNet on the SVHN dataset."""
    train_csv, test_csv, path = svhn_data.load_data()

    feature_names = ["image", "label", "x1", "y1", "x2", "y2",
                     "target_cls", "target_loc"]
    # Per-feature dataset transforms: read the image from disk, parse the
    # five stringified list columns; the two target features are produced
    # later by MyPipeline.edit_feature, so they get no dataset transform.
    dataset_transforms = [[ImageReader(parent_path=path)]]
    dataset_transforms += [[String2List()] for _ in range(5)]
    dataset_transforms += [[], []]
    # Only the image gets a train-time transform (min-max scaling).
    train_transforms = [[Minmax()]] + [[] for _ in range(7)]

    pipeline = MyPipeline(batch_size=256,
                          feature_name=feature_names,
                          train_data=train_csv,
                          validation_data=test_csv,
                          transform_dataset=dataset_transforms,
                          transform_train=train_transforms,
                          padded_batch=True)

    return Estimator(network=Network(), pipeline=pipeline, epochs=10)
19 changes: 13 additions & 6 deletions image_detection/svhn_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import tarfile
import tempfile
from operator import add

import h5py
import numpy as np
import pandas as pd
Expand All @@ -16,8 +18,7 @@ def get_bbox(index, hdf5_data):
item = hdf5_data['digitStruct']['bbox'][index].item()
for key in ['label', 'left', 'top', 'width', 'height']:
attr = hdf5_data[item][key]
values = [hdf5_data[attr.value[i].item()].value[0][0]
for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]]
values = [int(hdf5_data[attr.value[i].item()].value[0][0]) for i in range(len(attr))] if len(attr) > 1 else [int(attr.value[0][0])]
attrs[key] = values
return attrs

Expand All @@ -31,10 +32,15 @@ def img_boundingbox_data_constructor(data_folder, mode, csv_path):
if j % logging_interval == 0:
print("retrieving bounding box for %s: %f%%" % (mode, j/num_example*100))
img_name = get_name(j, f)
row_dict = get_bbox(j, f)
row_dict['img_name'] = os.path.join(mode, img_name)
bbox = get_bbox(j, f)
row_dict = {'image': os.path.join(mode, img_name),
'label': bbox["label"],
'x1': bbox["left"],
'y1': bbox["top"],
'x2': list(map(add, bbox["left"], bbox["width"])),
'y2': list(map(add, bbox["top"], bbox["height"]))}
row_list.append(row_dict)
bbox_df = pd.DataFrame(row_list, columns=['img_name','label','left','top','width','height'])
bbox_df = pd.DataFrame(row_list, columns=['image','label','x1','y1','x2','y2'])
bbox_df.to_csv(csv_path, index=False)
return bbox_df

Expand All @@ -44,7 +50,7 @@ def load_data(path=None):
if not os.path.exists(path):
os.mkdir(path)
train_csv = os.path.join(path, "train_data.csv")
test_csv = os.path.join(path, "eval_data.csv")
test_csv = os.path.join(path, "test_data.csv")
if not (os.path.exists(os.path.join(path, "train.tar.gz")) and os.path.exists(os.path.join(path, "test.tar.gz"))):
print("downloading data to %s" % path)
wget.download('http://ufldl.stanford.edu/housenumbers/train.tar.gz', path)
Expand All @@ -57,6 +63,7 @@ def load_data(path=None):
test_file.extractall(path=path)
train_file.extractall(path=path)
if not (os.path.exists(train_csv) and os.path.exists(test_csv)):
print("constructing bounding box data...")
train_folder = os.path.join(path, "train")
test_folder = os.path.join(path, "test")
img_boundingbox_data_constructor(train_folder, "train", train_csv)
Expand Down

0 comments on commit f85737e

Please sign in to comment.