Skip to content

Use local RandomState instead of seeding the global RNG #12259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions keras/datasets/boston_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
x = f['x']
y = f['y']

np.random.seed(seed)
rng = np.random.RandomState(seed)
indices = np.arange(len(x))
np.random.shuffle(indices)
rng.shuffle(indices)
x = x[indices]
y = y[indices]

Expand Down
6 changes: 3 additions & 3 deletions keras/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def load_data(path='imdb.npz', num_words=None, skip_top=0,
x_train, labels_train = f['x_train'], f['y_train']
x_test, labels_test = f['x_test'], f['y_test']

np.random.seed(seed)
rng = np.random.RandomState(seed)
indices = np.arange(len(x_train))
np.random.shuffle(indices)
rng.shuffle(indices)
x_train = x_train[indices]
labels_train = labels_train[indices]

indices = np.arange(len(x_test))
np.random.shuffle(indices)
rng.shuffle(indices)
x_test = x_test[indices]
labels_test = labels_test[indices]

Expand Down
4 changes: 2 additions & 2 deletions keras/datasets/reuters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def load_data(path='reuters.npz', num_words=None, skip_top=0,
with np.load(path) as f:
xs, labels = f['x'], f['y']

np.random.seed(seed)
rng = np.random.RandomState(seed)
indices = np.arange(len(xs))
np.random.shuffle(indices)
rng.shuffle(indices)
xs = xs[indices]
labels = labels[indices]

Expand Down
5 changes: 3 additions & 2 deletions keras/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ def __call__(self, shape, dtype=None):
num_rows *= dim
num_cols = shape[-1]
flat_shape = (num_rows, num_cols)
rng = np.random
if self.seed is not None:
np.random.seed(self.seed)
a = np.random.normal(0.0, 1.0, flat_shape)
rng = np.random.RandomState(self.seed)
a = rng.normal(0.0, 1.0, flat_shape)
u, _, v = np.linalg.svd(a, full_matrices=False)
# Pick the one with the correct shape.
q = u if u.shape == flat_shape else v
Expand Down
90 changes: 90 additions & 0 deletions tests/keras/datasets/datasets_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import tempfile

import numpy as np
import pytest

from keras.datasets import boston_housing
from keras.datasets import imdb
from keras.datasets import reuters


@pytest.fixture
def fake_downloaded_boston_path(monkeypatch):
num_rows = 100
num_cols = 10
rng = np.random.RandomState(123)

x = rng.randint(1, 100, size=(num_rows, num_cols))
y = rng.normal(loc=100, scale=15, size=num_rows)

with tempfile.NamedTemporaryFile('wb', delete=True) as f:
np.savez(f, x=x, y=y)
monkeypatch.setattr(boston_housing, 'get_file',
lambda *args, **kwargs: f.name)
yield f.name


@pytest.fixture
def fake_downloaded_imdb_path(monkeypatch):
train_rows = 100
test_rows = 20
seq_length = 10
rng = np.random.RandomState(123)

x_train = rng.randint(1, 100, size=(train_rows, seq_length))
y_train = rng.binomial(n=1, p=0.5, size=train_rows)
x_test = rng.randint(1, 100, size=(test_rows, seq_length))
y_test = rng.binomial(n=1, p=0.5, size=test_rows)

with tempfile.NamedTemporaryFile('wb', delete=True) as f:
np.savez(f, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
monkeypatch.setattr(imdb, 'get_file', lambda *args, **kwargs: f.name)
yield f.name


@pytest.fixture
def fake_downloaded_reuters_path(monkeypatch):
num_rows = 100
seq_length = 10
rng = np.random.RandomState(123)

x = rng.randint(1, 100, size=(num_rows, seq_length))
y = rng.binomial(n=1, p=0.5, size=num_rows)

with tempfile.NamedTemporaryFile('wb', delete=True) as f:
np.savez(f, x=x, y=y)
monkeypatch.setattr(reuters, 'get_file', lambda *args, **kwargs: f.name)
yield f.name


def test_boston_load_does_not_affect_global_rng(fake_downloaded_boston_path):
np.random.seed(1337)
before = np.random.randint(0, 100, size=10)

np.random.seed(1337)
boston_housing.load_data(path=fake_downloaded_boston_path, seed=9876)
after = np.random.randint(0, 100, size=10)

assert np.array_equal(before, after)


def test_imdb_load_does_not_affect_global_rng(fake_downloaded_imdb_path):
np.random.seed(1337)
before = np.random.randint(0, 100, size=10)

np.random.seed(1337)
imdb.load_data(path=fake_downloaded_imdb_path, seed=9876)
after = np.random.randint(0, 100, size=10)

assert np.array_equal(before, after)


def test_reuters_load_does_not_affect_global_rng(fake_downloaded_reuters_path):
np.random.seed(1337)
before = np.random.randint(0, 100, size=10)

np.random.seed(1337)
reuters.load_data(path=fake_downloaded_reuters_path, seed=9876)
after = np.random.randint(0, 100, size=10)

assert np.array_equal(before, after)
12 changes: 12 additions & 0 deletions tests/keras/initializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def test_orthogonal(tensor_shape):
target_mean=0.)


def test_orthogonal_init_does_not_affect_global_rng():
np.random.seed(1337)
before = np.random.randint(0, 100, size=10)

np.random.seed(1337)
init = initializers.orthogonal(seed=9876)
init(shape=(10, 5))
after = np.random.randint(0, 100, size=10)

assert np.array_equal(before, after)


@pytest.mark.parametrize('tensor_shape',
[(100, 100), (10, 20), (30, 80), (1, 2, 3, 4)],
ids=['FC', 'RNN', 'RNN_INVALID', 'CONV'])
Expand Down