Skip to content

Commit 9728631

Browse files
committed
fix multi-gpu bug in lib/non_local.py
1 parent d6dfaf2 commit 9728631

File tree

3 files changed

+43
-19
lines changed

3 files changed

+43
-19
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
## Statement
55
- Only do the experiments on MNIST dataset so far.
66
- You can find the non-local block in **lib/**.
7+
- The code can support **multi-gpu** now.
78
- If there is something wrong in my code, please contact me, thanks!
89

910
There are two version **non-local.py** and **non-local-simple-version.py**.
@@ -20,3 +21,5 @@ There are two version **non-local.py** and **non-local-simple-version.py**.
2021
- Experiments on Charades dataset.
2122
- Experiments on COCO dataset.
2223
- [x] Make sure how to do the Implementation of concatenation.
24+
- [x] Support multi-gpu.
25+
- [x] Fix the bug in **lib/non_local.py** when using multi-gpu.

demo_MNIST.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def calc_acc(x, y):
3232

3333
net = Network()
3434
if torch.cuda.is_available():
35+
net = nn.DataParallel(net)
3536
net.cuda()
3637

3738
opt = torch.optim.Adam(net.parameters(), lr=cfg.LR, weight_decay=cfg.weight_decay)

lib/non_local.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,36 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded
5959
self.phi = None
6060
self.concat_project = None
6161

62-
if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
63-
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64-
kernel_size=1, stride=1, padding=0)
65-
self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
66-
kernel_size=1, stride=1, padding=0)
67-
68-
if mode == 'embedded_gaussian':
69-
self.operation_function = self._embedded_gaussian
70-
elif mode == 'dot_product':
71-
self.operation_function = self._dot_product
72-
elif mode == 'concatenation':
73-
self.operation_function = self._concatenation
74-
self.concat_project = nn.Sequential(
75-
nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
76-
nn.ReLU()
77-
)
78-
elif mode == 'gaussian':
79-
self.operation_function = self._gaussian
62+
# if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
63+
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64+
kernel_size=1, stride=1, padding=0)
65+
66+
self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
67+
kernel_size=1, stride=1, padding=0)
68+
# elif mode == 'concatenation':
69+
self.concat_project = nn.Sequential(
70+
nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
71+
nn.ReLU()
72+
)
73+
74+
# if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
75+
# self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
76+
# kernel_size=1, stride=1, padding=0)
77+
# self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
78+
# kernel_size=1, stride=1, padding=0)
79+
#
80+
# if mode == 'embedded_gaussian':
81+
# self.operation_function = self._embedded_gaussian
82+
# elif mode == 'dot_product':
83+
# self.operation_function = self._dot_product
84+
# elif mode == 'concatenation':
85+
# self.operation_function = self._concatenation
86+
# self.concat_project = nn.Sequential(
87+
# nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
88+
# nn.ReLU()
89+
# )
90+
# elif mode == 'gaussian':
91+
# self.operation_function = self._gaussian
8092

8193
if sub_sample:
8294
self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
@@ -91,7 +103,15 @@ def forward(self, x):
91103
:return:
92104
'''
93105

94-
output = self.operation_function(x)
106+
if self.mode == 'embedded_gaussian':
107+
output = self._embedded_gaussian(x)
108+
elif mode == 'dot_product':
109+
output = self._dot_product(x)
110+
elif mode == 'concatenation':
111+
output = self._concatenation(x)
112+
elif mode == 'gaussian':
113+
output = self._gaussian(x)
114+
95115
return output
96116

97117
def _embedded_gaussian(self, x):

0 commit comments

Comments
 (0)