@@ -374,23 +374,18 @@ def set_determinism(
         for func in additional_settings:
             func(seed)
 
-    if torch.backends.flags_frozen():
-        warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.")
-        torch.backends.__allow_nonbracketed_mutation_flag = True
-
-    if seed is not None:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-    else:  # restore the original flags
-        torch.backends.cudnn.deterministic = _flag_deterministic
-        torch.backends.cudnn.benchmark = _flag_cudnn_benchmark
+    with torch.backends.__allow_nonbracketed_mutation():  # FIXME: better method without accessing private member
+        if seed is not None:
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+        else:  # restore the original flags
+            torch.backends.cudnn.deterministic = _flag_deterministic
+            torch.backends.cudnn.benchmark = _flag_cudnn_benchmark
+
     if use_deterministic_algorithms is not None:
-        if hasattr(torch, "use_deterministic_algorithms"):  # `use_deterministic_algorithms` is new in torch 1.8.0
-            torch.use_deterministic_algorithms(use_deterministic_algorithms)
-        elif hasattr(torch, "set_deterministic"):  # `set_deterministic` is new in torch 1.7.0
-            torch.set_deterministic(use_deterministic_algorithms)
-        else:
-            warnings.warn("use_deterministic_algorithms=True, but PyTorch version is too old to set the mode.")
+        # environment variable must be set to enable determinism for algorithms, alternative value is ":16:8"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = os.environ.get("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
+        torch.use_deterministic_algorithms(use_deterministic_algorithms)
 
 
 def list_to_dict(items):
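For reference, a minimal usage sketch of the patched helper, assuming a standard MONAI install where `set_determinism` is importable from `monai.utils`; the seed value and call sites are illustrative, not part of this change:

from monai.utils import set_determinism

# Enable reproducible behaviour: seeds the global RNGs, sets
# torch.backends.cudnn.deterministic / benchmark, and (with this patch)
# ensures CUBLAS_WORKSPACE_CONFIG is set before calling
# torch.use_deterministic_algorithms(True).
set_determinism(seed=0, use_deterministic_algorithms=True)

# ... run training / evaluation ...

# Restore the original cuDNN flags and switch the strict mode back off.
set_determinism(seed=None, use_deterministic_algorithms=False)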