-
Notifications
You must be signed in to change notification settings - Fork 0
/
EIM
72 lines (61 loc) · 2.29 KB
/
EIM
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
import numpy as np
import math
from models.layers.config import cfg
def Norm2d(in_channels):
    """Build a normalization layer for `in_channels` channels.

    The concrete layer class is read from cfg.MODEL.BNFUNC, so the
    normalization backend can be swapped through configuration without
    touching model code.
    """
    norm_cls = getattr(cfg.MODEL, 'BNFUNC')
    return norm_cls(in_channels)
class GatedSpatialConv2d(_ConvNd):
    """Convolution whose input is modulated by a learned spatial gate.

    A small gating network consumes the concatenation of the input
    features and a one-channel gating map and produces per-pixel
    sigmoid attention values, which re-weight the input features
    before the final convolution is applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        """
        :param in_channels: channels of the incoming feature map.
        :param out_channels: channels produced by the final convolution.
        :param kernel_size: convolution kernel size.
        :param stride: convolution stride.
        :param padding: convolution padding.
        :param dilation: convolution dilation.
        :param groups: convolution groups.
        :param bias: whether the final convolution carries a bias term.
        """
        super(GatedSpatialConv2d, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups,
            bias, 'zeros')

        # Gating network: (in_channels + 1) channels in, one sigmoid
        # attention channel out.
        self._gate_conv = nn.Sequential(
            Norm2d(in_channels + 1),
            nn.Conv2d(in_channels + 1, in_channels + 1, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels + 1, 1, 1),
            Norm2d(1),
            nn.Sigmoid()
        )

    def forward(self, input_features, gating_features):
        """
        :param input_features: [NxCxHxW] features from the shape (canny) branch.
        :param gating_features: [Nx1xHxW] single-channel features from the
            texture (resnet) branch.
        :return: convolution of the gated input features.
        """
        gate_input = torch.cat([input_features, gating_features], dim=1)
        alphas = self._gate_conv(gate_input)
        # Residual-style gating: scale by (1 + alpha) so the gate can only
        # amplify, never fully suppress, the input signal.
        gated = input_features * (alphas + 1)
        return F.conv2d(gated, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def reset_parameters(self):
        """Xavier-initialize the conv weight and zero the bias if present."""
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)