Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Deformable Convolution operation. #1586

Merged
merged 9 commits into from
Dec 4, 2019
Merged

Conversation

pedrofreire
Copy link
Contributor

For #1416, this adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

  • The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp ; the whole code was modified and refactored to remove redundancies and increase clarity, and to adapt it to torchvision.

  • The CPU part is a direct copy of the CUDA code; it might make sense to do follow-up adjustments in the CPU code to simplify it / optimize it, or to reuse functionality between CPU and CUDA..

  • We also add tests; they likely can be made more robust by doing multiple calls with random parameters.

@pedrofreire
Copy link
Contributor Author

I did not found which linter is used for C++, so the C++/CUDA code is currently without linting.
I also have to check if the code is Python 2 compatible :)

Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR @pedrofreire !

I did a first pass and have a few comments on the Python side.
Also, can you rebase the PR on top of master? I've fixed CI and it should be all green now.

About lint, we use clang-format, you can run https://github.com/pytorch/vision/blob/master/travis-scripts/run-clang-format/run-clang-format.py to check for lint errors, or run clang-format to reformat the code for you.

torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
test/test_ops.py Outdated Show resolved Hide resolved
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvements!

I have a few more comments. I think that we might want to pass the offsets as an argument as well to the nn.Module, sorry about the confusion before!

Also, we have some compilation failures (probably due to an old gcc) in https://travis-ci.org/pytorch/vision/builds/614474253?utm_source=github_status&utm_medium=notification , can you have a look? Here is one instance of the error

/tmp/pip-req-build-440z_mul/torchvision/csrc/cpu/DeformConv_cpu.cpp:984:58: error: converting to ‘std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>’ from initializer list would use explicit constructor ‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...) [with _UElements = {at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor&}; <template-parameter-2-2> = void; _Elements = {at::Tensor, at::Tensor, at::Tensor, at::Tensor}]’
     return {grad_input, grad_weight, grad_offset, grad_bias};

I'm ccing @ppwwyyxx and @rbgirshick for awareness, because detectron2 has an implementation of deform conv2d (inspired from mmdetection from my understanding), but without CPU nor torchscript support.

torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved

out_channels = weight.shape[0]
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave this as is, but ideally if the bias is None we would want to avoid adding it in the code.

One easy way of doing it would be something like

out = # output of deform_conv2d
if bias is not None:
    out = out + bias[:, None, :, :]

this is slightly less efficient than if we had it in the kernel, but saves some compute if the bias doesn't exist.
Also note that we can handle this on the C++ level with c10::optional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can do this as a follow-up.
Btw, what is the best way to handle optionals and autograd::Variable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can do this as a follow-up.

Sure, this is not a blocker for this PR

Btw, what is the best way to handle optionals and autograd::Variable?

I would have expected that c10::optional would be enough, but if that's not the case I'd need to look a bit more. Let's let this in a follow-up

torchvision/csrc/cpu/DeformConv_cpu.cpp Show resolved Hide resolved
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor nit and then this is good to go, thanks a lot @pedrofreire !

Also, there are still some gradient errors in some of the tests, there might be a difference in compilers that induces a difference in behavior somewhere :-/ https://travis-ci.org/pytorch/vision/jobs/615614936?utm_medium=notification&utm_source=github_status

Once that's fixed, tests are all green and the code is rebased this is ready to merge, thanks a lot!

torchvision/ops/deform_conv.py Outdated Show resolved Hide resolved
Pedro Freire and others added 9 commits December 1, 2019 03:54
This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

- The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp ; the whole code was modified and refactored to remove redundancies and increase clarity, and to adapt it to torchvision.

- The CPU part is a direct copy of the CUDA code; it might make sense to do follow-up adjustments in the CPU code to simplify it / optimize it, or to reuse functionality between CPU and CUDA..

- We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times.
* rename some variables and arguments to match Conv2d;
* add optional bias;
* add weight, offset and bias as module parameters;
* remove the n_parallel_imgs parameter;
* Fix __repr__;
* etc..

Initialization of weight and bias is the same as in Conv2d, and
initialization of offsets to zero is the same as in the paper.

