
Adding EfficientNetV2 architecture #5450

Merged Mar 2, 2022 · 18 commits

Changes from 1 commit
152 changes: 119 additions & 33 deletions torchvision/models/efficientnet.py
@@ -1,5 +1,7 @@
 import copy
 import math
+import warnings
+from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Optional, List, Sequence

@@ -40,8 +42,23 @@
 }


-class MBConvConfig:
-    # Stores information listed at Table 1 of the EfficientNet paper
+@dataclass
+class _MBConvConfig:
+    expand_ratio: float
+    kernel: int
+    stride: int
+    input_channels: int
+    out_channels: int
+    num_layers: int
+    block: Callable[..., nn.Module]
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
+        return _make_divisible(channels * width_mult, 8, min_value)
+
+
+class MBConvConfig(_MBConvConfig):
+    # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
     def __init__(
         self,
         expand_ratio: float,
@@ -52,36 +69,37 @@ def __init__(
         num_layers: int,
         width_mult: float,
         depth_mult: float,
+        block: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
-        self.expand_ratio = expand_ratio
-        self.kernel = kernel
-        self.stride = stride
-        self.input_channels = self.adjust_channels(input_channels, width_mult)
-        self.out_channels = self.adjust_channels(out_channels, width_mult)
-        self.num_layers = self.adjust_depth(num_layers, depth_mult)
-
-    def __repr__(self) -> str:
-        s = (
-            f"{self.__class__.__name__}("
-            f"expand_ratio={self.expand_ratio}"
-            f", kernel={self.kernel}"
-            f", stride={self.stride}"
-            f", input_channels={self.input_channels}"
-            f", out_channels={self.out_channels}"
-            f", num_layers={self.num_layers}"
-            f")"
-        )
-        return s
-
-    @staticmethod
-    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
-        return _make_divisible(channels * width_mult, 8, min_value)
+        input_channels = self.adjust_channels(input_channels, width_mult)
+        out_channels = self.adjust_channels(out_channels, width_mult)
+        num_layers = self.adjust_depth(num_layers, depth_mult)
+        if block is None:
+            block = MBConv
+        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)

     @staticmethod
     def adjust_depth(num_layers: int, depth_mult: float):
         return int(math.ceil(num_layers * depth_mult))


+class FusedMBConvConfig(_MBConvConfig):
+    # Stores information listed at Table 4 of the EfficientNetV2 paper
+    def __init__(
+        self,
+        expand_ratio: float,
+        kernel: int,
+        stride: int,
+        input_channels: int,
+        out_channels: int,
+        num_layers: int,
+        block: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        if block is None:
+            block = FusedMBConv
+        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
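For intuition, a quick sketch of how these configs behave; the numbers are illustrative, not a verbatim row from either paper:

    # adjust_channels rounds the scaled width to a multiple of 8 via _make_divisible
    MBConvConfig.adjust_channels(32, width_mult=1.4)  # 44.8 -> 48
    # adjust_depth rounds the scaled layer count up
    MBConvConfig.adjust_depth(4, depth_mult=1.8)  # ceil(7.2) -> 8

    # Each config row now carries its own block class:
    fused = FusedMBConvConfig(1.0, 3, 1, 24, 24, 2)  # block defaults to FusedMBConv
    regular = MBConvConfig(4.0, 3, 2, 48, 96, 6, 1.0, 1.0)  # block defaults to MBConv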


 class MBConv(nn.Module):
     def __init__(
         self,
@@ -149,26 +167,86 @@ def forward(self, input: Tensor) -> Tensor:
         return result


+class FusedMBConv(nn.Module):
+    def __init__(
+        self,
+        cnf: FusedMBConvConfig,
+        stochastic_depth_prob: float,
+        norm_layer: Callable[..., nn.Module],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError("illegal stride value")
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.SiLU
+
+        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+        if expanded_channels != cnf.input_channels:
+            # fused expand
+            layers.append(
+                ConvNormActivation(
+                    cnf.input_channels,
+                    expanded_channels,
+                    kernel_size=cnf.kernel,
+                    stride=cnf.stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+            # project
+            layers.append(
+                ConvNormActivation(
+                    expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+                )
+            )
+        else:
+            layers.append(
+                ConvNormActivation(
+                    cnf.input_channels,
+                    cnf.out_channels,
+                    kernel_size=cnf.kernel,
+                    stride=cnf.stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        self.block = nn.Sequential(*layers)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+        self.out_channels = cnf.out_channels
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            result = self.stochastic_depth(result)
+            result += input
+        return result
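As in the EfficientNetV2 paper, FusedMBConv replaces MBConv's 1x1 expansion plus depthwise convolution with a single regular convolution and drops squeeze-and-excitation, which is faster in the early stages. A minimal smoke test against the block defined above, with shapes chosen purely for illustration:

    import torch
    from torch import nn

    cnf = FusedMBConvConfig(4.0, 3, 1, 24, 24, 1)
    block = FusedMBConv(cnf, stochastic_depth_prob=0.0, norm_layer=nn.BatchNorm2d)
    out = block(torch.randn(1, 24, 56, 56))
    print(out.shape)  # torch.Size([1, 24, 56, 56]); residual path active (stride 1, equal channels)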


 class EfficientNet(nn.Module):
     def __init__(
         self,
         inverted_residual_setting: List[MBConvConfig],
         dropout: float,
         stochastic_depth_prob: float = 0.2,
         num_classes: int = 1000,
-        block: Optional[Callable[..., nn.Module]] = None,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
     ) -> None:
         """
-        EfficientNet main class
+        EfficientNet V1 and V2 main class

         Args:
             inverted_residual_setting (List[MBConvConfig]): Network structure
             dropout (float): The dropout probability
             stochastic_depth_prob (float): The stochastic depth probability
             num_classes (int): Number of classes
-            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
             norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
         """
         super().__init__()
@@ -178,12 +256,19 @@ def __init__(
raise ValueError("The inverted_residual_setting should not be empty")
elif not (
isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])
and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

if block is None:
block = MBConv
if "block" in kwargs:
warnings.warn(
"The parameter 'block' is deprecated since 0.13 and will be removed 0.15. "
"Please pass this information on 'MBConvConfig.block' instead."
)
if kwargs["block"] is not None:
for s in inverted_residual_setting:
if isinstance(s, MBConvConfig):
datumbox marked this conversation as resolved.
Show resolved Hide resolved
s.block = kwargs["block"]

if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -215,14 +300,15 @@ def __init__(
                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

-                stage.append(block(block_cnf, sd_prob, norm_layer))
+                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                 stage_block_id += 1

             layers.append(nn.Sequential(*stage))

         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
-        lastconv_output_channels = 4 * lastconv_input_channels
+        is_v2 = any([isinstance(s, FusedMBConvConfig) for s in inverted_residual_setting])
+        lastconv_output_channels = 1280 if is_v2 else 4 * lastconv_input_channels
         layers.append(
             ConvNormActivation(
                 lastconv_input_channels,
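Two details worth calling out in that last hunk: a config list containing any FusedMBConvConfig is treated as a V2 model, whose classification head is fixed at 1280 channels rather than 4x the last stage's output width, and the old block argument to EfficientNet survives only as a deprecated kwarg that rewrites MBConvConfig.block. A sketch of the intended usage; the stage rows and hyperparameters here are illustrative, not the full EfficientNetV2-S table:

    from functools import partial
    from torch import nn

    setting = [
        FusedMBConvConfig(1.0, 3, 1, 24, 24, 2),
        FusedMBConvConfig(4.0, 3, 2, 24, 48, 4),
        MBConvConfig(4.0, 3, 2, 48, 96, 6, 1.0, 1.0),
    ]
    model = EfficientNet(setting, dropout=0.2, norm_layer=partial(nn.BatchNorm2d, eps=1e-03))

    # Deprecated path: still honored for MBConvConfig entries, but emits a warning.
    model = EfficientNet(setting, dropout=0.2, block=MBConv)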