|
| 1 | +from typing import Any, Dict, Optional, Tuple, Union |
| 2 | + |
| 3 | +from ConfigSpace.configuration_space import ConfigurationSpace |
| 4 | +from ConfigSpace.hyperparameters import CategoricalHyperparameter |
| 5 | + |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +from torch import nn |
| 9 | + |
| 10 | +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent |
| 11 | +from autoPyTorch.pipeline.components.setup.network_head.utils import _activations |
| 12 | +from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter |
| 13 | + |
| 14 | + |
| 15 | +class NoHead(NetworkHeadComponent): |
| 16 | + """ |
| 17 | + Head which only adds a fully connected layer which takes the |
| 18 | + output of the backbone as input and outputs the predictions. |
| 19 | + Flattens any input in a array of shape [B, prod(input_shape)]. |
| 20 | + """ |
| 21 | + |
| 22 | + def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: |
| 23 | + layers = [nn.Flatten()] |
| 24 | + in_features = np.prod(input_shape).item() |
| 25 | + out_features = np.prod(output_shape).item() |
| 26 | + layers.append(_activations[self.config["activation"]]()) |
| 27 | + layers.append(nn.Linear(in_features=in_features, |
| 28 | + out_features=out_features)) |
| 29 | + return nn.Sequential(*layers) |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: |
| 33 | + return { |
| 34 | + 'shortname': 'NoHead', |
| 35 | + 'name': 'NoHead', |
| 36 | + 'handles_tabular': True, |
| 37 | + 'handles_image': True, |
| 38 | + 'handles_time_series': True, |
| 39 | + } |
| 40 | + |
| 41 | + @staticmethod |
| 42 | + def get_hyperparameter_search_space( |
| 43 | + dataset_properties: Optional[Dict[str, str]] = None, |
| 44 | + activation: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="activation", |
| 45 | + value_range=tuple(_activations.keys()), |
| 46 | + default_value=list(_activations.keys())[0]), |
| 47 | + ) -> ConfigurationSpace: |
| 48 | + cs = ConfigurationSpace() |
| 49 | + |
| 50 | + add_hyperparameter(cs, activation, CategoricalHyperparameter) |
| 51 | + |
| 52 | + return cs |
0 commit comments