Commit

black formatting
josephdviviano committed Sep 20, 2024
1 parent 54af465 commit e05ae2f
Showing 5 changed files with 30 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/gfn/containers/replay_buffer.py
@@ -117,13 +117,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         p_norm_distance: p-norm distance value to pass to torch.cdist, for the
             determination of novel states.
     """
+
     def __init__(
         self,
         env: Env,
         objects_type: Literal["transitions", "trajectories", "states"],
         capacity: int = 1000,
-        cutoff_distance: float = 0.,
-        p_norm_distance: float = 1.,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
     ):
         """Instantiates a prioritized replay buffer.
         Args:
@@ -137,7 +138,7 @@ def __init__(
                 norms are >= 0).
             p_norm_distance: p-norm distance value to pass to torch.cdist, for the
                 determination of novel states.
-        """
+        """
         super().__init__(env, objects_type, capacity)
         self.cutoff_distance = cutoff_distance
         self.p_norm_distance = p_norm_distance
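For context on this hunk: black normalizes bare float literals such as 0. to 0.0. The change is purely cosmetic; both spellings parse to the same value, so the constructor defaults are unchanged. A minimal sanity check (illustrative only, not repository code):

    # black rewrites 0. as 0.0 and 1. as 1.0; the parsed floats are equal.
    assert 0. == 0.0
    assert 1. == 1.0
    print("float literal normalization is value-preserving")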
12 changes: 10 additions & 2 deletions src/gfn/gflownet/detailed_balance.py
@@ -46,13 +46,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def get_scores(
         self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
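The two hunks above are pure rewraps: black splits each long print(...).format(e) call because it exceeds the default 88-character line limit. A standalone check (with a hypothetical exception e, not repository code) showing that both layouts produce identical output:

    e = KeyError("logF")

    # Single-line form, as it appeared before formatting:
    print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    # Wrapped form produced by black; the printed text is exactly the same.
    print(
        "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
            e
        )
    )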
4 changes: 3 additions & 1 deletion src/gfn/gflownet/flow_matching.py
@@ -30,7 +30,9 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
     def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
         super().__init__()
 
-        assert isinstance(logF, DiscretePolicyEstimator), "logF must be a Discrete Policy Estimator"
+        assert isinstance(
+            logF, DiscretePolicyEstimator
+        ), "logF must be a Discrete Policy Estimator"
         self.logF = logF
         self.alpha = alpha
 
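Same idea for long assert statements: black parenthesizes the condition and moves the message after the closing parenthesis, with no behavioral change. A toy illustration (the names here are made up, not repository code):

    x = 1.0

    # One-line form:
    assert isinstance(x, float), "x must be a float"

    # black's wrapped form of the same statement; semantics are identical.
    assert isinstance(
        x, float
    ), "x must be a float"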
12 changes: 10 additions & 2 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -81,13 +81,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )
 
     def cumulative_logprobs(
         self,
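The logF-filtering idiom shared by logF_named_parameters and logF_parameters can be seen on a toy module. The class below is hypothetical; only the named_parameters() filter matches the hunk above:

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.logF = nn.Linear(4, 1)  # registers "logF.weight" and "logF.bias"
            self.pf = nn.Linear(4, 4)    # unrelated parameters, filtered out

    # Keep only parameters whose dotted name contains "logF".
    toy = Toy()
    print([k for k, _ in toy.named_parameters() if "logF" in k])
    # -> ['logF.weight', 'logF.bias']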
4 changes: 3 additions & 1 deletion src/gfn/gflownet/trajectory_balance.py
@@ -39,7 +39,9 @@ def __init__(
         if isinstance(logZ, float):
             self.logZ = nn.Parameter(torch.tensor(logZ))
         else:
-            assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
+            assert isinstance(
+                logZ, ScalarEstimator
+            ), "logZ must be either float or a ScalarEstimator"
             self.logZ = logZ
 
         self.log_reward_clip_min = log_reward_clip_min
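For the branch shown above: a float logZ is promoted to a single trainable scalar, while anything else must already be a ScalarEstimator. A minimal sketch of the float path using plain PyTorch (illustrative, not repository code):

    import torch
    import torch.nn as nn

    logZ = 0.0
    logZ_param = nn.Parameter(torch.tensor(logZ))
    print(logZ_param.requires_grad)  # True: logZ is learned with the model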
