Add warn_only kwarg to use_deterministic_algorithms (pytorch#66233)
Summary:
Fixes pytorch#64883

Adds a `warn_only` kwarg to `use_deterministic_algorithms`. When it is enabled, calling an operation that does not have a deterministic implementation emits a warning rather than raising an error.
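
A minimal usage sketch of the new kwarg (both query functions below are the ones exercised by the tests in this PR):

```python
import torch

# Existing behavior: nondeterministic ops raise a RuntimeError.
torch.use_deterministic_algorithms(True)

# New warn-only mode: the same ops emit a warning and still execute.
torch.use_deterministic_algorithms(True, warn_only=True)

# Both settings can be queried back.
assert torch.are_deterministic_algorithms_enabled()
assert torch.is_deterministic_algorithms_warn_only_enabled()
```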

`torch.testing._internal.common_device_type.expectedAlertNondeterministic` is also refactored and documented in this PR to make it easier to use and understand.
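
As the test_torch.py changes below show, the decorator now takes a list of device types rather than a bare string, plus a `test_warning` kwarg. A sketch of the two call forms (this is internal test infrastructure, and the rationale for `test_warning=False` on backward ops is inferred, not stated in the diff):

```python
from torch.testing._internal.common_device_type import expectedAlertNondeterministic

# Forward ops leave test_warning at its default (True, judging by the diff),
# so the decorator checks both the error path (warn_only=False) and the
# new warning path (warn_only=True).
@expectedAlertNondeterministic('_bincount_cuda', ['cuda'])
def forward_func(slf, device):
    ...

# Backward ops pass test_warning=False, presumably because warnings raised
# on autograd's worker threads are not reliably observable from the test's
# catch_warnings context.
@expectedAlertNondeterministic('ctc_loss_backward_gpu', ['cuda'], test_warning=False)
def backward_func(slf, device):
    ...
```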

cc mruberry kurtamohler

Pull Request resolved: pytorch#66233

Reviewed By: bdhirsh

Differential Revision: D31616481

Pulled By: mruberry

fbshipit-source-id: 059634a82d54407492b1d8df08f059c758d0a420
kurtamohler authored and facebook-github-bot committed Oct 15, 2021
1 parent 687c226 commit a256489
Showing 8 changed files with 172 additions and 93 deletions.
28 changes: 21 additions & 7 deletions aten/src/ATen/Context.cpp
@@ -62,18 +62,32 @@ bool Context::deterministicAlgorithms() const {
   return _deterministic_algorithms;
 }
 
-void Context::setDeterministicAlgorithms(bool b) {
+bool Context::deterministicAlgorithmsWarnOnly() const {
+  return _deterministic_algorithms_warn_only;
+}
+
+void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
   _deterministic_algorithms = b;
+  _deterministic_algorithms_warn_only = warn_only;
 }
 
 void Context::alertNotDeterministic(c10::string_view const& caller) {
   if (globalContext().deterministicAlgorithms()) {
-    TORCH_CHECK(false,
-      caller, " does not have a deterministic implementation, but you set "
-      "'torch.use_deterministic_algorithms(True)'. You can turn off determinism ",
-      "just for this operation if that's acceptable for your application. You "
-      "can also file an issue at https://github.com/pytorch/pytorch/issues "
-      "to help us prioritize adding deterministic support for this operation.");
+    if (globalContext().deterministicAlgorithmsWarnOnly()) {
+      TORCH_WARN(
+        caller, " does not have a deterministic implementation, but you set "
+        "'torch.use_deterministic_algorithms(True, warn_only=True)'. "
+        "You can file an issue at https://github.com/pytorch/pytorch/issues "
+        "to help us prioritize adding deterministic support for this operation.");
+    } else {
+      TORCH_CHECK(false,
+        caller, " does not have a deterministic implementation, but you set "
+        "'torch.use_deterministic_algorithms(True)'. You can turn off "
+        "determinism just for this operation, or you can use the "
+        "'warn_only=True' option, if that's acceptable for your application. "
+        "You can also file an issue at https://github.com/pytorch/pytorch/issues "
+        "to help us prioritize adding deterministic support for this operation.");
+    }
   }
 }
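
For reference, a sketch of the user-visible effect of the two branches above (assumes a CUDA device; `torch.bincount` is one of the ops routed through `alertNotDeterministic`, per the tests below):

```python
import warnings
import torch

x = torch.randint(0, 10, (100,), device='cuda')

# warn_only=False (the default): the op raises a RuntimeError.
torch.use_deterministic_algorithms(True)
try:
    torch.bincount(x)
except RuntimeError as e:
    print(f'raised: {e}')

# warn_only=True: the same call warns and still returns a result.
torch.use_deterministic_algorithms(True, warn_only=True)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    counts = torch.bincount(x)
print(f'warned: {w[0].message if w else None}')
```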

4 changes: 3 additions & 1 deletion aten/src/ATen/Context.h
@@ -149,7 +149,8 @@ class TORCH_API Context {
   // }
 
   bool deterministicAlgorithms() const;
-  void setDeterministicAlgorithms(bool);
+  bool deterministicAlgorithmsWarnOnly() const;
+  void setDeterministicAlgorithms(bool, bool);
 
   // Note [Writing Nondeterministic Operations]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -234,6 +235,7 @@ class TORCH_API Context {
   bool enabled_cudnn = true;
   bool deterministic_cudnn = false;
   bool _deterministic_algorithms = false;
+  bool _deterministic_algorithms_warn_only = false;
   bool benchmark_cudnn = false;
   bool allow_tf32_cudnn = true;
   bool allow_tf32_cublas = true;
75 changes: 42 additions & 33 deletions test/test_torch.py
@@ -127,13 +127,21 @@ def test_wildcard_import(self):
 
     @wrapDeterministicFlagAPITest
     def test_deterministic_flag(self):
-        for deterministic in [True, False]:
-            torch.use_deterministic_algorithms(deterministic)
+        for deterministic, warn_only in product([True, False], [True, False]):
+            torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
             self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
+            self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
 
-        with self.assertRaisesRegex(RuntimeError, r"use_deterministic_algorithms expects a bool, but got int"):
+        with self.assertRaisesRegex(
+                TypeError,
+                r"_set_deterministic_algorithms\(\): argument 'mode' \(position 1\) must be bool, not int"):
             torch.use_deterministic_algorithms(1)
 
+        with self.assertRaisesRegex(
+                TypeError,
+                r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
+            torch.use_deterministic_algorithms(False, warn_only=1)
+
     def test_type_conversion_via_dtype_name(self):
         x = torch.tensor([1])
         self.assertEqual(x.byte().dtype, torch.uint8)
@@ -3880,7 +3888,7 @@ def test_nondeterministic_alert_AvgPool3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('avg_pool3d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('avg_pool3d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3892,7 +3900,7 @@ def test_nondeterministic_alert_AdaptiveAvgPool2d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('adaptive_avg_pool2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('adaptive_avg_pool2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3904,7 +3912,7 @@ def test_nondeterministic_alert_AdaptiveAvgPool3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('adaptive_avg_pool3d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('adaptive_avg_pool3d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3916,7 +3924,7 @@ def test_nondeterministic_alert_MaxPool3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('max_pool3d_with_indices_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('max_pool3d_with_indices_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3928,7 +3936,7 @@ def test_nondeterministic_alert_AdaptiveMaxPool2d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('adaptive_max_pool2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('adaptive_max_pool2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3940,7 +3948,7 @@ def test_nondeterministic_alert_FractionalMaxPool2d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('fractional_max_pool2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('fractional_max_pool2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3952,7 +3960,7 @@ def test_nondeterministic_alert_FractionalMaxPool3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('fractional_max_pool3d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('fractional_max_pool3d_backward_cuda', ['cuda'], test_warning=False)
        def backward_func(slf, device):
             res.backward(grad)
 
@@ -3967,7 +3975,7 @@ def test_nondeterministic_alert_interpolate_linear(self, device):
             align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('upsample_linear1d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('upsample_linear1d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3982,7 +3990,7 @@ def test_nondeterministic_alert_interpolate_bilinear(self, device):
             align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('upsample_bilinear2d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('upsample_bilinear2d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -3997,7 +4005,7 @@ def test_nondeterministic_alert_interpolate_bicubic(self, device):
             align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('upsample_bicubic2d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('upsample_bicubic2d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4012,7 +4020,7 @@ def test_nondeterministic_alert_interpolate_trilinear(self, device):
             align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('upsample_trilinear3d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('upsample_trilinear3d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4024,7 +4032,7 @@ def test_nondeterministic_alert_ReflectionPad1d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('reflection_pad1d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('reflection_pad1d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4036,7 +4044,7 @@ def test_nondeterministic_alert_ReflectionPad2d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('reflection_pad2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('reflection_pad2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4048,7 +4056,7 @@ def test_nondeterministic_alert_ReflectionPad3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('reflection_pad3d_backward_out_cuda', 'cuda')
+        @expectedAlertNondeterministic('reflection_pad3d_backward_out_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4060,7 +4068,7 @@ def test_nondeterministic_alert_ReplicationPad1d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('replication_pad1d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('replication_pad1d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4072,7 +4080,7 @@ def test_nondeterministic_alert_ReplicationPad2d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('replication_pad2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('replication_pad2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4084,7 +4092,7 @@ def test_nondeterministic_alert_ReplicationPad3d(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('replication_pad3d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('replication_pad3d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4095,7 +4103,7 @@ def test_nondeterministic_alert_NLLLoss(self, device):
         input = torch.randn(2, 3, 5, 5, device=device)
         target = torch.rand(2, 5, 5, device=device).mul(3).floor().long()
 
-        @expectedAlertNondeterministic('nll_loss2d_forward_out_cuda_template', 'cuda')
+        @expectedAlertNondeterministic('nll_loss2d_forward_out_cuda_template', ['cuda'])
         def forward_func(slf, device):
             module(input, target)
 
@@ -4110,9 +4118,10 @@ def test_nondeterministic_alert_CTCLoss(self, device):
         res = module(input, target, input_lengths, target_lengths)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('ctc_loss_backward_gpu', 'cuda')
+        @expectedAlertNondeterministic('ctc_loss_backward_gpu', ['cuda'], test_warning=False)
         def backward_func(slf, device):
-            res.backward(grad)
+            with warnings.catch_warnings(record=True) as w:
+                res.backward(grad)
 
         backward_func(self, device)
 
@@ -4124,7 +4133,7 @@ def test_nondeterministic_alert_EmbeddingBag_max(self, device):
         res = module(input)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('embedding_bag_backward_cuda_max', 'cuda')
+        @expectedAlertNondeterministic('embedding_bag_backward_cuda_max', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4137,7 +4146,7 @@ def test_func(op_call):
             index = torch.tensor([[3]], device=device)
             src = torch.tensor([[1.0]], device=device)
 
-            @expectedAlertNondeterministic('scatter_add_cuda_kernel', 'cuda')
+            @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
             def forward_func(slf, device):
                 op_call(input, dim, index, src)
 
@@ -4169,7 +4178,7 @@ def test_func(op_call):
             indices = torch.tensor([0, 0], device=device)
             values = torch.tensor([0., 1.], device=device)
 
-            @expectedAlertNondeterministic('put_', 'cuda')
+            @expectedAlertNondeterministic('put_', ['cuda'])
             def forward_func(slf, device):
                 op_call(a, indices, values, accumulate=True)
 
@@ -4182,7 +4191,7 @@ def test_nondeterministic_alert_histc(self, device):
         def test_func(op_call):
             a = torch.tensor([], device=device)
 
-            @expectedAlertNondeterministic('_histc_cuda', 'cuda')
+            @expectedAlertNondeterministic('_histc_cuda', ['cuda'])
             def forward_func(slf, device):
                 res = op_call(a, min=0, max=3)
 
@@ -4195,7 +4204,7 @@ def test_nondeterministic_alert_bincount(self, device):
         def test_func(op_call):
             a = torch.tensor([], device=device, dtype=torch.long)
 
-            @expectedAlertNondeterministic('_bincount_cuda', 'cuda')
+            @expectedAlertNondeterministic('_bincount_cuda', ['cuda'])
             def forward_func(slf, device):
                 res = op_call(a)
 
@@ -4207,7 +4216,7 @@ def forward_func(slf, device):
     # Ensures that kthvalue throws nondeterministic alerts in the correct cases
     @dtypes(torch.double)
     def test_nondeterministic_alert_kthvalue(self, device, dtype):
-        @expectedAlertNondeterministic('kthvalue CUDA', 'cuda')
+        @expectedAlertNondeterministic('kthvalue CUDA', ['cuda'])
         def test_func(slf, device, call_type):
             S = 10
             k = 5
@@ -4236,7 +4245,7 @@ def test_func(op_call):
             res = op_call(a, dim, index)
             grad = torch.ones_like(res)
 
-            @expectedAlertNondeterministic('scatter_add_cuda_kernel', 'cuda')
+            @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'], test_warning=False)
             def backward_func(slf, device):
                 res.backward(grad)
 
@@ -4251,7 +4260,7 @@ def test_nondeterministic_alert_grid_sample_2d(self, device):
         res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('grid_sampler_2d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('grid_sampler_2d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4263,7 +4272,7 @@ def test_nondeterministic_alert_grid_sample_3d(self, device):
         res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
         grad = torch.ones_like(res)
 
-        @expectedAlertNondeterministic('grid_sampler_3d_backward_cuda', 'cuda')
+        @expectedAlertNondeterministic('grid_sampler_3d_backward_cuda', ['cuda'], test_warning=False)
         def backward_func(slf, device):
             res.backward(grad)
 
@@ -4314,7 +4323,7 @@ def test_func(slf, device, call_type):
             else:
                 self.fail(f"'{call_type}' is not a valid call type")
 
-        @expectedAlertNondeterministic('median CUDA with indices output', 'cuda')
+        @expectedAlertNondeterministic('median CUDA with indices output', ['cuda'])
         def test_func_expect_error(slf, device, call_type):
             test_func(slf, device, call_type)
 
3 changes: 2 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -587,7 +587,8 @@ def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
 def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
 def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
 def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
-def _set_deterministic_algorithms(arg: _bool) -> None: ... # THPModule_setDeterministicAlgorithms
+def _get_deterministic_algorithms_warn_only() -> _bool: ... # THPModule_deterministicAlgorithmsWarnOnly
+def _set_deterministic_algorithms(mode: _bool, *, warn_only: _bool) -> None: ... # THPModule_setDeterministicAlgorithms
 def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
 def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
 def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
(Diffs for the remaining four changed files were not loaded.)