This also includes some other small unrelated fixes/improvements.
- We pass the offset in the forward of DeformConv2d, instead of having
an internal parameter. This adds some complexity to creating the module
(e.g. now you have to worry about the output size, to create the
offset), but it gives more flexibility.
- We also use make_tuple for tuple creation, in an attempt to fix error
w/ older compilers.
Old gcc versions were giving wrong results here, because they would
resolve abs as int -> int, thus causing undesired truncation. Replacing
abs by std::abs should allow for correct overloading of abs as float -> float.
We place offset arg before the weight arg, to be more
consistent with DeformConv2d.forward(input, offset)
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks a lot @pedrofreire !

@breznak
Copy link

breznak commented Jan 24, 2020

Thanks for adding DCN! 💯

I have a question, does this support DCNv2?
https://github.com/CharlesShang/DCNv2
https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch

Also what is the difference between the two repos?

@pedrofreire
Copy link
Contributor Author

This PR does not include DCNv2 unfortunately, though it should be straightforward, using this PR as a reference, to port them from mmdetection or one of the repos you mentioned (I have not looked for the differences between them, if there are any).

@breznak
Copy link

breznak commented Jan 24, 2020

Thank you for the swift answer! I'm not sure if I'm upto the task, but I might try. Or someone can join me.

@lucasjinreal
Copy link

Any updates on this issue?

@fmassa
Copy link
Member

fmassa commented Apr 6, 2020

@jinfagang which updates are you talking about?

@pedrofreire
Copy link
Contributor Author

I guess he is talking about having support for DCNv2 (this PR only included the original DCN) - since that is what the most recent comments discuss about.

@fmassa
Copy link
Member

fmassa commented Apr 6, 2020

Thanks for the clarification @pedrofreire
If that's indeed the question, then no updates for DCNv2 for now.

@idow09
Copy link

idow09 commented Jun 22, 2020

Hi! Thanks for the useful op.
Why can't I find this anywhere on the official documentation?

@fmassa
Copy link
Member

fmassa commented Jun 22, 2020

@idow09 we forgot to add it (plus PS RoIAlign / PS RoI Pool) to https://github.com/pytorch/vision/blob/master/docs/source/ops.rst

A PR adding those is welcome.

@idow09
Copy link

idow09 commented Jun 22, 2020

Okay I'll see if I can manage. Thanks again

fmassa pushed a commit to fmassa/vision-1 that referenced this pull request Jul 8, 2020
* Add Deformable Convolution operation.

This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

- The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp ; the whole code was modified and refactored to remove redundancies and increase clarity, and to adapt it to torchvision.

- The CPU part is a direct copy of the CUDA code; it might make sense to do follow-up adjustments in the CPU code to simplify it / optimize it, or to reuse functionality between CPU and CUDA..

- We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times.

* Update DeformConv to be more consistent w/ Conv2d

* rename some variables and arguments to match Conv2d;
* add optional bias;
* add weight, offset and bias as module parameters;
* remove the n_parallel_imgs parameter;
* Fix __repr__;
* etc..

Initialization of weight and bias is the same as in Conv2d, and
initialization of offsets to zero is the same as in the paper.

This also includes some other small unrelated fixes/improvements.

* Apply clang-format in DeformConv files.

* Import Optional type annotation

* Remove offset param from DeformConv2d module

- We pass the offset in the forward of DeformConv2d, instead of having
an internal parameter. This adds some complexity to creating the module
(e.g. now you have to worry about the output size, to create the
offset), but it gives more flexibility.
- We also use make_tuple for tuple creation, in an attempt to fix error
w/ older compilers.

* Replace abs by std::abs

Old gcc versions were giving wrong results here, because they would
resolve abs as int -> int, thus causing undesired truncation. Replacing
abs by std::abs should allow for correct overloading of abs as float -> float.

* Reorder declarations for clarity

* Reorder weight and offset args in deform_conv2d

We place offset arg before the weight arg, to be more
consistent with DeformConv2d.forward(input, offset)

* Replace abs by std::abs in DeformConv_cuda
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants