diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 0f63316a..77af9ef8 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import pytest +import numpy as np from .train_box import main as train_box_main from .train_discreteebm import main as train_discreteebm_main @@ -68,13 +69,13 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert final_l1_dist < 7.3e-4 + assert np.isclose(final_l1_dist, 8.7e-4, atol=1e-5) elif ndim == 2 and height == 16: - assert final_l1_dist < 4.8e-4 + assert np.isclose(final_l1_dist, 4.8e-4, atol=1e-5) elif ndim == 4 and height == 8: - assert final_l1_dist < 1.6e-4 + assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5) elif ndim == 4 and height == 16: - assert final_l1_dist < 2.45e-5 + assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-6) @pytest.mark.parametrize("ndim", [2, 4]) @@ -84,13 +85,13 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if ndim == 2 and alpha == 0.1: - assert final_l1_dist < 0.0026 + assert np.isclose(final_l1_dist, 0.0026, atol=1e-4) elif ndim == 2 and alpha == 1.0: - assert final_l1_dist < 0.017 + assert np.isclose(final_l1_dist, 0.017, atol=1e-3) elif ndim == 4 and alpha == 0.1: - assert final_l1_dist < 0.009 + assert np.isclose(final_l1_dist, 0.009, atol=1e-3) elif ndim == 4 and alpha == 1.0: - assert final_l1_dist < 0.062 + assert np.isclose(final_l1_dist, 0.062, atol=1e-3) @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -113,10 +114,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert final_jsd < 0.046 + assert np.isclose(final_jsd, 0.046, atol=1e-3) elif loss == "DB" and delta == 0.1: - assert final_jsd < 0.18 + assert np.isclose(final_jsd, 0.18, atol=1e-2) if loss == "TB" and delta == 0.25: - assert final_jsd < 0.015 + assert np.isclose(final_jsd, 0.015, atol=1e-3) elif loss == "DB" and delta == 0.25: - assert final_jsd < 0.027 + assert np.isclose(final_jsd, 0.027, atol=1e-3) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 996f4c1f..7483fecf 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -33,6 +33,8 @@ ) from gfn.modules import ScalarEstimator +DEFAULT_SEED = 4444 + def sample_from_reward(env: Box, n_samples: int): """Samples states from the true reward distribution @@ -83,7 +85,7 @@ def estimate_jsd(kde1, kde2): def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index a7aab784..f5e35a98 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -23,9 +23,11 @@ from gfn.utils.common import validate from gfn.utils.modules import NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index e9fd465c..368d9243 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -31,9 +31,11 @@ from gfn.utils.common import validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"