-
Notifications
You must be signed in to change notification settings - Fork 224
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Utility functions to get and set seeds for the Python, Numpy, TensorFlow and PyTorch random number generators.
- Loading branch information
1 parent
1398431
commit a87153b
Showing
14 changed files
with
308 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
This submodule contains utility functions to manage random number generator (RNG) seeds. It may change | ||
depending on how we decide to handle randomisation in tests (and elsewhere) going forwards. See | ||
https://github.com/SeldonIO/alibi-detect/issues/250. | ||
""" | ||
from contextlib import contextmanager | ||
import random | ||
import numpy as np | ||
import os | ||
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow | ||
|
||
if has_tensorflow: | ||
import tensorflow as tf | ||
if has_pytorch: | ||
import torch | ||
|
||
# Init global seed | ||
_ALIBI_SEED = None | ||
|
||
|
||
def set_seed(seed: int): | ||
""" | ||
Sets the Python, NumPy, TensorFlow and PyTorch random seeds, and the PYTHONHASHSEED env variable. | ||
Parameters | ||
---------- | ||
seed | ||
Value of the random seed to set. | ||
""" | ||
global _ALIBI_SEED | ||
seed = max(seed, 0) # TODO: This is a fix to allow --randomly-seed=0 in setup.cfg. To be removed in future | ||
_ALIBI_SEED = seed | ||
os.environ['PYTHONHASHSEED'] = str(seed) | ||
random.seed(seed) | ||
np.random.seed(seed) | ||
if has_tensorflow: | ||
tf.random.set_seed(seed) | ||
if has_pytorch: | ||
torch.manual_seed(seed) | ||
|
||
|
||
def get_seed() -> int: | ||
""" | ||
Gets the seed set by :func:`set_seed`. | ||
Example | ||
------- | ||
>>> from alibi_detect.utils._random import set_seed, get_seed | ||
>>> set_seed(42) | ||
>>> get_seed() | ||
42 | ||
""" | ||
if _ALIBI_SEED is not None: | ||
return _ALIBI_SEED | ||
else: | ||
raise RuntimeError('`set_seed` must be called before `get_seed` can be called.') | ||
|
||
|
||
@contextmanager | ||
def fixed_seed(seed: int): | ||
""" | ||
A context manager to run with a requested random seed (applied to all the RNG's set by :func:`set_seed`). | ||
Parameters | ||
---------- | ||
seed | ||
Value of the random seed to set in the isolated context. | ||
Example | ||
------- | ||
.. code-block :: python | ||
set_seed(0) | ||
with fixed_seed(42): | ||
dd = cd.LSDDDrift(X_ref) # seeds equal 42 here | ||
p_val = dd.predict(X_h0)['data']['p_val'] | ||
# seeds equal 0 here | ||
""" | ||
orig_seed = get_seed() | ||
set_seed(seed) | ||
try: | ||
yield | ||
finally: | ||
set_seed(orig_seed) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def seed(pytestconfig): | ||
""" | ||
Returns the random seed set by pytest-randomly. | ||
""" | ||
return pytestconfig.getoption("randomly_seed") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from alibi_detect.utils._random import set_seed, get_seed, fixed_seed | ||
import numpy as np | ||
import tensorflow as tf | ||
import torch | ||
|
||
|
||
def test_set_get_seed(seed): | ||
""" | ||
Tests the set_seed and get_seed fuctions. | ||
""" | ||
# Check initial seed within test is the one set by pytest-randomly | ||
current_seed = get_seed() | ||
assert current_seed == seed | ||
|
||
# Set another seed and check | ||
new_seed = seed + 42 | ||
set_seed(new_seed) | ||
current_seed = get_seed() | ||
assert current_seed == new_seed | ||
|
||
|
||
def test_fixed_seed(seed): | ||
""" | ||
Tests the fixed_seed context manager. | ||
""" | ||
n = 5 # Length of random number sequences | ||
|
||
nums0 = [] | ||
tmp_seed = seed + 42 | ||
with fixed_seed(tmp_seed): | ||
# Generate a sequence of random numbers | ||
for i in range(n): | ||
nums0.append(np.random.normal([1])) | ||
nums0.append(tf.random.normal([1])) | ||
nums0.append(torch.normal(torch.tensor([1.0]))) | ||
|
||
# Check seed unchanged after RNG calls | ||
assert get_seed() == tmp_seed | ||
|
||
# Generate another sequence of random numbers with same seed, and check equal | ||
nums1 = [] | ||
tmp_seed = seed + 42 | ||
with fixed_seed(tmp_seed): | ||
for i in range(n): | ||
nums1.append(np.random.normal([1])) | ||
nums1.append(tf.random.normal([1])) | ||
nums1.append(torch.normal(torch.tensor([1.0]))) | ||
assert nums0 == nums1 | ||
|
||
# Generate another sequence of random numbers with different seed, and check not equal | ||
nums2 = [] | ||
tmp_seed = seed + 99 | ||
with fixed_seed(tmp_seed): | ||
for i in range(n): | ||
nums2.append(np.random.normal([1])) | ||
nums2.append(tf.random.normal([1])) | ||
nums2.append(torch.normal(torch.tensor([1.0]))) | ||
assert nums1 != nums2 | ||
|
||
# Check seeds were reset upon exit of context managers | ||
assert get_seed() == seed |
Oops, something went wrong.