Skip to content

Commit

Permalink
More control over which seeds get set
Browse files Browse the repository at this point in the history
  • Loading branch information
cgpotts committed Apr 9, 2019
1 parent 924b4b8 commit 681d297
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,32 +228,61 @@ def info_filter(logrec):
return int('loss' in logrec.getMessage().lower())
logging.getLogger('tensorflow').addFilter(info_filter)

def fix_random_seed(seed=42, set_torch_cudnn=True):

def fix_random_seeds(
seed=42,
set_system=True,
set_torch=True,
set_tensorflow=True,
set_torch_cudnn=True):
"""Fix random seeds for reproducibility.
Parameters
----------
seed : int
Random seed to be set.
set_system : bool
Whether to set `np.random.seed(seed)` and `random.seed(seed)`
set_tensorflow : bool
Whether to set `tf.random.set_random_seed(seed)`
set_torch : bool
Whether to set `torch.manual_seed(seed)`
set_torch_cudnn: bool
Flag for whether to enable cudnn deterministic mode.
Note that deterministic mode can have a performance impact, depending on your model.
https://pytorch.org/docs/stable/notes/randomness.html
"""
Notes
-----
Even though the random seeds are explicitly set,
the behavior may still not be deterministic (especially when a
GPU is enabled), due to:
* CUDA: There are some PyTorch functions that use CUDA functions
that can be a source of non-determinism:
https://pytorch.org/docs/stable/notes/randomness.html
* PYTHONHASHSEED: On Python 3.3 and greater, hash randomization is
turned on by default. This seed could be fixed before calling the
python interpreter (PYTHONHASHSEED=0 python test.py). However, it
seems impossible to set it inside the python program:
https://stackoverflow.com/questions/30585108/disable-hash-randomization-from-within-python-program
"""
# set system seed
np.random.seed(seed)
random.seed(seed)
if set_system:
np.random.seed(seed)
random.seed(seed)

# set torch seed
torch.manual_seed(seed)
if set_torch:
torch.manual_seed(seed)

# set torch cudnn backend
if set_torch_cudnn:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set tf seed
tf.random.set_random_seed(seed)
if set_tensorflow:
tf.random.set_random_seed(seed)

0 comments on commit 681d297

Please sign in to comment.