Skip to content

Commit

Permalink
fix bug (#2312)
Browse files Browse the repository at this point in the history
  • Loading branch information
KsenijaS authored Jun 11, 2020
1 parent 883f1fb commit 747f406
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
21 changes: 2 additions & 19 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,34 +446,17 @@ def test_heatmaps_to_keypoints(self):
assert torch.all(out2[1].eq(out_trace2[1]))

def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
pretrained=True, min_size=200, max_size=300)

def forward(self, images):
output = self.model(images)
# TODO: The keypoints_scores require the use of Argmax that is updated in ONNX.
# For now we are testing all the output of KeypointRCNN except keypoints_scores.
# Enable When Argmax is updated in ONNX Runtime.
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']

images, test_images = self.get_test_images()
# TODO:
# Enable test for dummy_image (no detection) once issue is
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = KeyPointRCNN()
model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,), (dummy_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images

self.run_model(model, [(dummy_images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
Expand Down
6 changes: 5 additions & 1 deletion torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
xy_preds_i_2.to(dtype=torch.float32)], 0)

# TODO: simplify when indexing without rank will be supported by ONNX
base = num_keypoints * num_keypoints + num_keypoints + 1
ind = torch.arange(num_keypoints)
ind = ind.to(dtype=torch.int64) * base
end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \
.index_select(2, x_int.to(dtype=torch.int64))[:num_keypoints, 0, 0]
.index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64))

return xy_preds_i, end_scores_i


Expand Down

0 comments on commit 747f406

Please sign in to comment.