From d5cdf70d9ad20d6a047561e4752074baec944519 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Fri, 30 Oct 2020 13:29:51 +0100 Subject: [PATCH] backbones and heads are now handled by a BackboneHeadNet, leaving the chance to still create custom networks that are not split into backbone and head --- .idea/Auto-PyTorch.iml | 2 +- .idea/misc.xml | 2 +- .../setup/network/BackboneHeadNet.py | 112 +++++++ .../setup/network/InceptionTimeNet.py | 176 ---------- .../components/setup/network/ResNet.py | 5 - .../components/setup/network/TCNNet.py | 159 --------- .../setup/network/backbone/__init__.py | 29 ++ .../setup/network/backbone/base_backbone.py | 17 +- .../setup/network/backbone/tabular.py | 12 +- .../setup/network/backbone/time_series.py | 301 ++++++++++++++++++ .../components/setup/network/base_network.py | 63 ++-- .../setup/network/base_network_choice.py | 28 +- .../components/setup/network/head/__init__.py | 29 ++ .../setup/network/head/base_head.py | 16 +- 14 files changed, 536 insertions(+), 415 deletions(-) create mode 100644 autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py delete mode 100644 autoPyTorch/pipeline/components/setup/network/InceptionTimeNet.py delete mode 100644 autoPyTorch/pipeline/components/setup/network/TCNNet.py create mode 100644 autoPyTorch/pipeline/components/setup/network/backbone/time_series.py diff --git a/.idea/Auto-PyTorch.iml b/.idea/Auto-PyTorch.iml index 08e5fe3..8f42704 100644 --- a/.idea/Auto-PyTorch.iml +++ b/.idea/Auto-PyTorch.iml @@ -4,7 +4,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index f5106e0..ba07413 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py b/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py new file mode 100644 index 0000000..b5e83e0 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, Optional, Tuple + +import ConfigSpace as CS +import numpy as np +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter +) +from torch import nn + +from autoPyTorch.pipeline.components.setup.network.backbone import get_available_backbones, BaseBackbone, MLPBackbone, \ + ShapedMLPBackbone +from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent +from autoPyTorch.pipeline.components.setup.network.head import get_available_heads, BaseHead, FullyConnectedHead + + +class BackboneHeadNet(BaseNetworkComponent): + """ + Implementation of a dynamic network, that consists of a backbone and a head + """ + + def __init__( + self, + random_state: Optional[np.random.RandomState] = None, + **kwargs: Any + ): + super().__init__( + random_state=random_state, + ) + self.config = kwargs + self._backbones = get_available_backbones() + self._heads = get_available_heads() + self._backbones = get_available_backbones() + self._heads = get_available_heads() + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + "shortname": "BackboneHeadNet", + "name": "BackboneHeadNet", + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, + **kwargs: Any) -> ConfigurationSpace: + cs = ConfigurationSpace() + backbones = get_available_backbones() + heads = get_available_heads() + + # 
filter backbones and heads for those who support the current task type + task = dataset_properties["task_type"] + backbones = {name: backbone for name, backbone in backbones.items() if task in backbone.supported_tasks} + heads = {name: head for name, head in heads.items() if task in head.supported_tasks} + + backbone_hp = CategoricalHyperparameter("backbone", choices=backbones.keys()) + head_hp = CategoricalHyperparameter("head", choices=heads.keys()) + cs.add_hyperparameters([backbone_hp, head_hp]) + + # for each backbone and head, add a conditional search space if this backbone or head is chosen + for backbone_name in backbones.keys(): + backbone_cs = backbones[backbone_name].get_hyperparameter_search_space(dataset_properties) + cs.add_configuration_space(backbone_name, + backbone_cs, + parent_hyperparameter={"parent": backbone_hp, "value": backbone_name}) + + for head_name in heads.keys(): + head_cs: ConfigurationSpace = heads[head_name].get_hyperparameter_search_space(dataset_properties) + cs.add_configuration_space(head_name, + head_cs, + parent_hyperparameter={"parent": head_hp, "value": head_name}) + return cs + + def build_network(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: + """This method returns a pytorch network, that is dynamically built + + a self.config that is network specific, and contains the additional + configuration hyperparameters to build a domain specific network + """ + backbone_name = self.config["backbone"] + head_name = self.config["head"] + Backbone = self._backbones[backbone_name] + Head = self._heads[head_name] + + backbone = Backbone(**{k.replace(backbone_name, "").replace(":", ""): v + for k, v in self.config.items() if + k.startswith(backbone_name)}) + backbone_module = backbone.build_backbone(input_shape=input_shape) + backbone_output_shape = backbone.get_output_shape(input_shape=input_shape) + + head = Head(**{k.replace(head_name, "").replace(":", ""): v + for k, v in self.config.items() if + k.startswith(head_name)}) + head_module = head.build_head(input_shape=backbone_output_shape, output_shape=output_shape) + + return nn.Sequential(backbone_module, head_module) + + def __str__(self) -> str: + """ Allow a nice understanding of what components where used """ + info = vars(self) + # Remove unwanted info + info.pop('network', None) + info.pop('random_state', None) + return f"{self.config['backbone']} -> {self.config['head']} ({str(info)})" + + +if __name__ == "__main__": + cs = BackboneHeadNet.get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"}) + print(cs) + sample = cs.sample_configuration() + bnet = BackboneHeadNet(**sample) + print(bnet) + net = BackboneHeadNet(**sample).build_network(**{"input_shape": (10,), "output_shape": (10,)}) diff --git a/autoPyTorch/pipeline/components/setup/network/InceptionTimeNet.py b/autoPyTorch/pipeline/components/setup/network/InceptionTimeNet.py deleted file mode 100644 index 59734f8..0000000 --- a/autoPyTorch/pipeline/components/setup/network/InceptionTimeNet.py +++ /dev/null @@ -1,176 +0,0 @@ -# Code inspired by https://github.com/hfawaz/InceptionTime -# Paper: https://arxiv.org/pdf/1909.04939.pdf -from typing import Optional, Dict, Any - -import numpy as np -import torch -from torch import nn - -from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import UniformIntegerHyperparameter - -from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent - - -class 
_InceptionBlock(nn.Module): - def __init__(self, n_inputs, n_filters, kernel_size, bottleneck=None): - super(_InceptionBlock, self).__init__() - self.n_filters = n_filters - self.bottleneck = None \ - if bottleneck is None \ - else nn.Conv1d(n_inputs, bottleneck, kernel_size=1) - kernel_sizes = [kernel_size // (2 ** i) for i in range(3)] - n_inputs = n_inputs if bottleneck is None else bottleneck - # create 3 conv layers with different kernel sizes which are applied in parallel - self.pad1 = nn.ConstantPad1d( - padding=self.padding(kernel_sizes[0]), value=0) - self.conv1 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[0]) - self.pad2 = nn.ConstantPad1d( - padding=self.padding(kernel_sizes[1]), value=0) - self.conv2 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[1]) - self.pad3 = nn.ConstantPad1d( - padding=self.padding(kernel_sizes[2]), value=0) - self.conv3 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[2]) - # create 1 maxpool and conv layer which are also applied in parallel - self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - self.convpool = nn.Conv1d(n_inputs, n_filters, 1) - - self.bn = nn.BatchNorm1d(4 * n_filters) - - def padding(self, kernel_size): - if kernel_size % 2 == 0: - return kernel_size // 2, kernel_size // 2 - 1 - else: - return kernel_size // 2, kernel_size // 2 - - def get_n_outputs(self): - return 4 * self.n_filters - - def forward(self, x): - if self.bottleneck is not None: - x = self.bottleneck(x) - x1 = self.conv1(self.pad1(x)) - x2 = self.conv2(self.pad2(x)) - x3 = self.conv3(self.pad3(x)) - x4 = self.convpool(self.maxpool(x)) - x = torch.cat([x1, x2, x3, x4], dim=1) - x = self.bn(x) - return torch.relu(x) - - -class _ResidualBlock(nn.Module): - def __init__(self, n_res_inputs, n_outputs): - super(_ResidualBlock, self).__init__() - self.shortcut = nn.Conv1d(n_res_inputs, n_outputs, 1, bias=False) - self.bn = nn.BatchNorm1d(n_outputs) - - def forward(self, x, res): - shortcut = self.shortcut(res) - shortcut = self.bn(shortcut) - x += shortcut - return torch.relu(x) - - -class _InceptionTime(nn.Module): - def __init__(self, - in_features: int, - out_features: int, - config: Dict[str, Any]) -> None: - super().__init__() - self.config = config - n_inputs = in_features - n_filters = self.config["num_filters"] - bottleneck_size = self.config["bottleneck_size"] - kernel_size = self.config["kernel_size"] - n_res_inputs = in_features - for i in range(self.config["num_blocks"]): - block = _InceptionBlock(n_inputs=n_inputs, - n_filters=n_filters, - bottleneck=bottleneck_size, - kernel_size=kernel_size) - self.__setattr__(f"inception_block_{i}", block) - - # add a residual block after every 3 inception blocks - if i % 3 == 2: - n_res_outputs = block.get_n_outputs() - self.__setattr__(f"residual_block_{i}", _ResidualBlock(n_res_inputs=n_res_inputs, - n_outputs=n_res_outputs)) - n_res_inputs = n_res_outputs - n_inputs = block.get_n_outputs() - - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) - - fc_layers = [ - nn.Linear(in_features=block.get_n_outputs(), out_features=out_features)] - self.fc_layers = nn.Sequential(*fc_layers) - - def forward(self, x): - # swap sequence and feature dimensions for use with convolutional nets - x = x.transpose(1, 2).contiguous() - res = x - for i in range(self.config["num_blocks"]): - x = self.__getattr__(f"inception_block_{i}")(x) - if i % 3 == 2: - x = self.__getattr__(f"residual_block_{i}")(x, res) - res = x - x = self.global_avg_pool(x) - x = x.permute(0, 2, 1) - x = self.fc_layers(x).squeeze(dim=1) - return x - - -class 
InceptionTime(BaseNetworkComponent): - def __init__(self, - intermediate_activation: str, - final_activation: Optional[str], - random_state: Optional[np.random.RandomState] = None, - **kwargs) -> None: - super().__init__(intermediate_activation, final_activation, random_state) - self.config = kwargs - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - "shortname": "InceptionTime", - "name": "InceptionTime", - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_blocks: int = 1, - max_num_blocks: int = 10, - min_num_filters: int = 16, - max_num_filters: int = 64, - min_kernel_size: int = 32, - max_kernel_size: int = 64, - min_bottleneck_size: int = 16, - max_bottleneck_size: int = 64, - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - num_blocks_hp = UniformIntegerHyperparameter("num_blocks", - lower=min_num_blocks, - upper=max_num_blocks) - cs.add_hyperparameter(num_blocks_hp) - - num_filters_hp = UniformIntegerHyperparameter("num_filters", - lower=min_num_filters, - upper=max_num_filters) - cs.add_hyperparameter(num_filters_hp) - - bottleneck_size_hp = UniformIntegerHyperparameter("bottleneck_size", - lower=min_bottleneck_size, - upper=max_bottleneck_size) - cs.add_hyperparameter(bottleneck_size_hp) - - kernel_size_hp = UniformIntegerHyperparameter("kernel_size", - lower=min_kernel_size, - upper=max_kernel_size) - cs.add_hyperparameter(kernel_size_hp) - return cs - - def build_network(self, in_feature: int, out_features: int) -> torch.nn.Module: - network = _InceptionTime(in_features=in_feature, - out_features=out_features, - config=self.config) - return network diff --git a/autoPyTorch/pipeline/components/setup/network/ResNet.py b/autoPyTorch/pipeline/components/setup/network/ResNet.py index c9b1070..cfe64af 100644 --- a/autoPyTorch/pipeline/components/setup/network/ResNet.py +++ b/autoPyTorch/pipeline/components/setup/network/ResNet.py @@ -39,15 +39,11 @@ class ResNet(BaseNetworkComponent): def __init__( self, - intermediate_activation: str, - final_activation: Optional[str] = None, random_state: Optional[np.random.RandomState] = None, **kwargs: Any ): super().__init__( - intermediate_activation=intermediate_activation, - final_activation=final_activation, random_state=random_state, ) self.config = kwargs @@ -130,7 +126,6 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, min_num_units: int = 10, max_num_units: int = 1024, ) -> ConfigurationSpace: - cs = ConfigurationSpace() # The number of groups that will compose the resnet. 
That is, diff --git a/autoPyTorch/pipeline/components/setup/network/TCNNet.py b/autoPyTorch/pipeline/components/setup/network/TCNNet.py deleted file mode 100644 index ce49867..0000000 --- a/autoPyTorch/pipeline/components/setup/network/TCNNet.py +++ /dev/null @@ -1,159 +0,0 @@ -# Chomp1d, TemporalBlock and TemporalConvNet copied from -# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py, Carnegie Mellon University Locus Labs -# Paper: https://arxiv.org/pdf/1803.01271.pdf -from typing import Optional, Dict, Any - -import numpy as np -from torch import nn -from torch.nn.utils import weight_norm - -import ConfigSpace as CS -from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter, - UniformFloatHyperparameter, - UniformIntegerHyperparameter -) - -from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent - - -class _Chomp1d(nn.Module): - def __init__(self, chomp_size): - super(_Chomp1d, self).__init__() - self.chomp_size = chomp_size - - def forward(self, x): - return x[:, :, :-self.chomp_size].contiguous() - - -class _TemporalBlock(nn.Module): - def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): - super(_TemporalBlock, self).__init__() - self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp1 = _Chomp1d(padding) - self.relu1 = nn.ReLU() - self.dropout1 = nn.Dropout(dropout) - - self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp2 = _Chomp1d(padding) - self.relu2 = nn.ReLU() - self.dropout2 = nn.Dropout(dropout) - - self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, - self.conv2, self.chomp2, self.relu2, self.dropout2) - self.downsample = nn.Conv1d( - n_inputs, n_outputs, 1) if n_inputs != n_outputs else None - self.relu = nn.ReLU() - # self.init_weights() - - def init_weights(self): - self.conv1.weight.data.normal_(0, 0.01) - self.conv2.weight.data.normal_(0, 0.01) - if self.downsample is not None: - self.downsample.weight.data.normal_(0, 0.01) - - def forward(self, x): - out = self.net(x) - res = x if self.downsample is None else self.downsample(x) - return self.relu(out + res) - - -class _TemporalConvNet(nn.Module): - def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): - super(_TemporalConvNet, self).__init__() - layers = [] - num_levels = len(num_channels) - for i in range(num_levels): - dilation_size = 2 ** i - in_channels = num_inputs if i == 0 else num_channels[i - 1] - out_channels = num_channels[i] - layers += [_TemporalBlock(in_channels, - out_channels, - kernel_size, - stride=1, - dilation=dilation_size, - padding=(kernel_size - 1) * dilation_size, - dropout=dropout)] - - self.network = nn.Sequential(*layers) - - def forward(self, x): - return self.network(x) - - -class TCNNet(BaseNetworkComponent): - def __init__(self, - intermediate_activation: str, - final_activation: Optional[str], - random_state: Optional[np.random.RandomState] = None, - **kwargs: Any) -> None: - super().__init__(intermediate_activation, final_activation, random_state) - self.config = kwargs - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - "shortname": "TCN", - "name": "Temporal Convolutional Network", - } - - @staticmethod - def 
get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_blocks: int = 1, - max_num_blocks: int = 10, - min_num_filters: int = 4, - max_num_filters: int = 64, - min_kernel_size: int = 4, - max_kernel_size: int = 64, - min_dropout: float = 0.0, - max_dropout: float = 0.5 - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - num_blocks_hp = UniformIntegerHyperparameter("num_blocks", - lower=min_num_blocks, - upper=max_num_blocks) - cs.add_hyperparameter(num_blocks_hp) - - kernel_size_hp = UniformIntegerHyperparameter("kernel_size", - lower=min_kernel_size, - upper=max_kernel_size) - cs.add_hyperparameter(kernel_size_hp) - - use_dropout_hp = CategoricalHyperparameter("use_dropout", - choices=[True, False]) - cs.add_hyperparameter(use_dropout_hp) - - dropout_hp = UniformFloatHyperparameter("dropout", - lower=min_dropout, - upper=max_dropout) - cs.add_hyperparameter(dropout_hp) - cs.add_condition(CS.EqualsCondition(dropout_hp, use_dropout_hp, True)) - - for i in range(0, max_num_blocks): - num_filters_hp = UniformIntegerHyperparameter(f"num_filters_{i}", - lower=min_num_filters, - upper=max_num_filters) - cs.add_hyperparameter(num_filters_hp) - if i >= min_num_blocks: - cs.add_condition(CS.GreaterThanCondition( - num_filters_hp, num_blocks_hp, i)) - - return cs - - def build_network(self, in_feature: int, out_features: int) -> nn.Module: - num_channels = [self.config["num_filters_0"]] - for i in range(1, self.config["num_blocks"]): - num_channels.append(self.config[f"num_filters_{i}"]) - tcn = _TemporalConvNet(in_feature, - num_channels, - kernel_size=self.config["kernel_size"], - dropout=self.config["dropout"] if self.config["use_dropout"] else 0.0 - ) - fc_layers = [nn.Linear(in_features=num_channels[-1], - out_features=out_features)] - network = nn.Sequential(tcn, *fc_layers) - return network diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py b/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py index e69de29..3c817e5 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py +++ b/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py @@ -0,0 +1,29 @@ +import os +from typing import Dict +from collections import OrderedDict + +from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone +from autoPyTorch.pipeline.components.setup.network.backbone.tabular import MLPBackbone, ShapedMLPBackbone +from autoPyTorch.pipeline.components.setup.network.backbone.time_series import InceptionTimeBackbone, TCNBackbone + +from autoPyTorch.pipeline.components.base_component import ( + ThirdPartyComponents, + find_components, +) + +_directory = os.path.split(__file__)[0] +_backbones = find_components(__package__, + _directory, + BaseBackbone) +_addons = ThirdPartyComponents(BaseBackbone) + + +def add_backbone(backbone: BaseBackbone): + _addons.add_component(backbone) + + +def get_available_backbones() -> Dict[str, BaseBackbone]: + backbones = OrderedDict() + backbones.update(_backbones) + backbones.update(_addons.components) + return backbones diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py b/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py index c8cc1ac..b6a95b3 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py @@ -1,8 +1,13 @@ +from abc import abstractmethod from typing import Set, Any, 
Dict, Tuple + import torch from torch import nn -from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent, BaseEstimator +from autoPyTorch.pipeline.components.base_component import BaseEstimator +from autoPyTorch.pipeline.components.base_component import ( + autoPyTorchComponent, +) class BaseBackbone(autoPyTorchComponent): @@ -15,12 +20,14 @@ def __init__(self, self.config = kwargs def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: - input_shape = X["input_shape"] - self.backbone = self.build_backbone(input_shape=input_shape) + """ + Not used. Just for API compatibility. + """ return self + @abstractmethod def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - raise NotImplementedError + raise NotImplementedError() def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: """ @@ -32,4 +39,4 @@ def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: placeholder = torch.randn((1, *input_shape), dtype=torch.float) with torch.no_grad(): output = self.backbone(placeholder) - return output.shape + return tuple(output.shape[1:]) diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py b/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py index da71dee..83d8215 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py +++ b/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py @@ -21,7 +21,7 @@ } -class MLP(BaseBackbone): +class MLPBackbone(BaseBackbone): """ This component automatically creates a Multi Layer Perceptron based on a given config. @@ -73,8 +73,8 @@ def _add_layer(self, layers: List[torch.nn.Module], in_features: int, out_featur @staticmethod def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: return { - 'shortname': 'MLP', - 'name': 'Multi Layer Perceptron', + 'shortname': 'MLPBackbone', + 'name': 'MLPBackbone', } @staticmethod @@ -141,7 +141,7 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, return cs -class ShapedMLP(BaseBackbone): +class ShapedMLPBackbone(BaseBackbone): """ Implementation of a Shaped MLP -- an MLP with the number of units arranged so that a given shape is honored @@ -188,8 +188,8 @@ def _add_layer(self, layers: List[torch.nn.Module], @staticmethod def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: return { - 'shortname': 'ShapedMLP', - 'name': 'Shaped Multi Layer Perceptron', + 'shortname': 'ShapedMLPBackbone', + 'name': 'ShapedMLPBackbone', } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/time_series.py b/autoPyTorch/pipeline/components/setup/network/backbone/time_series.py new file mode 100644 index 0000000..185d720 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network/backbone/time_series.py @@ -0,0 +1,301 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone + + +# Code inspired by https://github.com/hfawaz/InceptionTime +# Paper: https://arxiv.org/pdf/1909.04939.pdf +class _InceptionBlock(nn.Module): + def __init__(self, n_inputs, n_filters, 
kernel_size, bottleneck=None): + super(_InceptionBlock, self).__init__() + self.n_filters = n_filters + self.bottleneck = None \ + if bottleneck is None \ + else nn.Conv1d(n_inputs, bottleneck, kernel_size=1) + kernel_sizes = [kernel_size // (2 ** i) for i in range(3)] + n_inputs = n_inputs if bottleneck is None else bottleneck + # create 3 conv layers with different kernel sizes which are applied in parallel + self.pad1 = nn.ConstantPad1d( + padding=self.padding(kernel_sizes[0]), value=0) + self.conv1 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[0]) + self.pad2 = nn.ConstantPad1d( + padding=self.padding(kernel_sizes[1]), value=0) + self.conv2 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[1]) + self.pad3 = nn.ConstantPad1d( + padding=self.padding(kernel_sizes[2]), value=0) + self.conv3 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[2]) + # create 1 maxpool and conv layer which are also applied in parallel + self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + self.convpool = nn.Conv1d(n_inputs, n_filters, 1) + + self.bn = nn.BatchNorm1d(4 * n_filters) + + def padding(self, kernel_size): + if kernel_size % 2 == 0: + return kernel_size // 2, kernel_size // 2 - 1 + else: + return kernel_size // 2, kernel_size // 2 + + def get_n_outputs(self): + return 4 * self.n_filters + + def forward(self, x): + if self.bottleneck is not None: + x = self.bottleneck(x) + x1 = self.conv1(self.pad1(x)) + x2 = self.conv2(self.pad2(x)) + x3 = self.conv3(self.pad3(x)) + x4 = self.convpool(self.maxpool(x)) + x = torch.cat([x1, x2, x3, x4], dim=1) + x = self.bn(x) + return torch.relu(x) + + +class _ResidualBlock(nn.Module): + def __init__(self, n_res_inputs, n_outputs): + super(_ResidualBlock, self).__init__() + self.shortcut = nn.Conv1d(n_res_inputs, n_outputs, 1, bias=False) + self.bn = nn.BatchNorm1d(n_outputs) + + def forward(self, x, res): + shortcut = self.shortcut(res) + shortcut = self.bn(shortcut) + x += shortcut + return torch.relu(x) + + +class _InceptionTime(nn.Module): + def __init__(self, + in_features: int, + config: Dict[str, Any]) -> None: + super().__init__() + self.config = config + n_inputs = in_features + n_filters = self.config["num_filters"] + bottleneck_size = self.config["bottleneck_size"] + kernel_size = self.config["kernel_size"] + n_res_inputs = in_features + for i in range(self.config["num_blocks"]): + block = _InceptionBlock(n_inputs=n_inputs, + n_filters=n_filters, + bottleneck=bottleneck_size, + kernel_size=kernel_size) + self.__setattr__(f"inception_block_{i}", block) + + # add a residual block after every 3 inception blocks + if i % 3 == 2: + n_res_outputs = block.get_n_outputs() + self.__setattr__(f"residual_block_{i}", _ResidualBlock(n_res_inputs=n_res_inputs, + n_outputs=n_res_outputs)) + n_res_inputs = n_res_outputs + n_inputs = block.get_n_outputs() + + def forward(self, x): + # swap sequence and feature dimensions for use with convolutional nets + x = x.transpose(1, 2).contiguous() + res = x + for i in range(self.config["num_blocks"]): + x = self.__getattr__(f"inception_block_{i}")(x) + if i % 3 == 2: + x = self.__getattr__(f"residual_block_{i}")(x, res) + res = x + x = x.transpose(1, 2).contiguous() + return x + + +class InceptionTimeBackbone(BaseBackbone): + supported_tasks = {"time_series_classification", "time_series_regression"} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + assert len(input_shape) == 2 + backbone = _InceptionTime(in_features=input_shape[-1], + config=self.config) + return backbone + + @staticmethod + def 
get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + 'shortname': 'InceptionTimeBackbone', + 'name': 'InceptionTimeBackbone', + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, + min_num_blocks: int = 1, + max_num_blocks: int = 10, + min_num_filters: int = 16, + max_num_filters: int = 64, + min_kernel_size: int = 32, + max_kernel_size: int = 64, + min_bottleneck_size: int = 16, + max_bottleneck_size: int = 64, + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + num_blocks_hp = UniformIntegerHyperparameter("num_blocks", + lower=min_num_blocks, + upper=max_num_blocks) + cs.add_hyperparameter(num_blocks_hp) + + num_filters_hp = UniformIntegerHyperparameter("num_filters", + lower=min_num_filters, + upper=max_num_filters) + cs.add_hyperparameter(num_filters_hp) + + bottleneck_size_hp = UniformIntegerHyperparameter("bottleneck_size", + lower=min_bottleneck_size, + upper=max_bottleneck_size) + cs.add_hyperparameter(bottleneck_size_hp) + + kernel_size_hp = UniformIntegerHyperparameter("kernel_size", + lower=min_kernel_size, + upper=max_kernel_size) + cs.add_hyperparameter(kernel_size_hp) + return cs + + +# Chomp1d, TemporalBlock and TemporalConvNet copied from +# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py, Carnegie Mellon University Locus Labs +# Paper: https://arxiv.org/pdf/1803.01271.pdf +class _Chomp1d(nn.Module): + def __init__(self, chomp_size): + super(_Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + return x[:, :, :-self.chomp_size].contiguous() + + +class _TemporalBlock(nn.Module): + def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + super(_TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = _Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = _Chomp1d(padding) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d( + n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + # self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class _TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(_TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i - 1] + out_channels = num_channels[i] + layers += [_TemporalBlock(in_channels, + out_channels, + kernel_size, + stride=1, + dilation=dilation_size, + padding=(kernel_size - 1) * dilation_size, + dropout=dropout)] + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + +class TCNBackbone(BaseBackbone): + supported_tasks = 
{"time_series_classification", "time_series_regression"} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + assert len(input_shape) == 2 + num_channels = [self.config["num_filters_0"]] + for i in range(1, self.config["num_blocks"]): + num_channels.append(self.config[f"num_filters_{i}"]) + backbone = _TemporalConvNet(input_shape[-1], + num_channels, + kernel_size=self.config["kernel_size"], + dropout=self.config["dropout"] if self.config["use_dropout"] else 0.0 + ) + return backbone + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + "shortname": "TCNBackbone", + "name": "TCNBackbone", + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, + min_num_blocks: int = 1, + max_num_blocks: int = 10, + min_num_filters: int = 4, + max_num_filters: int = 64, + min_kernel_size: int = 4, + max_kernel_size: int = 64, + min_dropout: float = 0.0, + max_dropout: float = 0.5 + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + num_blocks_hp = UniformIntegerHyperparameter("num_blocks", + lower=min_num_blocks, + upper=max_num_blocks) + cs.add_hyperparameter(num_blocks_hp) + + kernel_size_hp = UniformIntegerHyperparameter("kernel_size", + lower=min_kernel_size, + upper=max_kernel_size) + cs.add_hyperparameter(kernel_size_hp) + + use_dropout_hp = CategoricalHyperparameter("use_dropout", + choices=[True, False]) + cs.add_hyperparameter(use_dropout_hp) + + dropout_hp = UniformFloatHyperparameter("dropout", + lower=min_dropout, + upper=max_dropout) + cs.add_hyperparameter(dropout_hp) + cs.add_condition(CS.EqualsCondition(dropout_hp, use_dropout_hp, True)) + + for i in range(0, max_num_blocks): + num_filters_hp = UniformIntegerHyperparameter(f"num_filters_{i}", + lower=min_num_filters, + upper=max_num_filters) + cs.add_hyperparameter(num_filters_hp) + if i >= min_num_blocks: + cs.add_condition(CS.GreaterThanCondition( + num_filters_hp, num_blocks_hp, i)) + + return cs diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index e782302..15fbb7e 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -1,9 +1,7 @@ -import numbers from abc import abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Dict, Tuple, Optional import numpy as np - import torch from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent @@ -14,12 +12,14 @@ class BaseNetworkComponent(autoPyTorchSetupComponent): in Auto-Pytorch""" def __init__( - self, - random_state: Optional[np.random.RandomState] = None, + self, + random_state: Optional[np.random.RandomState] = None, + device: Optional[torch.device] = None ) -> None: + super().__init__() self.network = None self.random_state = random_state - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: """ @@ -36,10 +36,11 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: # information to fit this stage self.check_requirements(X, y) - in_features = X['num_features'] - out_features = X['num_classes'] + input_shape = X['input_shape'] + output_shape = X['output_shape'] - self.network = 
self.build_network(in_features, out_features) + self.network = self.build_network(input_shape=input_shape, + output_shape=output_shape) # Properly set the network training device self.to(self.device) @@ -47,14 +48,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: return self @abstractmethod - def build_network(self, in_feature: int, out_features: int) -> torch.nn.Module: - """This method returns a pytorch network, that is dynamically built - using: - - common network arguments from the base class: - * intermediate_activation - * final_activation - + def build_network(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> torch.nn.Module: + """This method returns a pytorch network, that is dynamically built using a self.config that is network specific, and contains the additional configuration hyperparameters to build a domain specific network """ @@ -92,30 +87,28 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None: # For the Network, we need the number of input features, # to build the first network layer - if 'num_features' not in X.keys(): - raise ValueError("Could not parse the number of input features in the fit dictionary " - "To fit a network, the number of features is needed to define " + if 'input_shape' not in X.keys(): + raise ValueError("Could not find the input shape in the fit dictionary " + "To fit a network, the input shape is needed to define " "the hidden layers, yet the dict contains only: {}".format( - X.keys() - ) - ) + X.keys()) + ) - assert isinstance(X['num_features'], numbers.Integral), "num_features: {}".format( - type(X['num_features']) - ) + # assert isinstance(X['input_shape'], numbers.Integral), "num_features: {}".format( + # type(X['num_features']) + # ) # For the Network, we need the number of classes, # to build the last layer - if 'num_classes' not in X: - raise ValueError("Could not parse the number of classes in the fit dictionary " - "To fit a network, the number of classes is needed to define " + if 'output_shape' not in X: + raise ValueError("Could not the output shape in the fit dictionary " + "To fit a network, the output shape is needed to define " "the hidden layers, yet the dict contains only: {}".format( - X.keys() - ) - ) - assert isinstance(X['num_classes'], numbers.Integral), "num_classes: {}".format( - type(X['num_classes']) - ) + X.keys()) + ) + # assert isinstance(X['num_classes'], numbers.Integral), "num_classes: {}".format( + # type(X['num_classes']) + # ) def get_network_weights(self) -> torch.nn.parameter.Parameter: """Returns the weights of the network""" diff --git a/autoPyTorch/pipeline/components/setup/network/base_network_choice.py b/autoPyTorch/pipeline/components/setup/network/base_network_choice.py index 51c78b1..c3b4cf2 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network_choice.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network_choice.py @@ -15,7 +15,6 @@ ) from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent - directory = os.path.split(__file__)[0] _networks = find_components(__package__, directory, @@ -31,10 +30,8 @@ class NetworkChoice(autoPyTorchChoice): def get_components(self) -> Dict[str, autoPyTorchComponent]: """Returns the available network components - Args: None - Returns: Dict[str, autoPyTorchComponent]: all baseNetwork components available as choices @@ -45,14 +42,13 @@ def get_components(self) -> Dict[str, autoPyTorchComponent]: return components def 
get_available_components( - self, - dataset_properties: Optional[Dict[str, str]] = None, - include: List[str] = None, - exclude: List[str] = None, + self, + dataset_properties: Optional[Dict[str, str]] = None, + include: List[str] = None, + exclude: List[str] = None, ) -> Dict[str, autoPyTorchComponent]: """Filters out components based on user provided include/exclude directives, as well as the dataset properties - Args: include (Optional[Dict[str, Any]]): what hyper-parameter configurations to honor when creating the configuration space @@ -60,7 +56,6 @@ def get_available_components( to remove from the configuration space dataset_properties (Optional[Dict[str, Union[str, int]]]): Caracteristics of the dataset to guide the pipeline choices of components - Returns: Dict[str, autoPyTorchComponent]: A filtered dict of Network components @@ -101,21 +96,19 @@ def get_available_components( return components_dict def get_hyperparameter_search_space( - self, - dataset_properties: Optional[Dict[str, str]] = None, - default: Optional[str] = None, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, + self, + dataset_properties: Optional[Dict[str, str]] = None, + default: Optional[str] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, ) -> ConfigurationSpace: """Returns the configuration space of the current chosen components - Args: dataset_properties (Optional[Dict[str, str]]): Describes the dataset to work on default (Optional[str]): Default component to use include: Optional[Dict[str, Any]]: what components to include. It is an exhaustive list, and will exclusively use this components. exclude: Optional[Dict[str, Any]]: which components to skip - Returns: ConfigurationSpace: the configuration space of the hyper-parameters of the chosen component @@ -134,8 +127,7 @@ def get_hyperparameter_search_space( raise ValueError("No Network found") if default is None: - defaults = ['MLPNet', - ] + defaults = ['BackboneHeadNet'] for default_ in defaults: if default_ in available_networks: default = default_ diff --git a/autoPyTorch/pipeline/components/setup/network/head/__init__.py b/autoPyTorch/pipeline/components/setup/network/head/__init__.py index e69de29..26d27f8 100644 --- a/autoPyTorch/pipeline/components/setup/network/head/__init__.py +++ b/autoPyTorch/pipeline/components/setup/network/head/__init__.py @@ -0,0 +1,29 @@ +import os +from typing import Dict +from collections import OrderedDict + +from autoPyTorch.pipeline.components.setup.network.head.base_head import BaseHead +from autoPyTorch.pipeline.components.setup.network.head.fully_connected import FullyConnectedHead +from autoPyTorch.pipeline.components.setup.network.head.fully_convolutional import FullyConvolutionalHead + +from autoPyTorch.pipeline.components.base_component import ( + ThirdPartyComponents, + find_components, +) + +_directory = os.path.split(__file__)[0] +_heads = find_components(__package__, + _directory, + BaseHead) +_addons = ThirdPartyComponents(BaseHead) + + +def add_head(head: BaseHead): + _addons.add_component(head) + + +def get_available_heads() -> Dict[str, BaseHead]: + heads = OrderedDict() + heads.update(_heads) + heads.update(_addons.components) + return heads diff --git a/autoPyTorch/pipeline/components/setup/network/head/base_head.py b/autoPyTorch/pipeline/components/setup/network/head/base_head.py index ef0fce7..1e2c586 100644 --- a/autoPyTorch/pipeline/components/setup/network/head/base_head.py +++ 
b/autoPyTorch/pipeline/components/setup/network/head/base_head.py @@ -1,8 +1,9 @@ +from abc import abstractmethod from typing import Set, Any, Dict, Tuple + from torch import nn from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent, BaseEstimator -from autoPyTorch.pipeline.components.setup.network.head import fully_convolutional class BaseHead(autoPyTorchComponent): @@ -15,14 +16,11 @@ def __init__(self, self.config = kwargs def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: - input_shape = X["backbone_output_shape"] - output_shape = X["head_output_shape"] - self.head = self.build_head(input_shape=input_shape, output_shape=output_shape) + """ + Not used. Just for API compatibility. + """ return self + @abstractmethod def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: - raise NotImplementedError - - -def get_available_heads() -> Set[BaseHead]: - return {fully_convolutional.ImageHead} + raise NotImplementedError()
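
Note on extending the new backbone registry (not part of the diff above): the backbone/__init__.py added by this patch exposes add_backbone() and get_available_backbones(), so third-party backbones can be plugged into BackboneHeadNet without touching the pipeline code. The following sketch shows one way this could look; ExampleBackbone, its num_units hyperparameter, and the tiny module it builds are illustrative assumptions, not code from this commit.

from typing import Any, Dict, Optional, Tuple

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformIntegerHyperparameter
from torch import nn

from autoPyTorch.pipeline.components.setup.network.BackboneHeadNet import BackboneHeadNet
from autoPyTorch.pipeline.components.setup.network.backbone import add_backbone, get_available_backbones
from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone


class ExampleBackbone(BaseBackbone):
    # restrict the backbone to tabular tasks so that the task_type filter in
    # BackboneHeadNet.get_hyperparameter_search_space keeps it
    supported_tasks = {"tabular_classification", "tabular_regression"}

    def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module:
        # store the module so BaseBackbone.get_output_shape can run its dummy forward pass
        self.backbone = nn.Sequential(
            nn.Linear(input_shape[-1], self.config["num_units"]),
            nn.ReLU()
        )
        return self.backbone

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
        return {"shortname": "ExampleBackbone", "name": "ExampleBackbone"}

    @staticmethod
    def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, Any]] = None,
                                        min_num_units: int = 8,
                                        max_num_units: int = 128) -> ConfigurationSpace:
        cs = ConfigurationSpace()
        cs.add_hyperparameter(UniformIntegerHyperparameter("num_units",
                                                           lower=min_num_units,
                                                           upper=max_num_units))
        return cs


# register the backbone; it should now be listed next to MLPBackbone, ShapedMLPBackbone, etc.
add_backbone(ExampleBackbone)
print(get_available_backbones().keys())

# BackboneHeadNet then picks it up when building the combined search space,
# as in the __main__ block of BackboneHeadNet.py
cs = BackboneHeadNet.get_hyperparameter_search_space(
    dataset_properties={"task_type": "tabular_classification"})
print(cs)

BackboneHeadNet namespaces each component's hyperparameters under its name via cs.add_configuration_space and strips the "<name>:" prefix again in build_network, so individual backbone and head search spaces do not need globally unique hyperparameter names.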