Skip to content

Commit

Permalink
[WIP] Add Keypoint R-CNN (facebookresearch#69)
Browse files Browse the repository at this point in the history
* [WIP] Keypoints inference on C2 models work

* Training seems to work

Still gives slightly worse results

* e2e training works but gives 3 and 5 mAP less

* Add modification proposed by @ChangErgou

Improves mAP by 1.5 points, to 0.514 and 0.609

* Keypoints reproduce expected results

* Clean coco.py

* Linter + remove unnecessary code

* Merge criteria for empty bboxes in has_valid_annotation

* Remove trailing print

* Add demo support for keypoints

Still need further cleanups and improvements, like adding fields support for the other ops in Keypoints

* More cleanups and misc improvements

* Fixes after rebase

* Add information to the readme

* Fix md formatting
  • Loading branch information
fmassa authored Feb 12, 2019
1 parent 1589ce0 commit e0a525a
Show file tree
Hide file tree
Showing 28 changed files with 1,013 additions and 32 deletions.
7 changes: 7 additions & 0 deletions MODEL_ZOO.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ R-50-FPN | Mask | 1x | 2 | 5.2 | 0.4536 | 11.3 | 0.12966 + 0.034 | 37.8 | 34.2 |
R-101-FPN | Mask | 1x | 2 | 7.9 | 0.5665 | 14.2 | 0.15384 + 0.034 | 40.1 | 36.1 | [6358805](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_101_FPN_1x.pth)
X-101-32x8d-FPN | Mask | 1x | 1 | 7.8 | 0.7562 | 37.8 | 0.21739 + 0.034 | 42.2 | 37.8 | [6358718](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_X_101_32x8d_FPN_1x.pth)

For person keypoint detection:

backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time(hr) | inference time(s/im) | box AP | keypoint AP | model id
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
R-50-FPN | Keypoint | 1x | 2 | 5.7 | 0.3771 | 9.4 | 0.10941 | 53.7 | 64.3 | 9981060



## Comparison with Detectron and mmdetection

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ python webcam.py --min-image-size 300 MODEL.DEVICE cpu
python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
# in order to see the probability heatmaps, pass --show-mask-heatmaps
python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu
# for the keypoint demo
python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
```

A notebook with the demo can be found in [demo/Mask_R-CNN_demo.ipynb](demo/Mask_R-CNN_demo.ipynb).
Expand Down
43 changes: 43 additions & 0 deletions configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://Caffe2Detectron/COCO/37697547/e2e_keypoint_rcnn_R-50-FPN_1x"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
NUM_CLASSES: 2
ROI_KEYPOINT_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "KeypointRCNNFeatureExtractor"
PREDICTOR: "KeypointRCNNPredictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 56
SHARE_BOX_FEATURE_EXTRACTOR: False
KEYPOINT_ON: True
DATASETS:
TRAIN: ("keypoints_coco_2014_train", "keypoints_coco_2014_valminusminival",)
TEST: ("keypoints_coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
43 changes: 43 additions & 0 deletions configs/e2e_keypoint_rcnn_R_50_FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
NUM_CLASSES: 2
ROI_KEYPOINT_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "KeypointRCNNFeatureExtractor"
PREDICTOR: "KeypointRCNNPredictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 56
SHARE_BOX_FEATURE_EXTRACTOR: False
KEYPOINT_ON: True
DATASETS:
TRAIN: ("keypoints_coco_2014_train", "keypoints_coco_2014_valminusminival",)
TEST: ("keypoints_coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
2 changes: 1 addition & 1 deletion configs/quick_schedules/e2e_faster_rcnn_R_50_C4_quick.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
50 changes: 50 additions & 0 deletions configs/quick_schedules/e2e_keypoint_rcnn_R_50_FPN_quick.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
NUM_CLASSES: 2
ROI_KEYPOINT_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "KeypointRCNNFeatureExtractor"
PREDICTOR: "KeypointRCNNPredictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 56
SHARE_BOX_FEATURE_EXTRACTOR: False
KEYPOINT_ON: True
DATASETS:
TRAIN: ("keypoints_coco_2014_minival",)
TEST: ("keypoints_coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (1500,)
MAX_ITER: 2000
IMS_PER_BATCH: 4
TEST:
IMS_PER_BATCH: 2
2 changes: 1 addition & 1 deletion configs/quick_schedules/e2e_mask_rcnn_R_50_C4_quick.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
2 changes: 1 addition & 1 deletion configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
2 changes: 1 addition & 1 deletion configs/quick_schedules/rpn_R_50_C4_quick.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
2 changes: 1 addition & 1 deletion configs/quick_schedules/rpn_R_50_FPN_quick.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: 600
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
Expand Down
75 changes: 75 additions & 0 deletions demo/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def run_on_opencv_image(self, image):
result = self.overlay_boxes(result, top_predictions)
if self.cfg.MODEL.MASK_ON:
result = self.overlay_mask(result, top_predictions)
if self.cfg.MODEL.KEYPOINT_ON:
result = self.overlay_keypoints(result, top_predictions)
result = self.overlay_class_names(result, top_predictions)

return result
Expand Down Expand Up @@ -297,6 +299,15 @@ def overlay_mask(self, image, predictions):

return composite

def overlay_keypoints(self, image, predictions):
keypoints = predictions.get_field("keypoints")
kps = keypoints.keypoints
scores = keypoints.get_field("logits")
kps = torch.cat((kps[:, :, 0:2], scores[:, :, None]), dim=2).numpy()
for region in kps:
image = vis_keypoints(image, region.transpose((1, 0)))
return image

def create_mask_montage(self, image, predictions):
"""
Create a montage showing the probability heatmaps for each one one of the
Expand Down Expand Up @@ -357,3 +368,67 @@ def overlay_class_names(self, image, predictions):
)

return image

import numpy as np
import matplotlib.pyplot as plt
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints

def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7):
"""Visualizes keypoints (adapted from vis_one_image).
kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob).
"""
dataset_keypoints = PersonKeypoints.NAMES
kp_lines = PersonKeypoints.CONNECTIONS

# Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
cmap = plt.get_cmap('rainbow')
colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

# Perform the drawing on a copy of the image, to allow for blending.
kp_mask = np.copy(img)

# Draw mid shoulder / mid hip first for better visualization.
mid_shoulder = (
kps[:2, dataset_keypoints.index('right_shoulder')] +
kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
sc_mid_shoulder = np.minimum(
kps[2, dataset_keypoints.index('right_shoulder')],
kps[2, dataset_keypoints.index('left_shoulder')])
mid_hip = (
kps[:2, dataset_keypoints.index('right_hip')] +
kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
sc_mid_hip = np.minimum(
kps[2, dataset_keypoints.index('right_hip')],
kps[2, dataset_keypoints.index('left_hip')])
nose_idx = dataset_keypoints.index('nose')
if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh:
cv2.line(
kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]),
color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA)
if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
cv2.line(
kp_mask, tuple(mid_shoulder), tuple(mid_hip),
color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA)

# Draw the keypoints.
for l in range(len(kp_lines)):
i1 = kp_lines[l][0]
i2 = kp_lines[l][1]
p1 = kps[0, i1], kps[1, i1]
p2 = kps[0, i2], kps[1, i2]
if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
cv2.line(
kp_mask, p1, p2,
color=colors[l], thickness=2, lineType=cv2.LINE_AA)
if kps[2, i1] > kp_thresh:
cv2.circle(
kp_mask, p1,
radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
if kps[2, i2] > kp_thresh:
cv2.circle(
kp_mask, p2,
radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)

# Blend the keypoints.
return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
15 changes: 14 additions & 1 deletion maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_C.MODEL = CN()
_C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
_C.MODEL.CLS_AGNOSTIC_BBOX_REG = False
Expand All @@ -38,7 +39,7 @@
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = 800 # (800,)
_C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,)
# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing
Expand Down Expand Up @@ -232,6 +233,18 @@
# GN
_C.MODEL.ROI_MASK_HEAD.USE_GN = False

_C.MODEL.ROI_KEYPOINT_HEAD = CN()
_C.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR = "KeypointRCNNFeatureExtractor"
_C.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR = "KeypointRCNNPredictor"
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES = (1.0 / 16,)
_C.MODEL.ROI_KEYPOINT_HEAD.MLP_HEAD_DIM = 1024
_C.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS = tuple(512 for _ in range(8))
_C.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION = 14
_C.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES = 17
_C.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True

# ---------------------------------------------------------------------------- #
# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
# Note that parts of a resnet may be used for both the backbone and the head
Expand Down
Loading

0 comments on commit e0a525a

Please sign in to comment.