conditional gfn #188

Merged 33 commits on Oct 24, 2024
Changes from 9 commits

Commits (33):
6e8dc4d
example of conditional GFN computation with TB only (for now)
josephdviviano Sep 25, 2024
39fb5ee
should be no change
josephdviviano Sep 25, 2024
2bc2263
Trajectories objects now have an optional .conditional field which opt…
josephdviviano Sep 25, 2024
99afaf3
small changes to logz parameter handling, optionally incorporate cond…
josephdviviano Sep 25, 2024
e6d25a0
logZ is optionally computed using a conditioning vector
josephdviviano Sep 25, 2024
2c72bf9
NeuralNets now have input/output dims
josephdviviano Sep 25, 2024
580c455
added a ConditionalDiscretePolicyEstimator, and the forward of GFNMod…
josephdviviano Sep 25, 2024
a74872f
added conditioning to sampler, which will save the tensor as an attri…
josephdviviano Sep 25, 2024
056d935
black
josephdviviano Sep 25, 2024
96b725c
API changes adapted
josephdviviano Oct 1, 2024
5cd32a7
added conditioning to all gflownets
josephdviviano Oct 1, 2024
877c4a0
both trajectories and transitions can now store a conditioning tensor
josephdviviano Oct 1, 2024
279a313
input_dim setting is now private
josephdviviano Oct 1, 2024
65135c1
added exception handling for all estimator calls potentially involvin…
josephdviviano Oct 1, 2024
b4c418c
API change -- n vs. n_trajectories
josephdviviano Oct 1, 2024
738b062
change test_box target value
josephdviviano Oct 1, 2024
4434e5f
API changes
josephdviviano Oct 1, 2024
851e03e
hacky fix for problematic test (added TODO)
josephdviviano Oct 1, 2024
5152295
working examples for all 4 major losses
josephdviviano Oct 4, 2024
1d64b55
added conditioning indexing for correct broadcasting
josephdviviano Oct 4, 2024
348ee82
added a ConditionalScalarEstimator which subclasses ConditionalDiscre…
josephdviviano Oct 4, 2024
9120afe
added modified DB example
josephdviviano Oct 4, 2024
f59f4de
conditioning added to modified db example
josephdviviano Oct 4, 2024
c5ef7ea
black
josephdviviano Oct 4, 2024
d67dfd5
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
d56a798
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
db8844c
added typing and a ConditionalScalarEstimator
josephdviviano Oct 9, 2024
e03c03a
added typing
josephdviviano Oct 9, 2024
6b47e06
typing
josephdviviano Oct 9, 2024
988faf0
typing
josephdviviano Oct 9, 2024
f2bbce3
added kwargs
josephdviviano Oct 9, 2024
eb13a2d
renamed torso to trunk
josephdviviano Oct 24, 2024
fd3d9dc
renamed torso to trunk
josephdviviano Oct 24, 2024
2 changes: 2 additions & 0 deletions src/gfn/containers/trajectories.py
@@ -50,6 +50,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
when_is_done: TT["n_trajectories", torch.long] | None = None,
is_backward: bool = False,
@@ -76,6 +77,7 @@ def __init__(
is used to compute the rewards, at each call of self.log_rewards
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states if states is not None else env.states_from_batch_shape((0, 0))
59 changes: 55 additions & 4 deletions src/gfn/gflownet/base.py
@@ -63,10 +63,10 @@ def sample_terminating_states(self, env: Env, n_samples: int) -> States:
return trajectories.last_states

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}
return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k}

def logz_parameters(self):
return [dict(self.named_parameters())["logZ"]]
return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k]

@abstractmethod
def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
@@ -176,7 +176,34 @@ def get_pfs_and_pbs(
~trajectories.actions.is_dummy
]
else:
estimator_outputs = self.pf(valid_states)
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
try:
estimator_outputs = self.pf(valid_states, masked_cond)
except TypeError as e:
print(
"conditioning was passed but `pf` is {}".format(
type(self.pf)
)
)
raise e
else:
# Here, we pass all valid states, i.e., non-sink states.
try:
estimator_outputs = self.pf(valid_states)
except TypeError as e:
print(
"conditioning was not passed but `pf` is {}".format(
type(self.pf)
)
)
raise e

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
@@ -196,7 +223,31 @@ def get_pfs_and_pbs(

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
estimator_outputs = self.pb(non_initial_valid_states)
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
try:
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
except TypeError as e:
print("conditioning was passed but `pb` is {}".format(type(self.pb)))
raise e
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
try:
estimator_outputs = self.pb(non_initial_valid_states)
except TypeError as e:
print(
"conditioning was not passed but `pb` is {}".format(type(self.pb))
)
raise e

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)
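The conditioning handling in get_pfs_and_pbs hinges on one broadcasting pattern: the per-trajectory conditioning tensor gets a leading time axis, is expanded to the trajectory length, and is then filtered with the same masks applied to the states. Below is a minimal standalone sketch of that pattern in plain torch; the names (`cond`, `is_sink`) and shapes (3 trajectories padded to length 5, 4-dimensional conditioning) are invented for illustration.

import torch

# Illustrative shapes: 3 trajectories padded to length 5, 4-dim conditioning.
max_len, n_traj, cond_dim = 5, 3, 4

cond = torch.randn(n_traj, cond_dim)               # one conditioning vector per trajectory
is_sink = torch.zeros(max_len, n_traj, dtype=torch.bool)
is_sink[3:, 0] = True                              # trajectory 0 ends early (sink-state padding)

# Same trick as above: prepend a time axis, expand it over the trajectory
# length, then keep only the rows that correspond to valid (non-sink) states.
expand_dims = (-1,) * cond.dim()                   # (-1, -1): leave existing dims unchanged
cond_per_step = cond.unsqueeze(0).expand((max_len,) + expand_dims)  # (max_len, n_traj, cond_dim)
masked_cond = cond_per_step[~is_sink]              # (n_valid_states, cond_dim)

print(masked_cond.shape)                           # torch.Size([13, 4])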
10 changes: 9 additions & 1 deletion src/gfn/gflownet/trajectory_balance.py
Review comment (Collaborator): I'm pleasantly surprised no change is needed for the LogPartitionVarianceLoss. Right?

Reply (Author): I don't think we need the conditioning information here, and I agree it's nice that the code naturally reflected that. Please correct me if I misunderstand this loss.

@@ -64,7 +64,15 @@ def loss(
_, _, scores = self.get_trajectories_scores(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
)
loss = (scores + self.logZ).pow(2).mean()

# If the conditioning values exist, we pass them to self.logZ
# (should be a ScalarEstimator or equivalent).
if trajectories.conditioning is not None:
logZ = self.logZ(trajectories.conditioning)
else:
logZ = self.logZ

loss = (scores + logZ.squeeze()).pow(2).mean()
if torch.isnan(loss):
raise ValueError("loss is nan")

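The change above is the heart of conditional TB: when a conditioning tensor is present, logZ is no longer a single learned scalar but is predicted from the conditioning by a scalar-output module. A toy sketch of the difference follows; the two-layer network, the shapes, and the `scores` stand-in are invented for illustration and are not the library's API.

import torch
import torch.nn as nn

n_traj, cond_dim = 8, 4
scores = torch.randn(n_traj)            # stand-in for log P_F - log P_B - log R per trajectory
conditioning = torch.randn(n_traj, cond_dim)

# Unconditional TB: logZ is one learned parameter shared by every trajectory.
logZ_param = nn.Parameter(torch.zeros(1))
loss_unconditional = (scores + logZ_param).pow(2).mean()

# Conditional TB: logZ is predicted per trajectory from its conditioning vector.
logZ_net = nn.Sequential(nn.Linear(cond_dim, 16), nn.ReLU(), nn.Linear(16, 1))
logZ_conditional = logZ_net(conditioning).squeeze(-1)   # (n_traj,)
loss_conditional = (scores + logZ_conditional).pow(2).mean()

print(loss_unconditional.item(), loss_conditional.item())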
62 changes: 59 additions & 3 deletions src/gfn/modules.py
@@ -73,8 +73,14 @@ def __init__(
self._output_dim_is_checked = False
self.is_backward = is_backward

def forward(self, states: States) -> TT["batch_shape", "output_dim", float]:
out = self.module(self.preprocessor(states))
def forward(
self, input: States | torch.Tensor
) -> TT["batch_shape", "output_dim", float]:
if isinstance(input, States):
input = self.preprocessor(input)

out = self.module(input)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
@@ -193,6 +199,56 @@ def to_probability_distribution(

return UnsqueezedCategorical(probs=probs)

# LogEdgeFlows are greedy, as are more P_B.
# LogEdgeFlows are greedy, as are most P_B.
else:
return UnsqueezedCategorical(logits=logits)


class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator):
r"""Container for forward and backward policy estimators for discrete environments.

$s \mapsto (P_F(s' \mid s, c))_{s' \in Children(s)}$.

or

$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$.
Review comment (Collaborator): Might be worth mentioning that this is a very specific conditioning use-case, where the condition is encoded separately and the embeddings are concatenated. I don't think we can do a generic one, but this should be enough as an example!

Reply (Author): What other conditioning approaches would be worth including? Cross attention?

Reply (Author): In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that.

Reply (Collaborator): I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc noting that the user might want to write their own module.


Attributes:
temperature: scalar to divide the logits by before softmax.
sf_bias: scalar to subtract from the exit action logit before dividing by
temperature.
epsilon: with probability epsilon, a random action is chosen.
"""

def __init__(
self,
state_module: nn.Module,
conditioning_module: nn.Module,
final_module: nn.Module,
n_actions: int,
preprocessor: Preprocessor | None,
is_backward: bool = False,
):
"""Initializes a estimator for P_F for discrete environments.

Args:
n_actions: Total number of actions in the Discrete Environment.
is_backward: if False, then this is a forward policy, else backward policy.
"""
super().__init__(state_module, n_actions, preprocessor, is_backward)
self.n_actions = n_actions
self.conditioning_module = conditioning_module
self.final_module = final_module

def forward(
self, states: States, conditioning: torch.Tensor
) -> TT["batch_shape", "output_dim", float]:
state_out = self.module(self.preprocessor(states))
conditioning_out = self.conditioning_module(conditioning)
out = self.final_module(torch.cat((state_out, conditioning_out), -1))

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True

return out
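As the review thread above notes, this estimator implements one specific conditioning scheme: the state and the conditioning are encoded by separate modules, and their embeddings are concatenated before a final head produces the action logits. The plain-torch sketch below shows just that wiring; every size is made up, and the environment preprocessor and output-dimension check of the real class are omitted.

import torch
import torch.nn as nn

state_dim, cond_dim, hidden, n_actions = 6, 4, 32, 5

state_module = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())        # encodes the state
conditioning_module = nn.Sequential(nn.Linear(cond_dim, hidden), nn.ReLU())  # encodes the condition
final_module = nn.Linear(2 * hidden, n_actions)                              # concatenated embeddings -> logits

states = torch.randn(10, state_dim)        # stand-in for preprocessor(states)
conditioning = torch.randn(10, cond_dim)

state_out = state_module(states)
conditioning_out = conditioning_module(conditioning)
logits = final_module(torch.cat((state_out, conditioning_out), dim=-1))

print(logits.shape)                        # torch.Size([10, 5])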
43 changes: 38 additions & 5 deletions src/gfn/samplers.py
@@ -21,16 +21,14 @@ class Sampler:
estimator: the submitted PolicyEstimator.
"""

def __init__(
self,
estimator: GFNModule,
) -> None:
def __init__(self, estimator: GFNModule) -> None:
self.estimator = estimator

def sample_actions(
self,
env: Env,
states: States,
conditioning: Optional[torch.Tensor] = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Optional[dict],
@@ -45,6 +43,7 @@ def sample_actions(
estimator: A GFNModule to pass to the probability distribution calculator.
env: The environment to sample actions from.
states: A batch of states.
conditioning: An optional tensor of conditioning information.
save_estimator_outputs: If True, the estimator outputs will be returned.
save_logprobs: If True, calculates and saves the log probabilities of sampled
actions.
@@ -68,7 +67,28 @@ def sample_actions(
the sampled actions under the probability distribution of the given
states.
"""
estimator_output = self.estimator(states)
# TODO: Should estimators instead ignore None for the conditioning vector?
Review comment (Collaborator): Wouldn't it be cleaner with fewer if/else blocks?

Reply (Author): Yes, there's a bit of cruft with all the if-else blocks, but as it stands an estimator can either accept one or two arguments and I think it's good if it fails noisily... what do you think?

Reply (Collaborator): OK, makes sense.

Reply (Author): I added these exception handlers to reduce the cruft.

if conditioning is not None:
try:
estimator_output = self.estimator(states, conditioning)
except TypeError as e:
print(
"conditioning was passed but `estimator` is {}".format(
type(self.estimator)
)
)
raise e
else:
try:
estimator_output = self.estimator(states)
except TypeError as e:
print(
"conditioning was not passed but `estimator` is {}".format(
type(self.estimator)
)
)
raise e

dist = self.estimator.to_probability_distribution(
states, estimator_output, **policy_kwargs
)
@@ -94,6 +114,7 @@ def sample_trajectories(
self,
env: Env,
states: Optional[States] = None,
conditioning: Optional[torch.Tensor] = None,
n_trajectories: Optional[int] = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
@@ -105,6 +126,7 @@ def sample_trajectories(
env: The environment to sample trajectories from.
states: If given, trajectories would start from such states. Otherwise,
trajectories are sampled from $s_o$ and n_trajectories must be provided.
conditioning: An optional tensor of conditioning information.
n_trajectories: If given, a batch of n_trajectories will be sampled all
starting from the environment's s_0.
save_estimator_outputs: If True, the estimator outputs will be returned. This
@@ -136,6 +158,9 @@ def sample_trajectories(
), "States should be a linear batch of states"
n_trajectories = states.batch_shape[0]

if conditioning is not None:
assert states.batch_shape == conditioning.shape[: len(states.batch_shape)]

device = states.tensor.device

dones = (
@@ -166,9 +191,15 @@ def sample_trajectories(
# during sampling. This is useful if, for example, you want to evaluate off
# policy actions later without repeating calculations to obtain the env
# distribution parameters.
if conditioning is not None:
masked_conditioning = conditioning[~dones]
else:
masked_conditioning = None

valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
env,
states[~dones],
masked_conditioning,
save_estimator_outputs=True if save_estimator_outputs else False,
save_logprobs=save_logprobs,
**policy_kwargs,
@@ -201,6 +232,7 @@ def sample_trajectories(
# Increment the step, determine which trajectories are finished, and eval
# rewards.
step += 1

# new_dones means those trajectories that just finished. Because we
# pad the sink state to every short trajectory, we need to make sure
# to filter out the already done ones.
@@ -236,6 +268,7 @@ def sample_trajectories(
trajectories = Trajectories(
env=env,
states=trajectories_states,
conditioning=conditioning,
actions=trajectories_actions,
when_is_done=trajectories_dones,
is_backward=self.estimator.is_backward,
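During sampling, the conditioning has to stay aligned with the trajectories that are still active: the batch-shape assertion checks that it matches the starting states, and at every step the same `~dones` mask that selects the active states also selects their conditioning rows. A small sketch of just that bookkeeping, with invented shapes:

import torch

n_traj, cond_dim = 6, 3
conditioning = torch.randn(n_traj, cond_dim)
states_batch_shape = (n_traj,)

# Mirrors the assertion in sample_trajectories: conditioning is batched like the states.
assert states_batch_shape == conditioning.shape[: len(states_batch_shape)]

# Two trajectories have already terminated; only the active ones step forward.
dones = torch.tensor([False, True, False, False, True, False])
masked_conditioning = conditioning[~dones]

print(masked_conditioning.shape)           # torch.Size([4, 3])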
9 changes: 9 additions & 0 deletions src/gfn/utils/modules.py
@@ -32,6 +32,7 @@ def __init__(
(i.e. all layers except last layer).
"""
super().__init__()
self._input_dim = input_dim
self._output_dim = output_dim

if torso is None:
@@ -69,6 +70,14 @@ def forward(
out = self.last_layer(out)
return out

@property
def input_dim(self):
return self._input_dim

@property
def output_dim(self):
return self._output_dim


class Tabular(nn.Module):
"""Implements a tabular policy.
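Exposing input_dim and output_dim as properties makes it easier to wire the pieces of a conditional estimator together without hard-coding sizes. The sketch below illustrates the idea with a stand-in module (SimpleTrunk is invented for this example and is not the library's MLP):

import torch
import torch.nn as nn

class SimpleTrunk(nn.Module):
    """Stand-in module that records its input/output sizes."""

    def __init__(self, input_dim: int, output_dim: int, hidden: int = 32):
        super().__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(), nn.Linear(hidden, output_dim)
        )

    def forward(self, x):
        return self.net(x)

    @property
    def input_dim(self):
        return self._input_dim

    @property
    def output_dim(self):
        return self._output_dim

state_net = SimpleTrunk(input_dim=6, output_dim=16)
cond_net = SimpleTrunk(input_dim=4, output_dim=16)

# The final head's input size is derived from the trunks' output_dim properties.
final_head = nn.Linear(state_net.output_dim + cond_net.output_dim, 5)

out = final_head(torch.cat((state_net(torch.randn(10, 6)), cond_net(torch.randn(10, 4))), dim=-1))
print(out.shape)                           # torch.Size([10, 5])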
2 changes: 1 addition & 1 deletion tutorials/examples/train_hypergrid_simple.py
@@ -64,4 +64,4 @@
loss = gflownet.loss(env, trajectories)
loss.backward()
optimizer.step()
pbar.set_postfix({"loss": loss.item()})
pbar.set_postfix({"loss": loss.item()})