-
Notifications
You must be signed in to change notification settings - Fork 9.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add affine_grid * missing setup * remove import * reformat * rename and reformat * reformat cpp
- Loading branch information
Showing
5 changed files
with
193 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .affine_grid import affine_grid | ||
|
||
__all__ = ['affine_grid'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch.autograd import Function | ||
from torch.autograd.function import once_differentiable | ||
|
||
from . import affine_grid_cuda | ||
|
||
|
||
class _AffineGridGenerator(Function): | ||
|
||
@staticmethod | ||
def forward(ctx, theta, size, align_corners): | ||
|
||
ctx.save_for_backward(theta) | ||
ctx.size = size | ||
ctx.align_corners = align_corners | ||
|
||
func = affine_grid_cuda.affine_grid_generator_forward | ||
|
||
output = func(theta, size, align_corners) | ||
|
||
return output | ||
|
||
@staticmethod | ||
@once_differentiable | ||
def backward(ctx, grad_output): | ||
theta = ctx.saved_tensors | ||
size = ctx.size | ||
align_corners = ctx.align_corners | ||
|
||
func = affine_grid_cuda.affine_grid_generator_backward | ||
|
||
grad_input = func(grad_output, theta, size, align_corners) | ||
|
||
return grad_input, None, None | ||
|
||
|
||
def affine_grid(theta, size, align_corners=False): | ||
if torch.__version__ >= '1.3': | ||
return F.affine_grid(theta, size, align_corners) | ||
elif align_corners: | ||
return F.affine_grid(theta, size) | ||
else: | ||
# enforce floating point dtype on theta | ||
if not theta.is_floating_point(): | ||
raise ValueError( | ||
'Expected theta to have floating point type, but got {}'. | ||
format(theta.dtype)) | ||
# check that shapes and sizes match | ||
if len(size) == 4: | ||
if theta.dim() != 3 or theta.size(-2) != 2 or theta.size(-1) != 3: | ||
raise ValueError( | ||
'Expected a batch of 2D affine matrices of shape Nx2x3 ' | ||
'for size {}. Got {}.'.format(size, theta.shape)) | ||
elif len(size) == 5: | ||
if theta.dim() != 3 or theta.size(-2) != 3 or theta.size(-1) != 4: | ||
raise ValueError( | ||
'Expected a batch of 3D affine matrices of shape Nx3x4 ' | ||
'for size {}. Got {}.'.format(size, theta.shape)) | ||
else: | ||
raise NotImplementedError( | ||
'affine_grid only supports 4D and 5D sizes, ' | ||
'for 2D and 3D affine transforms, respectively. ' | ||
'Got size {}.'.format(size)) | ||
if min(size) <= 0: | ||
raise ValueError( | ||
'Expected non-zero, positive output size. Got {}'.format(size)) | ||
return _AffineGridGenerator.apply(theta, size, align_corners) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp | ||
#include <ATen/ATen.h> | ||
#include <ATen/NativeFunctions.h> | ||
#include <torch/extension.h> | ||
|
||
namespace mmdetection { | ||
|
||
using namespace at; | ||
|
||
at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps, | ||
bool align_corners) { | ||
if (num_steps <= 1) { | ||
return at::tensor(0, grid.options()); | ||
} | ||
auto range = at::linspace(-1, 1, num_steps, grid.options()); | ||
if (!align_corners) { | ||
range = range * (num_steps - 1) / num_steps; | ||
} | ||
return range; | ||
} | ||
|
||
Tensor make_base_grid_4D(const Tensor& theta, int64_t N, int64_t C, int64_t H, | ||
int64_t W, bool align_corners) { | ||
auto base_grid = at::empty({N, H, W, 3}, theta.options()); | ||
|
||
base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners)); | ||
base_grid.select(-1, 1).copy_( | ||
linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1)); | ||
base_grid.select(-1, 2).fill_(1); | ||
|
||
return base_grid; | ||
} | ||
|
||
Tensor make_base_grid_5D(const Tensor& theta, int64_t N, int64_t C, int64_t D, | ||
int64_t H, int64_t W, bool align_corners) { | ||
auto base_grid = at::empty({N, D, H, W, 4}, theta.options()); | ||
|
||
base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners)); | ||
base_grid.select(-1, 1).copy_( | ||
linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1)); | ||
base_grid.select(-1, 2).copy_(linspace_from_neg_one(theta, D, align_corners) | ||
.unsqueeze_(-1) | ||
.unsqueeze_(-1)); | ||
base_grid.select(-1, 3).fill_(1); | ||
|
||
return base_grid; | ||
} | ||
|
||
Tensor affine_grid_generator_4D_forward(const Tensor& theta, int64_t N, | ||
int64_t C, int64_t H, int64_t W, | ||
bool align_corners) { | ||
Tensor base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners); | ||
auto grid = base_grid.view({N, H * W, 3}).bmm(theta.transpose(1, 2)); | ||
return grid.view({N, H, W, 2}); | ||
} | ||
|
||
Tensor affine_grid_generator_5D_forward(const Tensor& theta, int64_t N, | ||
int64_t C, int64_t D, int64_t H, | ||
int64_t W, bool align_corners) { | ||
Tensor base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners); | ||
auto grid = base_grid.view({N, D * H * W, 4}).bmm(theta.transpose(1, 2)); | ||
return grid.view({N, D, H, W, 3}); | ||
} | ||
|
||
Tensor affine_grid_generator_forward(const Tensor& theta, IntArrayRef size, | ||
bool align_corners) { | ||
if (size.size() == 4) { | ||
return affine_grid_generator_4D_forward(theta, size[0], size[1], size[2], | ||
size[3], align_corners); | ||
} else { | ||
return affine_grid_generator_5D_forward(theta, size[0], size[1], size[2], | ||
size[3], size[4], align_corners); | ||
} | ||
} | ||
|
||
Tensor affine_grid_generator_4D_backward(const Tensor& grad_grid, int64_t N, | ||
int64_t C, int64_t H, int64_t W, | ||
bool align_corners) { | ||
auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W, align_corners); | ||
AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, H, W, 2})); | ||
auto grad_theta = base_grid.view({N, H * W, 3}) | ||
.transpose(1, 2) | ||
.bmm(grad_grid.view({N, H * W, 2})); | ||
return grad_theta.transpose(1, 2); | ||
} | ||
|
||
Tensor affine_grid_generator_5D_backward(const Tensor& grad_grid, int64_t N, | ||
int64_t C, int64_t D, int64_t H, | ||
int64_t W, bool align_corners) { | ||
auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W, align_corners); | ||
AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, D, H, W, 3})); | ||
auto grad_theta = base_grid.view({N, D * H * W, 4}) | ||
.transpose(1, 2) | ||
.bmm(grad_grid.view({N, D * H * W, 3})); | ||
return grad_theta.transpose(1, 2); | ||
} | ||
|
||
Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size, | ||
bool align_corners) { | ||
if (size.size() == 4) { | ||
return affine_grid_generator_4D_backward(grad, size[0], size[1], size[2], | ||
size[3], align_corners); | ||
} else { | ||
return affine_grid_generator_5D_backward(grad, size[0], size[1], size[2], | ||
size[3], size[4], align_corners); | ||
} | ||
} | ||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("affine_grid_generator_forward", &affine_grid_generator_forward, | ||
"affine_grid_generator_forward"); | ||
m.def("affine_grid_generator_backward", &affine_grid_generator_backward, | ||
"affine_grid_generator_backward"); | ||
} | ||
|
||
} // namespace mmdetection |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters