Rethinking sampling #147

Merged
merged 70 commits on Feb 16, 2024

Commits (70)
95f4e01
improved documentation, and also formatting
josephdviviano Nov 16, 2023
7b0688b
notes, and some changes RE: passing of policy_kwargs (renamed)
josephdviviano Nov 16, 2023
0b812e5
renaming policy kwargs and changing usage
josephdviviano Nov 16, 2023
6eeb0af
added policy_kwargs
josephdviviano Nov 16, 2023
a391b28
renamed args
josephdviviano Nov 16, 2023
d9fa884
changed logic and added TODOs for improved efficiency
josephdviviano Nov 16, 2023
2a16b1b
sample trajectories bugfix
josephdviviano Nov 16, 2023
740e16f
sample trajectories bugfix
josephdviviano Nov 16, 2023
9d1ffd0
sync for debug
josephdviviano Nov 18, 2023
963693c
adding line environment
josephdviviano Nov 18, 2023
76ab487
TODO
josephdviviano Nov 21, 2023
8999dd3
added logic for sampling off policy and some helper functions
josephdviviano Nov 21, 2023
450ebf0
estimator_outputs can be passed around
josephdviviano Nov 21, 2023
a8b637e
estimator outputs can be saved
josephdviviano Nov 21, 2023
1acfcce
tweaks to demo
josephdviviano Nov 21, 2023
e052c82
added back in default recomputing behaviour for pf in off policy mode.
josephdviviano Nov 21, 2023
f897aab
bugfix
josephdviviano Nov 21, 2023
67ea36e
simplified logprobs calc
josephdviviano Nov 21, 2023
e88e1bb
v1 of the line tutorial
josephdviviano Nov 22, 2023
b28dc95
estimator outputs can be passed around to avoid recalculation
josephdviviano Nov 22, 2023
a72afe9
documentation & removal of reward clamping.
josephdviviano Nov 22, 2023
46693af
added clone (just a test)
josephdviviano Nov 23, 2023
119559d
black formatting, debugging code left in (commented), and log_reward_…
josephdviviano Nov 23, 2023
e8ab999
log_reward_clip_min is now default off
josephdviviano Nov 23, 2023
5e87e3f
log_reward_clip_min is now optional
josephdviviano Nov 23, 2023
732fb0f
black
josephdviviano Nov 23, 2023
5048f3c
added log reward clipping
josephdviviano Nov 23, 2023
3b4e597
formatting
josephdviviano Nov 23, 2023
5ef0c22
reorg of training loop (nothing is functionally different
josephdviviano Nov 23, 2023
d69d258
isort
josephdviviano Nov 23, 2023
e29a278
variable naming and a note
josephdviviano Nov 23, 2023
ff45949
no longer using deepcopy. removed all log_reward clipping, which shou…
josephdviviano Nov 23, 2023
2ca4ced
note RE typecasting
josephdviviano Nov 23, 2023
a4c1786
improved efficiency of the init, and also added a clone method for st…
josephdviviano Nov 23, 2023
f6edd53
note to self RE typecasting in the identity preprocessor -- not sure …
josephdviviano Nov 23, 2023
6aab6e0
black / isort
josephdviviano Nov 23, 2023
716ee7a
log reward clipping removed
josephdviviano Nov 23, 2023
dfb929d
debugging sync
josephdviviano Nov 23, 2023
5d62bee
Independent distributions
josephdviviano Nov 23, 2023
1ceb53d
synced (debug still included)
josephdviviano Nov 23, 2023
b67a6d2
Merge branch 'easier_environment_definition' of github.com:saleml/tor…
josephdviviano Nov 24, 2023
c419dd3
removed debugging notes (confirmed that the issue is with my personal…
josephdviviano Nov 24, 2023
a0f43c6
turned clipping on
josephdviviano Nov 24, 2023
d6ad17f
estimator_outputs now live inside trajectories
josephdviviano Nov 24, 2023
50b74d2
estimator outputs now live inside trajectories. if they aren't comput…
josephdviviano Nov 24, 2023
80b4c29
full support for estimator_outputs (lots of padding logic added -- th…
josephdviviano Nov 27, 2023
b5fdd32
TODO - I think we found a function that is never called
josephdviviano Nov 27, 2023
ac42f22
estimator outputs now saved in a padded format. Also, some logic chan…
josephdviviano Nov 27, 2023
05732a8
all flags are now off_policy for consistency
josephdviviano Nov 27, 2023
25235f0
all flags are now off_policy for consistency
josephdviviano Nov 27, 2023
0814725
all tests pass
josephdviviano Nov 27, 2023
72ac58f
added off policy flag
josephdviviano Nov 27, 2023
b0432c9
isort / black
josephdviviano Nov 27, 2023
b987a39
updated scripts with new API and tweaked tests (with reproducibility)
josephdviviano Nov 27, 2023
12eab45
tests passing
josephdviviano Nov 27, 2023
cd35cb8
syncing notebook states
josephdviviano Nov 27, 2023
44050a9
removed one order of magnitude precision required
josephdviviano Nov 27, 2023
038b67b
merge issues resolved
josephdviviano Nov 27, 2023
8388362
fixed tests
josephdviviano Nov 27, 2023
93f2e5f
removed comments
josephdviviano Nov 27, 2023
a6601d7
further loosened test tolerances
josephdviviano Nov 27, 2023
71da6b5
changes requested for PR
josephdviviano Feb 13, 2024
bafa1ad
moved training specific imports here to avoid circular deps
josephdviviano Feb 13, 2024
0990d51
circular deps fix
josephdviviano Feb 13, 2024
aa3c656
removing addiditons (additions commented out)
josephdviviano Feb 14, 2024
e2ad9dd
formatting common
josephdviviano Feb 14, 2024
be122ed
indexing reverted to old strategy with copius documentation
josephdviviano Feb 14, 2024
cfc560c
formatting of tests
josephdviviano Feb 14, 2024
2bebde2
isort / black
josephdviviano Feb 14, 2024
e7c7453
isort
josephdviviano Feb 14, 2024
20 changes: 8 additions & 12 deletions README.md
@@ -71,12 +71,10 @@ from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron

if __name__ == "__main__":

# 1 - We define the environment

env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks)
# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
input_dim=env.preprocessor.output_dim,
@@ -88,17 +86,14 @@ if __name__ == "__main__":
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet

# 4 - We define the GFlowNet.
gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0

# 5 - We define the sampler and the optimizer

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

# Policy parameters have their own LR.
@@ -110,7 +105,6 @@ if __name__ == "__main__":
optimizer.add_param_group({"params": logz_params, "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration

for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
optimizer.zero_grad()
@@ -193,6 +187,8 @@ Training GFlowNets requires one or multiple estimators, called `GFNModule`s, whi

For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States` object), should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, which implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py) for a square environment in which the forward policy has support either on a quarter disk or on an arc of a circle, such that the angle and the radius (for the quarter-disk part) are scaled samples from a mixture of Beta distributions. This example shows an intricate scenario, and user-defined environments are not expected to need this level of detail.
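
For illustration only (this block is not part of the diff), here is a minimal sketch of the pattern described above; the class name, constructor, and the exact signature of `to_probability_distribution` are assumptions rather than the library's API:

```python
import torch
from torch.distributions import Distribution, Independent, Normal


class GaussianPolicyModule(torch.nn.Module):
    """Hypothetical continuous policy: returns batched parameters of a diagonal Gaussian."""

    def __init__(self, input_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 2 * action_dim),  # means and log-stds
        )

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        return self.net(preprocessed_states)  # batched distribution parameters

    def to_probability_distribution(self, params: torch.Tensor) -> Distribution:
        mean, log_std = params.chunk(2, dim=-1)
        # Independent yields one distribution per state, with `sample` and
        # `log_prob` defined over the whole action vector.
        return Independent(Normal(mean, log_std.exp()), 1)
```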

In general (and perhaps obviously), the `to_probability_distribution` method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. This is accomplished using `policy_kwargs`, a `dict` of kwarg-value pairs used by the `Estimator` when calculating the new policy. In the discrete case, where common settings apply, their use can be seen in `DiscretePolicyEstimator`'s `to_probability_distribution` method: one can pass a softmax `temperature`, an `sf_bias` (a scalar to subtract from the exit action logit), or an `epsilon` that enables epsilon-greedy-style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (they depend on the details of the `to_probability_distribution` method, which is not generic for continuous GFNs), so this must be handled by the user via custom `policy_kwargs`.
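
As a hedged sketch of how this might look in practice (not taken from this PR: the `off_policy` flag and the forwarding of `policy_kwargs` to the estimator are assumptions), discrete off-policy sampling could be configured as follows:

```python
sampler = Sampler(estimator=pf_estimator)

# Hypothetical call: assumes sample_trajectories forwards policy_kwargs to
# DiscretePolicyEstimator.to_probability_distribution when sampling off-policy.
trajectories = sampler.sample_trajectories(
    env=env,
    n_trajectories=16,
    off_policy=True,
    policy_kwargs={
        "temperature": 2.0,  # softens the softmax over the action logits
        "sf_bias": 1.0,      # subtracted from the exit-action logit
        "epsilon": 0.05,     # probability of a uniform random action (epsilon-greedy)
    },
)
```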

Note that in all `GFNModule`s, the input of the `forward` function is a `States` object, which first needs to be transformed into a tensor. However, `states.tensor` does not necessarily have the structure that a neural network can use to generalize. It is common in these scenarios to transform the raw tensor states into ones where the structure is clearer, via a `Preprocessor` object that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The `forward` pass thus first calls the `preprocessor` on the `States` object before performing any other transformation. The `preprocessor` is an attribute of the module; if it is not explicitly defined, it is set to the environment's default, the identity preprocessor.
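
Schematically (an assumed sketch, not code from the repository), the flow inside a module's forward pass is:

```python
# `preprocessor`, `module`, and `states` stand in for the real objects; by
# default the preprocessor is the identity, returning `states.tensor` unchanged.
def module_forward(preprocessor, module, states):
    features = preprocessor(states)  # States -> tensor with exploitable structure
    return module(features)          # e.g. the NeuralNet MLP from the example above
```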

For discrete environments, a `Tabular` module is provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. These modules are provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py).
111 changes: 104 additions & 7 deletions src/gfn/containers/trajectories.py
@@ -7,13 +7,20 @@
from gfn.env import Env
from gfn.states import States

import numpy as np
import torch
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions


def is_tensor(t) -> bool:
"""Checks whether t is a torch.Tensor instance."""
return isinstance(t, Tensor)


# TODO: remove env from this class?
class Trajectories(Container):
"""Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
@@ -47,16 +54,21 @@ def __init__(
is_backward: bool = False,
log_rewards: TT["n_trajectories", torch.float] | None = None,
log_probs: TT["max_length", "n_trajectories", torch.float] | None = None,
estimator_outputs: torch.Tensor | None = None,
) -> None:
"""
Args:
env: The environment in which the trajectories are defined.
states: The states of the trajectories. Defaults to None.
actions: The actions of the trajectories. Defaults to None.
when_is_done: The time step at which each trajectory ends. Defaults to None.
is_backward: Whether the trajectories are backward or forward. Defaults to False.
log_rewards: The log_rewards of the trajectories. Defaults to None.
log_probs: The log probabilities of the trajectories' actions. Defaults to None.
states: The states of the trajectories.
actions: The actions of the trajectories.
when_is_done: The time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: The log_rewards of the trajectories.
log_probs: The log probabilities of the trajectories' actions.
estimator_outputs: When forward sampling off-policy for an n-step
trajectory, n forward passes will be made on some function approximator,
which may need to be re-used (for example, for evaluating PF). To avoid
duplicated effort, the outputs of the forward passes can be stored here.

If states is None, then the states are initialized to an empty States object,
that can be populated on the fly. If log_rewards is None, then `env.log_reward`
@@ -87,6 +99,7 @@ def __init__(
if log_probs is not None
else torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
)
self.estimator_outputs = estimator_outputs

def __repr__(self) -> str:
states = self.states.tensor.transpose(0, 1)
@@ -154,6 +167,21 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
if is_tensor(self.estimator_outputs):
# TODO: Is there a safer way to index self.estimator_outputs
# for n-dimensional estimator outputs?
#
# First we index along the first dimension of the estimator outputs.
# This can be thought of as the instance dimension, and is
# compatible with all supported indexing approaches (dim=1).
# All dims > 1 are not explicitly indexed unless the dimensionality
# of `index` matches all dimensions of `estimator_outputs` aside
# from the first (trajectory) dimension.
estimator_outputs = self.estimator_outputs[:, index]

Collaborator:
This implicitly assumes that self.estimator_outputs is of shape max_length x n_trajectories (as is the case for example for self.log_probs). Would this always be the case?

I feel like things would easily break here unless we force some structure on estimator_outputs. Rather than torch.Tensor, it has to be some TensorType with a specific shape IMO.


Collaborator (Author):
What do you think of simply:

        if is_tensor(self.estimator_outputs):
            estimator_outputs = self.estimator_outputs[..., index]
            estimator_outputs = estimator_outputs[:new_max_length]

?


Collaborator:
that should work !

# Next we index along the trajectory length (dim=0)
estimator_outputs = estimator_outputs[:new_max_length]
else:
estimator_outputs = None

return Trajectories(
env=self.env,
@@ -163,6 +191,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
is_backward=self.is_backward,
log_rewards=log_rewards,
log_probs=log_probs,
estimator_outputs=estimator_outputs,
)

@staticmethod
@@ -198,7 +227,10 @@ def extend(self, other: Trajectories) -> None:
Args:
other: an external set of Trajectories.
"""
if len(other) == 0:
return

# TODO: The replay buffer is storing `dones` - this wastes a lot of space.
self.actions.extend(other.actions)
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)
@@ -213,11 +245,76 @@

if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
(self._log_rewards, other._log_rewards),
dim=0,
)
else:
self._log_rewards = None

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
self.estimator_outputs = other.estimator_outputs

Collaborator:
but how would we match the indices of the trajectories to the indices of the estimator_outputs ?

This feels dangerous. I suggest just throwing an error when one is None and the other is not (either one).


Collaborator (Author):
I think the idea is to be able to extend an empty Trajectories instance, say with a stored buffer.

I agree it is dangerous but I think we should support this behaviour.

Admittedly it has been some time since I looked at this so I might be forgetting something.


Collaborator:
Fair enough!

elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)
output_dtype = self.estimator_outputs.dtype

if n_bs == 1:
# Concatenate along the only batch dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=0,
)
elif n_bs == 2:
if self.estimator_outputs.shape[0] != other.estimator_outputs.shape[0]:
# First we need to pad the first dimension on either self or other.
self_shape = np.array(self.estimator_outputs.shape)
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function (#154)
# The size of self needs to grow to match other along dim=0.
if self_shape[0] < other_shape[0]:
pad_dim = required_first_dim - self_shape[0]
pad_dim_full = (pad_dim,) + tuple(self_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=self.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=self.estimator_outputs.device,
)
self.estimator_outputs = torch.cat(
(self.estimator_outputs, output_padding),
dim=0,
)

# The size of other needs to grow to match self along dim=0.
if other_shape[0] < self_shape[0]:
pad_dim = required_first_dim - other_shape[0]
pad_dim_full = (pad_dim,) + tuple(other_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=other.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=other.estimator_outputs.device,
)
other.estimator_outputs = torch.cat(
(other.estimator_outputs, output_padding),
dim=0,
)

# Concatenate the tensors along the second dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=1,
).to(
dtype=output_dtype
) # Cast to prevent single precision becoming double precision... weird.

# Sanity check. TODO: Remove?
assert self.estimator_outputs.shape[:n_bs] == batch_shape

def to_transitions(self) -> Transitions:
"""Returns a `Transitions` object from the trajectories."""
states = self.states[:-1][~self.actions.is_dummy]
26 changes: 9 additions & 17 deletions src/gfn/env.py
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Optional, Tuple, Union

import torch
@@ -8,6 +7,7 @@
from gfn.actions import Actions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, States
from gfn.utils.common import set_seed

# Errors
NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
@@ -23,7 +23,6 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes an environment.

@@ -37,7 +36,6 @@
preprocessor: a Preprocessor object that converts raw states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.device = torch.device(device_str) if device_str is not None else s0.device

@@ -58,7 +56,6 @@

self.preprocessor = preprocessor
self.is_discrete = False
self.log_reward_clip = log_reward_clip

@abstractmethod
def make_States_class(self) -> type[States]:
@@ -83,7 +80,7 @@ def reset(
assert not (random and sink)

if random and seed is not None:
torch.manual_seed(seed)
set_seed(seed, performance_mode=True)

if batch_shape is None:
batch_shape = (1,)
@@ -94,15 +91,15 @@
)

@abstractmethod
def maskless_step(
def maskless_step( # TODO: rename to step, other method becomes _step.

Collaborator:
Good idea !

self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
"""Function that takes a batch of states and actions and returns a batch of next
states. Does not need to check whether the actions are valid or the states are sink states.
"""

@abstractmethod
def maskless_backward_step(
def maskless_backward_step( # TODO: rename to backward_step, other method becomes _backward_step.
self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
"""Function that takes a batch of states and actions and returns a batch of previous
@@ -134,7 +131,7 @@ def step(
) -> States:
"""Function that takes a batch of states and actions and returns a batch of next
states and a boolean tensor indicating sink states in the new batch."""
new_states = deepcopy(states)
new_states = states.clone() # TODO: Ensure this is efficient!
valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]
@@ -154,8 +151,6 @@
new_not_done_states_tensor = self.maskless_step(
not_done_states, not_done_actions
)
# if isinstance(new_states, DiscreteStates):
# new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)

new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor

@@ -168,7 +163,7 @@ def backward_step(
) -> States:
"""Function that takes a batch of states and actions and returns a batch of next
states and a boolean tensor indicating initial states in the new batch."""
new_states = deepcopy(states)
new_states = states.clone() # TODO: Ensure this is efficient!
valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]
@@ -197,8 +192,8 @@ def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Calculates the log reward (clipping small rewards)."""
return torch.log(self.reward(final_states)).clip(self.log_reward_clip)
"""Calculates the log reward."""
return torch.log(self.reward(final_states))

@property
def log_partition(self) -> float:
@@ -224,7 +219,6 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes a discrete environment.

@@ -234,12 +228,10 @@
sf: The final state tensor (shared among all trajectories).
device_str: String representation of a torch.device.
preprocessor: An optional preprocessor for intermediate states.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.n_actions = n_actions
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
super().__init__(s0, sf, device_str, preprocessor)
self.is_discrete = True
self.log_reward_clip = log_reward_clip

def make_Actions_class(self) -> type[Actions]:
env = self