Skip to content

Commit

Permalink
added groupnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 15, 2020
1 parent 3fb8aaa commit e9a9f7c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions networks/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ConvNormAct(nn.Module):
def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', kernel_size=None):
def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', groups=1, kernel_size=None):
super().__init__()
# type of convolution
if conv_type == 'basic' and mode is None or mode == 'down':
Expand All @@ -24,26 +24,27 @@ def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, acti
if kernel_size is None:
kernel_size = 4
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=False)
kernel_size=kernel_size, stride=2, padding=1, groups=groups, bias=False)
elif mode == 'down':
if kernel_size is None:
kernel_size = 4
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=False)
kernel_size=kernel_size, stride=2, padding=1, groups=groups, bias=False)
else:
if kernel_size is None:
kernel_size = 3
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=1, padding=1, bias=False)
kernel_size=kernel_size, stride=1, padding=1, groups=groups, bias=False)

# normalization
# TODO GroupNorm
if normalization == 'bn':
norm = nn.BatchNorm2d(out_channels)
elif normalization == 'ln':
norm = nn.LayerNorm(out_channels)
elif normalization == 'in':
norm = nn.InstanceNorm2d(out_channels)
elif normalization == 'gn':
norm = nn.GroupNorm(groups, out_channels)
else:
raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

Expand All @@ -66,14 +67,16 @@ def forward(self, x):


class ResBlock(nn.Module):
def __init__(self, in_channels, activation, normalization):
def __init__(self, in_channels, activation, normalization, groups=1):
super().__init__()
if normalization == 'bn':
norm = nn.BatchNorm2d(in_channels)
elif normalization == 'ln':
norm = nn.LayerNorm(in_channels)
elif normalization == 'in':
norm = nn.InstanceNorm2d(in_channels)
elif normalization == 'gn':
norm = nn.GroupNorm(groups, in_channels)
else:
raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

Expand All @@ -91,6 +94,7 @@ def __init__(self, in_channels, activation, normalization):
in_channels,
kernel_size=3,
padding=1,
groups=groups,
),
norm,
act,
Expand All @@ -99,6 +103,7 @@ def __init__(self, in_channels, activation, normalization):
in_channels,
kernel_size=3,
padding=1,
groups=groups,
),
norm
)
Expand Down

0 comments on commit e9a9f7c

Please sign in to comment.