2020from pycocotools .cocoeval import COCOeval
2121from ..modeling .keypoint_utils import oks_nms
2222from scipy .io import loadmat , savemat
23+ from ppdet .utils .logger import setup_logger
24+ logger = setup_logger (__name__ )
2325
2426__all__ = ['KeyPointTopDownCOCOEval' , 'KeyPointTopDownMPIIEval' ]
2527
@@ -38,7 +40,8 @@ def __init__(self,
3840 output_eval ,
3941 iou_type = 'keypoints' ,
4042 in_vis_thre = 0.2 ,
41- oks_thre = 0.9 ):
43+ oks_thre = 0.9 ,
44+ save_prediction_only = False ):
4245 super (KeyPointTopDownCOCOEval , self ).__init__ ()
4346 self .coco = COCO (anno_file )
4447 self .num_samples = num_samples
@@ -48,6 +51,7 @@ def __init__(self,
4851 self .oks_thre = oks_thre
4952 self .output_eval = output_eval
5053 self .res_file = os .path .join (output_eval , "keypoints_results.json" )
54+ self .save_prediction_only = save_prediction_only
5155 self .reset ()
5256
5357 def reset (self ):
@@ -90,6 +94,7 @@ def _write_coco_keypoint_results(self, keypoints):
9094 os .makedirs (self .output_eval )
9195 with open (self .res_file , 'w' ) as f :
9296 json .dump (results , f , sort_keys = True , indent = 4 )
97+ logger .info (f'The keypoint result is saved to { self .res_file } .' )
9398 try :
9499 json .load (open (self .res_file ))
95100 except Exception :
@@ -178,6 +183,10 @@ def accumulate(self):
178183 self .get_final_results (self .results ['all_preds' ],
179184 self .results ['all_boxes' ],
180185 self .results ['image_path' ])
186+ if self .save_prediction_only :
187+ logger .info (f'The keypoint result is saved to { self .res_file } '
188+ 'and do not evaluate the mAP.' )
189+ return
181190 coco_dt = self .coco .loadRes (self .res_file )
182191 coco_eval = COCOeval (self .coco , coco_dt , 'keypoints' )
183192 coco_eval .params .useSegm = None
@@ -191,6 +200,8 @@ def accumulate(self):
191200 self .eval_results ['keypoint' ] = keypoint_stats
192201
193202 def log (self ):
203+ if self .save_prediction_only :
204+ return
194205 stats_names = [
195206 'AP' , 'Ap .5' , 'AP .75' , 'AP (M)' , 'AP (L)' , 'AR' , 'AR .5' ,
196207 'AR .75' , 'AR (M)' , 'AR (L)'
@@ -213,9 +224,12 @@ def __init__(self,
213224 num_samples ,
214225 num_joints ,
215226 output_eval ,
216- oks_thre = 0.9 ):
227+ oks_thre = 0.9 ,
228+ save_prediction_only = False ):
217229 super (KeyPointTopDownMPIIEval , self ).__init__ ()
218230 self .ann_file = anno_file
231+ self .res_file = os .path .join (output_eval , "keypoints_results.json" )
232+ self .save_prediction_only = save_prediction_only
219233 self .reset ()
220234
221235 def reset (self ):
@@ -239,9 +253,32 @@ def update(self, inputs, outputs):
239253 self .results .append (results )
240254
241255 def accumulate (self ):
256+ self ._mpii_keypoint_results_save ()
257+ if self .save_prediction_only :
258+ logger .info (f'The keypoint result is saved to { self .res_file } '
259+ 'and do not evaluate the mAP.' )
260+ return
261+
242262 self .eval_results = self .evaluate (self .results )
243263
264+ def _mpii_keypoint_results_save (self ):
265+ results = []
266+ for res in self .results :
267+ if len (res ) == 0 :
268+ continue
269+ result = [{
270+ 'preds' : res ['preds' ][k ].tolist (),
271+ 'boxes' : res ['boxes' ][k ].tolist (),
272+ 'image_path' : res ['image_path' ][k ],
273+ } for k in range (len (res ))]
274+ results .extend (result )
275+ with open (self .res_file , 'w' ) as f :
276+ json .dump (results , f , sort_keys = True , indent = 4 )
277+ logger .info (f'The keypoint result is saved to { self .res_file } .' )
278+
244279 def log (self ):
280+ if self .save_prediction_only :
281+ return
245282 for item , value in self .eval_results .items ():
246283 print ("{} : {}" .format (item , value ))
247284
0 commit comments