File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -102,6 +102,10 @@ def _strtobool(val: str) -> bool:
102102NP_MAX = np .iinfo (np .uint32 ).max
103103MAX_SEED = NP_MAX + 1 # 2**32, the actual seed should be in [0, MAX_SEED - 1] for uint32
104104
105+ # Environment variable must be set to enable determinism for algorithms (alternative value is ":16:8").
106+ # This needs to be here to ensure it's set before deterministic algorithms are used/initialised.
107+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , ":4096:8" )
108+
105109
106110def zip_with (op , * vals , mapfunc = map ):
107111 """
@@ -383,8 +387,6 @@ def set_determinism(
383387 torch .backends .cudnn .benchmark = _flag_cudnn_benchmark
384388
385389 if use_deterministic_algorithms is not None :
386- # environment variable must be set to enable determinism for algorithms, alternative value is ":16:8"
387- os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , ":4096:8" )
388390 torch .use_deterministic_algorithms (use_deterministic_algorithms )
389391
390392
You can’t perform that action at this time.
0 commit comments