Merge pull request #491 from datamol-io/batch_ensemble
EnsembleMLP and EnsembleFeedForwardNN to enable parallelization of ensemble models
DomInvivo committed Dec 19, 2023
2 parents 639b3f8 + 62ab1ad commit 8cbf2d0
Showing 6 changed files with 1,410 additions and 50 deletions.
1 change: 1 addition & 0 deletions graphium/nn/architectures/__init__.py
@@ -3,3 +3,4 @@
from .global_architectures import TaskHeads
from .global_architectures import GraphOutputNN
from .pyg_architectures import FeedForwardPyg
from .global_architectures import EnsembleFeedForwardNN
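
With this export, the ensemble class is importable directly from `graphium.nn.architectures`. A minimal usage sketch, assuming the tensor shapes documented in the class below (the dimensions and hyperparameter values here are illustrative, not taken from the diff):

import torch
from graphium.nn.architectures import EnsembleFeedForwardNN

# 4 MLPs run in parallel; outputs are averaged over the ensemble dimension
ensemble_mlp = EnsembleFeedForwardNN(
    in_dim=16,
    out_dim=8,
    hidden_dims=[32, 32],
    num_ensemble=4,
    reduction="mean",
)
h = torch.randn(128, 16)  # [B, Din]
out = ensemble_mlp(h)     # [B, Dout] once the "mean" reduction removes the ensemble dim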
356 changes: 352 additions & 4 deletions graphium/nn/architectures/global_architectures.py
@@ -166,16 +166,18 @@ def __init__(
self.last_layer_is_readout = last_layer_is_readout
self._readout_cache = None

self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim]
self._parse_layers(layer_type=layer_type, residual_type=residual_type)
self._create_layers()
self._check_bad_arguments()

def _parse_layers(self, layer_type, residual_type):
# Parse the layer and residuals
from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT

self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT)
self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)


def _check_bad_arguments(self):
r"""
Raise comprehensive errors if the arguments seem wrong
@@ -403,6 +405,352 @@ def __repr__(self):
return class_str + layer_str


class EnsembleFeedForwardNN(FeedForwardNN):
def __init__(
self,
in_dim: int,
out_dim: int,
hidden_dims: Union[List[int], int],
num_ensemble: int,
reduction: Union[str, Callable],
subset_in_dim: Union[float, int] = 1.0,
depth: Optional[int] = None,
activation: Union[str, Callable] = "relu",
last_activation: Union[str, Callable] = "none",
dropout: float = 0.0,
last_dropout: float = 0.0,
normalization: Union[str, Callable] = "none",
first_normalization: Union[str, Callable] = "none",
last_normalization: Union[str, Callable] = "none",
residual_type: str = "none",
residual_skip_steps: int = 1,
name: str = "LNN",
layer_type: Union[str, nn.Module] = "ens-fc",
layer_kwargs: Optional[Dict] = None,
last_layer_is_readout: bool = False,
):
r"""
An ensemble of flexible feed-forward neural networks, with variable hidden
dimensions, support for multiple layer types, and support for different
residual connections.
Parameters:
in_dim:
Input feature dimensions of the layer
out_dim:
Output feature dimensions of the layer
hidden_dims:
Either an integer specifying all the hidden dimensions,
or a list of dimensions in the hidden layers.
Be careful, the "simple" residual type only supports
hidden dimensions of the same value.
num_ensemble:
Number of MLPs that run in parallel.
reduction:
Reduction to use at the end of the MLP. Choices:
- "none" or `None`: No reduction
- "mean": Mean reduction
- "sum": Sum reduction
- "max": Max reduction
- "min": Min reduction
- "median": Median reduction
- `Callable`: Any callable function. Must take `dim` as a keyword argument.
subset_in_dim:
If float, the ratio of input features to use by each MLP of the ensemble. Must be between 0 and 1.
If int, the number of input features to subset from `in_dim`.
If `None`, the subset_in_dim is set to `1.0`.
A different subset is used for each ensemble.
Only valid if the input shape is `[B, Din]`.
depth:
If `hidden_dims` is an integer, `depth` is 1 + the number of
hidden layers to use.
If `hidden_dims` is a list, then
`depth` must be `None` or equal to `len(hidden_dims) + 1`
activation:
activation function to use in the hidden layers.
last_activation:
activation function to use in the last layer.
dropout:
The ratio of units to dropout. Must be between 0 and 1
last_dropout:
The ratio of units to dropout for the last_layer. Must be between 0 and 1
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization
- `Callable`: Any callable function
first_normalization:
Normalization to use **before** the first layer. Same choices as `normalization`
last_normalization:
Normalization to use in the last layer. Same choices as `normalization`
residual_type:
- "none": No residual connection
- "simple": Residual connection similar to the ResNet architecture.
See class `ResidualConnectionSimple`
- "weighted": Residual connection similar to the Resnet architecture,
but with weights applied before the summation. See class `ResidualConnectionWeighted`
- "concat": Residual connection where the residual is concatenated instead
of being added.
- "densenet": Residual connection where the residual of all previous layers
are concatenated. This leads to a strong increase in the number of parameters
if there are multiple hidden layers.
residual_skip_steps:
The number of steps to skip between each residual connection.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
name:
Name attributed to the current network, for display and printing
purposes.
layer_type:
The type of layers to use in the network.
Either "ens-fc" as the `EnsembleFCLayer`, or a class representing the `nn.Module`
to use.
layer_kwargs:
The arguments to be used in the initialization of the layer provided by `layer_type`
last_layer_is_readout: Whether the last layer should be treated as a readout layer.
Allows the use of `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
"""

# Parse the ensemble arguments
if layer_kwargs is None:
layer_kwargs = {}
layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs)

# Parse the sample input dimension
self.subset_in_dim, self.subset_idx = self._parse_subset_in_dim(in_dim, subset_in_dim, num_ensemble)

super().__init__(
in_dim=in_dim,
out_dim=out_dim,
hidden_dims=hidden_dims,
depth=depth,
activation=activation,
last_activation=last_activation,
dropout=dropout,
last_dropout=last_dropout,
normalization=normalization,
first_normalization=first_normalization,
last_normalization=last_normalization,
residual_type=residual_type,
residual_skip_steps=residual_skip_steps,
name=name,
layer_type=layer_type,
layer_kwargs=layer_kwargs,
last_layer_is_readout=last_layer_is_readout,
)

# Parse the reduction
self.reduction = reduction
self.reduction_fn = self._parse_reduction(reduction)

def _create_layers(self):
self.full_dims[0] = self.subset_in_dim
super()._create_layers()
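
# Illustrative note (not in the original diff): overriding `_create_layers`
# lets the ensemble swap the first entry of `full_dims` for the per-MLP
# `subset_in_dim`, so each MLP only consumes its sampled subset of the inputs.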

def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int:
r"""
Parse the num_ensemble argument.
"""
num_ensemble_out = num_ensemble

# Get the num_ensemble from the layer_kwargs if it exists
num_ensemble_2 = None
if layer_kwargs is None:
layer_kwargs = {}
else:
num_ensemble_2 = layer_kwargs.get("num_ensemble", None)

if num_ensemble is None:
num_ensemble_out = num_ensemble_2

# Check that the num_ensemble is consistent
if num_ensemble_2 is not None:
assert (
num_ensemble_2 == num_ensemble
), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}"

# Check that `num_ensemble_out` is not None
assert (
num_ensemble_out is not None
), f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}"

return num_ensemble_out
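
# Illustrative behaviour (not in the original diff):
#   self._parse_num_ensemble(4, {})                   -> 4
#   self._parse_num_ensemble(4, {"num_ensemble": 4})  -> 4
#   self._parse_num_ensemble(4, {"num_ensemble": 5})  -> AssertionError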

def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]:
r"""
Parse the reduction argument.
"""

if isinstance(reduction, str):
reduction = reduction.lower()
if reduction is None or reduction == "none":
return None
elif reduction == "mean":
return torch.mean
elif reduction == "sum":
return torch.sum
elif reduction == "max":

def max_vals(x, dim):
return torch.max(x, dim=dim).values

return max_vals
elif reduction == "min":

def min_vals(x, dim):
return torch.min(x, dim=dim).values

return min_vals
elif reduction == "median":

def median_vals(x, dim):
return torch.median(x, dim=dim).values

return median_vals
elif callable(reduction):
return reduction
else:
raise ValueError(f"Unknown reduction {reduction}")
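
# Illustrative behaviour (not in the original diff): every reduction is
# applied over the ensemble dimension, e.g.
#   fn = self._parse_reduction("max")
#   fn(torch.randn(4, 128, 8), dim=-3).shape  # torch.Size([128, 8]), i.e. [L, B, Dout] -> [B, Dout]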

def _parse_subset_in_dim(
self, in_dim: int, subset_in_dim: Union[float, int], num_ensemble: int
) -> Tuple[int, Optional[torch.Tensor]]:
r"""
Parse the `subset_in_dim` argument into the number of input features to keep and the per-ensemble subset indices.
As a float, `subset_in_dim` is the ratio of input features to use by each MLP of the ensemble;
as an int, it is that number of input features directly.
Parameters:
in_dim: The number of input features, before subsampling
subset_in_dim:
If a float, the ratio of input features to use by each MLP of the ensemble; must be between 0 and 1.
If an int, the number of input features to use. A different random subset is used for each ensemble.
Only valid if the input shape is `[B, Din]`.
If None, all input features are used (equivalent to `subset_in_dim=1.0`).
num_ensemble:
Number of MLPs that run in parallel.
Returns:
subset_in_dim: The number of input features kept by each MLP of the ensemble.
subset_idx: A tensor of shape `[num_ensemble, subset_in_dim]` with the feature indices used by each MLP, or `None` if all features are kept.
"""

# Validate subset_in_dim: an int must be in (0, in_dim], a float in (0, 1]
subset_idx = None
if subset_in_dim is None:
return in_dim, None  # `None` keeps all input features
if isinstance(subset_in_dim, int):
assert (
subset_in_dim > 0 and subset_in_dim <= in_dim
), f"subset_in_dim={subset_in_dim}, in_dim={in_dim}"
elif isinstance(subset_in_dim, float):
assert subset_in_dim > 0.0 and subset_in_dim <= 1.0, f"subset_in_dim={subset_in_dim}"

# Convert to integer value
subset_in_dim = int(in_dim * subset_in_dim)
if subset_in_dim == 0:
subset_in_dim = 1

# Create the subset_idx, which is a list of indices to use for each ensemble
if subset_in_dim != in_dim:
subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)])

return subset_in_dim, subset_idx
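
# Illustrative behaviour (not in the original diff): with in_dim=16,
# subset_in_dim=0.5 and num_ensemble=3, this returns (8, subset_idx) where
# subset_idx has shape [3, 8] and holds a different random choice of 8 of
# the 16 input features for each MLP of the ensemble.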

def _parse_layers(self, layer_type, residual_type):
# Parse the layer and residuals
from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT

self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT)
self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)

def forward(self, h: torch.Tensor) -> torch.Tensor:
r"""
Subset the input features for each MLP,
forward the ensemble MLP on the input features,
then reduce the output if specified.
Parameters:
h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`:
Input feature tensor, before the MLP.
`Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles.
Returns:
`torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`:
Output feature tensor, after the MLP.
`Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles.
`L` is removed if a reduction is specified.
"""
# Subset the input features for each MLP in the ensemble
if self.subset_idx is not None:
if len(h.shape) != 2:
assert (
h.shape[-3] == 1
), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}."
h = h[..., self.subset_idx].transpose(-2, -3)

# Run the standard forward pass
h = super().forward(h)

# Reduce the output if specified
if self.reduction_fn is not None:
h = self.reduction_fn(h, dim=-3)

return h
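
# Shape summary for `forward` (illustrative, not in the original diff):
#   [B, Din]         -> [L, B, Dout]      (reduction "none")
#   [B, Din]         -> [B, Dout]         (any other reduction)
#   [..., L, B, Din] -> [..., L, B, Dout] (reduction "none")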

def get_init_kwargs(self) -> Dict[str, Any]:
"""
Get a dictionary that can be used to instantiate a new object with identical parameters.
"""
kw = super().get_init_kwargs()
kw["num_ensemble"] = self.num_ensemble
kw["reduction"] = self.reduction
return kw

def __repr__(self):
r"""
Controls how the class is printed
"""
class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n , num_ensemble={self.num_ensemble}, reduction={self.reduction}\n "
layer_str = f"[{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]"

return class_str + layer_str


class FeedForwardGraph(FeedForwardNN):
def __init__(
self,

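A sketch of how the ensemble dimension and the feature subsetting interact, following the `forward` docstring above (the concrete values are illustrative assumptions, not from the diff):

import torch
from graphium.nn.architectures import EnsembleFeedForwardNN

# 5 MLPs, each seeing its own random half of the 32 input features, no reduction
model = EnsembleFeedForwardNN(
    in_dim=32,
    out_dim=4,
    hidden_dims=[64],
    num_ensemble=5,
    reduction="none",
    subset_in_dim=0.5,
)
h = torch.randn(10, 32)  # [B, Din]
out = model(h)           # [L, B, Dout] = [5, 10, 4]; reduction="mean" would give [10, 4]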