Skip to content

Commit

Permalink
documentation nits and using convienience builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 23, 2023
1 parent 4161dba commit 4d8a145
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,11 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""This (and potentially log_reward) needs to be implemented."""
raise NotImplementedError("reward function not implemented")
"""The environment's reward given a state.
This or log_reward must be implemented.
"""
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Calculates the log reward (clipping small rewards)."""
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make_random_states_tensor(
def update_masks(self) -> None:
"Update the masks based on the current states."
self.set_default_typing()
self.forward_masks[..., :-1] = self.tensor != env.height - 1
self.set_nonexit_masks(self.tensor != env.height - 1)
self.backward_masks = self.tensor != 0

return HyperGridStates
Expand Down
6 changes: 3 additions & 3 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,10 @@ def _extend(masks, first_dim):
def set_nonexit_masks(self, cond, allow_exit: bool = False):
"""Sets the allowable actions according to cond, appending the exit mask.
A convienience function for common mask operations.
A convenience function for common mask operations.
Args:
cond: a boolean of shape (batch_shape,) + (state_shape,), which
cond: a boolean of shape (batch_shape,) + (n_actions - 1,), which
denotes which actions are not allowed. For example, if a state element
represents action count, and no action can be repeated more than 5
times, cond might be state.tensor >= 5.
Expand All @@ -408,7 +408,7 @@ def set_nonexit_masks(self, cond, allow_exit: bool = False):
def set_exit_masks(self, batch_idx):
"""Sets forward masks such that the only allowable next action is to exit.
A convienience function for common mask operations.
A convenience function for common mask operations.
Args:
batch_idx: A Boolean index along the batch dimension, along which to
Expand Down

0 comments on commit 4d8a145

Please sign in to comment.