torch.allclose opinfo (pytorch#68023)
Summary: Pull Request resolved: pytorch#68023

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D32295811

Pulled By: george-qi

fbshipit-source-id: 3253104a5a9655d8ba7bbba6620038ed6d6669f1
george-qi authored and facebook-github-bot committed Nov 10, 2021
1 parent 9a2db6f commit ae58644
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -6774,6 +6774,26 @@ def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs):
        for downscale_factor in (1, 3)
    ]

def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):
    samples = []
    sample_shapes = [(), (S,), (S, S, S)]
    atols = [1e-2, 1e-16]
    rtols = [1e-1, 0.5]
    for s, rtol, atol in product(sample_shapes, rtols, atols):
        # close sample: offset the tensor by atol so the pair lies within tolerance
        t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
        close = (t + atol).detach().requires_grad_(requires_grad)
        close_sample = SampleInput(t, args=(close,), kwargs=dict(rtol=rtol, atol=atol))
        samples.append(close_sample)

        # random sample: two independently drawn tensors, generally not close
        a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
        b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
        r_sample = SampleInput(a, args=(b,), kwargs=dict(rtol=rtol, atol=atol))
        samples.append(r_sample)

    return samples

foreach_unary_op_db: List[OpInfo] = [
    ForeachFuncInfo('exp'),
@@ -7657,6 +7677,16 @@ def ref_pairwise_distance(input1, input2):
                            device_type='cuda', dtypes=[torch.cfloat],
                            active_if=IS_WINDOWS),
           )),
    OpInfo('allclose',
           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
           ref=np.allclose,
           supports_autograd=False,
           supports_forward_ad=False,
           sample_inputs_func=sample_inputs_allclose,
           skips=(
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           ),
           supports_out=False),
    OpInfo('broadcast_to',
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           supports_out=False,
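For context, a minimal, self-contained sketch (not part of the commit; it assumes the usual OpInfo consumption pattern, a CPU device, and float32 so .numpy() is valid) of how the new entry is exercised: the samples produced by sample_inputs_allclose are fed to torch.allclose and compared against the np.allclose reference declared on the OpInfo. Per the torch docs, torch.allclose(input, other, rtol, atol) returns True iff |input - other| <= atol + rtol * |other| holds elementwise, and np.allclose implements the same check.

import numpy as np
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Look up the new OpInfo entry and generate its samples on CPU / float32.
op = next(o for o in op_db if o.name == 'allclose')
for sample in op.sample_inputs_func(op, 'cpu', torch.float32, requires_grad=False):
    other = sample.args[0]
    # Both libraries apply |input - other| <= atol + rtol * |other| elementwise,
    # so the torch result should match the NumPy reference on every sample.
    torch_result = torch.allclose(sample.input, other, **sample.kwargs)
    ref_result = np.allclose(sample.input.numpy(), other.numpy(), **sample.kwargs)
    assert torch_result == bool(ref_result)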
