Skip to content

Commit cc266e5

Browse files
authored
add --save_prediction_only support for TopDown KeyPoint Metric (PaddlePaddle#3865)
* add --save_prediction_only support for TopDown KeyPoint Metric * add a use case for save_prediction_only
1 parent ed0cd8d commit cc266e5

File tree

3 files changed

+56
-8
lines changed

3 files changed

+56
-8
lines changed

configs/keypoint/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/hig
7575

7676
#MPII DataSet
7777
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml
78+
79+
#当只需要保存评估预测的结果时,可以通过设置save_prediction_only参数实现,评估预测结果默认保存在output/keypoints_results.json文件中
80+
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml --save_prediction_only
7881
```
7982

8083
**模型预测:**

ppdet/engine/trainer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,27 @@ def _init_metrics(self, validate=False):
227227
eval_dataset = self.cfg['EvalDataset']
228228
eval_dataset.check_or_download_dataset()
229229
anno_file = eval_dataset.get_anno()
230+
save_prediction_only = self.cfg.get('save_prediction_only', False)
230231
self._metrics = [
231-
KeyPointTopDownCOCOEval(anno_file,
232-
len(eval_dataset), self.cfg.num_joints,
233-
self.cfg.save_dir)
232+
KeyPointTopDownCOCOEval(
233+
anno_file,
234+
len(eval_dataset),
235+
self.cfg.num_joints,
236+
self.cfg.save_dir,
237+
save_prediction_only=save_prediction_only)
234238
]
235239
elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
236240
eval_dataset = self.cfg['EvalDataset']
237241
eval_dataset.check_or_download_dataset()
238242
anno_file = eval_dataset.get_anno()
243+
save_prediction_only = self.cfg.get('save_prediction_only', False)
239244
self._metrics = [
240-
KeyPointTopDownMPIIEval(anno_file,
241-
len(eval_dataset), self.cfg.num_joints,
242-
self.cfg.save_dir)
245+
KeyPointTopDownMPIIEval(
246+
anno_file,
247+
len(eval_dataset),
248+
self.cfg.num_joints,
249+
self.cfg.save_dir,
250+
save_prediction_only=save_prediction_only)
243251
]
244252
elif self.cfg.metric == 'MOTDet':
245253
self._metrics = [JDEDetMetric(), ]

ppdet/metrics/keypoint_metrics.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from pycocotools.cocoeval import COCOeval
2121
from ..modeling.keypoint_utils import oks_nms
2222
from scipy.io import loadmat, savemat
23+
from ppdet.utils.logger import setup_logger
24+
logger = setup_logger(__name__)
2325

2426
__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']
2527

@@ -38,7 +40,8 @@ def __init__(self,
3840
output_eval,
3941
iou_type='keypoints',
4042
in_vis_thre=0.2,
41-
oks_thre=0.9):
43+
oks_thre=0.9,
44+
save_prediction_only=False):
4245
super(KeyPointTopDownCOCOEval, self).__init__()
4346
self.coco = COCO(anno_file)
4447
self.num_samples = num_samples
@@ -48,6 +51,7 @@ def __init__(self,
4851
self.oks_thre = oks_thre
4952
self.output_eval = output_eval
5053
self.res_file = os.path.join(output_eval, "keypoints_results.json")
54+
self.save_prediction_only = save_prediction_only
5155
self.reset()
5256

5357
def reset(self):
@@ -90,6 +94,7 @@ def _write_coco_keypoint_results(self, keypoints):
9094
os.makedirs(self.output_eval)
9195
with open(self.res_file, 'w') as f:
9296
json.dump(results, f, sort_keys=True, indent=4)
97+
logger.info(f'The keypoint result is saved to {self.res_file}.')
9398
try:
9499
json.load(open(self.res_file))
95100
except Exception:
@@ -178,6 +183,10 @@ def accumulate(self):
178183
self.get_final_results(self.results['all_preds'],
179184
self.results['all_boxes'],
180185
self.results['image_path'])
186+
if self.save_prediction_only:
187+
logger.info(f'The keypoint result is saved to {self.res_file} '
188+
'and do not evaluate the mAP.')
189+
return
181190
coco_dt = self.coco.loadRes(self.res_file)
182191
coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
183192
coco_eval.params.useSegm = None
@@ -191,6 +200,8 @@ def accumulate(self):
191200
self.eval_results['keypoint'] = keypoint_stats
192201

193202
def log(self):
203+
if self.save_prediction_only:
204+
return
194205
stats_names = [
195206
'AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
196207
'AR .75', 'AR (M)', 'AR (L)'
@@ -213,9 +224,12 @@ def __init__(self,
213224
num_samples,
214225
num_joints,
215226
output_eval,
216-
oks_thre=0.9):
227+
oks_thre=0.9,
228+
save_prediction_only=False):
217229
super(KeyPointTopDownMPIIEval, self).__init__()
218230
self.ann_file = anno_file
231+
self.res_file = os.path.join(output_eval, "keypoints_results.json")
232+
self.save_prediction_only = save_prediction_only
219233
self.reset()
220234

221235
def reset(self):
@@ -239,9 +253,32 @@ def update(self, inputs, outputs):
239253
self.results.append(results)
240254

241255
def accumulate(self):
256+
self._mpii_keypoint_results_save()
257+
if self.save_prediction_only:
258+
logger.info(f'The keypoint result is saved to {self.res_file} '
259+
'and do not evaluate the mAP.')
260+
return
261+
242262
self.eval_results = self.evaluate(self.results)
243263

264+
def _mpii_keypoint_results_save(self):
265+
results = []
266+
for res in self.results:
267+
if len(res) == 0:
268+
continue
269+
result = [{
270+
'preds': res['preds'][k].tolist(),
271+
'boxes': res['boxes'][k].tolist(),
272+
'image_path': res['image_path'][k],
273+
} for k in range(len(res))]
274+
results.extend(result)
275+
with open(self.res_file, 'w') as f:
276+
json.dump(results, f, sort_keys=True, indent=4)
277+
logger.info(f'The keypoint result is saved to {self.res_file}.')
278+
244279
def log(self):
280+
if self.save_prediction_only:
281+
return
245282
for item, value in self.eval_results.items():
246283
print("{} : {}".format(item, value))
247284

0 commit comments

Comments
 (0)