Easier environment definition #143
Conversation
fix __repr__ of modules
…mplemented, and log_reward simply takes the log of reward() and clips by default.
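From that description, the default behavior would look roughly like this (a sketch only, assuming the `log_reward_clip` attribute introduced later in this PR; not the exact committed code):

```python
import torch

# Hypothetical sketch of the default described above: users implement
# reward(), and log_reward() falls back to a clipped log of it.
def log_reward(self, final_states) -> torch.Tensor:
    # Clip very negative values (log of near-zero rewards) to avoid -inf.
    return torch.log(self.reward(final_states)).clamp_min(self.log_reward_clip)
```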
Great PR! Thanks. Could you also run `pre-commit run --all` at the end? Some files will get modified.
src/gfn/states.py (Outdated)
```
@@ -371,3 +384,55 @@ def _extend(masks, first_dim):

        self.forward_masks = _extend(self.forward_masks, required_first_dim)
        self.backward_masks = _extend(self.backward_masks, required_first_dim)

+   # The helper methods are convienience functions for common mask operations.
```
typo
src/gfn/states.py (Outdated)
```
def set_nonexit_masks(self, cond, allow_exit: bool = False):
    """Sets the allowable actions according to cond, appending the exit mask.

    A convienience function for common mask operations.
```
ditto
I just apparently can't spell this one :)
src/gfn/states.py (Outdated)
```
    A convienience function for common mask operations.

    Args:
        cond: a boolean of shape (batch_shape,) + (state_shape,), which
```
Isn't this true only when `state_shape == n_actions - 1`? I think you meant `n_actions - 1` rather than `state_shape`.
you're right -- good catch
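To make the corrected shape concrete (hypothetical numbers, not from the PR): for a grid with two dimensions plus an exit action, `n_actions = 3`, so `cond` covers the `n_actions - 1 = 2` non-exit actions:

```python
import torch

batch_shape, n_actions = (4,), 3
cond = torch.zeros(batch_shape + (n_actions - 1,), dtype=torch.bool)
print(cond.shape)  # torch.Size([4, 2]), i.e. (batch_shape,) + (n_actions - 1,)
```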
src/gfn/env.py (Outdated)
```
@@ -184,12 +190,12 @@ def backward_step(
        return new_states

    def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-       """Either this or log_reward needs to be implemented."""
-       return torch.exp(self.log_reward(final_states))
+       """This (and potentially log_reward) needs to be implemented."""
```
Why "and"?
fixed the docs
src/gfn/gym/hypergrid.py (Outdated)
```
            self.backward_masks,
        )

        self.set_default_typing()
        self.forward_masks[..., :-1] = self.tensor != env.height - 1
```
Do you think we should use `set_nonexit_masks` for this line?
makes sense, good catch!
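The suggested change would presumably reduce the masked assignment above to a single helper call (a sketch of the idea, not the exact committed line):

```python
# Instead of: self.forward_masks[..., :-1] = self.tensor != env.height - 1
self.set_nonexit_masks(self.tensor != env.height - 1)
```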
Also, it seems like there are GitHub actions now. Do you happen to know why the checks fail?

No, it's some kind of environment thing -- I'm going to fix that as part of this PR once and for all :)
… longer (I am not sure what introduced this bug)
…chgfn into easier_environment_definition
All comments fixed -- just waiting to see if the checks pass :)
LGTM
```
@@ -23,16 +23,21 @@ def __init__(
        sf: Optional[TT["state_shape", torch.float]] = None,
        device_str: Optional[str] = None,
        preprocessor: Optional[Preprocessor] = None,
+       log_reward_clip: Optional[float] = -100.0,
```
good idea!
```
        self.is_discrete = True
        self.log_reward_clip = log_reward_clip
```
unnecessary
```
@@ -303,6 +303,19 @@ def __init__(
        self.forward_masks = cast(torch.Tensor, forward_masks)
        self.backward_masks = cast(torch.Tensor, backward_masks)

+       self.set_default_typing()
+
+   def set_default_typing(self) -> None:
```
great idea!
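A plausible shape for this helper, judging from the surrounding `cast` calls (the body is an assumption, not the committed code): it centralizes the torchtyping casts for both masks.

```python
from typing import cast

import torch
from torchtyping import TensorType as TT

def set_default_typing(self) -> None:
    # Assumed body: re-apply the default torchtyping annotations to the masks.
    self.forward_masks = cast(
        TT["batch_shape", "n_actions", torch.bool], self.forward_masks
    )
    self.backward_masks = cast(
        TT["batch_shape", "n_actions - 1", torch.bool], self.backward_masks
    )
```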
```
        super().__init__(**kwargs)
        self.logZ_value = nn.Parameter(logZ_value)
```
The only place `BoxStateFlowModule` is used is in `train_box.py`:

```python
logZ = torch.tensor(0.0, device=env.device, requires_grad=True)
# We need a LogStateFlowEstimator
module = BoxStateFlowModule(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,
    hidden_dim=args.hidden_dim,
    n_hidden_layers=args.n_hidden,
    torso=None,  # We do not tie the parameters of the flow function to PF
    logZ_value=logZ,
)
```

Naive PyTorch question: why do we need `nn.Parameter`?
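For what it's worth, `nn.Parameter` registers the tensor with the module, so it appears in `module.parameters()` and `state_dict()` and gets picked up by optimizers; a plain tensor attribute, even with `requires_grad=True`, is not registered. A minimal illustration (hypothetical module, not from the repo):

```python
import torch
from torch import nn

class FlowModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.Parameter(torch.tensor(0.0))   # in parameters()
        self.plain = torch.tensor(0.0, requires_grad=True)  # not in parameters()

m = FlowModule()
print([name for name, _ in m.named_parameters()])  # ['registered']
```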
```
-   def true_reward(
-       self, final_states: DiscreteStates
-   ) -> TT["batch_shape", torch.float]:
+   def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float]:
```
Right! `true_reward` was useless.
src/gfn/states.py (Outdated)
```
@@ -371,3 +384,55 @@ def _extend(masks, first_dim):

        self.forward_masks = _extend(self.forward_masks, required_first_dim)
        self.backward_masks = _extend(self.backward_masks, required_first_dim)

+   # The helper methods are convenience functions for common mask operations.
+   def set_nonexit_masks(self, cond, allow_exit: bool = False):
```
Are there other places than `hypergrid.py` where this is used?
```
            dim=-1,
        ).bool()

    def init_forward_masks(self, set_ones: bool = True):
```
perfect
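Judging from the signature alone, a plausible sketch of this helper (the body is an assumption): initialize `forward_masks` to all-True or all-False before environment-specific constraints are applied.

```python
import torch

def init_forward_masks(self, set_ones: bool = True):
    # Assumed body: start from a uniform mask over all n_actions.
    shape = self.batch_shape + (self.n_actions,)
    if set_ones:
        self.forward_masks = torch.ones(shape, dtype=torch.bool, device=self.device)
    else:
        self.forward_masks = torch.zeros(shape, dtype=torch.bool, device=self.device)
```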
Tests pass!
This is a very important PR. Thank you @josephdviviano.
Should this be merged to master or to stable?
Summary of changes:
- Helper methods added to `DiscreteEnv` to make mask definition easier.
- Reworked `log_reward` and `reward` methods.
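Put together, the intended usage reads something like this: a hypothetical minimal environment under the assumptions above (`MyEnv`, its reward, and the mask rule are illustrative, not from this PR, and the other required pieces such as step functions are omitted for brevity):

```python
import torch
from gfn.env import DiscreteEnv

class MyEnv(DiscreteEnv):
    # reward() is the one method that must be implemented; log_reward()
    # is derived from it (log plus clipping) by default.
    def reward(self, final_states) -> torch.Tensor:
        return final_states.tensor.sum(-1).float() + 0.1

    # Masks are defined with the new helpers instead of manual slicing.
    def update_masks(self, states) -> None:
        states.set_nonexit_masks(states.tensor != self.height - 1)
```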