Skip to content

Commit

Permalink
Merge branch 'easier_environment_definition' of github.com:saleml/torchgfn into rethinking_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 24, 2023
2 parents 1ceb53d + 84bb169 commit b67a6d2
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 26 deletions.
7 changes: 5 additions & 2 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,11 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""This (and potentially log_reward) needs to be implemented."""
raise NotImplementedError("reward function not implemented")
"""The environment's reward given a state.
This or log_reward must be implemented.
"""
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Calculates the log reward."""
Expand Down
7 changes: 6 additions & 1 deletion src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def make_random_states_tensor(
def update_masks(self) -> None:
"Update the masks based on the current states."
self.set_default_typing()
self.forward_masks[..., :-1] = self.tensor != env.height - 1
# Not allowed to take any action beyond the environment height, but
# allow early termination.
self.set_nonexit_action_masks(
self.tensor == env.height - 1,
allow_exit=True,
)
self.backward_masks = self.tensor != 0

return HyperGridStates
Expand Down
16 changes: 8 additions & 8 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,17 +401,17 @@ def _extend(masks, first_dim):
self.forward_masks = _extend(self.forward_masks, required_first_dim)
self.backward_masks = _extend(self.backward_masks, required_first_dim)

# The helper methods are convienience functions for common mask operations.
def set_nonexit_masks(self, cond, allow_exit: bool = False):
"""Sets the allowable actions according to cond, appending the exit mask.
# The helper methods are convenience functions for common mask operations.
def set_nonexit_action_masks(self, cond, allow_exit: bool):
"""Masks denoting disallowed actions according to cond, appending the exit mask.
A convienience function for common mask operations.
A convenience function for common mask operations.
Args:
cond: a boolean of shape (batch_shape,) + (state_shape,), which
denotes which actions are not allowed. For example, if a state element
cond: a boolean of shape (batch_shape,) + (n_actions - 1,), which
denotes which actions are *not* allowed. For example, if a state element
represents action count, and no action can be repeated more than 5
times, cond might be state.tensor >= 5.
times, cond might be state.tensor > 5 (assuming count starts at 0).
allow_exit: sets whether exiting can happen at any point in the
trajectory - if so, it should be set to True.
"""
Expand All @@ -424,7 +424,7 @@ def set_nonexit_masks(self, cond, allow_exit: bool = False):
def set_exit_masks(self, batch_idx):
"""Sets forward masks such that the only allowable next action is to exit.
A convienience function for common mask operations.
A convenience function for common mask operations.
Args:
batch_idx: A Boolean index along the batch dimension, along which to
Expand Down
25 changes: 13 additions & 12 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass

import pytest
import numpy as np

from .train_box import main as train_box_main
from .train_discreteebm import main as train_discreteebm_main
Expand Down Expand Up @@ -68,13 +69,13 @@ def test_hypergrid(ndim: int, height: int):
args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories)
final_l1_dist = train_hypergrid_main(args)
if ndim == 2 and height == 8:
assert final_l1_dist < 7.3e-4
assert np.isclose(final_l1_dist, 9.14e-4, atol=1e-5)
elif ndim == 2 and height == 16:
assert final_l1_dist < 4.8e-4
assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-5)
elif ndim == 4 and height == 8:
assert final_l1_dist < 1.6e-4
assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5)
elif ndim == 4 and height == 16:
assert final_l1_dist < 2.45e-5
assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-6)


@pytest.mark.parametrize("ndim", [2, 4])
Expand All @@ -84,13 +85,13 @@ def test_discreteebm(ndim: int, alpha: float):
args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories)
final_l1_dist = train_discreteebm_main(args)
if ndim == 2 and alpha == 0.1:
assert final_l1_dist < 0.0026
assert np.isclose(final_l1_dist, 2.97e-3, atol=1e-3)
elif ndim == 2 and alpha == 1.0:
assert final_l1_dist < 0.017
assert np.isclose(final_l1_dist, 0.017, atol=1e-3)
elif ndim == 4 and alpha == 0.1:
assert final_l1_dist < 0.009
assert np.isclose(final_l1_dist, 0.009, atol=1e-3)
elif ndim == 4 and alpha == 1.0:
assert final_l1_dist < 0.062
assert np.isclose(final_l1_dist, 0.062, atol=1e-3)


@pytest.mark.parametrize("delta", [0.1, 0.25])
Expand All @@ -113,10 +114,10 @@ def test_box(delta: float, loss: str):
print(args)
final_jsd = train_box_main(args)
if loss == "TB" and delta == 0.1:
assert final_jsd < 0.046
assert np.isclose(final_jsd, 3.81e-2, atol=1e-3)
elif loss == "DB" and delta == 0.1:
assert final_jsd < 0.18
assert np.isclose(final_jsd, 0.134, atol=1e-2)
if loss == "TB" and delta == 0.25:
assert final_jsd < 0.015
assert np.isclose(final_jsd, 2.93e-3, atol=1e-3)
elif loss == "DB" and delta == 0.25:
assert final_jsd < 0.027
assert np.isclose(final_jsd, 0.0142, atol=1e-3)
4 changes: 3 additions & 1 deletion tutorials/examples/train_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
)
from gfn.modules import ScalarEstimator

DEFAULT_SEED = 4444


def sample_from_reward(env: Box, n_samples: int):
"""Samples states from the true reward distribution
Expand Down Expand Up @@ -83,7 +85,7 @@ def estimate_jsd(kde1, kde2):


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item()
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)

device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
Expand Down
4 changes: 3 additions & 1 deletion tutorials/examples/train_discreteebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from gfn.utils.common import validate
from gfn.utils.modules import NeuralNet, Tabular

DEFAULT_SEED = 4444


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item()
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)

device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
Expand Down
4 changes: 3 additions & 1 deletion tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
from gfn.utils.common import validate
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular

DEFAULT_SEED = 4444


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item()
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)

device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
Expand Down

0 comments on commit b67a6d2

Please sign in to comment.