Skip to content

Commit

Permalink
Merge pull request #159 from GFNOrg/device_handling_fix
Browse files Browse the repository at this point in the history
device handling fix
  • Loading branch information
josephdviviano authored Feb 19, 2024
2 parents c776af4 + 35908de commit 617cc22
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
trajectory - if so, it should be set to True.
"""
if allow_exit:
exit_idx = torch.zeros(self.batch_shape + (1,))
exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device)
else:
exit_idx = torch.ones(self.batch_shape + (1,))
exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device)
self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False

def set_exit_masks(self, batch_idx):
Expand Down
6 changes: 0 additions & 6 deletions src/gfn/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
else:
self.torso = torso
self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim)
self.device = None

def forward(
self, preprocessed_states: TT["batch_shape", "input_dim", float]
Expand All @@ -66,11 +65,6 @@ def forward(
ingestion by the MLP.
Returns: out, a set of continuous variables.
"""
if self.device is None:
self.device = preprocessed_states.device
self.to(
self.device
) # TODO: This is maybe fine but could result in weird errors if the model keeps bouncing between devices.
out = self.torso(preprocessed_states)
out = self.last_layer(out)
return out
Expand Down
66 changes: 66 additions & 0 deletions tutorials/examples/train_hypergrid_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet

torch.manual_seed(0)
exploration_rate = 0.5
learning_rate = 0.0005

# Setup the Environment.
env = HyperGrid(
ndim=5,
height=2,
device_str="cuda" if torch.cuda.is_available() else "cpu",
)

# Build the GFlowNet.
module_PF = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions,
)
module_PB = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso,
)
pf_estimator = DiscretePolicyEstimator(
module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor
)
pb_estimator = DiscretePolicyEstimator(
module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
)
gflownet = TBGFlowNet(init_logZ=0.0, pf=pf_estimator, pb=pb_estimator, off_policy=True)

# Feed pf to the sampler.
sampler = Sampler(estimator=pf_estimator)

# Move the gflownet to the GPU.
if torch.cuda.is_available():
gflownet = gflownet.to("cuda")

# Policy parameters have their own LR. Log Z gets dedicated learning rate
# (typically higher).
optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": 1e-1})

n_iterations = int(1e4)
batch_size = int(1e5)

for i in (pbar := tqdm(range(n_iterations))):
trajectories = sampler.sample_trajectories(
env,
n_trajectories=batch_size,
off_policy=True,
epsilon=exploration_rate,
)
optimizer.zero_grad()
loss = gflownet.loss(env, trajectories)
loss.backward()
optimizer.step()
pbar.set_postfix({"loss": loss.item()})

0 comments on commit 617cc22

Please sign in to comment.