1
+ from fastestimator .pipeline .dynamic .preprocess import AbstractPreprocessing as AbstractPreprocessingD
2
+ from fastestimator .architecture .retinanet import RetinaNet , get_fpn_anchor_box , get_target
3
+ from fastestimator .pipeline .dynamic .preprocess import ImageReader
4
+ from fastestimator .pipeline .static .preprocess import Minmax
5
+ from fastestimator .estimator .estimator import Estimator
6
+ from fastestimator .pipeline .pipeline import Pipeline
7
+ from fastestimator .estimator .trace import Accuracy
8
+ import tensorflow as tf
9
+ import numpy as np
10
+ import svhn_data
11
+ import cv2
12
+
13
class Network:
    """Bundle the RetinaNet model with its optimizer and loss for training and evaluation."""

    def __init__(self):
        # Input shape (64, 64, 3) and 10 classes match the 64x64 resize and the
        # `num_classes=10` used by the pipeline and the loss in this file.
        self.model = RetinaNet(input_shape=(64, 64, 3), num_classes=10)
        self.optimizer = tf.optimizers.Adam()
        self.loss = MyLoss()

    def train_op(self, batch):
        """Run one optimization step on ``batch``.

        Args:
            batch: Mapping that provides the "image", "target_cls" and
                "target_loc" entries.

        Returns:
            Tuple ``(predictions, loss)`` from the forward pass that was
            executed before the weights were updated.
        """
        with tf.GradientTape() as tape:
            # Ask for training behaviour explicitly, mirroring the explicit
            # `training=False` in `eval_op`. Without the flag, any layer whose
            # behaviour depends on the mode (e.g. batch normalization or
            # dropout, if the model contains them) is not guaranteed to run in
            # training mode inside a hand-written training step.
            predictions = self.model(batch["image"], training=True)
            loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return predictions, loss

    def eval_op(self, batch):
        """Run a forward pass in inference mode and compute the loss.

        Args:
            batch: Mapping that provides the "image", "target_cls" and
                "target_loc" entries.

        Returns:
            Tuple ``(predictions, loss)``; no weights are modified.
        """
        predictions = self.model(batch["image"], training=False)
        loss = self.loss((batch["target_cls"], batch["target_loc"]), predictions)
        return predictions, loss
31
+
32
class MyPipeline(Pipeline):
    """Pipeline whose per-record hook prepares the image and the anchor targets."""

    def edit_feature(self, feature):
        """Normalize the box corners, resize the image and attach anchor targets.

        Args:
            feature: Mapping with "image", "label", "x1", "y1", "x2" and "y2"
                entries; it is modified in place.

        Returns:
            The same ``feature`` mapping, with "target_cls" and "target_loc" added.
        """
        # Read the size before resizing: the corners are expressed relative to
        # the original image extent.
        original_shape = feature["image"].shape
        height = original_shape[0]
        width = original_shape[1]

        # Compute all four scaled corners first, then store them, so that no
        # entry is overwritten unless every division succeeded.
        extents = (("x1", width), ("y1", height), ("x2", width), ("y2", height))
        scaled = [(key, feature[key] / extent) for key, extent in extents]
        for key, value in scaled:
            feature[key] = value

        feature["image"] = cv2.resize(feature["image"], (64, 64))

        anchorbox = get_fpn_anchor_box(input_shape=feature["image"].shape)
        targets = get_target(
            anchorbox,
            feature["label"],
            feature["x1"],
            feature["y1"],
            feature["x2"],
            feature["y2"],
            num_classes=10,
        )
        target_cls, target_loc = targets
        feature["target_cls"] = target_cls
        feature["target_loc"] = target_loc
        return feature
41
+
42
class String2List(AbstractPreprocessingD):
    """Parse a bracketed, comma-separated string of integers into a NumPy array."""

    def transform(self, data):
        """Convert a string such as ``'[1, 2, 3]'`` into ``np.array([1, 2, 3])``.

        Args:
            data: Text form of a flat list of integers enclosed in one pair
                of brackets.

        Returns:
            One-dimensional integer array; empty when the list has no items.

        Raises:
            ValueError: If an item cannot be parsed as an integer.
        """
        # Trim surrounding whitespace first so that a trailing newline or
        # space is not mistaken for the closing bracket by the slice below.
        inner = data.strip()[1:-1]
        if not inner.strip():
            # "[]" would otherwise reach int("") and raise ValueError.
            return np.array([], dtype=int)
        return np.array([int(item) for item in inner.split(",")])
