Add missing Modules to nn.functional (pytorch#1801)
* Add dropout2d and dropout3d to nn.functional

* Add several loss functions to nn.functional

* Add tests

* Use dropout from the backend

* Add docs

* Fixes

* Edit the loss modules to call their functional counterparts
aron-bordin authored and soumith committed Jul 19, 2017
1 parent 31894ca commit 11f3ccf
Showing 6 changed files with 149 additions and 40 deletions.
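
As a rough usage sketch (not part of the patch; shapes and values are illustrative), the entry points this commit exposes on torch.nn.functional can be called directly:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.randn(4, 3, 8, 8))                     # N x C x H x W
y = F.dropout2d(x, p=0.5, training=True)                   # zeroes entire channels at random
z = F.dropout3d(Variable(torch.randn(2, 3, 4, 8, 8)), p=0.2, training=True)

input = Variable(torch.randn(5, 10), requires_grad=True)
target = Variable(torch.randn(5, 10))
loss = F.l1_loss(input, target)                            # mirrors nn.L1Loss()(input, target)
loss.backward()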
71 changes: 63 additions & 8 deletions docs/source/nn.rst
@@ -913,6 +913,16 @@ Dropout functions

.. autofunction:: alpha_dropout

:hidden:`dropout2d`
~~~~~~~~~~~~~~~~~~~

.. autofunction:: dropout2d

:hidden:`dropout3d`
~~~~~~~~~~~~~~~~~~~

.. autofunction:: dropout3d

Distance functions
----------------------------------

@@ -930,30 +940,70 @@ Distance functions
Loss functions
--------------

:hidden:`nll_loss`
~~~~~~~~~~~~~~~~~~
:hidden:`binary_cross_entropy`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: nll_loss
.. autofunction:: binary_cross_entropy

:hidden:`poisson_nll_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: poisson_nll_loss

:hidden:`kl_div`
~~~~~~~~~~~~~~~~
:hidden:`cosine_embedding_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: kl_div
.. autofunction:: cosine_embedding_loss

:hidden:`cross_entropy`
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: cross_entropy

:hidden:`binary_cross_entropy`
:hidden:`hinge_embedding_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: binary_cross_entropy
.. autofunction:: hinge_embedding_loss

:hidden:`kl_div`
~~~~~~~~~~~~~~~~

.. autofunction:: kl_div

:hidden:`l1_loss`
~~~~~~~~~~~~~~~~~

.. autofunction:: l1_loss

:hidden:`mse_loss`
~~~~~~~~~~~~~~~~~~

.. autofunction:: mse_loss

:hidden:`margin_ranking_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: margin_ranking_loss

:hidden:`multilabel_margin_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: multilabel_margin_loss

:hidden:`multilabel_soft_margin_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: multilabel_soft_margin_loss

:hidden:`multi_margin_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: multi_margin_loss

:hidden:`nll_loss`
~~~~~~~~~~~~~~~~~~

.. autofunction:: nll_loss

:hidden:`binary_cross_entropy_with_logits`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -965,6 +1015,11 @@ Loss functions

.. autofunction:: smooth_l1_loss

:hidden:`soft_margin_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: soft_margin_loss

:hidden:`triplet_margin_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 0 additions & 1 deletion test/common_nn.py
@@ -216,7 +216,6 @@
),
]


criterion_tests = [
dict(module_name='L1Loss',
input_size=(2, 3, 4),
1 change: 1 addition & 0 deletions test/test_nn.py
@@ -3669,6 +3669,7 @@ def add_test(test):
test_params['constructor'] = getattr(nn, name)
test = NewModuleTest(**test_params)
add_test(test)

for test_params in criterion_tests + new_criterion_tests:
name = test_params.pop('module_name')
test_params['constructor'] = getattr(nn, name)
57 changes: 55 additions & 2 deletions torch/nn/functional.py
@@ -447,6 +447,14 @@ def alpha_dropout(input, p=0.5, training=False):
return output.mul_(a).add_(b)


def dropout2d(input, p=0.5, training=False, inplace=False):
return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)


def dropout3d(input, p=0.5, training=False, inplace=False):
return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)


def threshold(input, threshold, value, inplace=False):
return _functions.thnn.Threshold.apply(input, threshold, value, inplace)

@@ -632,7 +640,7 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
return torch.sum(loss)


def kl_div(input, target, size_average=True):
def kl_div(input, target, size_average=True, weight=None):
r"""The `Kullback-Leibler divergence`_ Loss.
See :class:`~torch.nn.KLDivLoss` for details.
@@ -642,8 +650,10 @@ def kl_div(input, target, size_average=True):
target: Variable of the same shape as input
size_average: if True the output is divided by the number of elements
in input tensor
weight (Tensor, optional): a manual rescaling weight given to each
class. If given, has to be a Tensor of size "nclasses"
"""
return _functions.thnn.KLDivLoss(size_average)(input, target)
return _functions.thnn.KLDivLoss(size_average, weight=weight)(input, target)


def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-100):
@@ -730,6 +740,49 @@ def smooth_l1_loss(input, target, size_average=True):
return _functions.thnn.SmoothL1Loss(size_average)(input, target)


def l1_loss(input, target, size_average=True):
return _functions.thnn.L1Loss(size_average)(input, target)


def mse_loss(input, target, size_average=True):
return _functions.thnn.MSELoss(size_average)(input, target)


def margin_ranking_loss(input1, input2, target, margin=0, size_average=True):
return _functions.loss.MarginRankingLoss(margin, size_average)(input1, input2, target)


def hinge_embedding_loss(input, target, margin=1.0, size_average=True):
return _functions.loss.HingeEmbeddingLoss(margin, size_average)(input, target)


def multilabel_margin_loss(input, target, size_average=True):
return _functions.thnn.MultiLabelMarginLoss(size_average)(input, target)


def soft_margin_loss(input, target, size_average=True):
return _functions.thnn.SoftMarginLoss(size_average)(input, target)


def multilabel_soft_margin_loss(input, target, weight=None, size_average=True):
input = torch.sigmoid(input)
return binary_cross_entropy(input, target, weight, size_average)


def cosine_embedding_loss(input1, input2, target, margin=0, size_average=True):
return _functions.loss.CosineEmbeddingLoss(margin, size_average)(input1, input2, target)


def multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=True):
if p != 1 and p != 2:
raise ValueError('only p == 1 and p == 2 supported')
if weight is not None and weight.dim() != 1:
raise ValueError('weight must be one-dimensional')

return _functions.thnn.MultiMarginLoss(size_average, p, margin,
weight=weight)(input, target)


def pixel_shuffle(input, upscale_factor):
r"""Rearranges elements in a tensor of shape ``[*, C*r^2, H, W]`` to a
tensor of shape ``[*, C, H*r, W*r]``.
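
A brief sanity check of the wrappers above (illustrative only, not from the patch); since each loss Module now forwards to its functional counterpart, the two forms should agree:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

input = Variable(torch.randn(3, 5), requires_grad=True)
target = Variable(torch.randn(3, 5))

# module and functional forms compute the same value
assert torch.equal(nn.MSELoss()(input, target).data,
                   F.mse_loss(input, target).data)

# multi_margin_loss validates its arguments before dispatching to the backend
labels = Variable(torch.LongTensor([1, 0, 4]))
F.multi_margin_loss(input, labels, p=2)                    # fine
# F.multi_margin_loss(input, labels, p=3)                  # raises ValueError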
4 changes: 2 additions & 2 deletions torch/nn/modules/dropout.py
@@ -96,7 +96,7 @@ def __init__(self, p=0.5, inplace=False):
self.inplace = inplace

def forward(self, input):
return self._backend.Dropout2d.apply(input, self.p, self.training, self.inplace)
return F.dropout2d(input, self.p, self.training, self.inplace)

def __repr__(self):
inplace_str = ', inplace' if self.inplace else ''
@@ -149,7 +149,7 @@ def __init__(self, p=0.5, inplace=False):
self.inplace = inplace

def forward(self, input):
return self._backend.Dropout3d.apply(input, self.p, self.training, self.inplace)
return F.dropout3d(input, self.p, self.training, self.inplace)

def __repr__(self):
inplace_str = ', inplace' if self.inplace else ''
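
With the change above, nn.Dropout2d and nn.Dropout3d become thin wrappers over the new functional calls. A small illustration (assumed usage, not from the patch):

import torch
import torch.nn as nn
from torch.autograd import Variable

x = Variable(torch.randn(2, 4, 6, 6))
m = nn.Dropout2d(p=0.5)

m.train()
y = m(x)                                  # routed through F.dropout2d; random channels zeroed

m.eval()
assert torch.equal(m(x).data, x.data)     # with training=False, dropout is a no-op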
55 changes: 28 additions & 27 deletions torch/nn/modules/loss.py
@@ -17,22 +17,12 @@ def __init__(self, size_average=True):
super(_Loss, self).__init__()
self.size_average = size_average

def forward(self, input, target):
_assert_no_grad(target)
backend_fn = getattr(self._backend, type(self).__name__)
return backend_fn(self.size_average)(input, target)


class _WeightedLoss(_Loss):
def __init__(self, weight=None, size_average=True):
super(_WeightedLoss, self).__init__(size_average)
self.register_buffer('weight', weight)

def forward(self, input, target):
_assert_no_grad(target)
backend_fn = getattr(self._backend, type(self).__name__)
return backend_fn(self.size_average, weight=self.weight)(input, target)


class L1Loss(_Loss):
r"""Creates a criterion that measures the mean absolute value of the
@@ -66,7 +56,9 @@ class L1Loss(_Loss):
>>> output = loss(input, target)
>>> output.backward()
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.l1_loss(input, target, size_average=self.size_average)


class NLLLoss(_WeightedLoss):
@@ -236,7 +228,9 @@ class KLDivLoss(_WeightedLoss):
.. _Kullback-Leibler divergence:
https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.kl_div(input, target, size_average=self.size_average, weight=self.weight)


class MSELoss(_Loss):
@@ -271,7 +265,9 @@ class MSELoss(_Loss):
>>> output = loss(input, target)
>>> output.backward()
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.mse_loss(input, target, size_average=self.size_average)


class BCELoss(_WeightedLoss):
@@ -293,7 +289,10 @@ class BCELoss(_WeightedLoss):
to `False`, the losses are instead summed.
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.binary_cross_entropy(input, target, weight=self.weight,
size_average=self.size_average)


class BCEWithLogitsLoss(Module):
@@ -358,8 +357,7 @@ def __init__(self, margin=1.0, size_average=True):
self.size_average = size_average

def forward(self, input, target):
return self._backend.HingeEmbeddingLoss(self.margin,
self.size_average)(input, target)
return F.hinge_embedding_loss(input, target, self.margin, self.size_average)


class MultiLabelMarginLoss(_Loss):
@@ -379,7 +377,9 @@ class MultiLabelMarginLoss(_Loss):
This allows for different samples to have variable amounts of target classes
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.multilabel_margin_loss(input, target, size_average=self.size_average)


class SmoothL1Loss(_Loss):
@@ -399,7 +399,9 @@ class SmoothL1Loss(_Loss):
The division by `n` can be avoided if one sets the internal variable
`size_average` to `False`
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.smooth_l1_loss(input, target, size_average=self.size_average)


class SoftMarginLoss(_Loss):
@@ -414,7 +416,9 @@ class SoftMarginLoss(_Loss):
The normalization by the number of elements in the input can be disabled by
setting `self.size_average` to `False`.
"""
pass
def forward(self, input, target):
_assert_no_grad(target)
return F.soft_margin_loss(input, target, size_average=self.size_average)


class CrossEntropyLoss(_WeightedLoss):
@@ -481,8 +485,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
"""

def forward(self, input, target):
return F.binary_cross_entropy(torch.sigmoid(input), target,
self.weight, self.size_average)
return F.multilabel_soft_margin_loss(input, target, self.weight, self.size_average)


class CosineEmbeddingLoss(Module):
@@ -513,8 +516,7 @@ def __init__(self, margin=0, size_average=True):
self.size_average = size_average

def forward(self, input1, input2, target):
return self._backend.CosineEmbeddingLoss(self.margin,
self.size_average)(input1, input2, target)
return F.cosine_embedding_loss(input1, input2, target, self.margin, self.size_average)


class MarginRankingLoss(Module):
@@ -542,8 +544,7 @@ def __init__(self, margin=0, size_average=True):
self.size_average = size_average

def forward(self, input1, input2, target):
return self._backend.MarginRankingLoss(self.margin,
self.size_average)(input1, input2, target)
return F.margin_ranking_loss(input1, input2, target, self.margin, self.size_average)


class MultiMarginLoss(Module):
@@ -580,8 +581,8 @@ def __init__(self, p=1, margin=1, weight=None, size_average=True):
self.weight = weight

def forward(self, input, target):
return self._backend.MultiMarginLoss(self.size_average, self.p,
self.margin, weight=self.weight)(input, target)
return F.multi_margin_loss(input, target, self.p, self.margin,
self.weight, self.size_average)


class TripletMarginLoss(Module):
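
The refactor above settles on a single pattern: each loss Module stores only its configuration and defers the math to torch.nn.functional. A minimal sketch of that pattern (hypothetical example, not code from the patch):

import torch.nn as nn
import torch.nn.functional as F


class MyHingeEmbeddingLoss(nn.Module):
    """Hypothetical loss written in the new Module -> functional style."""

    def __init__(self, margin=1.0, size_average=True):
        super(MyHingeEmbeddingLoss, self).__init__()
        self.margin = margin
        self.size_average = size_average

    def forward(self, input, target):
        # the actual computation lives in the functional layer added by this commit
        return F.hinge_embedding_loss(input, target, self.margin, self.size_average)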
