For CriterionTests, have check_gradgrad actually only affect gradgrad checks. (pytorch#44060)

Summary:
Pull Request resolved: pytorch#44060

Right now, check_gradgrad = False skips the first-order grad checks as well, not just the gradgrad checks.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D23484018

Pulled By: gchanan

fbshipit-source-id: 24a8f1af41f9918aaa62bc3cd78b139b2f8de1e1
gchanan authored and facebook-github-bot committed Sep 3, 2020
1 parent 42f9897 commit 49215d7
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch/testing/_internal/common_nn.py
@@ -5070,9 +5070,6 @@ def __call__(self, test_case):
         self._do_extra_tests(test_case, module, input, target)
 
     def _do_extra_tests(self, test_case, module, input, target):
-        if not self.check_gradgrad:
-            return
-
         test_case.assertFalse(target.requires_grad)
 
         params = tuple(x for x in module.parameters())
@@ -5090,6 +5087,10 @@ def apply_fn(input1, input2, *params):
         # TODO: we don't pass `target` as part of inputs because we don't
         # currently compute the gradient w.r.t. target for loss functions.
         gradcheck(apply_fn, inputs)
+
+        if not self.check_gradgrad:
+            return
+
         gradgradcheck(apply_fn, inputs)
 
     def test_cuda(self, test_case, dtype=None, extra_args=None):
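For context, here is a minimal, self-contained sketch of the control flow after this change: gradcheck always runs, and only gradgradcheck is gated on check_gradgrad. The CriterionStub class, its do_extra_tests method, and the loss function are hypothetical stand-ins for illustration, not the real CriterionTests harness; only the gating logic mirrors the patched _do_extra_tests.

import torch
from torch.autograd import gradcheck, gradgradcheck

class CriterionStub:
    # Hypothetical stand-in for CriterionTests; only the check_gradgrad
    # gating below mirrors the patched _do_extra_tests.
    def __init__(self, check_gradgrad=True):
        self.check_gradgrad = check_gradgrad

    def do_extra_tests(self, apply_fn, inputs):
        # The first-order gradient check now always runs.
        gradcheck(apply_fn, inputs)

        # Only the second-order check is skipped when check_gradgrad is False.
        if not self.check_gradgrad:
            return
        gradgradcheck(apply_fn, inputs)

# Usage: gradcheck/gradgradcheck need double precision for numerical accuracy.
inp = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
target = torch.randn(3, 5, dtype=torch.double)  # no gradient w.r.t. target

def loss_fn(x):
    return torch.nn.functional.mse_loss(x, target)

CriterionStub(check_gradgrad=False).do_extra_tests(loss_fn, (inp,))  # gradcheck only
CriterionStub(check_gradgrad=True).do_extra_tests(loss_fn, (inp,))   # both checks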