47
+
48
class MyLoss(tf.losses.Loss):
    """Detection loss: focal classification loss plus smooth-L1 box loss."""

    def call(self, y_true, y_pred):
        """Return the sum of the focal loss and the smooth-L1 loss.

        Args:
            y_true: Pair ``(cls_gt, loc_gt)``.
            y_pred: Pair ``(cls_pred, loc_pred)``.

        Returns:
            Scalar tensor ``focal_loss + smooth_l1_loss``.
        """
        cls_gt, loc_gt = tuple(y_true)
        cls_pred, loc_pred = tuple(y_pred)
        focal_loss, obj_idx = self.focal_loss(cls_gt, cls_pred, num_classes=10)
        smooth_l1_loss = self.smooth_l1(loc_gt, loc_pred, obj_idx)
        return focal_loss + smooth_l1_loss

    def focal_loss(self, cls_gt, cls_pred, num_classes, alpha=0.25, gamma=2.0):
        """Compute the focal-weighted binary cross-entropy over selected anchors.

        Args:
            cls_gt: Integer class target per anchor, shape [B, A]. Entries
                >= 0 are treated as objects, -1 as background, and anything
                below -1 is left out of the loss.
            cls_pred: Predicted class scores, shape [B, A, K].
            num_classes: Depth of the one-hot encoding (K).
            alpha: Weight applied to positive entries; negatives get ``1 - alpha``.
            gamma: Exponent of the focal modulation.

        Returns:
            Tuple ``(focal_loss, obj_idx)`` where ``obj_idx`` holds the indices
            of the object anchors, for reuse by ``smooth_l1``.
        """
        obj_idx = tf.where(tf.greater_equal(cls_gt, 0))  # index of object
        obj_bg_idx = tf.where(tf.greater_equal(cls_gt, -1))  # index of object and background
        # Background (-1) becomes an all-zero row in the one-hot encoding.
        cls_gt = tf.one_hot(cls_gt, num_classes)
        cls_gt = tf.gather_nd(cls_gt, obj_bg_idx)
        cls_pred = tf.gather_nd(cls_pred, obj_bg_idx)

        # Per-image count of the selected anchors, looked up for every selected
        # row through its batch index (first column of obj_bg_idx).
        # NOTE(review): this counts object AND background anchors; the original
        # comment called it the "object count". Confirm which normaliser is
        # intended before changing it.
        _, idx, count = tf.unique_with_counts(obj_bg_idx[:, 0])
        object_count = tf.gather_nd(count, tf.reshape(idx, (-1, 1)))
        object_count = tf.tile(tf.reshape(object_count, (-1, 1)), [1, num_classes])
        object_count = tf.cast(object_count, tf.float32)

        # Flatten everything to one column so each (anchor, class) pair is one sample.
        cls_gt = tf.reshape(cls_gt, (-1, 1))
        cls_pred = tf.reshape(cls_pred, (-1, 1))
        object_count = tf.reshape(object_count, (-1, 1))

        # Focal weight for each selected (anchor, class) pair.
        is_positive = tf.equal(cls_gt, 1)
        alpha_factor = tf.ones_like(cls_gt) * alpha
        alpha_factor = tf.where(is_positive, alpha_factor, 1 - alpha_factor)
        focal_weight = tf.where(is_positive, 1 - cls_pred, cls_pred)
        focal_weight = alpha_factor * focal_weight**gamma / object_count
        focal_loss = tf.losses.BinaryCrossentropy()(cls_gt, cls_pred, sample_weight=focal_weight)
        return focal_loss, obj_idx

    def smooth_l1(self, loc_gt, loc_pred, obj_idx):
        """Compute the mean smooth-L1 loss over the box values of object anchors.

        Args:
            loc_gt: Box regression targets, shape [B, A, 4].
            loc_pred: Predicted box values, shape [B, A, 4].
            obj_idx: Indices of the object anchors, as returned by ``focal_loss``.

        Returns:
            Scalar tensor; 0 when ``obj_idx`` selects no anchors.
        """
        loc_gt = tf.reshape(tf.gather_nd(loc_gt, obj_idx), (-1, 1))
        loc_pred = tf.reshape(tf.gather_nd(loc_pred, obj_idx), (-1, 1))
        loc_diff = tf.abs(loc_gt - loc_pred)
        smooth_l1_loss = tf.where(tf.less(loc_diff, 1), 0.5 * loc_diff**2, loc_diff - 0.5)
        # A plain reduce_mean over zero elements is NaN, which would poison the
        # gradients for a batch without any object anchor. Dividing the sum by
        # the element count with divide_no_nan yields 0 in that case and the
        # ordinary mean otherwise.
        element_count = tf.cast(tf.size(smooth_l1_loss), smooth_l1_loss.dtype)
        return tf.math.divide_no_nan(tf.reduce_sum(smooth_l1_loss), element_count)
90
+
91
def get_estimator():
    """Assemble the pipeline and network into an ``Estimator`` configured for 10 epochs.

    Returns:
        The ``Estimator`` built from ``Network()`` and a ``MyPipeline`` that
        reads the train/validation CSVs returned by ``svhn_data.load_data``.
    """
    train_csv, test_csv, path = svhn_data.load_data()

    feature_names = ["image", "label", "x1", "y1", "x2", "y2", "target_cls", "target_loc"]

    # One list of ops per entry of `feature_names`, in the same order: the
    # image is read from disk, the next five columns are parsed from their
    # string form, and the two target entries get no dataset-level op.
    dataset_ops = [[ImageReader(parent_path=path)]]
    dataset_ops += [[String2List()] for _ in range(5)]
    dataset_ops += [[], []]

    # Only the image gets a train-time op; the other seven entries get none.
    train_ops = [[Minmax()]] + [[] for _ in range(7)]

    pipeline = MyPipeline(
        batch_size=256,
        feature_name=feature_names,
        train_data=train_csv,
        validation_data=test_csv,
        transform_dataset=dataset_ops,
        transform_train=train_ops,
        padded_batch=True,
    )

    return Estimator(network=Network(), pipeline=pipeline, epochs=10)
0 commit comments