Skip to content

Commit f0883a5

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into country_dataset
2 parents 9eb4f11 + cc0d1be commit f0883a5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,18 @@ def __init__(
188188
keypoint_roi_pool=None,
189189
keypoint_head=None,
190190
keypoint_predictor=None,
191-
num_keypoints=17,
191+
num_keypoints=None,
192192
):
193193

194194
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
195195
if min_size is None:
196196
min_size = (640, 672, 704, 736, 768, 800)
197197

198-
if num_classes is not None:
198+
if num_keypoints is not None:
199199
if keypoint_predictor is not None:
200-
raise ValueError("num_classes should be None when keypoint_predictor is specified")
200+
raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
201+
else:
202+
num_keypoints = 17
201203

202204
out_channels = backbone.out_channels
203205

0 commit comments

Comments
 (0)