Fix L1Loss when target.requires_grad is True. (pytorch#44471)
Summary:
Pull Request resolved: pytorch#44471

L1Loss had a completely different (and incorrect, see pytorch#43228) path when target.requires_grad was True.

This PR does the following:

1) adds derivative support for target via the normal derivatives.yaml route;
2) kills the separate (and incorrect) path that was used when target.requires_grad was True (a short sketch of the resulting behavior follows below);
3) modifies the L1Loss CriterionTests so that the derivative with respect to target is also checked.
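
For illustration only (not part of this commit), a minimal sketch of the behavior change, assuming a build that includes this fix: gradients now flow to target through the regular autograd machinery.

    import torch
    import torch.nn.functional as F

    # Sketch: with this fix, target participates in autograd just like input.
    inp = torch.randn(2, 3, 4, requires_grad=True)
    tgt = torch.randn(2, 3, 4, requires_grad=True)

    F.l1_loss(inp, tgt, reduction='mean').backward()

    # d|x - t|/dt = -d|x - t|/dx, so the two gradients are elementwise negatives.
    print(torch.allclose(tgt.grad, -inp.grad))  # expected: True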

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23626008

Pulled By: gchanan

fbshipit-source-id: 2828be16b56b8dabe114962223d71b0e9a85f0f5
gchanan authored and facebook-github-bot committed Sep 11, 2020
1 parent ea55820 commit 3de2c0b
Showing 3 changed files with 8 additions and 11 deletions.
2 changes: 2 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1202,6 +1202,7 @@

 - name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: l1_loss_backward(grad, self, target, reduction)
+  target: l1_loss_backward(grad, target, self, reduction)

 - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: mse_loss_backward(grad, self, target, reduction)

@@ -1520,6 +1521,7 @@
 - name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
   self: zeros_like(grad, at::MemoryFormat::Preserve)
+  target: zeros_like(grad, at::MemoryFormat::Preserve)

 - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
   grad_output: log_sigmoid_backward(grad, self, buffer)
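
A note on why the target formula can reuse l1_loss_backward with self and target swapped: for the elementwise loss |self - target|, the derivative with respect to target is sign(target - self), which is exactly the self derivative with the two operands exchanged (the reduction scaling is symmetric as well). A hypothetical numerical check, not part of the commit's test plan, that exercises both new entries:

    import torch
    import torch.nn.functional as F

    # Hypothetical check: gradcheck compares the analytic derivatives declared
    # in derivatives.yaml against finite differences for input and target.
    x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
    t = torch.randn(4, 5, dtype=torch.double, requires_grad=True)

    for reduction in ('none', 'mean', 'sum'):
        assert torch.autograd.gradcheck(
            lambda a, b: F.l1_loss(a, b, reduction=reduction), (x, t))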
13 changes: 4 additions & 9 deletions torch/nn/functional.py
@@ -2629,15 +2629,10 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
                       stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    if target.requires_grad:
-        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
-        ret = torch.abs(input - target)
-        if reduction != 'none':
-            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
-    else:
-        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-        ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
-    return ret


+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))


 def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
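
With the Python-side branch gone, both the target.requires_grad case and the plain case go through torch.broadcast_tensors and the single ATen l1_loss kernel. A small illustrative sketch (not from the commit) showing that a broadcast target still gets a gradient of its original shape, since autograd sums over the expanded dimension:

    import torch
    import torch.nn.functional as F

    # Illustrative only: target is broadcast from (3,) to (2, 3) inside l1_loss;
    # l1_loss warns about the size mismatch, but broadcasting still applies.
    inp = torch.randn(2, 3, requires_grad=True)
    tgt = torch.randn(3, requires_grad=True)

    F.l1_loss(inp, tgt, reduction='sum').backward()
    print(inp.grad.shape, tgt.grad.shape)  # torch.Size([2, 3]) torch.Size([3])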
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_nn.py
@@ -3879,7 +3879,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='L1Loss',
         input_size=(2, 3, 4),
-        target_size=(2, 3, 4),
+        target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
         reference_fn=lambda i, t, _: 1. / i.numel() *
         sum((a - b).abs().sum() for a, b in zip(i, t)),
     ),

@@ -4277,7 +4277,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='L1Loss',
         input_size=(),
-        target_size=(),
+        target_fn=lambda: torch.randn((), requires_grad=True),
         reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
         desc='scalar',
     ),
