Merge pull request #157 from GFNOrg/padding_util_fn
factored duplicated padding code into a single utility function (DRY), addressing #154
josephdviviano authored Feb 19, 2024
2 parents 3276492 + ebfd4c8 commit 8fed836
Showing 1 changed file with 35 additions and 39 deletions.
74 changes: 35 additions & 39 deletions src/gfn/containers/trajectories.py
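
For orientation before the diff: the change consolidates two near-identical blocks that pad estimator_outputs along dim 0 with -inf into a single module-level helper. A minimal, self-contained sketch of that padding behavior (toy shapes and variable names chosen for illustration, not the library's API):

import torch

# A (max_len=3, n_trajectories=4) block of per-step estimator outputs.
outputs = torch.randn(3, 4)

# Grow it along dim 0 to a target length of 5 with -inf rows, keeping dtype and device.
target_dim0 = 5
pad_rows = torch.full(
    (target_dim0 - outputs.shape[0],) + tuple(outputs.shape[1:]),
    fill_value=-float("inf"),
    dtype=outputs.dtype,
    device=outputs.device,
)
padded = torch.cat((outputs, pad_rows), dim=0)
print(padded.shape)  # torch.Size([5, 4])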
@@ -77,9 +77,7 @@ def __init__(
self.env = env
self.is_backward = is_backward
self.states = (
states
if states is not None
else env.states_from_batch_shape((0, 0))
states if states is not None else env.states_from_batch_shape((0, 0))
)
assert len(self.states.batch_shape) == 2
self.actions = (
@@ -260,59 +258,43 @@ def extend(self, other: Trajectories) -> None:
):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)

# Cast other to match self.
output_dtype = self.estimator_outputs.dtype
other.estimator_outputs = other.estimator_outputs.to(dtype=output_dtype)

if n_bs == 1:
# Concatenate along the only batch dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=0,
)

elif n_bs == 2:
if self.estimator_outputs.shape[0] != other.estimator_outputs.shape[0]:
# First we need to pad the first dimension on either self or other.
self_shape = np.array(self.estimator_outputs.shape)
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function (#154)
# The size of self needs to grow to match other along dim=0.
if self_shape[0] < other_shape[0]:
pad_dim = required_first_dim - self_shape[0]
pad_dim_full = (pad_dim,) + tuple(self_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=self.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=self.estimator_outputs.device,
)
self.estimator_outputs = torch.cat(
(self.estimator_outputs, output_padding),
dim=0,
# Concatenate along the first dimension, padding where required.
self_dim0 = self.estimator_outputs.shape[0]
other_dim0 = other.estimator_outputs.shape[0]
if self_dim0 != other_dim0:
# We need to pad the first dimension on either self or other.
required_first_dim = max(self_dim0, other_dim0)

if self_dim0 < other_dim0:
self.estimator_outputs = pad_dim0_to_target(
self.estimator_outputs,
required_first_dim,
)

# The size of other needs to grow to match self along dim=0.
if other_shape[0] < self_shape[0]:
pad_dim = required_first_dim - other_shape[0]
pad_dim_full = (pad_dim,) + tuple(other_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=other.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=other.estimator_outputs.device,
)
other.estimator_outputs = torch.cat(
(other.estimator_outputs, output_padding),
dim=0,
elif self_dim0 > other_dim0:
other.estimator_outputs = pad_dim0_to_target(
other.estimator_outputs,
required_first_dim,
)

# Concatenate the tensors along the second dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=1,
).to(
dtype=output_dtype
) # Cast to prevent single precision becoming double precision... weird.
)

# Sanity check. TODO: Remove?
assert self.estimator_outputs.shape[:n_bs] == batch_shape
@@ -376,3 +358,17 @@ def to_non_initial_intermediary_and_terminating_states(
terminating_states = self.last_states
terminating_states.log_rewards = self.log_rewards
return intermediary_states, terminating_states


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor a to match the dimention of b."""
assert a.shape[0] < target_dim0, "a is already larger than target_dim0!"
pad_dim = target_dim0 - a.shape[0]
pad_dim_full = (pad_dim,) + tuple(a.shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=a.dtype,  # TODO: This dtype isn't always respected, hence the explicit cast in extend().
device=a.device,
)
return torch.cat((a, output_padding), dim=0)
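
As a usage sketch, the new helper composes with the dim=1 concatenation in extend as follows: pad whichever batch is shorter along dim 0, then concatenate along the trajectory dimension. The shapes below are toy values, and the helper is re-declared locally so the snippet runs on its own rather than importing from the library:

import torch

def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
    """Pads a along dim 0 with -inf rows until its first dimension equals target_dim0."""
    assert a.shape[0] < target_dim0, "a.shape[0] must be smaller than target_dim0!"
    pad_shape = (target_dim0 - a.shape[0],) + tuple(a.shape[1:])
    padding = torch.full(
        pad_shape, fill_value=-float("inf"), dtype=a.dtype, device=a.device
    )
    return torch.cat((a, padding), dim=0)

# Two batches of estimator outputs with different maximum trajectory lengths.
self_outputs = torch.randn(3, 4)   # (max_len=3, n_trajectories=4)
other_outputs = torch.randn(5, 2)  # (max_len=5, n_trajectories=2)

required_first_dim = max(self_outputs.shape[0], other_outputs.shape[0])
if self_outputs.shape[0] < other_outputs.shape[0]:
    self_outputs = pad_dim0_to_target(self_outputs, required_first_dim)
elif self_outputs.shape[0] > other_outputs.shape[0]:
    other_outputs = pad_dim0_to_target(other_outputs, required_first_dim)

merged = torch.cat((self_outputs, other_outputs), dim=1)
print(merged.shape)  # torch.Size([5, 6])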

0 comments on commit 8fed836
