Data necessary to re-train keypointrcnn_resnet50_fpn missing #7063

Open
oliverdain opened this issue Jan 6, 2023 · 3 comments

Comments


oliverdain commented Jan 6, 2023

🐛 Describe the bug

I'm trying to retrain a keypointrcnn_resnet50_fpn model. Specifically, I've removed the roi_heads.keypoint_predictor and replaced it with a new one that matches the number of keypoints in my case. I've frozen most of the rest of the model and I'm trying to train just that last layer and the roi_heads layers. This all works. My basic starting point is:

from typing import Callable, Tuple

import torch
from torchvision.models.detection import (
    KeypointRCNN_ResNet50_FPN_Weights,
    keypointrcnn_resnet50_fpn,
)
from torchvision.models.detection.keypoint_rcnn import KeypointRCNN, KeypointRCNNPredictor


def get_trainable_mask_rcnn(
    num_keypoints: int,
    unfreeze_keypoint_head: bool,
) -> Tuple[KeypointRCNN, Callable[[torch.Tensor], torch.Tensor]]:
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    # Only 2 classes per bounding box: dog (class 1) or background (class 0).
    model = keypointrcnn_resnet50_fpn(weights=weights, num_classes=2)

    # First freeze the whole model.
    model.requires_grad_(False)

    # Figure out the number of input channels to the final layer (so this works even if the
    # pre-trained architecture changes a bit).
    first_module_in_keypoint_pred = next(model.roi_heads.keypoint_predictor.children())
    pred_layer_map_channels = first_module_in_keypoint_pred.in_channels
    # Now replace that final layer; the fresh predictor's parameters are trainable by default.
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(pred_layer_map_channels, num_keypoints)

    if unfreeze_keypoint_head:
        model.roi_heads.keypoint_head.requires_grad_(True)

    return (model, weights.transforms())

From here I can train the model. But it's very hard to do it well because what the model returns is unusual. Specifically, when the model is in train() mode it returns only the losses but not the keypoint locations, and when it is in eval() mode it returns only the boxes, keypoints, etc. but not the losses. Thus, if you want to do something standard like track both training and validation losses, you have to reverse engineer the loss computations done in training mode and compute them manually. Similarly, if you want to track something like the percentage of keypoints correct for both training and validation, that's very hard to do because you can't get the predicted keypoints for the training set.
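
For example, just to log validation losses and metrics you end up with something like this sketch (val_images and val_targets are placeholders assumed to follow the target format quoted below):

import torch

@torch.no_grad()
def validation_losses_and_predictions(model, val_images, val_targets):
    # The loss branch only runs in train() mode, so we have to switch modes even though
    # we are evaluating. Note train() also changes how the RPN samples proposals, so
    # these losses are only an approximation of a "validation loss".
    model.train()
    loss_dict = model(val_images, val_targets)  # Dict[str, Tensor] of losses only
    # The prediction branch only runs in eval() mode.
    model.eval()
    predictions = model(val_images)  # List[Dict[str, Tensor]]: boxes, labels, scores, keypoints
    return loss_dict, predictions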

From https://github.com/pytorch/vision/blob/main/torchvision/models/detection/keypoint_rcnn.py:

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.
    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

I think that most users would like to be able to get losses and boxes/keypoints in both train and validation mode. I suggest either extending the returned dicts to contain that information (in eval() mode the losses would be returned if and only if targets were provided), or returning the various raw tensors (heatmaps, bounding boxes, etc.) along with some simple functions that turn them into keypoints, scores, etc., so users can work with the outputs more flexibly.

I'm sure I can find a way to reverse engineer things to make this work, but that's not very user-friendly. More importantly, it seems brittle, since an upgrade to torchvision is likely to break things in unforeseen ways.

Versions

PyTorch version: 1.13.0+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.10.6 (main, Oct 11 2022, 01:13:46) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.10.0-20-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 525.60.13
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==1.13.0+cu116
[pip3] torchvision==0.14.0+cu116
[conda] Could not collect

cc @datumbox

oliverdain (Author) commented:

I think the returned keypoints_scores aren't meaningful. As I understand it, Mask R-CNN's keypoint variant does a softmax over the keypoint heatmap. However, the heatmaps_to_keypoints method seems to take the raw heatmap values, before they've been through the softmax, resize them to fit the ROI box via bicubic interpolation, and then simply return the maximum value. That value is meaningless without the other values in the heatmap. Indeed, if you run the values in keypoints_scores through a logistic function they almost all come back as > 0.99, yet because the other values in the heatmap are also high, the actual (softmax) probabilities are quite low. I think you probably want to compute the softmax before returning/computing keypoints_scores or, probably better, just return the re-scaled heatmap directly.
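
For example, a sketch of the kind of normalization I mean, assuming access to the raw [N, K, H, W] per-ROI heatmap logits (which the public API doesn't currently return):

import torch
import torch.nn.functional as F

def keypoint_peak_probabilities(heatmap_logits: torch.Tensor) -> torch.Tensor:
    # heatmap_logits: raw per-keypoint heatmap logits of shape [N, K, H, W].
    # A softmax over all spatial positions gives each keypoint a proper distribution
    # over locations, so the value at the peak is an actual probability rather than
    # an unnormalized logit.
    n, k, h, w = heatmap_logits.shape
    probs = F.softmax(heatmap_logits.view(n, k, h * w), dim=-1)
    return probs.max(dim=-1).values  # [N, K]: probability mass at each keypoint's peak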

NicolasHug (Member) commented:

Thank you for the detailed report @oliverdain

I think that most users would like to be able to get losses and boxes/keypoints in both train and validation mode.

Indeed, this has been a popular feature request for a while. We're keeping track of it in #1574; I will add an entry there so we remember to address the keypoint models as well.

This is something we want to do, but to be perfectly honest I can't provide you with an ETA on when this will be done. The main difficulty here is that addressing this use-case would certainly mean breaking backward compatibility, and this is something we try to avoid because it can cause a lot of disruption. We're still brainstorming the best way to do that.

oliverdain (Author) commented:

Thanks @NicolasHug

The main difficulty here is that addressing this use-case would certainly mean breaking backward compatibility

Currently the forward method returns a Dict[str, Tensor] of losses in train mode and a List[Dict[str, Tensor]] of predictions in eval mode. I think adding a few more keys to those dicts to carry the missing information (e.g. the losses in eval mode) probably isn't a breaking change for most people.

Any thoughts on the model returning the wrong score for the keypoints?
