From e9a9f7c2b419a29900ab5c0cdaa0f5845c253d7c Mon Sep 17 00:00:00 2001 From: Andrew Zhao Date: Tue, 15 Dec 2020 10:19:13 +0800 Subject: [PATCH] added groupnorm --- networks/layers.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/networks/layers.py b/networks/layers.py index 41ba54f..d5d6553 100644 --- a/networks/layers.py +++ b/networks/layers.py @@ -6,7 +6,7 @@ class ConvNormAct(nn.Module): - def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', kernel_size=None): + def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', groups=1, kernel_size=None): super().__init__() # type of convolution if conv_type == 'basic' and mode is None or mode == 'down': @@ -24,26 +24,27 @@ def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, acti if kernel_size is None: kernel_size = 4 conv = conv(in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, stride=2, padding=1, bias=False) + kernel_size=kernel_size, stride=2, padding=1, groups=groups, bias=False) elif mode == 'down': if kernel_size is None: kernel_size = 4 conv = conv(in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, stride=2, padding=1, bias=False) + kernel_size=kernel_size, stride=2, padding=1, groups=groups, bias=False) else: if kernel_size is None: kernel_size = 3 conv = conv(in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, stride=1, padding=1, bias=False) + kernel_size=kernel_size, stride=1, padding=1, groups=groups, bias=False) # normalization - # TODO GroupNorm if normalization == 'bn': norm = nn.BatchNorm2d(out_channels) elif normalization == 'ln': norm = nn.LayerNorm(out_channels) elif normalization == 'in': norm = nn.InstanceNorm2d(out_channels) + elif normalization == 'gn': + norm = nn.GroupNorm(groups, out_channels) else: raise NotImplementedError('Please only choose normalization [bn, ln, in]') @@ -66,7 +67,7 @@ def forward(self, x): class ResBlock(nn.Module): - def __init__(self, in_channels, activation, normalization): + def __init__(self, in_channels, activation, normalization, groups=1): super().__init__() if normalization == 'bn': norm = nn.BatchNorm2d(in_channels) @@ -74,6 +75,8 @@ def __init__(self, in_channels, activation, normalization): norm = nn.LayerNorm(in_channels) elif normalization == 'in': norm = nn.InstanceNorm2d(in_channels) + elif normalization == 'gn': + norm = nn.GroupNorm(groups, in_channels) else: raise NotImplementedError('Please only choose normalization [bn, ln, in]') @@ -91,6 +94,7 @@ def __init__(self, in_channels, activation, normalization): in_channels, kernel_size=3, padding=1, + groups=groups, ), norm, act, @@ -99,6 +103,7 @@ def __init__(self, in_channels, activation, normalization): in_channels, kernel_size=3, padding=1, + groups=groups, ), norm )