Changes to handling of logF and logZ #186

Merged · 7 commits · Sep 20, 2024
148 changes: 102 additions & 46 deletions README.md
@@ -57,62 +57,118 @@ Example scripts and notebooks for the three environments are provided [here](htt

### Standalone example

This example, which shows how to use the library for a simple discrete environment, requires the [`tqdm`](https://github.com/tqdm/tqdm) package to run. Use `pip install tqdm` or install all extra requirements with `pip install .[scripts]` or `pip install torchgfn[scripts]`. In the first example, we will train a Trajectory Balance GFlowNet:

```python
import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid  # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet  # NeuralNet is a simple multi-layer perceptron (MLP)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator.
module_PF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,
)  # Neural network for the forward policy, with as many outputs as there are actions.

module_PB = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    torso=module_PF.torso,  # We share all the parameters of P_F and P_B, except for the last layer.
)

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet.
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)  # We initialize logZ to 0.

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy.

# Different policy parameters can have their own LR.
# Log Z gets a dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration.
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})
```
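
As a quick, optional sanity check after training, you can sample fresh trajectories with the learned forward policy and look at where they terminate. This is a minimal sketch reusing the `sampler` defined above; the `last_states` attribute and its `.tensor` field are assumed names for the terminal states held by the sampled `Trajectories` container.

```python
# Hedged post-training check: sample with the learned policy and inspect terminal states.
# Assumes the training script above has run; `last_states` / `.tensor` are assumed attribute names.
with torch.no_grad():
    trajectories = sampler.sample_trajectories(env=env, n_trajectories=64)

terminal_states = trajectories.last_states
print(terminal_states.tensor)  # grid coordinates of the sampled terminal states
```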

In this second example, we instead train using Sub-Trajectory Balance (SubTB). You can see that we simply assemble our GFlowNet from slightly different building blocks:

```python
import torch
from tqdm import tqdm

from gfn.gflownet import SubTBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,
)  # Neural network for the forward policy, with as many outputs as there are actions.

module_PB = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    torso=module_PF.torso,  # We share all the parameters of P_F and P_B, except for the last layer.
)

module_logF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,  # Important for ScalarEstimators!
)

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet.
gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9)

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

# Different policy parameters can have their own LR.
# Log F gets a dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})
```

## Contributing
13 changes: 13 additions & 0 deletions src/gfn/gflownet/detailed_balance.py
@@ -37,10 +37,23 @@ def __init__(
        log_reward_clip_min: float = -float("inf"),
    ):
        super().__init__(pf, pb)
        assert isinstance(logF, ScalarEstimator), "logF must be a ScalarEstimator"
        self.logF = logF
        self.forward_looking = forward_looking
        self.log_reward_clip_min = log_reward_clip_min

    def logF_named_parameters(self):
        try:
            return {k: v for k, v in self.named_parameters() if "logF" in k}
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def logF_parameters(self):
        try:
            return [v for k, v in self.named_parameters() if "logF" in k]
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def get_scores(
        self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
    ) -> Tuple[
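
The `logF_parameters()` helper added here makes it straightforward to give the state-flow estimator its own learning rate for detailed balance, mirroring what the README now does for TB and SubTB. A minimal sketch, assuming a `DBGFlowNet` import from `gfn.gflownet` and the `pf_estimator`, `pb_estimator`, and `ScalarEstimator`-based `logF_estimator` built as in the README examples:

```python
import torch

from gfn.gflownet import DBGFlowNet  # assumed import path, matching the other GFlowNet classes

# pf_estimator, pb_estimator, and logF_estimator are built as in the README examples.
gfn = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator)

# Policy parameters get one learning rate; the logF estimator gets its own (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})
```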
3 changes: 2 additions & 1 deletion src/gfn/gflownet/flow_matching.py
@@ -23,13 +23,14 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
    3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

    Attributes:
        logF: an estimator of log edge flows.
        alpha: weight for the reward matching loss.
    """

    def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
        super().__init__()

        assert isinstance(logF, DiscretePolicyEstimator), "logF must be a Discrete Policy Estimator"
        self.logF = logF
        self.alpha = alpha
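
For context, the flow-matching GFlowNet expects `logF` to be a full discrete estimator (one log edge flow per action) rather than a scalar estimator, which is exactly what the new assertion enforces. Below is a minimal, hedged construction sketch reusing the HyperGrid `env` from the README examples; the `env.n_actions` output dimension for the edge-flow module is an assumption about the expected shape:

```python
# Hedged sketch: constructing an FMGFlowNet with a DiscretePolicyEstimator as logF.
# Assumes the HyperGrid `env` from the README examples.
from gfn.gflownet import FMGFlowNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils import NeuralNet

module_logF_edges = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,  # one log edge flow per action, not a single scalar
)
logF_edge_estimator = DiscretePolicyEstimator(module_logF_edges, env.n_actions, preprocessor=env.preprocessor)

gfn = FMGFlowNet(logF=logF_edge_estimator)  # passing a ScalarEstimator here would now trip the assert
```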
13 changes: 13 additions & 0 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -70,12 +70,25 @@ def __init__(
        forward_looking: bool = False,
    ):
        super().__init__(pf, pb)
        assert isinstance(logF, ScalarEstimator), "logF must be a ScalarEstimator"
        self.logF = logF
        self.weighting = weighting
        self.lamda = lamda
        self.log_reward_clip_min = log_reward_clip_min
        self.forward_looking = forward_looking

    def logF_named_parameters(self):
        try:
            return {k: v for k, v in self.named_parameters() if "logF" in k}
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def logF_parameters(self):
        try:
            return [v for k, v in self.named_parameters() if "logF" in k]
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def cumulative_logprobs(
        self,
        trajectories: Trajectories,
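
A handy use of the new `logF_named_parameters()` is to verify that the dedicated optimizer group really contains only the flow estimator's weights, for instance when the P_F and P_B torsos are tied as in the README example. A small sketch, assuming the SubTB `gfn` from the README:

```python
# Inspection sketch: list the parameters that the logF group will train.
# Assumes the SubTB `gfn` constructed in the README example above.
for name, param in gfn.logF_named_parameters().items():
    print(name, tuple(param.shape))  # every name here should belong to the logF module
```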
15 changes: 9 additions & 6 deletions src/gfn/gflownet/trajectory_balance.py
@@ -10,7 +10,7 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator


class TBGFlowNet(TrajectoryBasedGFlowNet):
@@ -23,22 +23,25 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
    the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.

    Attributes:
        logZ: a float, or a ScalarEstimator instance (for conditional GFNs).
        log_reward_clip_min: If finite, clips log rewards to this value.
    """

    def __init__(
        self,
        pf: GFNModule,
        pb: GFNModule,
        logZ: float | ScalarEstimator = 0.0,
        log_reward_clip_min: float = -float("inf"),
    ):
        super().__init__(pf, pb)

        if isinstance(logZ, float):
            self.logZ = nn.Parameter(torch.tensor(logZ))
        else:
            assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
            self.logZ = logZ

        self.log_reward_clip_min = log_reward_clip_min

    def loss(
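
With this change, `TBGFlowNet` accepts either a plain float (wrapped into a learnable `nn.Parameter`) or a `ScalarEstimator`, the latter being the hook for conditional GFlowNets where log Z depends on the conditioning. A minimal sketch of the estimator variant, assuming the HyperGrid `env`, `pf_estimator`, and `pb_estimator` from the README example; the one-output `NeuralNet` below is purely illustrative:

```python
# Sketch: passing a ScalarEstimator as logZ instead of a float.
# Assumes `env`, `pf_estimator`, and `pb_estimator` from the README example.
from gfn.gflownet import TBGFlowNet
from gfn.modules import ScalarEstimator
from gfn.utils import NeuralNet

module_logZ = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,  # ScalarEstimators must produce a single output
)
logZ_estimator = ScalarEstimator(module=module_logZ, preprocessor=env.preprocessor)

gfn = TBGFlowNet(logZ=logZ_estimator, pf=pf_estimator, pb=pb_estimator)
# gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)  # the float form still works
```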