
Commit

black
josephdviviano committed Oct 4, 2024
1 parent f59f4de commit c5ef7ea
Showing 4 changed files with 45 additions and 15 deletions.
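Every hunk in this commit is a pure `black` reflow: lines longer than black's default 88-character limit are split across multiple lines, with trailing commas added where the formatter's "magic trailing comma" convention applies, and no behavior changes. A minimal sketch of the transformation, using a hypothetical function that is not from this repository:

```python
def combine(first: int, second: int, third: int, fourth: int) -> int:
    # Hypothetical stand-in; only the call-site formatting matters here.
    return first + second + third + fourth


# Before black: the call written on a single line (imagine it exceeding 88 chars).
result = combine(1, 2, 3, 4)

# After black: a too-long call is exploded to one argument per line, with a
# trailing comma so that future diffs only touch the lines that actually change.
result = combine(
    1,
    2,
    3,
    4,
)
```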
18 changes: 13 additions & 5 deletions src/gfn/containers/trajectories.py
@@ -376,7 +376,10 @@ def to_states(self) -> States:
 
     def to_non_initial_intermediary_and_terminating_states(
         self,
-    ) -> Union[Tuple[States, States, torch.Tensor, torch.Tensor], Tuple[States, States, None, None]]:
+    ) -> Union[
+        Tuple[States, States, torch.Tensor, torch.Tensor],
+        Tuple[States, States, None, None],
+    ]:
         """Returns all intermediate and terminating `States` from the trajectories.
 
         This is useful for the flow matching loss, that requires its inputs to be distinguished.
@@ -390,9 +393,9 @@ def to_non_initial_intermediary_and_terminating_states(
         if self.conditioning is not None:
             traj_len = self.states.batch_shape[0]
             expand_dims = (traj_len,) + tuple(self.conditioning.shape)
-            intermediary_conditioning = self.conditioning.unsqueeze(0).expand(expand_dims)[
-                ~states.is_sink_state & ~states.is_initial_state
-            ]
+            intermediary_conditioning = self.conditioning.unsqueeze(0).expand(
+                expand_dims
+            )[~states.is_sink_state & ~states.is_initial_state]
             conditioning = self.conditioning  # n_final_states == n_trajectories.
         else:
             intermediary_conditioning = None
@@ -401,7 +404,12 @@ def to_non_initial_intermediary_and_terminating_states(
         intermediary_states = states[~states.is_sink_state & ~states.is_initial_state]
         terminating_states = self.last_states
         terminating_states.log_rewards = self.log_rewards
-        return (intermediary_states, terminating_states, intermediary_conditioning, conditioning)
+        return (
+            intermediary_states,
+            terminating_states,
+            intermediary_conditioning,
+            conditioning,
+        )
 
 
 def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
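The reflowed return statement above produces the 4-tuple that the flow-matching GFlowNet consumes (see the flow_matching.py hunk below). A minimal sketch of the consumer side, assuming `trajectories`, `gflownet` (an FMGFlowNet), and `env` already exist from torchgfn's usual sampling loop:

```python
# Sketch only; `trajectories`, `gflownet`, and `env` are assumed to be
# constructed elsewhere. Names mirror the diff above.
states_tuple = trajectories.to_non_initial_intermediary_and_terminating_states()
(
    intermediary_states,        # non-initial, non-sink states
    terminating_states,         # last state of each trajectory, log_rewards attached
    intermediary_conditioning,  # None for unconditional trajectories
    conditioning,               # None for unconditional trajectories
) = states_tuple

loss = gflownet.loss(env, states_tuple)  # matches FMGFlowNet.loss below
```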
22 changes: 16 additions & 6 deletions src/gfn/gflownet/detailed_balance.py
@@ -18,12 +18,13 @@
 def check_compatibility(states, actions, transitions):
     if states.batch_shape != tuple(actions.batch_shape):
         if type(transitions) is not Transitions:
-            raise TypeError("`transitions` is type={}, not Transitions".format(type(transitions)))
+            raise TypeError(
+                "`transitions` is type={}, not Transitions".format(type(transitions))
+            )
         else:
             raise ValueError(" wrong happening with log_pf evaluations")
-
 
 
 class DBGFlowNet(PFBasedGFlowNet[Transitions]):
     r"""The Detailed Balance GFlowNet.
@@ -50,7 +51,10 @@ def __init__(
         log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb)
-        assert any(isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator]), "logF must be a ScalarEstimator or derived"
+        assert any(
+            isinstance(logF, cls)
+            for cls in [ScalarEstimator, ConditionalScalarEstimator]
+        ), "logF must be a ScalarEstimator or derived"
         self.logF = logF
         self.forward_looking = forward_looking
         self.log_reward_clip_min = log_reward_clip_min
@@ -152,7 +156,9 @@ def get_scores(
         # Evaluate the log PB of the actions, with optional conditioning.
         if transitions.conditioning is not None:
             with has_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(valid_next_states, transitions.conditioning[~transitions.is_done])
+                module_output = self.pb(
+                    valid_next_states, transitions.conditioning[~transitions.is_done]
+                )
         else:
             with no_conditioning_exception_handler("pb", self.pb):
                 module_output = self.pb(valid_next_states)
@@ -263,7 +269,9 @@ def get_scores(
         # next_states are also states, for which we already did a forward pass.
         if transitions.conditioning is not None:
             with has_conditioning_exception_handler("pf", self.pf):
-                module_output = self.pf(valid_next_states, transitions.conditioning[mask])
+                module_output = self.pf(
+                    valid_next_states, transitions.conditioning[mask]
+                )
         else:
             with no_conditioning_exception_handler("pf", self.pf):
                 module_output = self.pf(valid_next_states)
@@ -276,7 +284,9 @@ def get_scores(
 
         if transitions.conditioning is not None:
             with has_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(valid_next_states, transitions.conditioning[mask])
+                module_output = self.pb(
+                    valid_next_states, transitions.conditioning[mask]
+                )
         else:
             with no_conditioning_exception_handler("pb", self.pb):
                 module_output = self.pb(valid_next_states)
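All four hunks in this file reflow the same dispatch pattern: a conditional estimator is called with the conditioning tensor masked down to the rows being evaluated, while an unconditional one is called on the states alone. A self-contained sketch of that pattern (the `estimator` callable here is a hypothetical stand-in for torchgfn's `pf`/`pb` estimators, and the exception handlers are omitted):

```python
from typing import Callable, Optional

import torch


def evaluate_estimator(
    estimator: Callable[..., torch.Tensor],
    valid_states: torch.Tensor,
    conditioning: Optional[torch.Tensor],
    mask: torch.Tensor,
) -> torch.Tensor:
    # Mirror of the dispatch in get_scores: conditioning, if present, is
    # masked to the same rows as the states before the forward pass.
    if conditioning is not None:
        return estimator(valid_states, conditioning[mask])
    return estimator(valid_states)


# Toy usage with a dummy estimator that ignores its conditioning argument.
def dummy(states: torch.Tensor, conditioning: Optional[torch.Tensor] = None) -> torch.Tensor:
    return states.sum(dim=-1)


states = torch.randn(3, 4)                            # 3 valid states
cond = torch.randn(5, 2)                              # conditioning for 5 rows
mask = torch.tensor([True, False, True, False, True])  # selects the 3 valid rows
out = evaluate_estimator(dummy, states, cond, mask)
```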
15 changes: 12 additions & 3 deletions src/gfn/gflownet/flow_matching.py
@@ -184,9 +184,18 @@ def loss(
             tuple of states, the first one being the internal states of the trajectories
             (i.e. non-terminal states), and the second one being the terminal states of the
             trajectories."""
-        intermediary_states, terminating_states, intermediary_conditioning, terminating_conditioning = states_tuple
-        fm_loss = self.flow_matching_loss(env, intermediary_states, intermediary_conditioning)
-        rm_loss = self.reward_matching_loss(env, terminating_states, terminating_conditioning)
+        (
+            intermediary_states,
+            terminating_states,
+            intermediary_conditioning,
+            terminating_conditioning,
+        ) = states_tuple
+        fm_loss = self.flow_matching_loss(
+            env, intermediary_states, intermediary_conditioning
+        )
+        rm_loss = self.reward_matching_loss(
+            env, terminating_states, terminating_conditioning
+        )
         return fm_loss + self.alpha * rm_loss
 
     def to_training_samples(
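For context on the expression being reflowed: flow matching penalizes flow imbalance at intermediary states, reward matching ties the flow at terminating states to their log-rewards, and `alpha` weights the two terms. A toy sketch of the combination (the component values are placeholders, not torchgfn's real computations):

```python
import torch

# Placeholder component losses; in torchgfn they come from flow_matching_loss
# and reward_matching_loss applied to the two halves of the states tuple.
fm_loss = torch.tensor(0.8)  # flow imbalance at intermediary states
rm_loss = torch.tensor(0.3)  # reward mismatch at terminating states
alpha = 1.0                  # FMGFlowNet weighting hyperparameter

total = fm_loss + alpha * rm_loss  # the return expression in the diff above
```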
5 changes: 4 additions & 1 deletion src/gfn/gflownet/sub_trajectory_balance.py
@@ -75,7 +75,10 @@ def __init__(
         forward_looking: bool = False,
     ):
         super().__init__(pf, pb)
-        assert any(isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator]), "logF must be a ScalarEstimator or derived"
+        assert any(
+            isinstance(logF, cls)
+            for cls in [ScalarEstimator, ConditionalScalarEstimator]
+        ), "logF must be a ScalarEstimator or derived"
         self.logF = logF
         self.weighting = weighting
         self.lamda = lamda
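A design note on the assertion reflowed here and in detailed_balance.py: `isinstance` natively accepts a tuple of types, so the `any(...)` construction has a one-line equivalent that black would not need to split. A self-contained sketch of the alternative (the two classes are stand-ins for torchgfn's estimators, not the real ones):

```python
class ScalarEstimator:
    """Stand-in for torchgfn's ScalarEstimator, for illustration only."""


class ConditionalScalarEstimator:
    """Stand-in for torchgfn's ConditionalScalarEstimator."""


logF = ConditionalScalarEstimator()

# Equivalent to any(isinstance(logF, cls) for cls in [...]): isinstance
# checks against every type in the tuple in a single call.
assert isinstance(
    logF, (ScalarEstimator, ConditionalScalarEstimator)
), "logF must be a ScalarEstimator or derived"
```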
