From 76074ce4c2953046e663155de2c6afcf435c798d Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Tue, 12 Dec 2023 17:34:00 -0500 Subject: [PATCH 01/18] First step towards parallelized ensemble MLPs --- .../nn/architectures/global_architectures.py | 243 +++++++++- graphium/nn/base_layers.py | 47 +- graphium/nn/ensemble_layers.py | 418 ++++++++++++++++++ graphium/utils/spaces.py | 5 + 4 files changed, 670 insertions(+), 43 deletions(-) create mode 100644 graphium/nn/ensemble_layers.py diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 6570ca492..a40e6f044 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -166,16 +166,17 @@ def __init__( self.last_layer_is_readout = last_layer_is_readout self._readout_cache = None + self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] + self._parse_layers(layer_type=layer_type, residual_type=residual_type) + self._create_layers() + self._check_bad_arguments() + + def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT - self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT) self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) - self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] - self._create_layers() - self._check_bad_arguments() - def _check_bad_arguments(self): r""" Raise comprehensive errors if the arguments seem wrong @@ -403,6 +404,238 @@ def __repr__(self): return class_str + layer_str +class EnsembleFeedForwardNN(nn.Module, MupMixin): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dims: Union[List[int], int], + num_ensemble: int, + reduction: Union[str, Callable], + depth: Optional[int] = None, + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + normalization: Union[str, Callable] = "none", + first_normalization: Union[str, Callable] = "none", + last_normalization: Union[str, Callable] = "none", + residual_type: str = "none", + residual_skip_steps: int = 1, + name: str = "LNN", + layer_type: Union[str, nn.Module] = "ens-fc", + layer_kwargs: Optional[Dict] = None, + last_layer_is_readout: bool = False, + ): + r""" + An ensemble of flexible neural network architecture, with variable hidden dimensions, + support for multiple layer types, and support for different residual + connections. + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + hidden_dims: + Either an integer specifying all the hidden dimensions, + or a list of dimensions in the hidden layers. + Be careful, the "simple" residual type only supports + hidden dimensions of the same value. + + num_ensemble: + Number of MLPs that run in parallel. + + reduction: + Reduction to use at the end of the MLP. Choices: + + - "none" or `None`: No reduction + - "mean": Mean reduction + - "sum": Sum reduction + - "max": Max reduction + - "min": Min reduction + - "median": Median reduction + - `Callable`: Any callable function. Must take `dim` as a keyword argument. + + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. 
+ If `hidden_dims` is a list, then + `depth` must be `None` or equal to `len(hidden_dims) + 1` + + activation: + activation function to use in the hidden layers. + + last_activation: + activation function to use in the last layer. + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + last_dropout: + The ratio of units to dropout for the last_layer. Must be between 0 and 1 + + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization + - `Callable`: Any callable function + + first_normalization: + Whether to use batch normalization **before** the first layer + + last_normalization: + Whether to use batch normalization in the last layer + + residual_type: + - "none": No residual connection + - "simple": Residual connection similar to the ResNet architecture. + See class `ResidualConnectionSimple` + - "weighted": Residual connection similar to the Resnet architecture, + but with weights applied before the summation. See class `ResidualConnectionWeighted` + - "concat": Residual connection where the residual is concatenated instead + of being added. + - "densenet": Residual connection where the residual of all previous layers + are concatenated. This leads to a strong increase in the number of parameters + if there are multiple hidden layers. + + residual_skip_steps: + The number of steps to skip between each residual connection. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + + name: + Name attributed to the current network, for display and printing + purposes. + + layer_type: + The type of layers to use in the network. + Either "ens-fc" as the `EnsembleFCLayer`, or a class representing the `nn.Module` + to use. + + layer_kwargs: + The arguments to be used in the initialization of the layer provided by `layer_type` + + last_layer_is_readout: Whether the last layer should be treated as a readout layer. + Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + + """ + + # Parse the ensemble arguments + self.num_ensemble = num_ensemble + num_ensemble_2 = layer_kwargs.get("num_ensemble", None) + if num_ensemble_2 is None: + layer_kwargs["num_ensemble"] = num_ensemble + else: + assert num_ensemble_2 == num_ensemble, f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + depth=depth, + activation=activation, + last_activation=last_activation, + dropout=dropout, + last_dropout=last_dropout, + normalization=normalization, + first_normalization=first_normalization, + last_normalization=last_normalization, + residual_type=residual_type, + residual_skip_steps=residual_skip_steps, + name=name, + layer_type=layer_type, + layer_kwargs=layer_kwargs, + last_layer_is_readout=last_layer_is_readout, + reduction=reduction, + ) + + # Parse the reduction + self.reduction = reduction + self.reduction_fn = self._parse_reduction(reduction) + + + def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: + r""" + Parse the reduction argument. 
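
        # Illustrative sketch (not part of the diff above; plain PyTorch): the parsed
        # reduction is expected to be a callable that accepts `dim` as a keyword argument.
        import torch
        h = torch.randn(3, 8, 16)                       # hypothetical [num_ensemble, batch, features]
        assert torch.mean(h, dim=0).shape == (8, 16)    # what the "mean" option resolves to
        # Note that torch.max / torch.min / torch.median return (values, indices)
        # namedtuples when called with `dim`, so downstream code may need `.values`.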
+ """ + + if isinstance(reduction, str): + reduction = reduction.lower() + if reduction is None or reduction == "none": + return None + elif reduction == "mean": + return torch.mean + elif reduction == "sum": + return torch.sum + elif reduction == "max": + return torch.max + elif reduction == "min": + return torch.min + elif reduction == "median": + return torch.median + elif callable(reduction): + return reduction + else: + raise ValueError(f"Unknown reduction {reduction}") + + def _parse_layers(self, layer_type, residual_type): + # Parse the layer and residuals + from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT + self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT) + self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) + + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the ensemble MLP on the input features, then reduce the output if specified. + + Parameters: + + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + + Input feature tensor, before the MLP. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles. + + Returns: + + `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`: + + Output feature tensor, after the MLP. + `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles. + `L` is removed if a reduction is specified. + """ + + h = super().forward(h) + if self.reduction is not None: + h = self.reduction(h, dim=-2) + + return h + + def get_init_kwargs(self) -> Dict[str, Any]: + """ + Get a dictionary that can be used to instanciate a new object with identical parameters. + """ + kw = super().get_init_kwargs() + kw["num_ensemble"] = self.num_ensemble + kw["reduction"] = self.reduction + return kw + + def __repr__(self): + r""" + Controls how the class is printed + """ + class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n , num_ensemble={self.num_ensemble}, reduction={self.reduction}\n " + layer_str = f"[{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]" + + return class_str + layer_str + + class FeedForwardGraph(FeedForwardNN): def __init__( self, diff --git a/graphium/nn/base_layers.py b/graphium/nn/base_layers.py index 8a8b29f1a..229fdcdc6 100644 --- a/graphium/nn/base_layers.py +++ b/graphium/nn/base_layers.py @@ -259,43 +259,6 @@ def width_mult(self): return self.absolute_width / self.base_width -class MuReadoutGraphium(MuReadout): - """ - PopTorch-compatible replacement for `mup.MuReadout` - - Not quite a drop-in replacement for `mup.MuReadout` - you need to specify - `base_width`. - - Set `base_width` to width of base model passed to `mup.set_base_shapes` - to get same results on IPU and CPU. 
Should still "work" with any other - value, but won't give the same results as CPU - """ - - def __init__(self, in_features, *args, **kwargs): - super().__init__(in_features, *args, **kwargs) - self.base_width = in_features - - @property - def absolute_width(self): - return float(self.in_features) - - @property - def base_width(self): - return self._base_width - - @base_width.setter - def base_width(self, val): - if val is None: - return - assert isinstance( - val, (int, torch.int, torch.long) - ), f"`base_width` must be None, int or long, provided {val} of type {type(val)}" - self._base_width = val - - def width_mult(self): - return self.absolute_width / self.base_width - - class FCLayer(nn.Module): def __init__( self, @@ -490,6 +453,8 @@ def __init__( last_layer_is_readout: bool = False, droppath_rate: float = 0.0, constant_droppath_rate: bool = True, + fc_layer: FCLayer = FCLayer, + fc_layer_kwargs: Optional[dict] = None, ): r""" Simple multi-layer perceptron, built of a series of FCLayers @@ -538,12 +503,17 @@ def __init__( If `True`, drop rates will remain constant accross layers. Otherwise, drop rates will vary stochastically. See `DropPath.get_stochastic_drop_rate` + fc_layer: + The fully connected layer to use. Must inherit from `FCLayer`. + fc_layer_kwargs: + Keyword arguments to pass to the fully connected layer. """ super().__init__() self.in_dim = in_dim self.out_dim = out_dim + self.fc_layer_kwargs = deepcopy(fc_layer_kwargs) or {} # Parse the hidden dimensions and depth if isinstance(hidden_dims, int): @@ -585,7 +555,7 @@ def __init__( # Add a fully-connected layer fully_connected.append( - FCLayer( + fc_layer( all_dims[ii], all_dims[ii + 1], activation=this_activation, @@ -593,6 +563,7 @@ def __init__( dropout=this_dropout, is_readout_layer=is_readout_layer, droppath_rate=this_drop_rate, + **self.fc_layer_kwargs, ) ) diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py new file mode 100644 index 000000000..261ccd3b5 --- /dev/null +++ b/graphium/nn/ensemble_layers.py @@ -0,0 +1,418 @@ +from typing import Union, Callable, Optional, Type, Tuple, Iterable +from copy import deepcopy +from loguru import logger + + +import torch +import torch.nn as nn +import mup.init as mupi +from mup import set_base_shapes + +from graphium.nn.base_layers import FCLayer, MLP + +class EnsembleLinear(nn.Module): + def __init__(self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + ): + r""" + Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`. + + Parameters: + in_dim: + Input dimension of the linear layers + out_dim: + Output dimension of the linear layers. + num_ensemble: + Number of linear layers in the ensemble. + + + """ + super(EnsembleLinear, self).__init__() + + # Initialize weight and bias as learnable parameters + self.weight = nn.Parameter(torch.Tensor(num_ensemble, out_dim, in_dim)) + if bias: + self.bias = nn.Parameter(torch.Tensor(num_ensemble, 1, out_dim)) + else: + self.register_parameter('bias', None) + + # Initialize parameters + self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_ + self.reset_parameters() + + def reset_parameters(self): + """ + Reset the parameters of the linear layer using the `init_fn`. 
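
        # Illustrative sketch (not part of the diff; plain PyTorch, made-up shapes):
        # the batched matmul used in `forward` below behaves like applying
        # `num_ensemble` independent nn.Linear layers to the same input batch.
        import torch
        L, B, Din, Dout = 3, 8, 16, 4
        W = torch.randn(L, Dout, Din)                   # mirrors self.weight
        b = torch.randn(L, 1, Dout)                     # mirrors self.bias
        h = torch.randn(B, Din)                         # one batch shared by all members
        ens = torch.matmul(W, h.transpose(-1, -2)).transpose(-1, -2) + b    # [L, B, Dout]
        ref = torch.stack([h @ W[i].T + b[i, 0] for i in range(L)])         # per-member result
        assert torch.allclose(ens, ref, atol=1e-5)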
+ """ + # Initialize weight using the provided initialization function + self.init_fn(self.weight) + + # Initialize bias if present + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the batched linear transformation on the input features. + + Parameters: + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + Input feature tensor, before the batched linear transformation. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of linear layers. + + Returns: + `torch.Tensor[..., L, B, Dout]`: + Output feature tensor, after the batched linear transformation. + `Dout` is the number of output features, , `B` is the batch size, and `L` is the number of linear layers. + """ + + # Perform the linear transformation using torch.matmul + h = torch.matmul(self.weight, h.transpose(-1, -2)).transpose(-1, -2) + + # Add bias if present + if self.bias is not None: + h += self.bias + + return h + +class EnsembleFCLayer(FCLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + normalization: Union[str, Callable] = "none", + bias: bool = True, + init_fn: Optional[Callable] = None, + is_readout_layer: bool = False, + droppath_rate: float = 0.0, + ): + r""" + Multiple fully connected layers running in parallel. + This layer is centered around a `torch.nn.Linear` module. + The order in which transformations are applied is: + + - Dense Layer + - Activation + - Dropout (if applicable) + - Batch Normalization (if applicable) + + Parameters: + in_dim: + Input dimension of the layer (the `torch.nn.Linear`) + out_dim: + Output dimension of the layer. + num_ensemble: + Number of linear layers in the ensemble. + dropout: + The ratio of units to dropout. No dropout by default. + activation: + Activation function to use. + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization + - `Callable`: Any callable function + bias: + Whether to enable bias in for the linear layer. + init_fn: + Initialization function to use for the weight of the layer. Default is + $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$ + is_readout_layer: Whether the layer should be treated as a readout layer by replacing of `torch.nn.Linear` + by `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + + droppath_rate: + stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 + Attributes: + dropout (int): + The ratio of units to dropout. 
+ normalization (None or Callable): + Normalization layer + linear (`torch.nn.Linear`): + The linear layer + activation (`torch.nn.Module`): + The activation layer + init_fn (Callable): + Initialization function used for the weight of the layer + in_dim (int): + Input dimension of the linear layer + out_dim (int): + Output dimension of the linear layer + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + normalization=normalization, + bias=bias, + init_fn=init_fn, + is_readout_layer=is_readout_layer, + droppath_rate=droppath_rate, + ) + + # Linear layer, or MuReadout layer + if not is_readout_layer: + self.linear = EnsembleLinear(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn) + else: + self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias) + + self.reset_parameters() + + def reset_parameters(self, init_fn=None): + """ + Reset the parameters of the linear layer using the `init_fn`. + """ + set_base_shapes(self, None, rescale_params=False) # Set the shapes of the tensors, useful for mup + self.linear.reset_parameters() + + def __repr__(self): + rep = super().__repr__() + rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})" + return rep + +class EnsembleMuReadoutGraphium(EnsembleLinear): + """ + This layer implements an ensemble version of μP with a 1/width multiplier and a + constant variance initialization for both weights and biases. + """ + def __init__(self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + readout_zero_init=False, + output_mult=1.0 + ): + self.output_mult = output_mult + self.readout_zero_init = readout_zero_init + self.base_width = in_dim + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + num_ensemble=num_ensemble, + bias=bias, + init_fn=init_fn, + ) + + def reset_parameters(self) -> None: + if self.readout_zero_init: + self.weight.data[:] = 0 + if self.bias is not None: + self.bias.data[:] = 0 + else: + super().reset_parameters() + + def width_mult(self): + assert hasattr(self.weight, 'infshape'), ( + 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' + 'switch to distributed training with ' + 'torch.nn.parallel.DistributedDataParallel instead' + ) + return self.weight.infshape.width_mult() + + def _rescale_parameters(self): + '''Rescale parameters to convert SP initialization to μP initialization. + + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + ''' + if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. 
" + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.") + if self.bias is not None: + self.bias.data *= self.width_mult()**0.5 + self.weight.data *= self.width_mult()**0.5 + self._has_rescaled_params = True + + def forward(self, x): + return super().forward( + self.output_mult * x / self.width_mult()) + + @property + def absolute_width(self): + return float(self.in_features) + + @property + def base_width(self): + return self._base_width + + @base_width.setter + def base_width(self, val): + if val is None: + return + assert isinstance( + val, (int, torch.int, torch.long) + ), f"`base_width` must be None, int or long, provided {val} of type {type(val)}" + self._base_width = val + + def width_mult(self): + return self.absolute_width / self.base_width + + +class EnsembleMLP(MLP): + def __init__( + self, + in_dim: int, + hidden_dims: Union[Iterable[int], int], + out_dim: int, + depth: int, + num_ensemble: int, + reduction: Optional[Union[str, Callable]] = "none", + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + normalization: Union[Type[None], str, Callable] = "none", + last_normalization: Union[Type[None], str, Callable] = "none", + first_normalization: Union[Type[None], str, Callable] = "none", + last_layer_is_readout: bool = False, + droppath_rate: float = 0.0, + constant_droppath_rate: bool = True, + ): + r""" + Simple multi-layer perceptron, built of a series of FCLayers + + Parameters: + in_dim: + Input dimension of the MLP + hidden_dims: + Either an integer specifying all the hidden dimensions, + or a list of dimensions in the hidden layers. + out_dim: + Output dimension of the MLP. + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. + If `hidden_dims` is a list, then + `depth` must be `None` or equal to `len(hidden_dims) + 1` + num_ensemble: + Number of MLPs that run in parallel. + reduction: + Reduction to use at the end of the MLP. Choices: + + - "none" or `None`: No reduction + - "mean": Mean reduction + - "sum": Sum reduction + - "max": Max reduction + - "min": Min reduction + - `Callable`: Any callable function. Must take `dim` as a keyword argument. + activation: + Activation function to use in all the layers except the last. + if `layers==1`, this parameter is ignored + last_activation: + Activation function to use in the last layer. + dropout: + The ratio of units to dropout. Must be between 0 and 1 + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization in the hidden layers. + - `Callable`: Any callable function + + if `layers==1`, this parameter is ignored + last_normalization: + Norrmalization to use **after the last layer**. Same options as `normalization`. + first_normalization: + Norrmalization to use in **before the first layer**. Same options as `normalization`. + last_dropout: + The ratio of units to dropout at the last layer. + last_layer_is_readout: Whether the last layer should be treated as a readout layer. 
+ Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + droppath_rate: + stochastic depth drop rate, between 0 and 1. + See https://arxiv.org/abs/1603.09382 + constant_droppath_rate: + If `True`, drop rates will remain constant accross layers. + Otherwise, drop rates will vary stochastically. + See `DropPath.get_stochastic_drop_rate` + """ + + super().__init__( + in_dim=in_dim, + hidden_dims=hidden_dims, + out_dim=out_dim, + depth=depth, + num_ensemble=num_ensemble, + activation=activation, + last_activation=last_activation, + dropout=dropout, + last_dropout=last_dropout, + normalization=normalization, + last_normalization=last_normalization, + first_normalization=first_normalization, + last_layer_is_readout=last_layer_is_readout, + droppath_rate=droppath_rate, + constant_droppath_rate=constant_droppath_rate, + + ) + + self.reduction = self._parse_reduction(reduction) + + def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: + r""" + Parse the reduction argument. + """ + + if isinstance(reduction, str): + reduction = reduction.lower() + if reduction is None or reduction == "none": + return None + elif reduction == "mean": + return torch.mean + elif reduction == "sum": + return torch.sum + + elif callable(reduction): + return reduction + else: + raise ValueError(f"Unknown reduction {reduction}") + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the ensemble MLP on the input features, then reduce the output if specified. + + Parameters: + + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + + Input feature tensor, before the MLP. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles. + + Returns: + + `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`: + + Output feature tensor, after the MLP. + `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles. + `L` is removed if a reduction is specified. 
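
        # Illustrative usage sketch (assumes the full patch series is applied and
        # `graphium` is importable; the dimensions are made up):
        import torch
        from graphium.nn.ensemble_layers import EnsembleMLP
        mlp = EnsembleMLP(in_dim=16, hidden_dims=[32, 32], out_dim=4, num_ensemble=3)
        out = mlp(torch.randn(8, 16))        # -> torch.Size([3, 8, 4]) with reduction="none"
        out = mlp(torch.randn(5, 3, 8, 16))  # extra leading dims broadcast -> torch.Size([5, 3, 8, 4])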
+ """ + h = super().forward(h) + if self.reduction is not None: + h = self.reduction(h, dim=-2) + return h + + def __repr__(self): + r""" + Controls how the class is printed + """ + rep = super().__repr__() + rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})" + diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index d821223a4..67f1f4502 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -27,6 +27,10 @@ "fc": BaseLayers.FCLayer, } +ENSEMBLE_FC_LAYERS_DICT = { + "ens-fc": BaseLayers.EnsembleFCLayer, +} + PYG_LAYERS_DICT = { "pyg:gcn": PygLayers.GCNConvPyg, "pyg:gin": PygLayers.GINConvPyg, @@ -41,6 +45,7 @@ LAYERS_DICT = deepcopy(FC_LAYERS_DICT) LAYERS_DICT.update(deepcopy(PYG_LAYERS_DICT)) +ENSEMBLE_LAYERS_DICT = deepcopy(ENSEMBLE_FC_LAYERS_DICT) RESIDUALS_DICT = { "none": Residuals.ResidualConnectionNone, From 44f9eba94cb3708bfdfbf104852923ce02730908 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Tue, 12 Dec 2023 17:34:38 -0500 Subject: [PATCH 02/18] applied black linting --- .../nn/architectures/global_architectures.py | 8 ++- graphium/nn/ensemble_layers.py | 70 ++++++++++--------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index a40e6f044..fe1a82445 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -174,6 +174,7 @@ def __init__( def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT + self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT) self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) @@ -532,7 +533,9 @@ def __init__( if num_ensemble_2 is None: layer_kwargs["num_ensemble"] = num_ensemble else: - assert num_ensemble_2 == num_ensemble, f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" + assert ( + num_ensemble_2 == num_ensemble + ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" super().__init__( in_dim=in_dim, @@ -559,7 +562,6 @@ def __init__( self.reduction = reduction self.reduction_fn = self._parse_reduction(reduction) - def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: r""" Parse the reduction argument. @@ -587,10 +589,10 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT + self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT) self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) - def forward(self, h: torch.Tensor) -> torch.Tensor: r""" Apply the ensemble MLP on the input features, then reduce the output if specified. 
diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index 261ccd3b5..ad8182d67 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -10,14 +10,16 @@ from graphium.nn.base_layers import FCLayer, MLP + class EnsembleLinear(nn.Module): - def __init__(self, - in_dim: int, - out_dim: int, - num_ensemble: int, - bias: bool = True, - init_fn: Optional[Callable] = None, - ): + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + ): r""" Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`. @@ -38,7 +40,7 @@ def __init__(self, if bias: self.bias = nn.Parameter(torch.Tensor(num_ensemble, 1, out_dim)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) # Initialize parameters self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_ @@ -79,6 +81,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: return h + class EnsembleFCLayer(FCLayer): def __init__( self, @@ -162,7 +165,9 @@ def __init__( # Linear layer, or MuReadout layer if not is_readout_layer: - self.linear = EnsembleLinear(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn) + self.linear = EnsembleLinear( + in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn + ) else: self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias) @@ -180,20 +185,23 @@ def __repr__(self): rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})" return rep + class EnsembleMuReadoutGraphium(EnsembleLinear): """ This layer implements an ensemble version of μP with a 1/width multiplier and a constant variance initialization for both weights and biases. """ - def __init__(self, - in_dim: int, - out_dim: int, - num_ensemble: int, - bias: bool = True, - init_fn: Optional[Callable] = None, - readout_zero_init=False, - output_mult=1.0 - ): + + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + readout_zero_init=False, + output_mult=1.0, + ): self.output_mult = output_mult self.readout_zero_init = readout_zero_init self.base_width = in_dim @@ -214,35 +222,35 @@ def reset_parameters(self) -> None: super().reset_parameters() def width_mult(self): - assert hasattr(self.weight, 'infshape'), ( - 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' - 'switch to distributed training with ' - 'torch.nn.parallel.DistributedDataParallel instead' + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" ) return self.weight.infshape.width_mult() def _rescale_parameters(self): - '''Rescale parameters to convert SP initialization to μP initialization. + """Rescale parameters to convert SP initialization to μP initialization. Warning: This method is NOT idempotent and should be called only once unless you know what you are doing. - ''' - if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params: + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: raise RuntimeError( "`_rescale_parameters` has been called once before already. 
" "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" "If you called `set_base_shapes` on a model loaded from a checkpoint, " "or just want to re-set the base shapes of an existing model, " "make sure to set the flag `rescale_params=False`.\n" - "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.") + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) if self.bias is not None: - self.bias.data *= self.width_mult()**0.5 - self.weight.data *= self.width_mult()**0.5 + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 self._has_rescaled_params = True def forward(self, x): - return super().forward( - self.output_mult * x / self.width_mult()) + return super().forward(self.output_mult * x / self.width_mult()) @property def absolute_width(self): @@ -361,7 +369,6 @@ def __init__( last_layer_is_readout=last_layer_is_readout, droppath_rate=droppath_rate, constant_droppath_rate=constant_droppath_rate, - ) self.reduction = self._parse_reduction(reduction) @@ -415,4 +422,3 @@ def __repr__(self): """ rep = super().__repr__() rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})" - From 015373f5222409f6fff13174d5f801706151e6e0 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:38:29 -0500 Subject: [PATCH 03/18] Tested the ensemble layers --- graphium/nn/base_layers.py | 12 +- graphium/nn/ensemble_layers.py | 17 ++- graphium/utils/spaces.py | 3 +- tests/test_ensemble_layers.py | 256 +++++++++++++++++++++++++++++++++ 4 files changed, 274 insertions(+), 14 deletions(-) create mode 100644 tests/test_ensemble_layers.py diff --git a/graphium/nn/base_layers.py b/graphium/nn/base_layers.py index 229fdcdc6..1d402b760 100644 --- a/graphium/nn/base_layers.py +++ b/graphium/nn/base_layers.py @@ -236,7 +236,7 @@ class MuReadoutGraphium(MuReadout): def __init__(self, in_features, *args, **kwargs): super().__init__(in_features, *args, **kwargs) - self.base_width = in_features + self._base_width = in_features @property def absolute_width(self): @@ -442,7 +442,7 @@ def __init__( in_dim: int, hidden_dims: Union[Iterable[int], int], out_dim: int, - depth: int, + depth: Optional[int] = None, activation: Union[str, Callable] = "relu", last_activation: Union[str, Callable] = "none", dropout: float = 0.0, @@ -530,12 +530,12 @@ def __init__( all_dims = [in_dim] + self.hidden_dims + [out_dim] fully_connected = [] - if depth == 0: + if self.depth == 0: self.fully_connected = None return else: - for ii in range(depth): - if ii < (depth - 1): + for ii in range(self.depth): + if ii < (self.depth - 1): # Define the parameters for all intermediate layers this_activation = activation this_normalization = normalization @@ -551,7 +551,7 @@ def __init__( if constant_droppath_rate: this_drop_rate = droppath_rate else: - this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, depth) + this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, self.depth) # Add a fully-connected layer fully_connected.append( diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index ad8182d67..800ed7c71 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -50,6 +50,7 @@ def reset_parameters(self): """ Reset the parameters of the linear layer using the `init_fn`. 
""" + set_base_shapes(self, None, rescale_params=False) # Set the shapes of the tensors, useful for mup # Initialize weight using the provided initialization function self.init_fn(self.weight) @@ -169,7 +170,7 @@ def __init__( in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn ) else: - self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias) + self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias) self.reset_parameters() @@ -202,9 +203,10 @@ def __init__( readout_zero_init=False, output_mult=1.0, ): + self.in_dim = in_dim self.output_mult = output_mult self.readout_zero_init = readout_zero_init - self.base_width = in_dim + self._base_width = in_dim super().__init__( in_dim=in_dim, out_dim=out_dim, @@ -254,7 +256,7 @@ def forward(self, x): @property def absolute_width(self): - return float(self.in_features) + return float(self.in_dim) @property def base_width(self): @@ -279,8 +281,8 @@ def __init__( in_dim: int, hidden_dims: Union[Iterable[int], int], out_dim: int, - depth: int, num_ensemble: int, + depth: Optional[int] = None, reduction: Optional[Union[str, Callable]] = "none", activation: Union[str, Callable] = "relu", last_activation: Union[str, Callable] = "none", @@ -304,13 +306,13 @@ def __init__( or a list of dimensions in the hidden layers. out_dim: Output dimension of the MLP. + num_ensemble: + Number of MLPs that run in parallel. depth: If `hidden_dims` is an integer, `depth` is 1 + the number of hidden layers to use. If `hidden_dims` is a list, then `depth` must be `None` or equal to `len(hidden_dims) + 1` - num_ensemble: - Number of MLPs that run in parallel. reduction: Reduction to use at the end of the MLP. Choices: @@ -358,7 +360,6 @@ def __init__( hidden_dims=hidden_dims, out_dim=out_dim, depth=depth, - num_ensemble=num_ensemble, activation=activation, last_activation=last_activation, dropout=dropout, @@ -369,6 +370,8 @@ def __init__( last_layer_is_readout=last_layer_is_readout, droppath_rate=droppath_rate, constant_droppath_rate=constant_droppath_rate, + fc_layer=EnsembleFCLayer, + fc_layer_kwargs={"num_ensemble": num_ensemble}, ) self.reduction = self._parse_reduction(reduction) diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index 67f1f4502..3a7f46109 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -4,6 +4,7 @@ import torchmetrics.functional as TorchMetrics import graphium.nn.base_layers as BaseLayers +import graphium.nn.ensemble_layers as EnsembleLayers from graphium.nn.architectures import FeedForwardNN, FeedForwardPyg, TaskHeads import graphium.utils.custom_lr as CustomLR import graphium.data.datamodule as Datamodules @@ -28,7 +29,7 @@ } ENSEMBLE_FC_LAYERS_DICT = { - "ens-fc": BaseLayers.EnsembleFCLayer, + "ens-fc": EnsembleLayers.EnsembleFCLayer, } PYG_LAYERS_DICT = { diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py new file mode 100644 index 000000000..e516f7675 --- /dev/null +++ b/tests/test_ensemble_layers.py @@ -0,0 +1,256 @@ +""" +Unit tests for the different layers of graphium/nn/ensemble_layers +""" + +import numpy as np +import torch +from torch.nn import Linear +import unittest as ut + +from graphium.nn.base_layers import FCLayer, MLP +from graphium.nn.ensemble_layers import EnsembleLinear, EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium + + +class test_Ensemble_Layers(ut.TestCase): + + # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_linear(self, in_dim: int, out_dim: int, 
num_ensemble: int, batch_size: int, more_batch_dim:int): + + msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleLinear instance + ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)] + for i, linear_layer in enumerate(linear_layers): + linear_layer.weight.data = ensemble_linear.weight.data[i] + if ensemble_linear.bias is not None: + linear_layer.bias.data = ensemble_linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_linear(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, linear_layer in enumerate(linear_layers): + + individual_output = linear_layer(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_linear(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, linear_layer in enumerate(linear_layers): + + if more_batch_dim: + individual_output = linear_layer(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = linear_layer(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + + def test_ensemble_linear(self): + # more_batch_dim=0 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + + # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_fclayer(self, 
in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False): + + msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFCLayer instance + ensemble_fclayer = EnsembleFCLayer(in_dim, out_dim, num_ensemble, is_readout_layer=is_readout_layer) + + # Create equivalent separate FCLayer layers with synchronized weights and biases + fc_layers = [FCLayer(in_dim, out_dim, is_readout_layer=is_readout_layer) for _ in range(num_ensemble)] + for i, fc_layer in enumerate(fc_layers): + fc_layer.linear.weight.data = ensemble_fclayer.linear.weight.data[i] + if ensemble_fclayer.bias is not None: + fc_layer.linear.bias.data = ensemble_fclayer.linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_fclayer(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, fc_layer in enumerate(fc_layers): + + individual_output = fc_layer(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_fclayer(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, fc_layer in enumerate(fc_layers): + + if more_batch_dim: + individual_output = fc_layer(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = fc_layer(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + + def test_ensemble_fclayer(self): + # more_batch_dim=0 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, 
batch_size=13, more_batch_dim=7) + + # Test `is_readout_layer` + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True) + + + + + # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False): + + msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleMLP instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleMLP(in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) for _ in range(num_ensemble)] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.fully_connected): + layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.fully_connected[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + + individual_output = mlp(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + + if more_batch_dim: + individual_output = mlp(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = mlp(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + + def test_ensemble_mlp(self): + # more_batch_dim=0 + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + 
self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + # Test `last_layer_is_readout` + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True) + + +if __name__ == '__main__': + ut.main() From e0f841a2fd546d8c9883b55d089388237f7730d5 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:38:47 -0500 Subject: [PATCH 04/18] black linting --- tests/test_ensemble_layers.py | 93 +++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index e516f7675..ff96e0fad 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -8,14 +8,19 @@ import unittest as ut from graphium.nn.base_layers import FCLayer, MLP -from graphium.nn.ensemble_layers import EnsembleLinear, EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium +from graphium.nn.ensemble_layers import ( + EnsembleLinear, + EnsembleFCLayer, + EnsembleMLP, + EnsembleMuReadoutGraphium, +) class test_Ensemble_Layers(ut.TestCase): - # for drop_rate=0.5, test if the output shape is correct - def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int): - + def check_ensemble_linear( + self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int + ): msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" # Create EnsembleLinear instance @@ -37,13 +42,11 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba # Make sure that the outputs of the individual layers are the same as the ensemble output for i, linear_layer in enumerate(linear_layers): - individual_output = linear_layer(input_tensor) individual_output = individual_output.detach().numpy() ensemble_output_i = ensemble_output[i].detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension if more_batch_dim: out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) @@ -58,7 +61,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba # Make sure that the outputs of the individual layers are the same as the ensemble output for i, linear_layer in enumerate(linear_layers): - if more_batch_dim: individual_output = linear_layer(input_tensor[:, i]) ensemble_output_i = ensemble_output[:, i] @@ -69,8 +71,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba ensemble_output_i = 
ensemble_output_i.detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - - def test_ensemble_linear(self): # more_batch_dim=0 self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) @@ -87,10 +87,16 @@ def test_ensemble_linear(self): self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) - # for drop_rate=0.5, test if the output shape is correct - def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False): - + def check_ensemble_fclayer( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + is_readout_layer=False, + ): msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" # Create EnsembleFCLayer instance @@ -112,13 +118,11 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b # Make sure that the outputs of the individual layers are the same as the ensemble output for i, fc_layer in enumerate(fc_layers): - individual_output = fc_layer(input_tensor) individual_output = individual_output.detach().numpy() ensemble_output_i = ensemble_output[i].detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension if more_batch_dim: out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) @@ -133,7 +137,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b # Make sure that the outputs of the individual layers are the same as the ensemble output for i, fc_layer in enumerate(fc_layers): - if more_batch_dim: individual_output = fc_layer(input_tensor[:, i]) ensemble_output_i = ensemble_output[:, i] @@ -144,8 +147,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b ensemble_output_i = ensemble_output_i.detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - - def test_ensemble_fclayer(self): # more_batch_dim=0 self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) @@ -163,24 +164,39 @@ def test_ensemble_fclayer(self): self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) # Test `is_readout_layer` - self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True) - self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True) - self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True) - - - + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True + ) + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True + ) + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True + ) # for drop_rate=0.5, test if the output shape is correct - def check_ensemble_mlp(self, in_dim: int, out_dim: int, 
num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False): - + def check_ensemble_mlp( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" # Create EnsembleMLP instance hidden_dims = [17, 17, 17] - ensemble_mlp = EnsembleMLP(in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout) + ensemble_mlp = EnsembleMLP( + in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout + ) # Create equivalent separate MLP layers with synchronized weights and biases - mlps = [MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) for _ in range(num_ensemble)] + mlps = [ + MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] for i, mlp in enumerate(mlps): for j, layer in enumerate(mlp.fully_connected): layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i] @@ -196,13 +212,11 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch # Make sure that the outputs of the individual layers are the same as the ensemble output for i, mlp in enumerate(mlps): - individual_output = mlp(input_tensor) individual_output = individual_output.detach().numpy() ensemble_output_i = ensemble_output[i].detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension if more_batch_dim: out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) @@ -217,7 +231,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch # Make sure that the outputs of the individual layers are the same as the ensemble output for i, mlp in enumerate(mlps): - if more_batch_dim: individual_output = mlp(input_tensor[:, i]) ensemble_output_i = ensemble_output[:, i] @@ -228,8 +241,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch ensemble_output_i = ensemble_output_i.detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - - def test_ensemble_mlp(self): # more_batch_dim=0 self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) @@ -247,10 +258,16 @@ def test_ensemble_mlp(self): self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) # Test `last_layer_is_readout` - self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True) - self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True) - self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True) - - -if __name__ == '__main__': + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) + + +if __name__ == 
"__main__": ut.main() From dd52ca5122746eccfb2ab2188f3c77e11d75642a Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:42:48 -0500 Subject: [PATCH 05/18] Added tests for mu_readout --- tests/test_ensemble_layers.py | 59 +++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index ff96e0fad..66de95e48 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -7,7 +7,7 @@ from torch.nn import Linear import unittest as ut -from graphium.nn.base_layers import FCLayer, MLP +from graphium.nn.base_layers import FCLayer, MLP, MuReadoutGraphium from graphium.nn.ensemble_layers import ( EnsembleLinear, EnsembleFCLayer, @@ -19,15 +19,27 @@ class test_Ensemble_Layers(ut.TestCase): # for drop_rate=0.5, test if the output shape is correct def check_ensemble_linear( - self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + use_mureadout=False, ): msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" - # Create EnsembleLinear instance - ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + if use_mureadout: + # Create EnsembleMuReadoutGraphium instance + ensemble_linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [MuReadoutGraphium(in_dim, out_dim) for _ in range(num_ensemble)] + else: + # Create EnsembleLinear instance + ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)] - # Create equivalent separate Linear layers with synchronized weights and biases - linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)] for i, linear_layer in enumerate(linear_layers): linear_layer.weight.data = ensemble_linear.weight.data[i] if ensemble_linear.bias is not None: @@ -87,6 +99,41 @@ def test_ensemble_linear(self): self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + def test_ensemble_mureadout_graphium(self): + # Test `use_mureadout` + # more_batch_dim=0 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + + # more_batch_dim=1 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1, use_mureadout=True + ) + + # more_batch_dim=7 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + 
in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + # for drop_rate=0.5, test if the output shape is correct def check_ensemble_fclayer( self, From d926c28ee22030faf274f07e2a845b6ca8d02d31 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:44:58 -0500 Subject: [PATCH 06/18] Added `NotImplementedError` to remember testing EnsembleFeedForwardNN --- graphium/nn/architectures/__init__.py | 1 + tests/test_ensemble_layers.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/graphium/nn/architectures/__init__.py b/graphium/nn/architectures/__init__.py index 1025d900a..a02a5cc8c 100644 --- a/graphium/nn/architectures/__init__.py +++ b/graphium/nn/architectures/__init__.py @@ -3,3 +3,4 @@ from .global_architectures import TaskHeads from .global_architectures import GraphOutputNN from .pyg_architectures import FeedForwardPyg +from .global_architectures import EnsembleFeedForwardNN diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 66de95e48..74f029957 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -14,6 +14,7 @@ EnsembleMLP, EnsembleMuReadoutGraphium, ) +from graphium.nn.architectures import FeedForwardNN, EnsembleFeedForwardNN class test_Ensemble_Layers(ut.TestCase): @@ -315,6 +316,8 @@ def test_ensemble_mlp(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True ) + def test_ensemble_feed_forward_nn(self): + raise NotImplementedError if __name__ == "__main__": ut.main() From db25f107587aba3b2341c820554d88f6a2b8462a Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:45:14 -0500 Subject: [PATCH 07/18] Forgot to stage --- tests/test_ensemble_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 74f029957..55843c460 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -319,5 +319,6 @@ def test_ensemble_mlp(self): def test_ensemble_feed_forward_nn(self): raise NotImplementedError + if __name__ == "__main__": ut.main() From 0dd49d25eaf696f88c15718d0dc44815d9f9d5dd Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:46:36 -0500 Subject: [PATCH 08/18] minor --- tests/test_ensemble_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 55843c460..5c649f804 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -318,6 +318,7 @@ def test_ensemble_mlp(self): def test_ensemble_feed_forward_nn(self): raise NotImplementedError + # Don't forget to test the `reduce` argument if __name__ == "__main__": From e8c970c05ce004d39abf6d7a31256b820f49ca5b Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 00:27:35 -0500 Subject: [PATCH 09/18] Testing the MLP, and the `reduction` --- .../nn/architectures/global_architectures.py | 45 +++-- graphium/nn/ensemble_layers.py | 6 +- tests/test_ensemble_layers.py | 184 +++++++++++++++++- 3 files changed, 214 insertions(+), 21 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index fe1a82445..cfe5dd692 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -405,7 +405,7 @@ def __repr__(self): return 
class_str + layer_str -class EnsembleFeedForwardNN(nn.Module, MupMixin): +class EnsembleFeedForwardNN(FeedForwardNN): def __init__( self, in_dim: int, @@ -528,14 +528,9 @@ def __init__( """ # Parse the ensemble arguments - self.num_ensemble = num_ensemble - num_ensemble_2 = layer_kwargs.get("num_ensemble", None) - if num_ensemble_2 is None: - layer_kwargs["num_ensemble"] = num_ensemble - else: - assert ( - num_ensemble_2 == num_ensemble - ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" + if layer_kwargs is None: + layer_kwargs = {} + layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs) super().__init__( in_dim=in_dim, @@ -555,13 +550,39 @@ def __init__( layer_type=layer_type, layer_kwargs=layer_kwargs, last_layer_is_readout=last_layer_is_readout, - reduction=reduction, ) # Parse the reduction self.reduction = reduction self.reduction_fn = self._parse_reduction(reduction) + def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int: + r""" + Parse the num_ensemble argument. + """ + num_ensemble_out = num_ensemble + + # Get the num_ensemble from the layer_kwargs if it exists + num_ensemble_2 = None + if layer_kwargs is None: + layer_kwargs = {} + else: + num_ensemble_2 = layer_kwargs.get("num_ensemble", None) + + if num_ensemble is None: + num_ensemble_out = num_ensemble_2 + + # Check that the num_ensemble is consistent + if num_ensemble_2 is not None: + assert ( + num_ensemble_2 == num_ensemble + ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" + + # Check that `num_ensemble_out` is not None + assert num_ensemble_out is not None, f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}" + + return num_ensemble_out + def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: r""" Parse the reduction argument. @@ -614,8 +635,8 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: """ h = super().forward(h) - if self.reduction is not None: - h = self.reduction(h, dim=-2) + if self.reduction_fn is not None: + h = self.reduction_fn(h, dim=-3) return h diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index 800ed7c71..7915fc9d2 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -374,7 +374,7 @@ def __init__( fc_layer_kwargs={"num_ensemble": num_ensemble}, ) - self.reduction = self._parse_reduction(reduction) + self.reduction_fn = self._parse_reduction(reduction) def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: r""" @@ -415,8 +415,8 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: `L` is removed if a reduction is specified. 
""" h = super().forward(h) - if self.reduction is not None: - h = self.reduction(h, dim=-2) + if self.reduction_fn is not None: + h = self.reduction_fn(h, dim=-3) return h def __repr__(self): diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 5c649f804..4a14582eb 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -18,7 +18,7 @@ class test_Ensemble_Layers(ut.TestCase): - # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_linear( self, in_dim: int, @@ -135,7 +135,7 @@ def test_ensemble_mureadout_graphium(self): in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7, use_mureadout=True ) - # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_fclayer( self, in_dim: int, @@ -222,7 +222,7 @@ def test_ensemble_fclayer(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True ) - # for drop_rate=0.5, test if the output shape is correct + def check_ensemble_mlp( self, in_dim: int, @@ -316,10 +316,182 @@ def test_ensemble_mlp(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True ) - def test_ensemble_feed_forward_nn(self): - raise NotImplementedError - # Don't forget to test the `reduce` argument + def check_ensemble_feedforwardnn( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, out_dim, hidden_dims, num_ensemble, reduction=None, last_layer_is_readout=last_layer_is_readout + ) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [ + FeedForwardNN(in_dim, out_dim, hidden_dims, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.layers): + layer.linear.weight.data = ensemble_mlp.layers[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.layers[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + individual_outputs.append(mlp(input_tensor)) + individual_outputs = torch.stack(individual_outputs).detach().numpy() + for i, mlp in enumerate(mlps): + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_outputs[..., i, :, :], atol=1e-5, err_msg=msg) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + 
self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + if more_batch_dim: + individual_output = mlp(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = mlp(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + + def check_ensemble_feedforwardnn_mean( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, out_dim, hidden_dims, num_ensemble, reduction="mean", last_layer_is_readout=last_layer_is_readout + ) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [ + FeedForwardNN(in_dim, out_dim, hidden_dims, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.layers): + layer.linear.weight.data = ensemble_mlp.layers[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.layers[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (batch_size, out_dim), msg=msg) + + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + individual_outputs.append(mlp(input_tensor)) + individual_outputs = torch.stack(individual_outputs, dim=-3) + individual_outputs = individual_outputs.mean(dim=-3).detach().numpy() + np.testing.assert_allclose(ensemble_output.detach().numpy(), individual_outputs, atol=1e-5, err_msg=msg) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor).detach() + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + if more_batch_dim: + individual_outputs.append(mlp(input_tensor[:, i])) + else: + individual_outputs.append(mlp(input_tensor[i])) + individual_output = torch.stack(individual_outputs, dim=-3).mean(dim=-3).detach().numpy() + np.testing.assert_allclose(ensemble_output, individual_output, atol=1e-5, err_msg=msg) + + + + + def test_ensemble_feedforwardnn(self): + # more_batch_dim=0 + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, 
batch_size=1, more_batch_dim=0) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + # Test `last_layer_is_readout` + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) + + # Test `reduction` + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) if __name__ == "__main__": ut.main() From be6d0188318cc3786ac441d734c60cd01e11bbe7 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 00:27:59 -0500 Subject: [PATCH 10/18] black --- .../nn/architectures/global_architectures.py | 4 +- tests/test_ensemble_layers.py | 68 +++++++++++++------ 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index cfe5dd692..6f8bd2141 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -579,7 +579,9 @@ def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int: ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" # Check that `num_ensemble_out` is not None - assert num_ensemble_out is not None, f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}" + assert ( + num_ensemble_out is not None + ), f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}" return num_ensemble_out diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 4a14582eb..581adbbf3 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -18,7 +18,6 @@ class test_Ensemble_Layers(ut.TestCase): - def check_ensemble_linear( self, in_dim: int, @@ -135,7 +134,6 @@ def test_ensemble_mureadout_graphium(self): in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7, use_mureadout=True ) - def check_ensemble_fclayer( self, in_dim: int, @@ -222,7 +220,6 @@ def test_ensemble_fclayer(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True ) - def check_ensemble_mlp( self, in_dim: int, @@ -316,7 +313,6 @@ def test_ensemble_mlp(self): in_dim=11, 
out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True ) - def check_ensemble_feedforwardnn( self, in_dim: int, @@ -331,7 +327,12 @@ def check_ensemble_feedforwardnn( # Create EnsembleFeedForwardNN instance hidden_dims = [17, 17, 17] ensemble_mlp = EnsembleFeedForwardNN( - in_dim, out_dim, hidden_dims, num_ensemble, reduction=None, last_layer_is_readout=last_layer_is_readout + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction=None, + last_layer_is_readout=last_layer_is_readout, ) # Create equivalent separate MLP layers with synchronized weights and biases @@ -359,7 +360,9 @@ def check_ensemble_feedforwardnn( individual_outputs = torch.stack(individual_outputs).detach().numpy() for i, mlp in enumerate(mlps): ensemble_output_i = ensemble_output[i].detach().numpy() - np.testing.assert_allclose(ensemble_output_i, individual_outputs[..., i, :, :], atol=1e-5, err_msg=msg) + np.testing.assert_allclose( + ensemble_output_i, individual_outputs[..., i, :, :], atol=1e-5, err_msg=msg + ) # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension if more_batch_dim: @@ -385,7 +388,6 @@ def check_ensemble_feedforwardnn( ensemble_output_i = ensemble_output_i.detach().numpy() np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) - def check_ensemble_feedforwardnn_mean( self, in_dim: int, @@ -400,7 +402,12 @@ def check_ensemble_feedforwardnn_mean( # Create EnsembleFeedForwardNN instance hidden_dims = [17, 17, 17] ensemble_mlp = EnsembleFeedForwardNN( - in_dim, out_dim, hidden_dims, num_ensemble, reduction="mean", last_layer_is_readout=last_layer_is_readout + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction="mean", + last_layer_is_readout=last_layer_is_readout, ) # Create equivalent separate MLP layers with synchronized weights and biases @@ -421,14 +428,15 @@ def check_ensemble_feedforwardnn_mean( # Check for the output shape self.assertEqual(ensemble_output.shape, (batch_size, out_dim), msg=msg) - # Make sure that the outputs of the individual layers are the same as the ensemble output individual_outputs = [] for i, mlp in enumerate(mlps): individual_outputs.append(mlp(input_tensor)) individual_outputs = torch.stack(individual_outputs, dim=-3) individual_outputs = individual_outputs.mean(dim=-3).detach().numpy() - np.testing.assert_allclose(ensemble_output.detach().numpy(), individual_outputs, atol=1e-5, err_msg=msg) + np.testing.assert_allclose( + ensemble_output.detach().numpy(), individual_outputs, atol=1e-5, err_msg=msg + ) # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension if more_batch_dim: @@ -452,24 +460,39 @@ def check_ensemble_feedforwardnn_mean( individual_output = torch.stack(individual_outputs, dim=-3).mean(dim=-3).detach().numpy() np.testing.assert_allclose(ensemble_output, individual_output, atol=1e-5, err_msg=msg) - - - def test_ensemble_feedforwardnn(self): # more_batch_dim=0 - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0 + ) + 
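# Illustrative aside, not part of the patch: the `more_batch_dim` values exercised by these calls
# correspond to an optional extra leading batch dimension. With reduction=None, the shapes asserted
# by the checks above are [B, Din] -> [L, B, Dout], [L, B, Din] -> [L, B, Dout] and
# [M, L, B, Din] -> [M, L, B, Dout]. A standalone sketch of those conventions, reusing only the
# constructor arguments that already appear in this file (anything else here is assumed for illustration):
import torch
from graphium.nn.architectures import EnsembleFeedForwardNN

mlp = EnsembleFeedForwardNN(in_dim=11, out_dim=5, hidden_dims=[17, 17, 17], num_ensemble=3, reduction=None)
assert mlp(torch.randn(13, 11)).shape == (3, 13, 5)           # [B, Din]       -> [L, B, Dout]
assert mlp(torch.randn(3, 13, 11)).shape == (3, 13, 5)        # [L, B, Din]    -> [L, B, Dout]
assert mlp(torch.randn(7, 3, 13, 11)).shape == (7, 3, 13, 5)  # [M, L, B, Din] -> [M, L, B, Dout]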
self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0 + ) # more_batch_dim=1 - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1 + ) # more_batch_dim=7 - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) - self.check_ensemble_feedforwardnn(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7 + ) # Test `last_layer_is_readout` self.check_ensemble_feedforwardnn( @@ -493,5 +516,6 @@ def test_ensemble_feedforwardnn(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True ) + if __name__ == "__main__": ut.main() From 05365054230cf0dd193a33d438a7929119e0bf67 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 13:40:40 -0500 Subject: [PATCH 11/18] Edited the `spaces.py` --- graphium/utils/spaces.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index 3a7f46109..5d0b82a0a 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -5,7 +5,7 @@ import graphium.nn.base_layers as BaseLayers import graphium.nn.ensemble_layers as EnsembleLayers -from graphium.nn.architectures import FeedForwardNN, FeedForwardPyg, TaskHeads +import graphium.nn.architectures as Architectures import graphium.utils.custom_lr as CustomLR import graphium.data.datamodule as Datamodules import graphium.ipu.ipu_losses as IPULosses @@ -138,4 +138,9 @@ "dummy-pretrained-model": "tests/dummy-pretrained-model.ckpt", # dummy model used for testing purposes } -FINETUNING_HEADS_DICT = {"mlp": FeedForwardNN, "gnn": FeedForwardPyg, "task_head": TaskHeads} +FINETUNING_HEADS_DICT = { + "mlp": Architectures.FeedForwardNN, + "gnn": Architectures.FeedForwardPyg, + "task_head": Architectures.TaskHeads, + "ens-mlp": Architectures.EnsembleFeedForwardNN, + } From dd157c442bb6ff501395922dfe5a47138a7e17c9 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 13:47:52 -0500 Subject: [PATCH 12/18] black --- graphium/utils/spaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index 5d0b82a0a..e18cd2302 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -143,4 +143,4 @@ "gnn": Architectures.FeedForwardPyg, "task_head": Architectures.TaskHeads, "ens-mlp": Architectures.EnsembleFeedForwardNN, - } +} From cbd3758aec44d788cab5931d842617033e112880 Mon Sep 17 
00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 15:53:28 -0500 Subject: [PATCH 13/18] added more options to the `parse_reduction` in EnsembleMLP --- graphium/nn/ensemble_layers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index 7915fc9d2..3787b0b3b 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -321,8 +321,9 @@ def __init__( - "sum": Sum reduction - "max": Max reduction - "min": Min reduction + - "median": Median reduction - `Callable`: Any callable function. Must take `dim` as a keyword argument. - activation: + activation: Activation function to use in all the layers except the last. if `layers==1`, this parameter is ignored last_activation: @@ -389,7 +390,12 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona return torch.mean elif reduction == "sum": return torch.sum - + elif reduction == "max": + return torch.max + elif reduction == "min": + return torch.min + elif reduction == "median": + return torch.median elif callable(reduction): return reduction else: From 2b1f9ab12e909372a30c966ee79fe8e05ebe7e79 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 14 Dec 2023 16:04:45 -0500 Subject: [PATCH 14/18] fixed the `parse_reduction` to take values from min, max, median --- graphium/nn/architectures/global_architectures.py | 13 ++++++++++--- graphium/nn/ensemble_layers.py | 12 +++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 6f8bd2141..97bce7f0a 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -599,16 +599,23 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona elif reduction == "sum": return torch.sum elif reduction == "max": - return torch.max + def max_vals(x, dim): + return torch.max(x, dim=dim).values + return max_vals elif reduction == "min": - return torch.min + def min_vals(x, dim): + return torch.min(x, dim=dim).values + return min_vals elif reduction == "median": - return torch.median + def median_vals(x, dim): + return torch.median(x, dim=dim).values + return median_vals elif callable(reduction): return reduction else: raise ValueError(f"Unknown reduction {reduction}") + def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index 3787b0b3b..b1e51b5fd 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -391,11 +391,17 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona elif reduction == "sum": return torch.sum elif reduction == "max": - return torch.max + def max_vals(x, dim): + return torch.max(x, dim=dim).values + return max_vals elif reduction == "min": - return torch.min + def min_vals(x, dim): + return torch.min(x, dim=dim).values + return min_vals elif reduction == "median": - return torch.median + def median_vals(x, dim): + return torch.median(x, dim=dim).values + return median_vals elif callable(reduction): return reduction else: From 785935175cff903be7803f1e37bd588005d0e7a5 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Fri, 15 Dec 2023 14:07:17 -0500 Subject: [PATCH 15/18] Added subset of the input features --- 
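The subsetting introduced in this patch boils down to advanced indexing with one index tensor per ensemble member: indexing the last dimension of a [B, Din] input with an index tensor of shape [L, Dsub] yields [B, L, Dsub], and a transpose then brings it to the [L, B, Dsub] layout the ensemble layers expect (the exact placement of that transpose is refined in the later patches of this series). A minimal sketch in plain torch, with illustrative shapes only:

import torch

B, Din, L, Dsub = 13, 11, 3, 7    # batch, input dim, ensemble size, subset size
h = torch.randn(B, Din)           # input of shape [B, Din]
subset_idx = torch.stack([torch.randperm(Din)[:Dsub] for _ in range(L)])  # [L, Dsub], one subset per member

h_sub = h[..., subset_idx]        # advanced indexing -> [B, L, Dsub]
h_sub = h_sub.transpose(-2, -3)   # -> [L, B, Dsub], one feature subset per MLP
print(h_sub.shape)                # torch.Size([3, 13, 7])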
.../nn/architectures/global_architectures.py | 71 ++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 97bce7f0a..ba6f319b2 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -413,6 +413,7 @@ def __init__( hidden_dims: Union[List[int], int], num_ensemble: int, reduction: Union[str, Callable], + subset_sample_ratio: float = 1.0, depth: Optional[int] = None, activation: Union[str, Callable] = "relu", last_activation: Union[str, Callable] = "none", @@ -461,6 +462,11 @@ def __init__( - "median": Median reduction - `Callable`: Any callable function. Must take `dim` as a keyword argument. + subset_sample_ratio: + Ratio of the subset of the ensemble to use. + Must be between 0 and 1. A different subset is used for each ensemble. + Only valid if the input shape is `[B, Din]`. + depth: If `hidden_dims` is an integer, `depth` is 1 + the number of hidden layers to use. @@ -532,6 +538,9 @@ def __init__( layer_kwargs = {} layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs) + # Parse the sample ratio + self.subset_sample_ratio, self.subset_in_dim, self.subset_idx = self._parse_subset_sample(subset_sample_ratio, num_ensemble) + super().__init__( in_dim=in_dim, out_dim=out_dim, @@ -556,6 +565,10 @@ def __init__( self.reduction = reduction self.reduction_fn = self._parse_reduction(reduction) + def _create_layers(self): + self.full_dims[0] = self.subset_in_dim + super()._create_layers() + def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int: r""" Parse the num_ensemble argument. @@ -615,6 +628,52 @@ def median_vals(x, dim): else: raise ValueError(f"Unknown reduction {reduction}") + def _parse_subset_sample(self, in_dim: int, subset_sample_ratio: float, num_ensemble: int) -> Tuple[float, int]: + r""" + Parse the subset_sample_ratio argument and the subset_in_dim. + + The subset_sample_ratio is the ratio of the hidden features to use by each MLP of the ensemble. + The subset_in_dim is the number of input features to use by each MLP of the ensemble. + + Parameters: + + in_dim: The number of input features, before subsampling + + subset_sample_ratio: + Ratio of the subset of features to use by each MLP of the ensemble. + Must be between 0 and 1. A different subset is used for each ensemble. + Only valid if the input shape is `[B, Din]`. + + If None, the subset_sample_ratio is set to 1.0. + + num_ensemble: + Number of MLPs that run in parallel. + + Returns: + + subset_sample_ratio: The ratio of the subset of features to use by each MLP of the ensemble. + subset_in_dim: The number of input features to use by each MLP of the ensemble. + subset_idx: The indices of the features to use by each MLP of the ensemble. 
+ """ + + # Parse the subset_sample_ratio, make sure value is between 0 and 1 + if subset_sample_ratio is None: + subset_sample_ratio = 1.0 + assert subset_sample_ratio > 0.0 and subset_sample_ratio <= 1.0, f"subset_sample_ratio={subset_sample_ratio}" + + # Parse the subset_in_dim, make sure value is between 0 and in_dim + subset_in_dim = int(torch.ceil(in_dim * subset_sample_ratio).item()) + if subset_in_dim == 0: + subset_in_dim = 1 + + # Create the subset_idx, which is a list of indices to use for each ensemble + if subset_in_dim == in_dim: + subset_idx = None + else: + subset_idx = torch.stack([ + torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) + + return subset_sample_ratio, subset_in_dim, subset_idx def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals @@ -625,7 +684,9 @@ def _parse_layers(self, layer_type, residual_type): def forward(self, h: torch.Tensor) -> torch.Tensor: r""" - Apply the ensemble MLP on the input features, then reduce the output if specified. + Subset the hidden dimension for each MLP, + forward the ensemble MLP on the input features, + then reduce the output if specified. Parameters: @@ -642,8 +703,16 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles. `L` is removed if a reduction is specified. """ + # Subset the input features for each MLP in the ensemble + if self.subset_idx is not None: + if len(h.shape) != 2: + assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}" + h = h[..., self.subset_idx] + # Run the standard forward pass h = super().forward(h) + + # Reduce the output if specified if self.reduction_fn is not None: h = self.reduction_fn(h, dim=-3) From f95872de739bdb8d559f0535e82e983317aeb576 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Fri, 15 Dec 2023 14:54:40 -0500 Subject: [PATCH 16/18] Added `subset_in_dim` option --- .../nn/architectures/global_architectures.py | 65 ++++++++++++------- graphium/nn/ensemble_layers.py | 6 ++ 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index ba6f319b2..c0b2a4b24 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -413,7 +413,7 @@ def __init__( hidden_dims: Union[List[int], int], num_ensemble: int, reduction: Union[str, Callable], - subset_sample_ratio: float = 1.0, + subset_in_dim: Union[float, int] = 1.0, depth: Optional[int] = None, activation: Union[str, Callable] = "relu", last_activation: Union[str, Callable] = "none", @@ -462,9 +462,11 @@ def __init__( - "median": Median reduction - `Callable`: Any callable function. Must take `dim` as a keyword argument. - subset_sample_ratio: - Ratio of the subset of the ensemble to use. - Must be between 0 and 1. A different subset is used for each ensemble. + subset_in_dim: + If float, ratio of the subset of the ensemble to use. Must be between 0 and 1. + If int, number of elements to subset from in_dim. + If `None`, the subset_in_dim is set to `1.0`. + A different subset is used for each ensemble. Only valid if the input shape is `[B, Din]`. 
depth: @@ -538,8 +540,8 @@ def __init__( layer_kwargs = {} layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs) - # Parse the sample ratio - self.subset_sample_ratio, self.subset_in_dim, self.subset_idx = self._parse_subset_sample(subset_sample_ratio, num_ensemble) + # Parse the sample input dimension + self.subset_in_dim, self.subset_idx = self._parse_subset_in_dim(in_dim, subset_in_dim, num_ensemble) super().__init__( in_dim=in_dim, @@ -612,68 +614,81 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona elif reduction == "sum": return torch.sum elif reduction == "max": + def max_vals(x, dim): return torch.max(x, dim=dim).values + return max_vals elif reduction == "min": + def min_vals(x, dim): return torch.min(x, dim=dim).values + return min_vals elif reduction == "median": + def median_vals(x, dim): return torch.median(x, dim=dim).values + return median_vals elif callable(reduction): return reduction else: raise ValueError(f"Unknown reduction {reduction}") - def _parse_subset_sample(self, in_dim: int, subset_sample_ratio: float, num_ensemble: int) -> Tuple[float, int]: + def _parse_subset_in_dim( + self, in_dim: int, subset_in_dim: Union[float, int], num_ensemble: int + ) -> Tuple[float, int]: r""" - Parse the subset_sample_ratio argument and the subset_in_dim. + Parse the subset_in_dim argument and the subset_in_dim. - The subset_sample_ratio is the ratio of the hidden features to use by each MLP of the ensemble. + The subset_in_dim is the ratio of the hidden features to use by each MLP of the ensemble. The subset_in_dim is the number of input features to use by each MLP of the ensemble. Parameters: in_dim: The number of input features, before subsampling - subset_sample_ratio: + subset_in_dim: Ratio of the subset of features to use by each MLP of the ensemble. Must be between 0 and 1. A different subset is used for each ensemble. Only valid if the input shape is `[B, Din]`. - If None, the subset_sample_ratio is set to 1.0. + If None, the subset_in_dim is set to 1.0. num_ensemble: Number of MLPs that run in parallel. Returns: - subset_sample_ratio: The ratio of the subset of features to use by each MLP of the ensemble. - subset_in_dim: The number of input features to use by each MLP of the ensemble. + subset_in_dim: The ratio of the subset of features to use by each MLP of the ensemble. subset_idx: The indices of the features to use by each MLP of the ensemble. 
""" - # Parse the subset_sample_ratio, make sure value is between 0 and 1 - if subset_sample_ratio is None: - subset_sample_ratio = 1.0 - assert subset_sample_ratio > 0.0 and subset_sample_ratio <= 1.0, f"subset_sample_ratio={subset_sample_ratio}" + # Parse the subset_in_dim, make sure value is between 0 and 1 + if subset_in_dim is None: + subset_in_dim = 1.0 + if isinstance(subset_in_dim, int): + assert ( + subset_in_dim > 0 and subset_in_dim <= in_dim + ), f"subset_in_dim={subset_in_dim}, in_dim={in_dim}" + elif isinstance(subset_in_dim, float): + assert subset_in_dim > 0.0 and subset_in_dim <= 1.0, f"subset_in_dim={subset_in_dim}" - # Parse the subset_in_dim, make sure value is between 0 and in_dim - subset_in_dim = int(torch.ceil(in_dim * subset_sample_ratio).item()) - if subset_in_dim == 0: - subset_in_dim = 1 + # Convert to integer value + subset_in_dim = int(in_dim * subset_in_dim) + if subset_in_dim == 0: + subset_in_dim = 1 # Create the subset_idx, which is a list of indices to use for each ensemble if subset_in_dim == in_dim: subset_idx = None else: - subset_idx = torch.stack([ - torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) + subset_idx = torch.stack( + [torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)] + ).unsqueeze(-2) - return subset_sample_ratio, subset_in_dim, subset_idx + return subset_in_dim, subset_idx def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals @@ -707,7 +722,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: if self.subset_idx is not None: if len(h.shape) != 2: assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}" - h = h[..., self.subset_idx] + h = h[..., self.subset_idx] # Run the standard forward pass h = super().forward(h) diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py index b1e51b5fd..41c9a4018 100644 --- a/graphium/nn/ensemble_layers.py +++ b/graphium/nn/ensemble_layers.py @@ -391,16 +391,22 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona elif reduction == "sum": return torch.sum elif reduction == "max": + def max_vals(x, dim): return torch.max(x, dim=dim).values + return max_vals elif reduction == "min": + def min_vals(x, dim): return torch.min(x, dim=dim).values + return min_vals elif reduction == "median": + def median_vals(x, dim): return torch.median(x, dim=dim).values + return median_vals elif callable(reduction): return reduction From 8ad1822f6625c2f5d0dfd470e6c6243b36b3c966 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Fri, 15 Dec 2023 15:52:18 -0500 Subject: [PATCH 17/18] Added tests for the subset_in_dim --- .../nn/architectures/global_architectures.py | 8 +-- tests/test_ensemble_layers.py | 63 +++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index c0b2a4b24..2b2ecc1c2 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -684,9 +684,7 @@ def _parse_subset_in_dim( if subset_in_dim == in_dim: subset_idx = None else: - subset_idx = torch.stack( - [torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)] - ).unsqueeze(-2) + subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) return subset_in_dim, subset_idx @@ -721,8 +719,8 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: # Subset the input features for 
each MLP in the ensemble if self.subset_idx is not None: if len(h.shape) != 2: - assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}" - h = h[..., self.subset_idx] + assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}." + h = h[..., self.subset_idx].transpose(-2, -3) # Run the standard forward pass h = super().forward(h) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index 581adbbf3..e43b14ac1 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -460,6 +460,37 @@ def check_ensemble_feedforwardnn_mean( individual_output = torch.stack(individual_outputs, dim=-3).mean(dim=-3).detach().numpy() np.testing.assert_allclose(ensemble_output, individual_output, atol=1e-5, err_msg=msg) + def check_ensemble_feedforwardnn_simple( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + **kwargs, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction=None, + last_layer_is_readout=last_layer_is_readout, + **kwargs, + ) + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + def test_ensemble_feedforwardnn(self): # more_batch_dim=0 self.check_ensemble_feedforwardnn( @@ -516,6 +547,38 @@ def test_ensemble_feedforwardnn(self): in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True ) + # Test `subset_in_dim` + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=0.5 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=0.5 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=0.5 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=7 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=7 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=7 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=1.5 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=39 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=39 + ) + if __name__ == "__main__": ut.main() From 62ab1ad23c8a43b7f801f5397459e25a9a10c168 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Fri, 15 Dec 2023 15:58:47 -0500 Subject: [PATCH 18/18] Code cleaning --- 
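As a reading aid for the cleanup below: `subset_in_dim` accepts either a ratio of the input features or an absolute count, and both resolve to an integer number of features per ensemble member, clamped to at least one. A rough standalone sketch of that resolution logic, written as a free function for illustration (it is not the patched method itself):

from typing import Optional, Union

def resolve_subset_in_dim(in_dim: int, subset_in_dim: Optional[Union[float, int]]) -> int:
    # None means "keep every input feature"
    if subset_in_dim is None:
        return in_dim
    if isinstance(subset_in_dim, int):
        assert 0 < subset_in_dim <= in_dim
        return subset_in_dim
    # A float is interpreted as a ratio of the input features
    assert 0.0 < subset_in_dim <= 1.0
    return max(1, int(in_dim * subset_in_dim))

print(resolve_subset_in_dim(11, 0.5))   # 5
print(resolve_subset_in_dim(11, 7))     # 7
print(resolve_subset_in_dim(11, None))  # 11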
graphium/nn/architectures/global_architectures.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 2b2ecc1c2..1b6f44dd9 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -666,8 +666,9 @@ def _parse_subset_in_dim( """ # Parse the subset_in_dim, make sure value is between 0 and 1 + subset_idx = None if subset_in_dim is None: - subset_in_dim = 1.0 + return 1.0, None if isinstance(subset_in_dim, int): assert ( subset_in_dim > 0 and subset_in_dim <= in_dim @@ -681,9 +682,7 @@ def _parse_subset_in_dim( subset_in_dim = 1 # Create the subset_idx, which is a list of indices to use for each ensemble - if subset_in_dim == in_dim: - subset_idx = None - else: + if subset_in_dim != in_dim: subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) return subset_in_dim, subset_idx @@ -719,7 +718,9 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: # Subset the input features for each MLP in the ensemble if self.subset_idx is not None: if len(h.shape) != 2: - assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}." + assert ( + h.shape[-3] == 1 + ), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}." h = h[..., self.subset_idx].transpose(-2, -3) # Run the standard forward pass
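Taken together, these patches leave `EnsembleFeedForwardNN` with three knobs on top of the regular feed-forward network: `num_ensemble` parallel MLPs, an optional `reduction` over the ensemble dimension, and an optional `subset_in_dim` that gives each member its own random slice of the input features. A minimal usage sketch, assuming the final constructor signature matches the diffs above:

import torch
from graphium.nn.architectures import EnsembleFeedForwardNN

# 3 MLPs run in parallel; each sees a random half of the 11 input features,
# and their 5-dimensional outputs are averaged over the ensemble dimension.
model = EnsembleFeedForwardNN(
    in_dim=11,
    out_dim=5,
    hidden_dims=[17, 17, 17],
    num_ensemble=3,
    reduction="mean",
    subset_in_dim=0.5,
)

x = torch.randn(13, 11)  # [B, Din]
y = model(x)
print(y.shape)           # torch.Size([13, 5]); the ensemble dimension is reduced away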