From 3d8aa08aa2a1435e9a30d6866a89da5dfd23349d Mon Sep 17 00:00:00 2001
From: iKrishneel
Date: Tue, 5 May 2020 16:04:27 +0900
Subject: [PATCH] updated config and other changes

---
 efficient_net/config.py  | 40 ++++++++++++++++++---
 efficient_net/network.py | 76 ++++++++++++++++++++++++----------------
 2 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/efficient_net/config.py b/efficient_net/config.py
index 82ea483..4e1ff1e 100644
--- a/efficient_net/config.py
+++ b/efficient_net/config.py
@@ -1,16 +1,46 @@
 #!/usr/bin/env python
 
+from dataclasses import dataclass
+from efficient_net.activation import Swish
+
 
+@dataclass
 class MBConfig(object):
 
-    IN_CHANNELS = None
+    IN_CHANNELS: int = 1
+
+    OUT_CHANNELS: int = 1
 
-    OUT_CHANNELS = None
+    STRIDES: int = 1
+
+    KERNEL_SIZE: int = 1
 
-    EXPANSION_RATIO = 1
+    EXPANSION_FACTOR: int = 1
 
+    HAS_BIAS: bool = False
+
+    ID_SKIP: bool = True
 
-class ENConfig(object):
+    BATCH_NORM_MOMENTUM: float = 0.9
 
-    # MBConv config
+    BATCH_NORM_EPS: float = 1E-5
+
+    HAS_SE: bool = True
+
+    DROPOUT_PROB: float = 0.5
+
+    ACTIVATION = Swish
+
+    TRAINING: bool = True
+
+    @property
+    def identity_skip(self):
+        return self.ID_SKIP and \
+            self.IN_CHANNELS == self.OUT_CHANNELS
+
+
+@dataclass
+class ENConfig(MBConfig):
+
+    # MBConv config
     pass
diff --git a/efficient_net/network.py b/efficient_net/network.py
index 82f94d0..d237cd9 100644
--- a/efficient_net/network.py
+++ b/efficient_net/network.py
@@ -4,7 +4,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
-from activation import Swish
+from efficient_net.activation import Swish
+from efficient_net.config import MBConfig
 
 
 class ConvBNR(nn.Module):
@@ -58,52 +59,63 @@ def forward(self, inp):
 
 
 class MBConvX(nn.Module):
-
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 dropout_prob: float=0.0):
+
+    def __init__(self, config: MBConfig):
         super(MBConvX, self).__init__()
 
-        self._use_se = True
-
-        inner_channels = in_channels * 6
+        self.config = config
+        inner_channels = config.IN_CHANNELS * config.EXPANSION_FACTOR
+        bn_momentum = config.BATCH_NORM_MOMENTUM
+        bn_eps = config.BATCH_NORM_EPS
         ex_attrs = dict(
-            in_channels=in_channels, out_channels=inner_channels,
-            kernel_size=1, stride=1, groups=1, bias=False,
-            bn_momentum=0.9, bn_eps=1e-5)
+            in_channels=config.IN_CHANNELS,
+            out_channels=inner_channels,
+            kernel_size=1, stride=1,
+            padding=self.padding(1, 1),
+            groups=1, bias=False,
+            bn_momentum=bn_momentum,
+            bn_eps=bn_eps)
         dw_attrs = dict(
-            in_channels=inner_channels, out_channels=inner_channels,
-            kernel_size=3, stride=1, groups=inner_channels, bias=False,
-            padding=1, bn_momentum=0.9, bn_eps=1e-5)
+            in_channels=inner_channels,
+            out_channels=inner_channels,
+            kernel_size=3, stride=1,
+            groups=inner_channels, bias=False,
+            padding=self.padding(3, 1),
+            bn_momentum=bn_momentum,
+            bn_eps=bn_eps)
         op_attrs = dict(
-            in_channels=inner_channels, out_channels=out_channels,
-            kernel_size=1, stride=1, groups=1, bias=False,
-            bn_momentum=0.9, bn_eps=1e-5)
+            in_channels=inner_channels,
+            out_channels=config.OUT_CHANNELS,
+            kernel_size=1, stride=1,
+            groups=1, bias=False,
+            padding=self.padding(1, 1),
+            bn_momentum=bn_momentum,
+            bn_eps=bn_eps)
 
         self.conv_ip = ConvBNR(activation=Swish, **ex_attrs)
         self.conv_dw = ConvBNR(activation=Swish, **dw_attrs)
         self.conv_op = ConvBNR(**op_attrs)
 
-        if self._use_se:
+        if config.HAS_SE:
             self._sqex = SqueezeExcitation(
                 num_channels=inner_channels)
-
-        # todo: dropout
-
-
+
     def forward(self, inputs):
         x = inputs
         x = self.conv_ip(x)
         x = self.conv_dw(x)
-        x = self._sqex(x) if self._use_se else x
+        x = self._sqex(x) if self.config.HAS_SE else x
         x = self.conv_op(x)
-        # print(inputs.shape)
-        x = x + inputs
-        print(x.shape)
-
+        if self.config.identity_skip:
+            x = F.dropout(x, p=self.config.DROPOUT_PROB,
+                          training=self.config.TRAINING)
+            x = x + inputs
         return x
+
+    @staticmethod
+    def padding(kernel_size: int, stride: int):
+        return max(kernel_size - stride, 0) // 2
 
 
 class EfficientNetBase(nn.Module):
@@ -117,10 +129,14 @@ def __init__(self, ):
     import numpy as np
     x = np.random.random((1, 3, 32, 32)).astype(np.float32)
     y = torch.from_numpy(x)
-
+
     attrs = dict(in_channels=3, out_channels=32, kernel_size=1,
                  stride=1, bias=False, bn_momentum=0.9, bn_eps=0.001)
+
+    config = MBConfig(IN_CHANNELS=3, OUT_CHANNELS=3, KERNEL_SIZE=1,
+                      STRIDES=1, EXPANSION_FACTOR=6)
+    print(config)
     # m = ConvBNR(True, torch.nn.ReLU, **attrs)
-    m = MBConvX(3, 3)
+    m = MBConvX(config)
     m(y)
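
Usage sketch: a minimal example of driving the new config-based MBConvX block, mirroring the smoke test at the bottom of network.py. It assumes the efficient_net package from this repository is importable and that ConvBNR, SqueezeExcitation and Swish behave as shown in the diff; the expected output shape follows from the stride-1, "same"-padding setup above.

    import torch

    from efficient_net.config import MBConfig
    from efficient_net.network import MBConvX

    # identity skip is taken only when ID_SKIP is True and
    # IN_CHANNELS == OUT_CHANNELS (see MBConfig.identity_skip)
    config = MBConfig(IN_CHANNELS=3, OUT_CHANNELS=3, KERNEL_SIZE=1,
                      STRIDES=1, EXPANSION_FACTOR=6)

    block = MBConvX(config)
    x = torch.randn(1, 3, 32, 32)   # NCHW input
    y = block(x)                    # stride 1 keeps the spatial size
    print(y.shape)                  # expected: torch.Size([1, 3, 32, 32])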