Rethinking sampling #147
Conversation
…clip_min is now optional (only if it is finite)
Sorry in advance for the large PR - feel free to be critical ...
In general, I think all of the code in this PR makes a lot of sense. I don't understand all of your intent as deeply as you do (of course), but I think everything here is well put together.
That said, I have some broad bits of feedback:
- This PR is really large. While this is often necessarily the case with changes that aim to make large improvements across the whole codebase, it's always worth trying to keep changes focused when possible. I know you already know that, though, so no worries.
- We should probably come up with some strategy for dealing with the TODOs. I've worked in large projects where each TODO had to be associated with a specific GitHub issue, for example. A lot of them relate to copy semantics, which seems like a good, focused thing that we could pursue in isolation.
- I understand your intent with introducing a very generic `policy_kwargs` dictionary, as it's not possible to know what parameters might be needed by continuous off-policy exploration. I think we should keep an eye on how that winds up getting used in practice, though. It may be possible to type those parameters more strongly in the future (a sketch of the pattern appears after this list).
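For illustration, the generic-dict pattern might look something like the sketch below; the parameter names `temperature` and `epsilon` are made up for this example, not torchgfn's actual API.

```python
import torch

def sample_actions(logits: torch.Tensor, **policy_kwargs) -> torch.Tensor:
    # Hypothetical sketch of the generic-dict pattern: exploration
    # parameters are pulled out with defaults, so each policy can accept
    # whatever it needs without changing the sampler's signature.
    temperature = policy_kwargs.get("temperature", 1.0)
    epsilon = policy_kwargs.get("epsilon", 0.0)
    probs = torch.softmax(logits / temperature, dim=-1)
    # Mix in a uniform distribution for epsilon-greedy style exploration.
    probs = (1 - epsilon) * probs + epsilon / logits.shape[-1]
    return torch.multinomial(probs, num_samples=1)
```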
All in all, I think this is worth merging! Not least because I'm excited about #149 😄
Thanks for the feedback! I really like your idea of associating each TODO with an issue. That would also make it easier to go fix the thing (you just search for the issue number in the code). I can do this in a follow-up PR!

Sorry, I knew I was being naughty when I submitted this monster PR. It was essentially a grab bag of things I tried while getting the library to play ball for the gflownet workshop, and it would have been really annoying to split it out into various PRs post-hoc. I figured it wasn't too bad because I was the only one working on it, but I agree this is horrible practice and not conducive to collaboration.

I understand your desire for strong typing on the `policy_kwargs`. And just to clarify the intent of the PR: I was addressing multiple points of feedback.
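Concretely, the TODO-to-issue convention could look like this (the issue number is made up, purely to illustrate the convention):

```python
import torch

states = torch.zeros(4, 3)
# TODO(#123): clarify copy semantics -- is this defensive clone needed?
# (#123 is a hypothetical issue number illustrating the convention.)
states_copy = states.clone()
```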
I'll wait for @saleml (who is defending next week, so is likely distracted at the minute) to merge. No rush!
On it!
Sorry for my very late review.
This is a great PR that would make using the library much simpler. Thanks a lot @josephdviviano.
I left a few comments, questions, and suggestions. They are minor. Hopefully the tests will pass after the fixes.
src/gfn/containers/trajectories.py (Outdated)

```diff
@@ -65,7 +77,7 @@ def __init__(
         self.env = env
         self.is_backward = is_backward
         self.states = (
-            states
+            states.clone()  # TODO: Do we need this clone?
```
I don't see why we would need that.
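For context, a minimal demonstration of what a defensive clone guards against; whether that aliasing risk actually applies here is exactly the open question:

```python
import torch

a = torch.zeros(3)
b = a            # no clone: b aliases a's storage
b[0] = 1.0
print(a)         # tensor([1., 0., 0.]) -- the caller's tensor was mutated

c = torch.zeros(3)
d = c.clone()    # defensive copy: d owns separate storage
d[0] = 1.0
print(c)         # tensor([0., 0., 0.]) -- the caller's tensor is untouched
```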
```diff
@@ -155,6 +168,12 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
             self._log_rewards[index] if self._log_rewards is not None else None
         )

+        if is_tensor(self.estimator_outputs):
+            estimator_outputs = self.estimator_outputs[:, index]
```
This implicitly assumes that `self.estimator_outputs` is of shape `max_length x n_trajectories` (as is the case, for example, for `self.log_probs`). Would this always be the case?
I feel like things would easily break here unless we force some structure on `estimator_outputs`. Rather than `torch.Tensor`, it has to be some `TensorType` with a specific shape IMO.
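As an illustration, one lightweight alternative to a typed `TensorType` is a runtime guard on the assumed layout (a sketch; the function name is hypothetical):

```python
import torch

def check_estimator_outputs(
    estimator_outputs: torch.Tensor, max_length: int, n_trajectories: int
) -> None:
    # Hypothetical guard enforcing the (max_length, n_trajectories, ...)
    # layout that the indexing code above assumes.
    if tuple(estimator_outputs.shape[:2]) != (max_length, n_trajectories):
        raise ValueError(
            f"Expected leading shape ({max_length}, {n_trajectories}), "
            f"got {tuple(estimator_outputs.shape)}"
        )
```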
What do you think of simply:

```python
if is_tensor(self.estimator_outputs):
    estimator_outputs = self.estimator_outputs[..., index]
    estimator_outputs = estimator_outputs[:new_max_length]
```

?
That should work!
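As a quick sanity check of the suggested indexing on a toy 2-D tensor (shapes illustrative):

```python
import torch

x = torch.arange(12).reshape(4, 3)  # (max_length=4, n_trajectories=3)
index = [0, 2]                      # keep trajectories 0 and 2
new_max_length = 2

out = x[..., index]         # for a 2-D tensor, equivalent to x[:, index]
out = out[:new_max_length]  # trim the time dimension
print(out.shape)            # torch.Size([2, 2])
```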
```python
# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
    self.estimator_outputs = other.estimator_outputs
```
But how would we match the indices of the trajectories to the indices of the `estimator_outputs`?
This feels dangerous. I suggest just throwing an error when one is None and the other is not (either one).
I think the idea is to be able to extend an empty `Trajectories` instance, say with a stored buffer.
I agree it is dangerous, but I think we should support this behaviour.
Admittedly, it has been some time since I looked at this, so I might be forgetting something.
Fair enough!
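The compromise landed on here might look roughly like the following sketch (`merge_estimator_outputs` is a hypothetical helper; padding along the time dimension is omitted):

```python
import torch

def merge_estimator_outputs(mine, other):
    # Allow extending an empty container from a stored buffer, but fail
    # loudly when only the other side is missing estimator outputs.
    if mine is None and other is None:
        return None
    if mine is None:
        return other
    if other is None:
        raise ValueError(
            "Cannot extend trajectories that carry estimator_outputs with "
            "trajectories that do not."
        )
    return torch.cat([mine, other], dim=1)  # dim 1 = trajectories
```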
src/gfn/containers/trajectories.py (Outdated)

```python
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function.
```
Right! There is a function elsewhere that does something similar. Maybe for a future PR.
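The single reused function mentioned in the TODO could plausibly look like this (a hypothetical sketch, not the existing function):

```python
import torch

def pad_dim0_to(x: torch.Tensor, length: int, value: float = 0.0) -> torch.Tensor:
    # Pad a tensor along its first (time) dimension up to `length`,
    # so two trajectory-shaped tensors can be concatenated afterwards.
    if x.shape[0] >= length:
        return x
    pad_shape = (length - x.shape[0], *x.shape[1:])
    pad = torch.full(pad_shape, value, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=0)
```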
src/gfn/env.py (Outdated)

```diff
@@ -83,7 +79,7 @@ def reset(
         assert not (random and sink)

         if random and seed is not None:
-            torch.manual_seed(seed)
+            torch.manual_seed(seed)  # TODO: Improve seeding here?
```
How?
You made a `set_seed` function in `common.py`.
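For reference, a typical `set_seed` helper looks something like this (the actual one in `common.py` may differ):

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Seed every RNG the library might touch, not just torch's global one.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
```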
src/gfn/gym/discrete_ebm.py (Outdated)

```diff
@@ -119,6 +116,7 @@ def make_random_states_tensor(
             device=env.device,
         )

+    # TODO: Look into make masks - I don't think this is being called.
```
Yes, this function can safely be deleted.
Removed!
```python
def __setitem__(
    self, index: int | Sequence[int] | Sequence[bool], states: States
) -> None:
    """Set particular states of the batch."""
    self.tensor[index] = states.tensor

def clone(self) -> States:
```
What about the `batch_shape` and `log_reward` attributes?
Right -- I think the easiest solution here is to use `deepcopy` -- what do you think?
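A minimal sketch of the `deepcopy`-based option (the real `States` class carries more state than shown here):

```python
from copy import deepcopy

import torch

class States:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor
        self.batch_shape = tuple(tensor.shape)[:1]

    def clone(self) -> "States":
        # Unlike copying self.tensor alone, deepcopy also carries over
        # auxiliary attributes such as batch_shape (and any cached
        # log_reward), at the cost of copying everything reachable.
        return deepcopy(self)
```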
src/gfn/utils/modules.py
Outdated
arch.append(nn.Linear(hidden_dim, hidden_dim)) | ||
arch.append(activation()) | ||
self.torso = nn.Sequential(*arch) | ||
self.torso.hidden_dim = hidden_dim # TODO: what is this? |
Storing the `hidden_dim` attribute in `self.torso`.
Awesome!
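For anyone puzzled by that line: `nn.Module` instances are ordinary Python objects, so extra attributes can be attached directly. A toy illustration:

```python
import torch.nn as nn

hidden_dim = 64
torso = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
# Attach the dimensionality so downstream code can read it back without
# re-threading the value through every call site.
torso.hidden_dim = hidden_dim
print(torso.hidden_dim)  # 64
```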
```diff
-        return states.tensor.float()
+        return (
+            states.tensor.float()
+        )  # TODO: should we typecast here? not a true identity...
```
I don't understand the question.
It means that the identity preprocessor is typecasting data, which seems like possibly unexpected behaviour. I would expect this to return whatever tensor is already inside `states`, untouched.
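To make the distinction concrete, a sketch of the two behaviours (function names are illustrative):

```python
import torch

def true_identity(states_tensor: torch.Tensor) -> torch.Tensor:
    # What the name "identity" suggests: hand the tensor back untouched.
    return states_tensor

def casting_identity(states_tensor: torch.Tensor) -> torch.Tensor:
    # What the code above actually does: .float() silently converts e.g.
    # a long tensor to float32, so this is a typecast, not an identity.
    return states_tensor.float()

x = torch.arange(3)                 # dtype torch.int64
print(true_identity(x).dtype)       # torch.int64
print(casting_identity(x).dtype)    # torch.float32
```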
Hey Salem -- the indexing changes I implemented to fix this have broken the tests -- I'm working on that now. I opened a can of worms here! There's likely an elegant solution.
I finally reverted the change for that failing test. I'm not sure there's a better solution that isn't extremely complicated. I'd like to get this PR merged, but we can happily revisit this issue in a future, much smaller PR.
@saleml I'd love for you to check this before I merge :)
This is some great work! Thank you, Joseph, for this very important PR. I have read your replies to my comments and seen that the tests pass. I think this can be merged as is. Small nitpick: for this comment (#147 (comment)), do you think we should add the argument in the abstract function?
I added it!
This PR is a hodgepodge of a few tweaks, bugfixes, and investigations related to the sampling logic, including a new simple continuous example.

- `estimator_outputs` are now re-used when sampling off policy.
- `policy_kwargs` are passed around properly to do off-policy sampling, hopefully in a way that remains generic.
- `TODO`s added around the codebase based on my observations.
- `log_reward_min_clamp` is now off by default and only defined in the gflownets, NOT the environment.
- Sampling now depends on whether it happens `off_policy` or not. This is for efficiency: if sampling happens off policy, we save estimator outputs (because we assume we will need them later, to evaluate log probabilities of actions under the policy). If it's done `on_policy`, we calculate the log-rewards during the forward pass (see the schematic sketch below).
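Schematically, the efficiency trade-off in the last point might look like this (names illustrative, not torchgfn's exact API):

```python
import torch

def record_sampling_artifacts(
    off_policy: bool,
    estimator_outputs: torch.Tensor,
    dist: torch.distributions.Distribution,
    actions: torch.Tensor,
) -> dict:
    if off_policy:
        # Save the raw estimator outputs: log-probs under the training
        # policy must be recomputed later, so don't compute them now.
        return {"estimator_outputs": estimator_outputs, "log_probs": None}
    # On-policy: the sampling and training policies coincide, so log-probs
    # can be accumulated cheaply during the forward pass itself.
    return {"estimator_outputs": None, "log_probs": dist.log_prob(actions)}
```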