From c5ef7ea9f633a4b8ed2fbb2e14e2ad88f61a5d88 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 4 Oct 2024 16:10:19 -0400 Subject: [PATCH] black --- src/gfn/containers/trajectories.py | 18 +++++++++++++----- src/gfn/gflownet/detailed_balance.py | 22 ++++++++++++++++------ src/gfn/gflownet/flow_matching.py | 15 ++++++++++++--- src/gfn/gflownet/sub_trajectory_balance.py | 5 ++++- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 214c36b..d0545d9 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -376,7 +376,10 @@ def to_states(self) -> States: def to_non_initial_intermediary_and_terminating_states( self, - ) -> Union[Tuple[States, States, torch.Tensor, torch.Tensor], Tuple[States, States, None, None]]: + ) -> Union[ + Tuple[States, States, torch.Tensor, torch.Tensor], + Tuple[States, States, None, None], + ]: """Returns all intermediate and terminating `States` from the trajectories. This is useful for the flow matching loss, that requires its inputs to be distinguished. @@ -390,9 +393,9 @@ def to_non_initial_intermediary_and_terminating_states( if self.conditioning is not None: traj_len = self.states.batch_shape[0] expand_dims = (traj_len,) + tuple(self.conditioning.shape) - intermediary_conditioning = self.conditioning.unsqueeze(0).expand(expand_dims)[ - ~states.is_sink_state & ~states.is_initial_state - ] + intermediary_conditioning = self.conditioning.unsqueeze(0).expand( + expand_dims + )[~states.is_sink_state & ~states.is_initial_state] conditioning = self.conditioning # n_final_states == n_trajectories. else: intermediary_conditioning = None @@ -401,7 +404,12 @@ def to_non_initial_intermediary_and_terminating_states( intermediary_states = states[~states.is_sink_state & ~states.is_initial_state] terminating_states = self.last_states terminating_states.log_rewards = self.log_rewards - return (intermediary_states, terminating_states, intermediary_conditioning, conditioning) + return ( + intermediary_states, + terminating_states, + intermediary_conditioning, + conditioning, + ) def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor: diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 43c5710..2060f7b 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -18,12 +18,13 @@ def check_compatibility(states, actions, transitions): if states.batch_shape != tuple(actions.batch_shape): if type(transitions) is not Transitions: - raise TypeError("`transitions` is type={}, not Transitions".format(type(transitions))) + raise TypeError( + "`transitions` is type={}, not Transitions".format(type(transitions)) + ) else: raise ValueError(" wrong happening with log_pf evaluations") - class DBGFlowNet(PFBasedGFlowNet[Transitions]): r"""The Detailed Balance GFlowNet. @@ -50,7 +51,10 @@ def __init__( log_reward_clip_min: float = -float("inf"), ): super().__init__(pf, pb) - assert any(isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator]), "logF must be a ScalarEstimator or derived" + assert any( + isinstance(logF, cls) + for cls in [ScalarEstimator, ConditionalScalarEstimator] + ), "logF must be a ScalarEstimator or derived" self.logF = logF self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min @@ -152,7 +156,9 @@ def get_scores( # Evaluate the log PB of the actions, with optional conditioning. if transitions.conditioning is not None: with has_conditioning_exception_handler("pb", self.pb): - module_output = self.pb(valid_next_states, transitions.conditioning[~transitions.is_done]) + module_output = self.pb( + valid_next_states, transitions.conditioning[~transitions.is_done] + ) else: with no_conditioning_exception_handler("pb", self.pb): module_output = self.pb(valid_next_states) @@ -263,7 +269,9 @@ def get_scores( # next_states are also states, for which we already did a forward pass. if transitions.conditioning is not None: with has_conditioning_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states, transitions.conditioning[mask]) + module_output = self.pf( + valid_next_states, transitions.conditioning[mask] + ) else: with no_conditioning_exception_handler("pf", self.pf): module_output = self.pf(valid_next_states) @@ -276,7 +284,9 @@ def get_scores( if transitions.conditioning is not None: with has_conditioning_exception_handler("pb", self.pb): - module_output = self.pb(valid_next_states, transitions.conditioning[mask]) + module_output = self.pb( + valid_next_states, transitions.conditioning[mask] + ) else: with no_conditioning_exception_handler("pb", self.pb): module_output = self.pb(valid_next_states) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index a968569..5bf9b4b 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -184,9 +184,18 @@ def loss( tuple of states, the first one being the internal states of the trajectories (i.e. non-terminal states), and the second one being the terminal states of the trajectories.""" - intermediary_states, terminating_states, intermediary_conditioning, terminating_conditioning = states_tuple - fm_loss = self.flow_matching_loss(env, intermediary_states, intermediary_conditioning) - rm_loss = self.reward_matching_loss(env, terminating_states, terminating_conditioning) + ( + intermediary_states, + terminating_states, + intermediary_conditioning, + terminating_conditioning, + ) = states_tuple + fm_loss = self.flow_matching_loss( + env, intermediary_states, intermediary_conditioning + ) + rm_loss = self.reward_matching_loss( + env, terminating_states, terminating_conditioning + ) return fm_loss + self.alpha * rm_loss def to_training_samples( diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 2e930e1..5cbb8b5 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -75,7 +75,10 @@ def __init__( forward_looking: bool = False, ): super().__init__(pf, pb) - assert any(isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator]), "logF must be a ScalarEstimator or derived" + assert any( + isinstance(logF, cls) + for cls in [ScalarEstimator, ConditionalScalarEstimator] + ), "logF must be a ScalarEstimator or derived" self.logF = logF self.weighting = weighting self.lamda = lamda