ConstantPad2d and F.pad (pytorch#856)
szagoruyko authored and apaszke committed Mar 1, 2017
1 parent 37e0548 commit 12efd53
Showing 3 changed files with 118 additions and 1 deletion.
13 changes: 12 additions & 1 deletion test/test_nn.py
@@ -15,7 +15,7 @@
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm
from torch.autograd import Variable
from torch.autograd import Variable, gradcheck
from torch.nn import Parameter
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
    module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
@@ -635,6 +635,17 @@ def test_Dropout3d(self):
        input = torch.Tensor(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, input)

    def test_pad(self):
        inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
        gradcheck(lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
        gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,))
        gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,))
        gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,))
        gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,))

        inputs = Variable(torch.randn(1, 2, 3, 4, 4), requires_grad=True)
        gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,))

    def _test_maxpool_indices(self, num_dim, type=torch.FloatTensor):
        def expected_indices(dim):
            if dim == 1:
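The new test relies on torch.autograd.gradcheck, which compares the analytical gradient of a function against a finite-difference estimate and returns True when the two agree within tolerance. A minimal sketch of what each assertion above exercises (double precision is used here only to tighten the numerical comparison; it is not part of the original test):

import torch
import torch.nn.functional as F
from torch.autograd import Variable, gradcheck

# mixed positive/negative padding: pads where a value is positive, crops where negative
inputs = Variable(torch.randn(1, 3, 4, 4).double(), requires_grad=True)
assert gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,))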
72 changes: 72 additions & 0 deletions torch/nn/_functions/padding.py
@@ -0,0 +1,72 @@
from torch.autograd import Function


class ConstantPad2d(Function):

    def __init__(self, pad, value=0):
        super(ConstantPad2d, self).__init__()
        self.pad = pad
        self.value = value

    def forward(self, input):
        assert input.dim() == 4, 'only 4D supported for padding'
        pad_l, pad_r, pad_t, pad_b = self.pad
        h = input.size(2) + pad_t + pad_b
        w = input.size(3) + pad_l + pad_r
        assert w > 0 and h > 0, 'input is too small'

        self.input_size = input.size()

        # crop input if necessary
        output = input.new(input.size(0), input.size(1), h, w).fill_(self.value)
        c_input = input
        if pad_t < 0:
            c_input = c_input.narrow(2, -pad_t, c_input.size(2) + pad_t)
        if pad_b < 0:
            c_input = c_input.narrow(2, 0, c_input.size(2) + pad_b)
        if pad_l < 0:
            c_input = c_input.narrow(3, -pad_l, c_input.size(3) + pad_l)
        if pad_r < 0:
            c_input = c_input.narrow(3, 0, c_input.size(3) + pad_r)

        # crop output if necessary
        c_output = output
        if pad_t > 0:
            c_output = c_output.narrow(2, pad_t, c_output.size(2) - pad_t)
        if pad_b > 0:
            c_output = c_output.narrow(2, 0, c_output.size(2) - pad_b)
        if pad_l > 0:
            c_output = c_output.narrow(3, pad_l, c_output.size(3) - pad_l)
        if pad_r > 0:
            c_output = c_output.narrow(3, 0, c_output.size(3) - pad_r)
        c_output.copy_(c_input)
        return output

    def backward(self, grad_output):
        pad_l, pad_r, pad_t, pad_b = self.pad

        grad_input = grad_output.new(self.input_size).zero_()

        # crop grad_input if necessary
        cg_input = grad_input
        if pad_t < 0:
            cg_input = cg_input.narrow(2, -pad_t, cg_input.size(2) + pad_t)
        if pad_b < 0:
            cg_input = cg_input.narrow(2, 0, cg_input.size(2) + pad_b)
        if pad_l < 0:
            cg_input = cg_input.narrow(3, -pad_l, cg_input.size(3) + pad_l)
        if pad_r < 0:
            cg_input = cg_input.narrow(3, 0, cg_input.size(3) + pad_r)

        # crop grad_output if necessary
        cg_output = grad_output
        if pad_t > 0:
            cg_output = cg_output.narrow(2, pad_t, cg_output.size(2) - pad_t)
        if pad_b > 0:
            cg_output = cg_output.narrow(2, 0, cg_output.size(2) - pad_b)
        if pad_l > 0:
            cg_output = cg_output.narrow(3, pad_l, cg_output.size(3) - pad_l)
        if pad_r > 0:
            cg_output = cg_output.narrow(3, 0, cg_output.size(3) - pad_r)
        cg_input.copy_(cg_output)
        return grad_input
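The forward pass above allocates an output filled with the pad value, narrows both tensors down to their overlapping region, and copies the input across; backward mirrors the same narrowing on the gradients, so cropped positions receive zero gradient. A rough sketch of that behaviour, assuming the Variable-based Function call style used elsewhere in this commit:

import torch
from torch.autograd import Variable
from torch.nn._functions.padding import ConstantPad2d

x = Variable(torch.ones(1, 1, 2, 2), requires_grad=True)
# pad one column of 5s on the left, crop one row off the bottom
y = ConstantPad2d((1, 0, 0, -1), value=5)(x)   # size (1, 1, 1, 3)
y.backward(torch.ones(1, 1, 1, 3))
# x.grad is 1 where the input survived into the output and 0 in the cropped row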
34 changes: 34 additions & 0 deletions torch/nn/functional.py
@@ -4,6 +4,7 @@
from . import _functions
from .modules import utils
from torch.nn._functions.conv import ConvNd
from ._functions.padding import ConstantPad2d
from .modules.utils import _single, _pair, _triple
# Convolutions

@@ -531,3 +532,36 @@ def upsample_bilinear(input, size=None, scale_factor=None):
        scale_factor (int): multiplier for spatial size. Has to be an integer.
    """
    return _functions.thnn.UpsamplingBilinear2d(size, scale_factor)(input)


def pad(input, pad, mode='constant', value=0):
    """Pads a tensor.

    Currently only 2D and 3D padding is supported.
    For a 4D input tensor, pad should be of the form (pad_l, pad_r, pad_t, pad_b).
    For a 5D input tensor, pad should be (pad_l, pad_r, pad_t, pad_b, pad_front, pad_back).

    Args:
        input (Variable): 4D or 5D tensor
        pad (tuple): 4-element or 6-element tuple
        mode: 'constant', 'reflect' or 'replicate'
        value: fill value for 'constant' padding
    """
    if input.dim() == 4:
        assert len(pad) == 4, '4D tensors expect 4 values for padding'
        if mode == 'constant':
            return ConstantPad2d(pad, value)(input)
        elif mode == 'reflect':
            return _functions.thnn.ReflectionPad2d(*pad)(input)
        elif mode == 'replicate':
            return _functions.thnn.ReplicationPad2d(*pad)(input)
    elif input.dim() == 5:
        assert len(pad) == 6, '5D tensors expect 6 values for padding'
        if mode == 'constant':
            raise NotImplementedError
        elif mode == 'reflect':
            raise NotImplementedError
        elif mode == 'replicate':
            return _functions.thnn.ReplicationPad3d(*pad)(input)
    else:
        raise NotImplementedError("Only 4D and 5D padding is supported for now")
