Skip to content

Commit

Permalink
updated config and other changes
Browse files Browse the repository at this point in the history
  • Loading branch information
iKrishneel committed May 5, 2020
1 parent 6fc0814 commit 3d8aa08
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 35 deletions.
40 changes: 35 additions & 5 deletions efficient_net/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
#!/usr/bin/env python

from dataclasses import dataclass
from efficient_net.activation import Swish


@dataclass
class MBConfig(object):

IN_CHANNELS = None
IN_CHANNELS: int = 1

OUT_CHANNELS: int = 1

OUT_CHANNELS = None
STRIDES: int = 1

KERNEL_SIZE: int = 1

EXPANSION_RATIO = 1
EXPANSION_FACTOR: int = 1

HAS_BIAS: bool = False

ID_SKIP: bool = True

class ENConfig(object):
BATCH_NORM_MOMENTUM: float = 0.9

# MBConv config
BATCH_NORM_EPS: float = 1E-5

HAS_SE: bool = True

DROPOUT_PROB: float = 0.5

ACTIVATION = Swish

TRAINING: bool = True

@property
def identity_skip(self):
return self.ID_SKIP and \
self.IN_CHANNELS == self.OUT_CHANNELS


@dataclass
class ENConfig(MBConfig):

# MBConv config
pass
76 changes: 46 additions & 30 deletions efficient_net/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch.nn as nn
from torch.nn import functional as F

from activation import Swish
from efficient_net.activation import Swish
from efficient_net.config import MBConfig


class ConvBNR(nn.Module):
Expand Down Expand Up @@ -58,52 +59,63 @@ def forward(self, inp):


class MBConvX(nn.Module):

def __init__(self,
in_channels: int,
out_channels: int,
dropout_prob: float=0.0):

def __init__(self, config: MBConfig):
super(MBConvX, self).__init__()

self._use_se = True

inner_channels = in_channels * 6
self.config = config
inner_channels = config.IN_CHANNELS * config.EXPANSION_FACTOR
bn_momentum = config.BATCH_NORM_MOMENTUM
bn_eps = config.BATCH_NORM_EPS

ex_attrs = dict(
in_channels=in_channels, out_channels=inner_channels,
kernel_size=1, stride=1, groups=1, bias=False,
bn_momentum=0.9, bn_eps=1e-5)
in_channels=config.IN_CHANNELS,
out_channels=inner_channels,
kernel_size=1, stride=1,
padding=self.padding(1, 1),
groups=1, bias=False,
bn_momentum=bn_momentum,
bn_eps=bn_eps)
dw_attrs = dict(
in_channels=inner_channels, out_channels=inner_channels,
kernel_size=3, stride=1, groups=inner_channels, bias=False,
padding=1, bn_momentum=0.9, bn_eps=1e-5)
in_channels=inner_channels,
out_channels=inner_channels,
kernel_size=3, stride=1,
groups=inner_channels, bias=False,
padding=self.padding(3, 1),
bn_momentum=bn_momentum,
bn_eps=bn_eps)
op_attrs = dict(
in_channels=inner_channels, out_channels=out_channels,
kernel_size=1, stride=1, groups=1, bias=False,
bn_momentum=0.9, bn_eps=1e-5)
in_channels=inner_channels,
out_channels=config.OUT_CHANNELS,
kernel_size=1, stride=1,
groups=1, bias=False,
padding=self.padding(1, 1),
bn_momentum=bn_momentum,
bn_eps=bn_eps)

self.conv_ip = ConvBNR(activation=Swish, **ex_attrs)
self.conv_dw = ConvBNR(activation=Swish, **dw_attrs)
self.conv_op = ConvBNR(**op_attrs)

if self._use_se:
if config.HAS_SE:
self._sqex = SqueezeExcitation(
num_channels=inner_channels)

# todo: dropout



def forward(self, inputs):
x = inputs
x = self.conv_ip(x)
x = self.conv_dw(x)
x = self._sqex(x) if self._use_se else x
x = self._sqex(x) if self.config.HAS_SE else x
x = self.conv_op(x)
# print(inputs.shape)
x = x + inputs
print(x.shape)

if self.config.identity_skip:
x = F.dropout(x, p=self.config.DROPOUT_PROB,
training=self.config.TRAINING)
x = x + inputs
return x

@staticmethod
def padding(kernel_size: int, stride: int):
return max(kernel_size - stride, 0) // 2


class EfficientNetBase(nn.Module):
Expand All @@ -117,10 +129,14 @@ def __init__(self, ):
import numpy as np
x = np.random.random((1, 3, 32, 32)).astype(np.float32)
y = torch.from_numpy(x)

attrs = dict(in_channels=3, out_channels=32, kernel_size=1,
stride=1, bias=False, bn_momentum=0.9, bn_eps=0.001)

config = MBConfig(IN_CHANNELS=3, OUT_CHANNELS=3, KERNEL_SIZE=1,
STRIDES=1, EXPANSION_FACTOR=6)
print(config)
# m = ConvBNR(True, torch.nn.ReLU, **attrs)

m = MBConvX(3, 3)
m = MBConvX(config)
m(y)

0 comments on commit 3d8aa08

Please sign in to comment.