Always use fast gradcheck for LayerNorm 3d_no_affine_large_feature (pytorch#61848)

Summary:
Since the introduction of a test in https://github.com/pytorch/pytorch/pull/59987/files, slow-mode gradcheck has been failing intermittently (timing out or getting killed). This change opts the LayerNorm '3d_no_affine_large_feature' test into fast-mode gradcheck instead.
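
For context: gradcheck's fast mode verifies gradients along random vector projections instead of reconstructing the full Jacobian column by column, which is dramatically cheaper for large inputs like this test's (4, 56, 56, 56) tensor. Below is a minimal standalone sketch, not part of this commit, assuming PyTorch 1.9+ (where torch.autograd.gradcheck accepts fast_mode) and shrunk to a small input so it runs in seconds:

    import torch
    from torch.autograd import gradcheck

    # LayerNorm without learnable affine parameters, mirroring the
    # '3d_no_affine_large_feature' configuration at a much smaller size.
    # gradcheck wants double-precision inputs for reliable numerics.
    module = torch.nn.LayerNorm([8, 8], elementwise_affine=False)
    x = torch.randn(2, 8, 8, dtype=torch.double, requires_grad=True)

    # fast_mode=True compares v^T J u for random vectors u and v against a
    # numerical estimate, rather than building J one column at a time.
    assert gradcheck(lambda inp: module(inp), (x,), fast_mode=True)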

Pull Request resolved: pytorch#61848

Reviewed By: mrshenli

Differential Revision: D29765773

Pulled By: soulitzer

fbshipit-source-id: d78bee758cab76f26ba9f54925c42d4825db9449
soulitzer authored and facebook-github-bot committed Jul 19, 2021
1 parent 094abf5 commit 1b0a7f3
Showing 1 changed file with 6 additions and 2 deletions.
torch/testing/_internal/common_nn.py

@@ -1583,6 +1583,7 @@ def single_batch_reference_fn(input, parameters, module):
         input_size=(4, 56, 56, 56),
         cudnn=True,
         check_eval=True,
+        gradcheck_fast_mode=True,
         desc='3d_no_affine_large_feature',
     ),
     dict(
@@ -5297,6 +5298,7 @@ def __init__(self, *args, **kwargs):
         self.test_cpu = kwargs.get('test_cpu', True)
         self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
         self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
 
     def _check_gradients(self, test_case, module, input_tuple):
         params = tuple(x for x in module.parameters())
@@ -5316,11 +5318,13 @@ def fn_to_gradcheck(*inputs_and_params, **kwargs):
             test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
         else:
             test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
-                                           check_batched_grad=self.check_batched_grad))
+                                           check_batched_grad=self.check_batched_grad,
+                                           fast_mode=self.gradcheck_fast_mode))
 
         if self.check_gradgrad:
             test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
-                                               check_batched_grad=self.check_batched_grad))
+                                               check_batched_grad=self.check_batched_grad,
+                                               fast_mode=self.gradcheck_fast_mode))
 
     def _do_test(self, test_case, module, input):
         num_threads = torch.get_num_threads()
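
To show how the new kwarg threads through end to end: a test dict opts in with gradcheck_fast_mode=True, ModuleTest stores it, and _check_gradients forwards it to gradcheck and gradgradcheck as fast_mode. Here is a hypothetical, heavily reduced stand-in for that plumbing (the real harness in common_nn.py carries many more options; the class and method names below are illustrative):

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    # Illustrative stand-in for the ModuleTest plumbing in the diff above.
    class MiniModuleTest:
        def __init__(self, **kwargs):
            self.constructor = kwargs['constructor']
            self.input_size = kwargs['input_size']
            # None (the default) is falsy, so gradcheck stays in slow mode.
            self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)

        def check_gradients(self):
            module = self.constructor()
            x = torch.randn(*self.input_size, dtype=torch.double,
                            requires_grad=True)

            def fn(inp):
                return module(inp)

            assert gradcheck(fn, (x,), fast_mode=self.gradcheck_fast_mode)
            assert gradgradcheck(fn, (x,), fast_mode=self.gradcheck_fast_mode)

    # Each test dict opts in individually, as the LayerNorm entry now does.
    MiniModuleTest(
        constructor=lambda: torch.nn.LayerNorm([8, 8], elementwise_affine=False),
        input_size=(2, 8, 8),
        gradcheck_fast_mode=True,
    ).check_gradients()

Opting in per test keeps slow-mode coverage for every other module test while unblocking only the configuration that was timing out.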
