diff --git a/src/gfn/env.py b/src/gfn/env.py index b98c7e5..c21f958 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -187,8 +187,11 @@ def backward_step( return new_states def reward(self, final_states: States) -> TT["batch_shape", torch.float]: - """This (and potentially log_reward) needs to be implemented.""" - raise NotImplementedError("reward function not implemented") + """The environment's reward given a state. + + This or log_reward must be implemented. + """ + raise NotImplementedError("Reward function is not implemented.") def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: """Calculates the log reward.""" diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 2c4c285..71d2862 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -106,7 +106,12 @@ def make_random_states_tensor( def update_masks(self) -> None: "Update the masks based on the current states." self.set_default_typing() - self.forward_masks[..., :-1] = self.tensor != env.height - 1 + # Not allowed to take any action beyond the environment height, but + # allow early termination. + self.set_nonexit_action_masks( + self.tensor == env.height - 1, + allow_exit=True, + ) self.backward_masks = self.tensor != 0 return HyperGridStates diff --git a/src/gfn/states.py b/src/gfn/states.py index 13d47b0..f5d63a4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -401,17 +401,17 @@ def _extend(masks, first_dim): self.forward_masks = _extend(self.forward_masks, required_first_dim) self.backward_masks = _extend(self.backward_masks, required_first_dim) - # The helper methods are convienience functions for common mask operations. - def set_nonexit_masks(self, cond, allow_exit: bool = False): - """Sets the allowable actions according to cond, appending the exit mask. + # The helper methods are convenience functions for common mask operations. + def set_nonexit_action_masks(self, cond, allow_exit: bool): + """Masks denoting disallowed actions according to cond, appending the exit mask. - A convienience function for common mask operations. + A convenience function for common mask operations. Args: - cond: a boolean of shape (batch_shape,) + (state_shape,), which - denotes which actions are not allowed. For example, if a state element + cond: a boolean of shape (batch_shape,) + (n_actions - 1,), which + denotes which actions are *not* allowed. For example, if a state element represents action count, and no action can be repeated more than 5 - times, cond might be state.tensor >= 5. + times, cond might be state.tensor > 5 (assuming count starts at 0). allow_exit: sets whether exiting can happen at any point in the trajectory - if so, it should be set to True. """ @@ -424,7 +424,7 @@ def set_nonexit_masks(self, cond, allow_exit: bool = False): def set_exit_masks(self, batch_idx): """Sets forward masks such that the only allowable next action is to exit. - A convienience function for common mask operations. + A convenience function for common mask operations. Args: batch_idx: A Boolean index along the batch dimension, along which to diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 0f63316..0e2021e 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import pytest +import numpy as np from .train_box import main as train_box_main from .train_discreteebm import main as train_discreteebm_main @@ -68,13 +69,13 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert final_l1_dist < 7.3e-4 + assert np.isclose(final_l1_dist, 9.14e-4, atol=1e-5) elif ndim == 2 and height == 16: - assert final_l1_dist < 4.8e-4 + assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-5) elif ndim == 4 and height == 8: - assert final_l1_dist < 1.6e-4 + assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5) elif ndim == 4 and height == 16: - assert final_l1_dist < 2.45e-5 + assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-6) @pytest.mark.parametrize("ndim", [2, 4]) @@ -84,13 +85,13 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if ndim == 2 and alpha == 0.1: - assert final_l1_dist < 0.0026 + assert np.isclose(final_l1_dist, 2.97e-3, atol=1e-3) elif ndim == 2 and alpha == 1.0: - assert final_l1_dist < 0.017 + assert np.isclose(final_l1_dist, 0.017, atol=1e-3) elif ndim == 4 and alpha == 0.1: - assert final_l1_dist < 0.009 + assert np.isclose(final_l1_dist, 0.009, atol=1e-3) elif ndim == 4 and alpha == 1.0: - assert final_l1_dist < 0.062 + assert np.isclose(final_l1_dist, 0.062, atol=1e-3) @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -113,10 +114,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert final_jsd < 0.046 + assert np.isclose(final_jsd, 3.81e-2, atol=1e-3) elif loss == "DB" and delta == 0.1: - assert final_jsd < 0.18 + assert np.isclose(final_jsd, 0.134, atol=1e-2) if loss == "TB" and delta == 0.25: - assert final_jsd < 0.015 + assert np.isclose(final_jsd, 2.93e-3, atol=1e-3) elif loss == "DB" and delta == 0.25: - assert final_jsd < 0.027 + assert np.isclose(final_jsd, 0.0142, atol=1e-3) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 996f4c1..7483fec 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -33,6 +33,8 @@ ) from gfn.modules import ScalarEstimator +DEFAULT_SEED = 4444 + def sample_from_reward(env: Box, n_samples: int): """Samples states from the true reward distribution @@ -83,7 +85,7 @@ def estimate_jsd(kde1, kde2): def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index a7aab78..f5e35a9 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -23,9 +23,11 @@ from gfn.utils.common import validate from gfn.utils.modules import NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index e9fd465..368d924 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -31,9 +31,11 @@ from gfn.utils.common import validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"