fix gradcheck to generate valid input for forward AD complex (pytorch#68001)

Summary:
This fixes a few of the linalg checks that we disabled before!

It also seems to break sgn, abs and angle (sending this to CI to see if there are more). These functions used to only ever get pure imaginary or pure real values, so it is very likely that something is wrong with their formulas. But they are implemented element-wise, so it is not clear where the error could come from. I looked into it and nothing obvious seems wrong there (especially since the formula is correct in backward mode).
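
To make the invalid-input problem concrete, here is a minimal sketch (toy code, not part of this commit) using cholesky, one of the constrained ops whose forward AD checks this re-enables. The old splitting evaluated the user function at fn(x + 0 * 1j) and fn(x * 1j); a purely imaginary matrix is never positive-definite, so the second call handed cholesky an invalid input. The new splitting keeps the untouched component at its original value:

    import torch

    # Build a valid complex sample: Hermitian positive-definite.
    A = torch.randn(3, 3, dtype=torch.cdouble)
    A = A @ A.conj().T + 3 * torch.eye(3)
    x, y = A.real, A.imag

    # Old behavior: the "imaginary" wrapper called fn(inp * 1j).
    try:
        torch.linalg.cholesky(y * 1j)  # purely imaginary: not positive-definite
    except RuntimeError:
        print("invalid input: old-style fn(x * 1j) breaks cholesky")

    # New behavior: the constant part comes from the original sample,
    # fn(orig.real + inp * 1j), so the op is evaluated at a valid point.
    L = torch.linalg.cholesky(x + y * 1j)  # same matrix as A: well-defined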

Pull Request resolved: pytorch#68001

Reviewed By: soulitzer

Differential Revision: D32280475

Pulled By: albanD

fbshipit-source-id: e68b1ce0e2e97f8917c3d393141d649a7669aa9d
albanD authored and facebook-github-bot committed Nov 10, 2021
1 parent 94b6fa6 commit a6c0edf
Showing 2 changed files with 28 additions and 26 deletions.
17 changes: 11 additions & 6 deletions torch/autograd/gradcheck.py
@@ -965,17 +965,22 @@ def wrapped_fn(*inputs):
 
     return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
 
-def _real_and_imag_input(fn, complex_inp_indices):
-    # returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
-    # and f(x * 1j).
+def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
+    # returns new functions that take real inputs instead of complex inputs as
+    # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j).
+    # In each case, the other part is considered constant.
+    # We do not use 0 for the constant here to make sure we always call the user function with a valid input.
     def apply_to_c_inps(fn, fn_to_apply):
         def wrapped_fn(*inputs):
             new_inputs = list(inputs)
             for should_be_complex in complex_inp_indices:
-                new_inputs[should_be_complex] = fn_to_apply(new_inputs[should_be_complex])
+                new_inputs[should_be_complex] = fn_to_apply(new_inputs[should_be_complex],
+                                                            tupled_inputs[should_be_complex])
             return _as_tuple(fn(*new_inputs))
         return wrapped_fn
-    return apply_to_c_inps(fn, lambda x: x + 0 * 1j), apply_to_c_inps(fn, lambda x: x * 1j)
+    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
+    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
+    return real_fn, imag_fn
 
 
 def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol,
@@ -1003,7 +1008,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
     if check_forward_ad:
         complex_inp_indices = [i for i, inp in enumerate(tupled_inputs) if is_tensor_like(inp) and inp.is_complex()]
         if complex_inp_indices:
-            real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
+            real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices, tupled_inputs)
 
             imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
             imag_func_out = imag_fn(*imag_inputs)
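
For intuition, here is a hedged, standalone sketch (a toy reimplementation, not the gradcheck internals above) of how the two wrapped functions behave after this change:

    import torch

    def fn(z):
        return z.abs()

    z0 = torch.tensor([1.0 + 2.0j, -3.0 + 0.5j])  # original complex sample

    # real_fn perturbs only the real part; the imaginary part is frozen at z0.imag.
    real_fn = lambda x: fn(x + z0.imag * 1j)
    # imag_fn perturbs only the imaginary part; the real part is frozen at z0.real.
    imag_fn = lambda y: fn(z0.real + y * 1j)

    # Evaluated at the original sample, both reproduce fn(z0)...
    assert torch.allclose(real_fn(z0.real), fn(z0))
    assert torch.allclose(imag_fn(z0.imag), fn(z0))
    # ...unlike the old wrappers fn(x + 0 * 1j) and fn(x * 1j), which could
    # evaluate the function at points the user's samples never promised to be valid.
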
37 changes: 17 additions & 20 deletions torch/testing/_internal/common_methods_invocations.py
@@ -7305,6 +7305,9 @@ def ref_pairwise_distance(input1, input2):
                        # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes',
                                     dtypes=[torch.cfloat, torch.cdouble]),
+                       # The complex formula might be wrong
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                    dtypes=complex_types()),
                    ),
                    supports_inplace_autograd=False,
                    assert_autodiffed=True,
@@ -7373,7 +7376,7 @@ def ref_pairwise_distance(input1, input2):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_method_grad',
                                     device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
-                                    dtypes=[torch.cdouble]),
+                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                    )),
    BinaryUfuncInfo('add',
                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
@@ -7596,10 +7599,6 @@ def ref_pairwise_distance(input1, input2):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
                                     device_type='cuda', dtypes=[torch.cdouble],
                                     active_if=IS_WINDOWS),
-                       # Complex gradcheck tests asinh at points 0 + ix for x > 1 which are points
-                       # where asinh is not differentiable
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
-                                    dtypes=complex_types()),
                    )),
    UnaryUfuncInfo('atan',
                   aliases=('arctan', ),
@@ -7730,10 +7729,7 @@ def ref_pairwise_distance(input1, input2):
           check_batched_gradgrad=False,
           sample_inputs_func=sample_inputs_linalg_cholesky,
           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-          decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
-          skips=(
-              # Gradcheck for complex generates invalid inputs for this function
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)),
+          decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
    OpInfo('cholesky_inverse',
           dtypes=floating_and_complex_types(),
           backward_dtypes=floating_types(),
@@ -7759,9 +7755,7 @@ def ref_pairwise_distance(input1, input2):
           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
           skips=(
               # cholesky_solve does not correctly warn when resizing out= inputs
-              DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
-              # Gradcheck for complex generates invalid inputs for this function, i.e. NaNs.
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)),
+              DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),),),
    OpInfo('chunk',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
           sample_inputs_func=sample_inputs_chunk,
@@ -8522,9 +8516,6 @@ def ref_pairwise_distance(input1, input2):
           sample_inputs_func=sample_inputs_linalg_cholesky,
           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
-          skips=(
-              # Gradcheck for complex generates invalid inputs for this function
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
           ),
    OpInfo('linalg.cholesky_ex',
           aten_name='linalg_cholesky_ex',
@@ -8534,9 +8525,6 @@ def ref_pairwise_distance(input1, input2):
           sample_inputs_func=sample_inputs_linalg_cholesky,
           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
-          skips=(
-              # Gradcheck for complex generates invalid inputs for this function
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
           ),
    OpInfo('linalg.cond',
           aten_name='linalg_cond',
@@ -10019,7 +10007,12 @@ def ref_pairwise_distance(input1, input2):
                                     device_type='cpu', dtypes=[torch.complex64, torch.complex128]),
                        # Reference: https://github.com/pytorch/pytorch/issues/48486
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
-                                    device_type='cpu', dtypes=[torch.complex64])
+                                    device_type='cpu', dtypes=[torch.complex64]),
+                       # The complex formula might be wrong
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                    dtypes=complex_types()),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+                                    dtypes=complex_types()),
                    )),
    OpInfo('split',
           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
@@ -10482,7 +10475,11 @@ def ref_pairwise_distance(input1, input2):
                                                torch.bfloat16: 1e-2}),),
                   safe_casts_outputs=True,
                   supports_forward_ad=True,
-                  supports_complex_to_float=True),
+                  supports_complex_to_float=True,
+                  skips=(
+                      # The complex formula might be wrong
+                      DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                   dtypes=complex_types()),),),
    UnaryUfuncInfo('isfinite',
                   ref=np.isfinite,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
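
For reference, a hedged sketch of how one of the re-enabled checks can be exercised directly (this assumes asinh's complex forward AD now passes, as implied by the skip removed above):

    import torch
    from torch.autograd import gradcheck

    z = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
    # check_forward_ad routes gradcheck through the real/imag input splitting
    # fixed in torch/autograd/gradcheck.py above.
    assert gradcheck(torch.asinh, (z,), check_forward_ad=True)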
