Skip to content

Commit 21f1f39

Browse files
committed
Determinism fixes
1 parent 1dc0e68 commit 21f1f39

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

monai/utils/misc.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -374,23 +374,18 @@ def set_determinism(
374374
for func in additional_settings:
375375
func(seed)
376376

377-
if torch.backends.flags_frozen():
378-
warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.")
379-
torch.backends.__allow_nonbracketed_mutation_flag = True
380-
381-
if seed is not None:
382-
torch.backends.cudnn.deterministic = True
383-
torch.backends.cudnn.benchmark = False
384-
else: # restore the original flags
385-
torch.backends.cudnn.deterministic = _flag_deterministic
386-
torch.backends.cudnn.benchmark = _flag_cudnn_benchmark
377+
with torch.backends.__allow_nonbracketed_mutation(): # FIXME: better method without accessing private member
378+
if seed is not None:
379+
torch.backends.cudnn.deterministic = True
380+
torch.backends.cudnn.benchmark = False
381+
else: # restore the original flags
382+
torch.backends.cudnn.deterministic = _flag_deterministic
383+
torch.backends.cudnn.benchmark = _flag_cudnn_benchmark
384+
387385
if use_deterministic_algorithms is not None:
388-
if hasattr(torch, "use_deterministic_algorithms"): # `use_deterministic_algorithms` is new in torch 1.8.0
389-
torch.use_deterministic_algorithms(use_deterministic_algorithms)
390-
elif hasattr(torch, "set_deterministic"): # `set_deterministic` is new in torch 1.7.0
391-
torch.set_deterministic(use_deterministic_algorithms)
392-
else:
393-
warnings.warn("use_deterministic_algorithms=True, but PyTorch version is too old to set the mode.")
386+
# environment variable must be set to enable determinism for algorithms, alternative value is ":16:8"
387+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = os.environ.get("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
388+
torch.use_deterministic_algorithms(use_deterministic_algorithms)
394389

395390

396391
def list_to_dict(items):

tests/utils/test_set_determinism.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def test_values(self):
4242
self.assertEqual(seed, get_seed())
4343
a = np.random.randint(seed)
4444
b = torch.randint(seed, (1,))
45-
# tset when global flag support is disabled
45+
46+
# test when global flag support is disabled
4647
torch.backends.disable_global_flags()
4748
set_determinism(seed=seed)
4849
c = np.random.randint(seed)
@@ -60,12 +61,23 @@ def setUp(self):
6061

6162
@SkipIfBeforePyTorchVersion((1, 8)) # beta feature
6263
@skip_if_no_cuda
63-
def test_algo(self):
64+
def test_algo_not_deterministic(self):
65+
"""
66+
Test `avg_pool3d_backward_cuda` correctly raises an exception since it lacks a deterministic implementation.
67+
"""
6468
with self.assertRaises(RuntimeError):
6569
x = torch.randn(20, 16, 50, 44, 31, requires_grad=True, device="cuda:0")
6670
y = torch.nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))(x)
6771
y.sum().backward()
6872

73+
@skip_if_no_cuda
74+
def test_algo_cublas_env(self):
75+
"""
76+
Test `torch.mm` does not raise an exception with the CUBLAS_WORKSPACE_CONFIG environment variable correctly set.
77+
"""
78+
x = torch.rand(5, 5, device="cuda:0")
79+
_ = torch.mm(x, x)
80+
6981
def tearDown(self):
7082
set_determinism(None, use_deterministic_algorithms=False)
7183

0 commit comments

Comments (0)