diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py
index 2cf9fcc6..a5205249 100644
--- a/src/gfn/containers/replay_buffer.py
+++ b/src/gfn/containers/replay_buffer.py
@@ -117,13 +117,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         p_norm_distance: p-norm distance value to pass to torch.cdist, for the
             determination of novel states.
     """
+
     def __init__(
         self,
         env: Env,
         objects_type: Literal["transitions", "trajectories", "states"],
         capacity: int = 1000,
-        cutoff_distance: float = 0.,
-        p_norm_distance: float = 1.,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
     ):
         """Instantiates a prioritized replay buffer.
         Args:
@@ -137,7 +138,7 @@ def __init__(
                 norms are >= 0).
             p_norm_distance: p-norm distance value to pass to torch.cdist, for the
                 determination of novel states.
-        """
+        """
         super().__init__(env, objects_type, capacity)
         self.cutoff_distance = cutoff_distance
         self.p_norm_distance = p_norm_distance
diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py
index c08966f2..63c975f6 100644
--- a/src/gfn/gflownet/detailed_balance.py
+++ b/src/gfn/gflownet/detailed_balance.py
@@ -46,13 +46,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def get_scores(
         self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py
index 193f7c07..f363663d 100644
--- a/src/gfn/gflownet/flow_matching.py
+++ b/src/gfn/gflownet/flow_matching.py
@@ -30,7 +30,9 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
 
     def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
         super().__init__()
-        assert isinstance(logF, DiscretePolicyEstimator), "logF must be a Discrete Policy Estimator"
+        assert isinstance(
+            logF, DiscretePolicyEstimator
+        ), "logF must be a Discrete Policy Estimator"
         self.logF = logF
         self.alpha = alpha
 
diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py
index 0c14c497..2184bacc 100644
--- a/src/gfn/gflownet/sub_trajectory_balance.py
+++ b/src/gfn/gflownet/sub_trajectory_balance.py
@@ -81,13 +81,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def cumulative_logprobs(
         self,
diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py
index 3e9e88c7..691d7388 100644
--- a/src/gfn/gflownet/trajectory_balance.py
+++ b/src/gfn/gflownet/trajectory_balance.py
@@ -39,7 +39,9 @@ def __init__(
         if isinstance(logZ, float):
             self.logZ = nn.Parameter(torch.tensor(logZ))
         else:
-            assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
+            assert isinstance(
+                logZ, ScalarEstimator
+            ), "logZ must be either float or a ScalarEstimator"
             self.logZ = logZ
         self.log_reward_clip_min = log_reward_clip_min
 