This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add support to attach heads to DenseNets #383

Closed
@@ -50,8 +50,17 @@
"model": {
"name": "densenet",
"num_blocks": [6, 12, 24, 16],
"num_classes": 1000,
"small_input": false
"small_input": false,
"heads": [
{
"name": "fully_connected",
"unique_id": "default_head",
"num_classes": 1000,
"fork_block": "trunk_output",
"in_plane": 1024,
"zero_init_bias": true
}
]
},
"optimizer": {
"name": "sgd",
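The head entry now carries the classifier settings that previously lived on the model itself: `num_classes` moves into the head, `fork_block` names the attachment point, and `in_plane` must match the 1024 channels the trunk produces at `trunk_output` for this block layout. A minimal sketch of consuming the new-style section, assuming the generic `build_model` helper wires up the `heads` list (which is what the updated test below exercises):

```python
# Hedged sketch: build a DenseNet with the head attached via config.
# Assumes classy_vision's build_model consumes the "heads" list shown above.
from classy_vision.models import build_model

model_config = {
    "name": "densenet",
    "num_blocks": [6, 12, 24, 16],
    "small_input": False,
    "heads": [
        {
            "name": "fully_connected",
            "unique_id": "default_head",
            "num_classes": 1000,
            "fork_block": "trunk_output",  # attach at the trunk's final output
            "in_plane": 1024,              # trunk width for this block layout
            "zero_init_bias": True,
        }
    ],
}

model = build_model(model_config)  # FullyConnectedHead now hangs off "trunk_output"
```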
18 changes: 16 additions & 2 deletions classy_vision/heads/fully_connected_head.py
@@ -18,7 +18,13 @@ class FullyConnectedHead(ClassyHead):
layer (:class:`torch.nn.Linear`).
"""

def __init__(self, unique_id: str, num_classes: int, in_plane: int):
def __init__(
self,
unique_id: str,
num_classes: int,
in_plane: int,
zero_init_bias: bool = False,
):
"""Constructor for FullyConnectedHead

Args:
@@ -37,6 +43,9 @@ def __init__(self, unique_id: str, num_classes: int, in_plane: int):
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = None if num_classes is None else nn.Linear(in_plane, num_classes)

if zero_init_bias:
self.fc.bias.data.zero_()

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FullyConnectedHead":
"""Instantiates a FullyConnectedHead from a configuration.
@@ -50,7 +59,12 @@ def from_config(cls, config: Dict[str, Any]) -> "FullyConnectedHead":
"""
num_classes = config.get("num_classes", None)
in_plane = config["in_plane"]
return cls(config["unique_id"], num_classes, in_plane)
return cls(
config["unique_id"],
num_classes,
in_plane,
zero_init_bias=config.get("zero_init_bias", False),
)

def forward(self, x):
# perform average pooling:
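The new flag is threaded through `from_config` with a `False` default, so existing head configs keep their behaviour; note that zeroing the bias presumes `num_classes` is set, since `self.fc` is `None` otherwise. A small, hedged sketch of exercising the option in isolation (values illustrative):

```python
# Standalone sketch of the zero_init_bias option added in this diff.
from classy_vision.heads.fully_connected_head import FullyConnectedHead

head = FullyConnectedHead.from_config(
    {
        "unique_id": "default_head",
        "num_classes": 1000,
        "in_plane": 1024,
        "zero_init_bias": True,  # new; configs that omit it default to False
    }
)
assert head.fc.bias.abs().sum().item() == 0.0  # classifier bias starts at zero
```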
105 changes: 59 additions & 46 deletions classy_vision/models/densenet.py
@@ -62,29 +62,6 @@ def forward(self, x):
return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
"""
Block of densely connected layers at same resolution.
"""

def __init__(self, num_layers, in_planes, growth_rate=32, expansion=4):

# assertions:
assert is_pos_int(in_planes)
assert is_pos_int(growth_rate)
assert is_pos_int(expansion)

# create block of dense layers at same resolution:
super(_DenseBlock, self).__init__()
for idx in range(num_layers):
layer = _DenseLayer(
in_planes + idx * growth_rate,
growth_rate=growth_rate,
expansion=expansion,
)
self.add_module("denselayer-%d" % (idx + 1), layer)


class _Transition(nn.Sequential):
"""
Transition layer to reduce spatial resolution.
@@ -130,6 +107,13 @@ def __init__(
Set `final_bn_relu` to `False` to exclude the final batchnorm and ReLU
layers. These settings are useful when
training Siamese networks.

Contains the following attachable blocks:
block{block_idx}-{idx}: The output of dense layer {idx} inside dense
block {block_idx}
transition-{idx}: The output of the transition layer that follows dense block {idx}
trunk_output: The final output of the `DenseNet`. This is
where a `fully_connected` head is normally attached.
"""
super().__init__()

@@ -165,31 +149,28 @@
)
# loop over spatial resolutions:
num_planes = init_planes
self.features = nn.Sequential()
blocks = []
for idx, num_layers in enumerate(num_blocks):

# add dense block:
block = _DenseBlock(
num_layers, num_planes, growth_rate=growth_rate, expansion=expansion
# add dense block
block = self._make_dense_block(
num_layers,
num_planes,
idx,
growth_rate=growth_rate,
expansion=expansion,
)
self.features.add_module("denseblock-%d" % (idx + 1), block)
blocks.append(block)
num_planes = num_planes + num_layers * growth_rate

# add transition layer:
if idx != len(num_blocks) - 1:
trans = _Transition(num_planes, num_planes // 2)
self.features.add_module("transition-%d" % (idx + 1), trans)
blocks.append(self.build_attachable_block(f"transition-{idx}", trans))
num_planes = num_planes // 2

# final batch normalization:
if final_bn_relu:
self.features.add_module("norm-final", nn.BatchNorm2d(num_planes))
self.features.add_module("relu-final", nn.ReLU(inplace=INPLACE))
blocks.append(self._make_trunk_output_block(num_planes, final_bn_relu))

# final classifier:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = None if num_classes is None else nn.Linear(num_planes, num_classes)
self.num_planes = num_planes
self.features = nn.Sequential(*blocks)

# initialize weights of convolutional and batchnorm layers:
for m in self.modules():
@@ -202,6 +183,36 @@ def __init__(self,
elif isinstance(m, nn.Linear):
m.bias.data.zero_()

def _make_trunk_output_block(self, num_planes, final_bn_relu):
layers = nn.Sequential()
if final_bn_relu:
# final batch normalization:
layers.add_module("norm-final", nn.BatchNorm2d(num_planes))
layers.add_module("relu-final", nn.ReLU(inplace=INPLACE))
return self.build_attachable_block("trunk_output", layers)

def _make_dense_block(
self, num_layers, in_planes, block_idx, growth_rate=32, expansion=4
):
assert is_pos_int(in_planes)
assert is_pos_int(growth_rate)
assert is_pos_int(expansion)

# create a block of dense layers at same resolution:
layers = []
for idx in range(num_layers):
layers.append(
self.build_attachable_block(
f"block{block_idx}-{idx}",
_DenseLayer(
in_planes + idx * growth_rate,
growth_rate=growth_rate,
expansion=expansion,
),
)
)
return nn.Sequential(*layers)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DenseNet":
"""Instantiates a DenseNet from a configuration.
@@ -234,14 +245,16 @@ def forward(self, x):
# evaluate all dense blocks:
out = self.features(out)

# perform average pooling:
out = self.avgpool(out)

# final classifier:
out = out.view(out.size(0), -1)
if self.fc is not None:
out = self.fc(out)
return out
# By default the classification layer is implemented as one head on top
# of the last block. The head is automatically computed right after the
# last block.
head_outputs = self.execute_heads()
if len(head_outputs) == 0:
raise Exception("Expecting at least one head that generates output")
elif len(head_outputs) == 1:
return list(head_outputs.values())[0]
else:
return head_outputs

def get_optimizer_params(self):
# use weight decay on BatchNorm for DenseNets
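With the classifier removed from the trunk, `forward` now defers to `execute_heads()`: one attached head returns its tensor directly, while several heads would return a dict of per-head outputs. A rough, self-contained sketch of that contract, reusing the `small_densenet` values from the test file below and assuming `build_model` attaches the configured heads:

```python
# Sketch of the new forward contract; config values mirror the
# "small_densenet" test config further down in this diff.
import torch
from classy_vision.models import build_model

config = {
    "name": "densenet",
    "num_blocks": [1, 1, 1, 1],
    "init_planes": 4,
    "growth_rate": 32,
    "expansion": 4,
    "final_bn_relu": True,
    "small_input": True,
    "heads": [
        {
            "name": "fully_connected",
            "unique_id": "default_head",
            "num_classes": 1000,
            "fork_block": "trunk_output",
            "in_plane": 60,
            "zero_init_bias": True,
        }
    ],
}

model = build_model(config)
out = model(torch.randn(2, 3, 32, 32))  # small_input targets CIFAR-sized images
assert out.shape == (2, 1000)           # one head -> tensor; several heads -> dict
```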
13 changes: 11 additions & 2 deletions test/models_densenet_test.py
@@ -15,12 +15,21 @@
"small_densenet": {
"name": "densenet",
"num_blocks": [1, 1, 1, 1],
"num_classes": 1000,
"init_planes": 4,
"growth_rate": 32,
"expansion": 4,
"final_bn_relu": True,
"small_input": True,
"heads": [
{
"name": "fully_connected",
"unique_id": "default_head",
"num_classes": 1000,
"fork_block": "trunk_output",
"in_plane": 60,
"zero_init_bias": True,
}
],
}
}

@@ -49,5 +58,5 @@ def _test_model(self, model_config):

compare_model_state(self, state, new_state, check_heads=True)

def test_small_resnet(self):
def test_small_densenet(self):
self._test_model(MODELS["small_densenet"])
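For reference, the `in_plane` of 60 in the test head falls straight out of the trunk's channel bookkeeping: each dense block adds `num_layers * growth_rate` channels and every transition halves the count. A quick sanity check mirroring the loop in `DenseNet.__init__`:

```python
# Channel-count check for the "small_densenet" head config above.
num_blocks, init_planes, growth_rate = [1, 1, 1, 1], 4, 32

num_planes = init_planes
for idx, num_layers in enumerate(num_blocks):
    num_planes += num_layers * growth_rate  # dense block grows the channel count
    if idx != len(num_blocks) - 1:
        num_planes //= 2                    # transition layer halves it
print(num_planes)  # 60 -> the "in_plane" expected by the attached head
```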