Skip to content

Commit d9b95df

Browse files
committed
added no head (#218)
1 parent b70de2b commit d9b95df

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Any, Dict, Optional, Tuple, Union
2+
3+
from ConfigSpace.configuration_space import ConfigurationSpace
4+
from ConfigSpace.hyperparameters import CategoricalHyperparameter
5+
6+
import numpy as np
7+
8+
from torch import nn
9+
10+
from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent
11+
from autoPyTorch.pipeline.components.setup.network_head.utils import _activations
12+
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
13+
14+
15+
class NoHead(NetworkHeadComponent):
16+
"""
17+
Head which only adds a fully connected layer which takes the
18+
output of the backbone as input and outputs the predictions.
19+
Flattens any input in a array of shape [B, prod(input_shape)].
20+
"""
21+
22+
def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
23+
layers = [nn.Flatten()]
24+
in_features = np.prod(input_shape).item()
25+
out_features = np.prod(output_shape).item()
26+
layers.append(_activations[self.config["activation"]]())
27+
layers.append(nn.Linear(in_features=in_features,
28+
out_features=out_features))
29+
return nn.Sequential(*layers)
30+
31+
@staticmethod
32+
def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
33+
return {
34+
'shortname': 'NoHead',
35+
'name': 'NoHead',
36+
'handles_tabular': True,
37+
'handles_image': True,
38+
'handles_time_series': True,
39+
}
40+
41+
@staticmethod
42+
def get_hyperparameter_search_space(
43+
dataset_properties: Optional[Dict[str, str]] = None,
44+
activation: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="activation",
45+
value_range=tuple(_activations.keys()),
46+
default_value=list(_activations.keys())[0]),
47+
) -> ConfigurationSpace:
48+
cs = ConfigurationSpace()
49+
50+
add_hyperparameter(cs, activation, CategoricalHyperparameter)
51+
52+
return cs

test/test_pipeline/components/setup/test_setup_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def backbone(request):
1414
return request.param
1515

1616

17-
@pytest.fixture(params=['fully_connected'])
17+
@pytest.fixture(params=['fully_connected', 'no_head'])
1818
def head(request):
1919
return request.param
2020

0 commit comments

Comments
 (0)