Skip to content

Commit

Permalink
Random seed utilities (#496)
Browse files Browse the repository at this point in the history
Utility functions to get and set seeds for the Python, Numpy, TensorFlow and PyTorch random number generators.
  • Loading branch information
ascillitoe authored May 16, 2022
1 parent 1398431 commit a87153b
Show file tree
Hide file tree
Showing 14 changed files with 308 additions and 117 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ repos:
hooks:
- id: mypy
additional_dependencies: [
types-requests>=2.25.0,
types-requests~=2.25,
types-toml~=0.10
]
1 change: 0 additions & 1 deletion alibi_detect/cd/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,4 +1373,3 @@ def predict(self, # type: ignore[override]
cd['data']['coupling_yy'] = coupling[1]
cd['data']['coupling_xy'] = coupling[2]
return cd

1 change: 0 additions & 1 deletion alibi_detect/cd/sklearn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,3 @@ def get_config(self) -> dict:
The detector's configuration dictionary.
"""
raise NotImplementedError("get_config not yet implemented for `ClassifierDrift` with sklearn backend.")

2 changes: 1 addition & 1 deletion alibi_detect/saving/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def write_config(cfg: dict, filepath: Union[str, os.PathLike]):
cfg = _replace(cfg, None, "None") # Note: None replaced with "None" as None/null not valid TOML
logger.info('Writing config to {}'.format(filepath.joinpath('config.toml')))
with open(filepath.joinpath('config.toml'), 'w') as f:
toml.dump(cfg, f, encoder=toml.TomlNumpyEncoder()) # type: ignore[call-arg, attr-defined]
toml.dump(cfg, f, encoder=toml.TomlNumpyEncoder()) # type: ignore[misc]


def _save_preprocess_config(preprocess_fn: Callable,
Expand Down
12 changes: 6 additions & 6 deletions alibi_detect/saving/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class KernelConfig(CustomBaseModel):
"A string referencing a filepath to a serialized kernel in `.dill` format, or an object registry reference."

# Below kwargs are only passed if kernel == @GaussianRBF
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
"""
Bandwidth used for the kernel. Needn’t be specified if being inferred or trained. Can pass multiple values to eval
kernel with and then average.
Expand All @@ -367,7 +367,7 @@ class KernelConfigResolved(CustomBaseModel):
"The kernel."

# Below kwargs are only passed if kernel == @GaussianRBF
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
"""
Bandwidth used for the kernel. Needn’t be specified if being inferred or trained. Can pass multiple values to eval
kernel with and then average.
Expand Down Expand Up @@ -660,7 +660,7 @@ class MMDDriftConfig(DriftDetectorConfig):
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
kernel: Optional[Union[str, KernelConfig]] = None
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
configure_kernel_from_x_ref: bool = True
n_permutations: int = 100
device: Optional[Literal['cpu', 'cuda']] = None
Expand All @@ -677,7 +677,7 @@ class MMDDriftConfigResolved(DriftDetectorConfigResolved):
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
kernel: Optional[Union[Callable, KernelConfigResolved]] = None
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
configure_kernel_from_x_ref: bool = True
n_permutations: int = 100
device: Optional[Literal['cpu', 'cuda']] = None
Expand All @@ -693,7 +693,7 @@ class LSDDDriftConfig(DriftDetectorConfig):
"""
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
n_permutations: int = 100
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
Expand All @@ -710,7 +710,7 @@ class LSDDDriftConfigResolved(DriftDetectorConfigResolved):
"""
preprocess_at_init: bool = True
update_x_ref: Optional[Dict[str, int]] = None
sigma: Optional[NDArray[float]] = None
sigma: Optional[NDArray[np.float32]] = None
n_permutations: int = 100
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
Expand Down
84 changes: 84 additions & 0 deletions alibi_detect/utils/_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
This submodule contains utility functions to manage random number generator (RNG) seeds. It may change
depending on how we decide to handle randomisation in tests (and elsewhere) going forwards. See
https://github.com/SeldonIO/alibi-detect/issues/250.
"""
from contextlib import contextmanager
import random
import numpy as np
import os
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow

if has_tensorflow:
import tensorflow as tf
if has_pytorch:
import torch

# Module-level record of the seed most recently applied through `set_seed`.
# Stays `None` until `set_seed` is called; `get_seed` raises a RuntimeError while it is `None`.
_ALIBI_SEED = None


def set_seed(seed: int):
    """
    Seed the Python, NumPy, TensorFlow and PyTorch random number generators in one call.

    The value is also written to the ``PYTHONHASHSEED`` environment variable and remembered at module
    level so that it can later be read back with :func:`get_seed`. TensorFlow and PyTorch are only seeded
    when the corresponding framework is available.

    Parameters
    ----------
    seed
        Value of the random seed to set. Negative values are replaced by 0.
    """
    global _ALIBI_SEED

    # TODO: Clamping is a fix to allow --randomly-seed=0 in setup.cfg. To be removed in future
    if seed < 0:
        seed = 0

    # Record the seed before touching any RNG so that `get_seed` reflects the requested value.
    _ALIBI_SEED = seed

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Optional backends
    if has_tensorflow:
        tf.random.set_seed(seed)
    if has_pytorch:
        torch.manual_seed(seed)


def get_seed() -> int:
    """
    Return the seed most recently applied with :func:`set_seed`.

    Returns
    -------
    The stored seed value.

    Raises
    ------
    RuntimeError
        If :func:`set_seed` has not been called yet, so there is no seed to report.

    Example
    -------
    >>> from alibi_detect.utils._random import set_seed, get_seed
    >>> set_seed(42)
    >>> get_seed()
    42
    """
    # Guard clause: nothing has been stored until `set_seed` runs at least once.
    if _ALIBI_SEED is None:
        raise RuntimeError('`set_seed` must be called before `get_seed` can be called.')
    return _ALIBI_SEED


@contextmanager
def fixed_seed(seed: int):
    """
    Context manager that temporarily applies `seed` to every RNG handled by :func:`set_seed`.

    On entry the currently stored seed is read with :func:`get_seed` and `seed` is applied. On exit,
    whether the body finished normally or raised, the previously stored seed is applied again via
    :func:`set_seed`.

    Parameters
    ----------
    seed
        Value of the random seed to set in the isolated context.

    Raises
    ------
    RuntimeError
        Propagated from :func:`get_seed` when no seed has been set before entering the context.

    Example
    -------
    .. code-block :: python

        set_seed(0)
        with fixed_seed(42):
            dd = cd.LSDDDrift(X_ref)  # seeds equal 42 here
            p_val = dd.predict(X_h0)['data']['p_val']
        # seeds equal 0 here
    """
    seed_on_entry = get_seed()
    set_seed(seed)
    try:
        yield
    finally:
        # Always hand back the outer seed, even if the body raised.
        set_seed(seed_on_entry)
2 changes: 1 addition & 1 deletion alibi_detect/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def validate(cls, val: Any, field: ModelField) -> np.ndarray:
return _validate(cls, val, field)

else:
class NDArray(Generic[T], np.ndarray[Any, T]): # type: ignore[no-redef]
class NDArray(Generic[T], np.ndarray[Any, T]): # type: ignore[no-redef, type-var]
"""
A Generic pydantic model to validate (and coerce) np.ndarray's.
"""
Expand Down
9 changes: 9 additions & 0 deletions alibi_detect/utils/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest


@pytest.fixture
def seed(pytestconfig):
    """
    Fixture exposing the random seed chosen by pytest-randomly.

    The value is read from the ``randomly_seed`` option of the active pytest configuration.
    """
    randomly_seed = pytestconfig.getoption("randomly_seed")
    return randomly_seed
61 changes: 61 additions & 0 deletions alibi_detect/utils/tests/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from alibi_detect.utils._random import set_seed, get_seed, fixed_seed
import numpy as np
import tensorflow as tf
import torch


def test_set_get_seed(seed):
    """
    Tests the set_seed and get_seed functions.
    """
    # On entry, the stored seed should be the one set by pytest-randomly
    assert get_seed() == seed

    # After re-seeding, get_seed should report the new value
    shifted_seed = seed + 42
    set_seed(shifted_seed)
    assert get_seed() == shifted_seed


def test_fixed_seed(seed):
    """
    Tests the fixed_seed context manager.

    Sequences drawn from NumPy, TensorFlow and PyTorch under the same temporary seed must match,
    sequences drawn under a different seed must differ, and the outer seed must be back in place
    once every context has exited.
    """
    n = 5  # Length of random number sequences

    def _draw_sequence():
        # Interleave draws from the three RNGs, n rounds in total
        draws = []
        for _ in range(n):
            draws.append(np.random.normal([1]))
            draws.append(tf.random.normal([1]))
            draws.append(torch.normal(torch.tensor([1.0])))
        return draws

    # First sequence; the stored seed must remain the temporary one after the RNG calls
    with fixed_seed(seed + 42):
        nums0 = _draw_sequence()
        assert get_seed() == seed + 42

    # Same temporary seed again -> identical sequence
    with fixed_seed(seed + 42):
        nums1 = _draw_sequence()
    assert nums0 == nums1

    # Different temporary seed -> different sequence
    with fixed_seed(seed + 99):
        nums2 = _draw_sequence()
    assert nums1 != nums2

    # Check seeds were reset upon exit of context managers
    assert get_seed() == seed
Loading

0 comments on commit a87153b

Please sign in to comment.