Skip to content

Commit

Permalink
Code cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Dec 15, 2023
1 parent 8ad1822 commit 62ab1ad
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions graphium/nn/architectures/global_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,9 @@ def _parse_subset_in_dim(
"""

# Parse the subset_in_dim, make sure value is between 0 and 1
subset_idx = None
if subset_in_dim is None:
subset_in_dim = 1.0
return 1.0, None
if isinstance(subset_in_dim, int):
assert (
subset_in_dim > 0 and subset_in_dim <= in_dim
Expand All @@ -681,9 +682,7 @@ def _parse_subset_in_dim(
subset_in_dim = 1

# Create the subset_idx, which is a list of indices to use for each ensemble
if subset_in_dim == in_dim:
subset_idx = None
else:
if subset_in_dim != in_dim:
subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)])

return subset_in_dim, subset_idx
Expand Down Expand Up @@ -719,7 +718,9 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:
# Subset the input features for each MLP in the ensemble
if self.subset_idx is not None:
if len(h.shape) != 2:
assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}."
assert (
h.shape[-3] == 1
), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}."
h = h[..., self.subset_idx].transpose(-2, -3)

# Run the standard forward pass
Expand Down

0 comments on commit 62ab1ad

Please sign in to comment.