Skip to content

Commit d5cdf70

Browse files
committed
backbones and heads are now handled by a BackboneHeadNet, leaving the chance to still create custom networks that are not split into backbone and head
1 parent 5ee3da3 commit d5cdf70

14 files changed

+536
-415
lines changed

.idea/Auto-PyTorch.iml

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from typing import Any, Dict, Optional, Tuple
2+
3+
import ConfigSpace as CS
4+
import numpy as np
5+
from ConfigSpace.configuration_space import ConfigurationSpace
6+
from ConfigSpace.hyperparameters import (
7+
CategoricalHyperparameter
8+
)
9+
from torch import nn
10+
11+
from autoPyTorch.pipeline.components.setup.network.backbone import get_available_backbones, BaseBackbone, MLPBackbone, \
12+
ShapedMLPBackbone
13+
from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent
14+
from autoPyTorch.pipeline.components.setup.network.head import get_available_heads, BaseHead, FullyConnectedHead
15+
16+
17+
class BackboneHeadNet(BaseNetworkComponent):
18+
"""
19+
Implementation of a dynamic network, that consists of a backbone and a head
20+
"""
21+
22+
def __init__(
23+
self,
24+
random_state: Optional[np.random.RandomState] = None,
25+
**kwargs: Any
26+
):
27+
super().__init__(
28+
random_state=random_state,
29+
)
30+
self.config = kwargs
31+
self._backbones = get_available_backbones()
32+
self._heads = get_available_heads()
33+
self._backbones = get_available_backbones()
34+
self._heads = get_available_heads()
35+
36+
@staticmethod
37+
def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
38+
return {
39+
"shortname": "BackboneHeadNet",
40+
"name": "BackboneHeadNet",
41+
}
42+
43+
@staticmethod
44+
def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None,
45+
**kwargs: Any) -> ConfigurationSpace:
46+
cs = ConfigurationSpace()
47+
backbones = get_available_backbones()
48+
heads = get_available_heads()
49+
50+
# filter backbones and heads for those who support the current task type
51+
task = dataset_properties["task_type"]
52+
backbones = {name: backbone for name, backbone in backbones.items() if task in backbone.supported_tasks}
53+
heads = {name: head for name, head in heads.items() if task in head.supported_tasks}
54+
55+
backbone_hp = CategoricalHyperparameter("backbone", choices=backbones.keys())
56+
head_hp = CategoricalHyperparameter("head", choices=heads.keys())
57+
cs.add_hyperparameters([backbone_hp, head_hp])
58+
59+
# for each backbone and head, add a conditional search space if this backbone or head is chosen
60+
for backbone_name in backbones.keys():
61+
backbone_cs = backbones[backbone_name].get_hyperparameter_search_space(dataset_properties)
62+
cs.add_configuration_space(backbone_name,
63+
backbone_cs,
64+
parent_hyperparameter={"parent": backbone_hp, "value": backbone_name})
65+
66+
for head_name in heads.keys():
67+
head_cs: ConfigurationSpace = heads[head_name].get_hyperparameter_search_space(dataset_properties)
68+
cs.add_configuration_space(head_name,
69+
head_cs,
70+
parent_hyperparameter={"parent": head_hp, "value": head_name})
71+
return cs
72+
73+
def build_network(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
74+
"""This method returns a pytorch network, that is dynamically built
75+
76+
a self.config that is network specific, and contains the additional
77+
configuration hyperparameters to build a domain specific network
78+
"""
79+
backbone_name = self.config["backbone"]
80+
head_name = self.config["head"]
81+
Backbone = self._backbones[backbone_name]
82+
Head = self._heads[head_name]
83+
84+
backbone = Backbone(**{k.replace(backbone_name, "").replace(":", ""): v
85+
for k, v in self.config.items() if
86+
k.startswith(backbone_name)})
87+
backbone_module = backbone.build_backbone(input_shape=input_shape)
88+
backbone_output_shape = backbone.get_output_shape(input_shape=input_shape)
89+
90+
head = Head(**{k.replace(head_name, "").replace(":", ""): v
91+
for k, v in self.config.items() if
92+
k.startswith(head_name)})
93+
head_module = head.build_head(input_shape=backbone_output_shape, output_shape=output_shape)
94+
95+
return nn.Sequential(backbone_module, head_module)
96+
97+
def __str__(self) -> str:
98+
""" Allow a nice understanding of what components where used """
99+
info = vars(self)
100+
# Remove unwanted info
101+
info.pop('network', None)
102+
info.pop('random_state', None)
103+
return f"{self.config['backbone']} -> {self.config['head']} ({str(info)})"
104+
105+
106+
if __name__ == "__main__":
107+
cs = BackboneHeadNet.get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"})
108+
print(cs)
109+
sample = cs.sample_configuration()
110+
bnet = BackboneHeadNet(**sample)
111+
print(bnet)
112+
net = BackboneHeadNet(**sample).build_network(**{"input_shape": (10,), "output_shape": (10,)})

autoPyTorch/pipeline/components/setup/network/InceptionTimeNet.py

-176
This file was deleted.

autoPyTorch/pipeline/components/setup/network/ResNet.py

-5
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,11 @@ class ResNet(BaseNetworkComponent):
3939

4040
def __init__(
4141
self,
42-
intermediate_activation: str,
43-
final_activation: Optional[str] = None,
4442
random_state: Optional[np.random.RandomState] = None,
4543
**kwargs: Any
4644
):
4745

4846
super().__init__(
49-
intermediate_activation=intermediate_activation,
50-
final_activation=final_activation,
5147
random_state=random_state,
5248
)
5349
self.config = kwargs
@@ -130,7 +126,6 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None,
130126
min_num_units: int = 10,
131127
max_num_units: int = 1024,
132128
) -> ConfigurationSpace:
133-
134129
cs = ConfigurationSpace()
135130

136131
# The number of groups that will compose the resnet. That is,

0 commit comments

Comments
 (0)