-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
conditional gfn #188
conditional gfn #188
Conversation
…ionally contains a tensor of conditioning vectors (one per trajectory)
…itioning into PB and PF computation
…ule can now accept raw tensors
…bute of the trajectory
Don't worry about the tests - they should be easy to fix. I can make the chances for DB, Sub-TB, and FM pretty easily if we agree this is a good approach, before a proper review. |
|
||
or | ||
|
||
$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be worth mentioning that this is a s very specific conditioning use-case, where the condition is encoded separately, and embeddings are concatenated.
I don't think we can do a generic one, but this should be enough as an example !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What other conditioning approaches would be worth including? Cross attention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc that the user might want to write their own module
@@ -68,7 +67,28 @@ def sample_actions( | |||
the sampled actions under the probability distribution of the given | |||
states. | |||
""" | |||
estimator_output = self.estimator(states) | |||
# TODO: Should estimators instead ignore None for the conditioning vector? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wouldn't it be cleaner with fewer if else blocks ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes there's a bit of cruft with all the if-else blocks, but as it stands an estimator can either accept one or two arguments and I think it's good if it fails noisily... what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok ! makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added these exception_handlers
to reduce the cruft.
LGTM! Looking forward to test this feature |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy to see this being added to the library. Great work! Great code design, and thanks for factorizing a few other things, including the context managers / error handlers.
I left a few comments and a suggestion for the script.
@@ -32,41 +35,41 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): | |||
def sample_trajectories( | |||
self, | |||
env: Env, | |||
n_samples: int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like you're handling the conditioning input to this function as a kwarg, whereas sampler
's sample_trajectories
have an explicit conditioning
input. I'm wondering if you have a particular reason for this choice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think maybe all functions should use an explicit conditioning
kwarg, what do you think? I can make those changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that it would be cleaner
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it should be done now, let me know if i missed something.
conditioning = torch.rand((batch_size, 1)) | ||
conditioning = (conditioning > 0.5).to(torch.float) # Randomly 1 and zero. | ||
|
||
trajectories = gflownet.sample_trajectories( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so I think I fixed this but moving to **kwargs: Any
but we have a multitude of other harder to handle pylance
issues that I'm not sure what to do about and warrants a discussion bigger than the scope of this PR, I think.
print("+ Training Conditional {}!".format(type(gflownet))) | ||
for i in (pbar := tqdm(range(n_iterations))): | ||
conditioning = torch.rand((batch_size, 1)) | ||
conditioning = (conditioning > 0.5).to(torch.float) # Randomly 1 and zero. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, the conditioning doesn't change anything in this example. While this file is a great way to show how one can code their conditional gflownet, what do you think of slightly altering the setting here, e.g. by making the environment conditional (e.g. hide one of the 4 modes if conditioning=1), and then, post-training, have some validation where we compare the resulting pair of distributions to the pair of target distributions ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, you're right. can we save this for a follow up PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I filed the issue here. if you agree I'd like to do this work separately.
gflownet = build_tb_gflownet(environment) | ||
train(environment, gflownet) | ||
|
||
gflownet = build_db_gflownet(environment) | ||
train(environment, gflownet) | ||
|
||
gflownet = build_db_mod_gflownet(environment) | ||
train(environment, gflownet) | ||
|
||
gflownet = build_subTB_gflownet(environment) | ||
train(environment, gflownet) | ||
|
||
gflownet = build_fm_gflownet(environment) | ||
train(environment, gflownet) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argparse this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fini
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pleasantly surprised no change is needed for the LogPartitionVarianceLoss. Right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need the conditioning information here, and I agree it's nice that the code naturally reflected that. Please correct me if I misunderstand this loss.
) -> tuple[DiscreteStates, DiscreteStates, torch.Tensor]: | ||
def to_training_samples(self, trajectories: Trajectories) -> Union[ | ||
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], | ||
Tuple[DiscreteStates, DiscreteStates, None, None], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤯
@@ -240,13 +240,20 @@ def __init__( | |||
self.conditioning_module = conditioning_module | |||
self.final_module = final_module | |||
|
|||
def forward( | |||
self, states: States, conditioning: torch.tensor | |||
def _forward_trunk( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is what you call trunk
the same thing I called torso
before ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah -- let me unify the naming
LGTM! Thanks for the PR |
Supports conditioning on a tensor of
shape=[n_trajectories, n_cond_dims]
. This is passed by the user during a call to the sampler.Implemented for all GFlowNets. Note that the current version expects a particular kind of estimator. I can imagine this will lead to future changes - e.g., we should have some
Estimators
which expect huggingface models, so we can use them to produce conditioning vectors / to initialize the policy (this will obviously be a future PR).Note that the conditioning is useless in my example, we should have a better use-case envisioned for the demo. The demo currently is not complete for all GFlowNet types.