From 0735ceba02b4fccb61b6f0ab8fff81a50b12d284 Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 11:56:51 -0500 Subject: [PATCH 1/7] isort --- src/gfn/gym/helpers/box_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index dd9acd0..c6342c7 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -1,11 +1,11 @@ """This file contains utilitary functions for the Box environment.""" from typing import Tuple -from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from torchtyping import TensorType as TT import numpy as np import torch import torch.nn as nn +from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily +from torchtyping import TensorType as TT from gfn.gym import Box from gfn.modules import GFNModule From 4161dba32dd32e2434846c07efdc9b6f2f2014ad Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 11:58:04 -0500 Subject: [PATCH 2/7] typo fixed --- src/gfn/states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index cf10288..766066a 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -385,7 +385,7 @@ def _extend(masks, first_dim): self.forward_masks = _extend(self.forward_masks, required_first_dim) self.backward_masks = _extend(self.backward_masks, required_first_dim) - # The helper methods are convienience functions for common mask operations. + # The helper methods are convenience functions for common mask operations. def set_nonexit_masks(self, cond, allow_exit: bool = False): """Sets the allowable actions according to cond, appending the exit mask. 
From 4d8a145008e9f02fa930b218c8fadade50fa6c05 Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 12:10:10 -0500 Subject: [PATCH 3/7] documentation nits and using convenience builtins --- src/gfn/env.py | 7 +++++-- src/gfn/gym/hypergrid.py | 2 +- src/gfn/states.py | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 90ddd24..a21ae38 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -190,8 +190,11 @@ def backward_step( return new_states def reward(self, final_states: States) -> TT["batch_shape", torch.float]: - """This (and potentially log_reward) needs to be implemented.""" - raise NotImplementedError("reward function not implemented") + """The environment's reward given a state. + + This or log_reward must be implemented. + """ + raise NotImplementedError("Reward function is not implemented.") def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: """Calculates the log reward (clipping small rewards).""" diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index d54c8a7..4fe38d7 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -109,7 +109,7 @@ def make_random_states_tensor( def update_masks(self) -> None: "Update the masks based on the current states." self.set_default_typing() - self.forward_masks[..., :-1] = self.tensor != env.height - 1 + self.set_nonexit_masks(self.tensor != env.height - 1) self.backward_masks = self.tensor != 0 return HyperGridStates diff --git a/src/gfn/states.py b/src/gfn/states.py index 766066a..ceab537 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -389,10 +389,10 @@ def _extend(masks, first_dim): def set_nonexit_masks(self, cond, allow_exit: bool = False): """Sets the allowable actions according to cond, appending the exit mask. - A convienience function for common mask operations. + A convenience function for common mask operations. 
Args: - cond: a boolean of shape (batch_shape,) + (state_shape,), which + cond: a boolean of shape (batch_shape,) + (n_actions - 1,), which denotes which actions are not allowed. For example, if a state element represents action count, and no action can be repeated more than 5 times, cond might be state.tensor >= 5. @@ -408,7 +408,7 @@ def set_nonexit_masks(self, cond, allow_exit: bool = False): def set_exit_masks(self, batch_idx): """Sets forward masks such that the only allowable next action is to exit. - A convienience function for common mask operations. + A convenience function for common mask operations. Args: batch_idx: A Boolean index along the batch dimension, along which to From 459203cfa5019d95766b9551d7f35e40ea634f75 Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 14:38:27 -0500 Subject: [PATCH 4/7] using set_nonexit_action_masks properly, with a comment for clarity --- src/gfn/gym/hypergrid.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 4fe38d7..028f716 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -109,7 +109,12 @@ def make_random_states_tensor( def update_masks(self) -> None: "Update the masks based on the current states." self.set_default_typing() - self.set_nonexit_masks(self.tensor != env.height - 1) + # Not allowed to take any action beyond the environment height, but + # allow early termination. 
+ self.set_nonexit_action_masks( + self.tensor == env.height - 1, + allow_exit=True, + ) self.backward_masks = self.tensor != 0 return HyperGridStates From fb7e819f858f94f65bc3e23c12b6ce3bd883d46c Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 14:39:01 -0500 Subject: [PATCH 5/7] Improved method name and documentation --- src/gfn/states.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index ceab537..0b631f2 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -386,16 +386,16 @@ def _extend(masks, first_dim): self.backward_masks = _extend(self.backward_masks, required_first_dim) # The helper methods are convenience functions for common mask operations. - def set_nonexit_masks(self, cond, allow_exit: bool = False): - """Sets the allowable actions according to cond, appending the exit mask. + def set_nonexit_action_masks(self, cond, allow_exit: bool): + """Masks denoting disallowed actions according to cond, appending the exit mask. A convenience function for common mask operations. Args: cond: a boolean of shape (batch_shape,) + (n_actions - 1,), which - denotes which actions are not allowed. For example, if a state element + denotes which actions are *not* allowed. For example, if a state element represents action count, and no action can be repeated more than 5 - times, cond might be state.tensor >= 5. + times, cond might be state.tensor > 5 (assuming count starts at 0). allow_exit: sets whether exiting can happen at any point in the trajectory - if so, it should be set to True. 
""" From 226185fe207723ebb1d9b70212bf4f7c8ad48ca0 Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 18:41:26 -0500 Subject: [PATCH 6/7] changed test logic (will run on github instead of this terrible computer) --- tutorials/examples/test_scripts.py | 25 +++++++++++++------------ tutorials/examples/train_box.py | 4 +++- tutorials/examples/train_discreteebm.py | 4 +++- tutorials/examples/train_hypergrid.py | 4 +++- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 0f63316..77af9ef 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import pytest +import numpy as np from .train_box import main as train_box_main from .train_discreteebm import main as train_discreteebm_main @@ -68,13 +69,13 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert final_l1_dist < 7.3e-4 + assert np.isclose(final_l1_dist, 8.7e-4, atol=1e-5) elif ndim == 2 and height == 16: - assert final_l1_dist < 4.8e-4 + assert np.isclose(final_l1_dist, 4.8e-4, atol=1e-5) elif ndim == 4 and height == 8: - assert final_l1_dist < 1.6e-4 + assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5) elif ndim == 4 and height == 16: - assert final_l1_dist < 2.45e-5 + assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-6) @pytest.mark.parametrize("ndim", [2, 4]) @@ -84,13 +85,13 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if ndim == 2 and alpha == 0.1: - assert final_l1_dist < 0.0026 + assert np.isclose(final_l1_dist, 0.0026, atol=1e-4) elif ndim == 2 and alpha == 1.0: - assert final_l1_dist < 0.017 + assert np.isclose(final_l1_dist, 0.017, atol=1e-3) elif 
ndim == 4 and alpha == 0.1: - assert final_l1_dist < 0.009 + assert np.isclose(final_l1_dist, 0.009, atol=1e-3) elif ndim == 4 and alpha == 1.0: - assert final_l1_dist < 0.062 + assert np.isclose(final_l1_dist, 0.062, atol=1e-3) @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -113,10 +114,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert final_jsd < 0.046 + assert np.isclose(final_jsd, 0.046, atol=1e-3) elif loss == "DB" and delta == 0.1: - assert final_jsd < 0.18 + assert np.isclose(final_jsd, 0.18, atol=1e-2) if loss == "TB" and delta == 0.25: - assert final_jsd < 0.015 + assert np.isclose(final_jsd, 0.015, atol=1e-3) elif loss == "DB" and delta == 0.25: - assert final_jsd < 0.027 + assert np.isclose(final_jsd, 0.027, atol=1e-3) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 996f4c1..7483fec 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -33,6 +33,8 @@ ) from gfn.modules import ScalarEstimator +DEFAULT_SEED = 4444 + def sample_from_reward(env: Box, n_samples: int): """Samples states from the true reward distribution @@ -83,7 +85,7 @@ def estimate_jsd(kde1, kde2): def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index a7aab78..f5e35a9 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -23,9 +23,11 @@ from gfn.utils.common import validate from gfn.utils.modules import NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if 
args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index e9fd465..368d924 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -31,9 +31,11 @@ from gfn.utils.common import validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular +DEFAULT_SEED = 4444 + def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + seed = args.seed if args.seed != 0 else DEFAULT_SEED torch.manual_seed(seed) device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" From 84bb16923d72eef75cb2068e76544a8fa9d86d4e Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 23:08:49 -0500 Subject: [PATCH 7/7] new targets --- tutorials/examples/test_scripts.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 77af9ef..0e2021e 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -69,9 +69,9 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert np.isclose(final_l1_dist, 8.7e-4, atol=1e-5) + assert np.isclose(final_l1_dist, 9.14e-4, atol=1e-5) elif ndim == 2 and height == 16: - assert np.isclose(final_l1_dist, 4.8e-4, atol=1e-5) + assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-5) elif ndim == 4 and height == 8: assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5) elif ndim == 4 and height == 16: @@ -85,7 +85,7 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if 
ndim == 2 and alpha == 0.1: - assert np.isclose(final_l1_dist, 0.0026, atol=1e-4) + assert np.isclose(final_l1_dist, 2.97e-3, atol=1e-3) elif ndim == 2 and alpha == 1.0: assert np.isclose(final_l1_dist, 0.017, atol=1e-3) elif ndim == 4 and alpha == 0.1: @@ -114,10 +114,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert np.isclose(final_jsd, 0.046, atol=1e-3) + assert np.isclose(final_jsd, 3.81e-2, atol=1e-3) elif loss == "DB" and delta == 0.1: - assert np.isclose(final_jsd, 0.18, atol=1e-2) + assert np.isclose(final_jsd, 0.134, atol=1e-2) if loss == "TB" and delta == 0.25: - assert np.isclose(final_jsd, 0.015, atol=1e-3) + assert np.isclose(final_jsd, 2.93e-3, atol=1e-3) elif loss == "DB" and delta == 0.25: - assert np.isclose(final_jsd, 0.027, atol=1e-3) + assert np.isclose(final_jsd, 0.0142, atol=1e-3)