Skip to content

Commit

Permalink
added helper methods and type checking for logF
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Sep 20, 2024
1 parent 0b62a2a commit 83f276a
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/gfn/gflownet/sub_trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,25 @@ def __init__(
forward_looking: bool = False,
):
super().__init__(pf, pb)
assert isinstance(logF, ScalarEstimator), "logF must be a ScalarEstimator"
self.logF = logF
self.weighting = weighting
self.lamda = lamda
self.log_reward_clip_min = log_reward_clip_min
self.forward_looking = forward_looking

def logF_named_parameters(self):
try:
return {k: v for k, v in self.named_parameters() if "logF" in k}
except KeyError as e:
print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

def logF_parameters(self):
try:
return [v for k, v in self.named_parameters() if "logF" in k]
except KeyError as e:
print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

def cumulative_logprobs(
self,
trajectories: Trajectories,
Expand Down

0 comments on commit 83f276a

Please sign in to comment.