From 9ae95a5911f72e004dace3d338a5e1b28a0781d8 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 10:24:20 -0400 Subject: [PATCH] black --- src/gfn/gflownet/base.py | 5 ++++- src/gfn/gflownet/trajectory_balance.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 2711e09..032639a 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -168,7 +168,10 @@ def get_pfs_and_pbs( if has_log_probs(trajectories) and not recalculate_all_logprobs: log_pf_trajectories = trajectories.log_probs else: - if trajectories.estimator_outputs is not None and not recalculate_all_logprobs: + if ( + trajectories.estimator_outputs is not None + and not recalculate_all_logprobs + ): estimator_outputs = trajectories.estimator_outputs[ ~trajectories.actions.is_dummy ] diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index b4abf3a..1f8799d 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -42,7 +42,10 @@ def __init__( self.log_reward_clip_min = log_reward_clip_min def loss( - self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False + self, + env: Env, + trajectories: Trajectories, + recalculate_all_logprobs: bool = False, ) -> TT[0, float]: """Trajectory balance loss. @@ -83,7 +86,10 @@ def __init__( self.log_reward_clip_min = log_reward_clip_min def loss( - self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False + self, + env: Env, + trajectories: Trajectories, + recalculate_all_logprobs: bool = False, ) -> TT[0, float]: """Log Partition Variance loss.