From 54af465314e92dcd230b022cc3fe79654ec57e75 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Fri, 20 Sep 2024 18:27:32 -0400
Subject: [PATCH] added helper methods and type checking for logZ, including allowing the user to have a conditional logZ

---
Reviewer notes (this section is ignored by `git am`):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; confirm them against the target file before applying.
- logZ now also accepts int (converted with float() so the parameter can
  carry gradients), validation raises TypeError instead of relying on
  assert (asserts are stripped under `python -O`), and the docstring
  describes the attribute that is actually stored. The post-image blob id
  on the `index` line below was not recomputed after these changes.
- Renaming init_logZ to logZ breaks callers that pass init_logZ= by keyword.

 src/gfn/gflownet/trajectory_balance.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py
index 1f8799d..3e9e88c 100644
--- a/src/gfn/gflownet/trajectory_balance.py
+++ b/src/gfn/gflownet/trajectory_balance.py
@@ -10,7 +10,7 @@ from gfn.containers import Trajectories
 from gfn.env import Env
 from gfn.gflownet.base import TrajectoryBasedGFlowNet
-from gfn.modules import GFNModule
+from gfn.modules import GFNModule, ScalarEstimator
 
 
 class TBGFlowNet(TrajectoryBasedGFlowNet):
@@ -23,7 +23,7 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
     the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.
 
     Attributes:
-        logZ: a LogZEstimator instance.
+        logZ: an nn.Parameter (scalar), or a ScalarEstimator for conditional GFNs.
         log_reward_clip_min: If finite, clips log rewards to this value.
     """
 
@@ -31,14 +31,22 @@ def __init__(
         self,
         pf: GFNModule,
         pb: GFNModule,
-        init_logZ: float = 0.0,
+        logZ: float | ScalarEstimator = 0.0,
         log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb)
 
-        self.logZ = nn.Parameter(
-            torch.tensor(init_logZ)
-        )  # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
+        if isinstance(logZ, ScalarEstimator):
+            self.logZ = logZ
+        elif isinstance(logZ, (int, float)):
+            # float() keeps the parameter floating-point so it can carry gradients.
+            self.logZ = nn.Parameter(torch.tensor(float(logZ)))
+        else:
+            raise TypeError(
+                "logZ must be either float or a ScalarEstimator, "
+                f"got {type(logZ).__name__}"
+            )
+
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(