Skip to content

Commit

Permalink
Tests that the seeds get set as the user desires
Browse files Browse the repository at this point in the history
  • Loading branch information
cgpotts committed Apr 9, 2019
1 parent 681d297 commit 05a7bc5
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import os
import pytest
import random
import tensorflow as tf
import torch
import utils

__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2019"

tf.enable_eager_execution()


@pytest.mark.parametrize("arg, expected", [
[
Expand Down Expand Up @@ -79,3 +83,30 @@ def test_glove2dict():
def test_get_vocab(X, n_words, expected):
result = utils.get_vocab(X, n_words=n_words)
assert result == expected


@pytest.mark.parametrize("set_value", [True, False])
def test_fix_random_seeds_system(set_value):
utils.fix_random_seeds(seed=42, set_system=set_value)
x = np.random.random()
utils.fix_random_seeds(seed=42, set_system=set_value)
y = np.random.random()
assert (x == y) == set_value


@pytest.mark.parametrize("set_value", [True, False])
def test_fix_random_seeds_pytorch(set_value):
utils.fix_random_seeds(seed=42, set_torch=set_value)
x = torch.rand(1)
utils.fix_random_seeds(seed=42, set_torch=set_value)
y = torch.rand(1)
assert (x == y) == set_value


@pytest.mark.parametrize("set_value", [True, False])
def test_fix_random_seeds_tensorflow(set_value):
utils.fix_random_seeds(seed=42, set_tensorflow=set_value)
x = tf.random.uniform([1]).numpy()
utils.fix_random_seeds(seed=42, set_tensorflow=set_value)
y = tf.random.uniform([1]).numpy()
assert (x == y) == set_value

0 comments on commit 05a7bc5

Please sign in to comment.