diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 1f8799d..3e9e88c 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -10,7 +10,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.gflownet.base import TrajectoryBasedGFlowNet -from gfn.modules import GFNModule +from gfn.modules import GFNModule, ScalarEstimator class TBGFlowNet(TrajectoryBasedGFlowNet): @@ -23,7 +23,7 @@ class TBGFlowNet(TrajectoryBasedGFlowNet): the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator. Attributes: - logZ: a LogZEstimator instance. + logZ: an nn.Parameter holding a scalar (when initialized from a float), or a ScalarEstimator instance (for conditional GFNs). log_reward_clip_min: If finite, clips log rewards to this value. """ @@ -31,14 +31,17 @@ def __init__( self, pf: GFNModule, pb: GFNModule, - init_logZ: float = 0.0, + logZ: float | ScalarEstimator = 0.0, log_reward_clip_min: float = -float("inf"), ): super().__init__(pf, pb) - self.logZ = nn.Parameter( - torch.tensor(init_logZ) - ) # TODO: Optionally, this should be a nn.Module to support conditional GFNs. + if isinstance(logZ, float): + self.logZ = nn.Parameter(torch.tensor(logZ)) + else: + assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator" + self.logZ = logZ + self.log_reward_clip_min = log_reward_clip_min def loss(