-
Notifications
You must be signed in to change notification settings - Fork 2
/
BAM.py
55 lines (41 loc) · 2.47 KB
/
BAM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 16 11:25:53 2019
@author: huyz
Reference: [BMVC2018] BAM: Bottleneck Attention Module
"""
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Example input used by the script entry point: 128 samples of 32x32 with 256
# channels; channels are on the last axis because BAM reads its channel count
# from shape[-1].
X = tf.placeholder(tf.float32, shape=[128, 32, 32, 256])

# Keyword arguments handed to slim.batch_norm by BAM.
batch_norm_params = dict(
    decay=0.995,  # decay rate of the moving mean/variance
    epsilon=0.001,  # small constant added to the variance so it is never 0
    updates_collections=None,  # None => moving statistics are updated in place
    # put the moving statistics in the trainable-variables collection too
    variables_collections=[tf.GraphKeys.TRAINABLE_VARIABLES],
)
def BAM(inputs, batch_norm_params, reduction_ratio=16, dilation_value=4, reuse=None, scope='BAM',
        num_dilated_convs=2):
    """Apply a Bottleneck Attention Module (BAM, Park et al., BMVC 2018) to ``inputs``.

    A channel-attention branch and a spatial-attention branch are computed,
    summed (broadcasting), squashed with a sigmoid and applied residually:
    ``output = inputs + inputs * sigmoid(channel + spatial)``.

    Args:
        inputs: 4-D tensor. The code averages over axes 1 and 2 and reads the
            channel count from the last axis, so it presumably expects
            channels-last (batch, height, width, channels) -- TODO confirm for
            callers. The channel dimension must be statically known.
        batch_norm_params: keyword arguments forwarded to ``slim.batch_norm``,
            which normalizes the last layer of each branch.
            NOTE(review): the dict defined in this file does not set
            ``is_training``, so slim's own default applies; confirm this is
            what callers want at inference time.
        reduction_ratio: bottleneck reduction factor; the hidden width is
            ``channels // reduction_ratio``, clamped to at least 1.
        dilation_value: dilation rate of the 3x3 convolutions in the spatial branch.
        reuse: forwarded to ``tf.variable_scope`` to share variables.
        scope: name of the variable scope the module's variables live in.
        num_dilated_convs: number of dilated 3x3 convolutions in the spatial
            branch (2, as in the original implementation).

    Returns:
        A tensor with the same shape as ``inputs``.

    Raises:
        ValueError: if the last dimension of ``inputs`` is not statically known.
    """
    with tf.variable_scope(scope, reuse=reuse):
        input_channel = inputs.get_shape().as_list()[-1]
        if input_channel is None:
            raise ValueError('BAM requires the channel dimension of inputs to be statically known.')
        # Clamp so a layer with fewer channels than reduction_ratio does not get a 0-wide bottleneck.
        num_squeeze = max(1, input_channel // reduction_ratio)

        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            # Channel attention: global average pool -> one-hidden-layer MLP -> BN.
            # keepdims=True keeps a (N, 1, 1, C) layout so it broadcasts against the spatial map.
            gap = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
            # The hidden layer needs a non-linearity: without it fc1 and fc2 compose
            # into a single rank-limited linear map and the bottleneck MLP is degenerate.
            channel = slim.fully_connected(gap, num_squeeze, activation_fn=tf.nn.relu, scope='fc1')
            channel = slim.fully_connected(channel, input_channel, activation_fn=None,
                                           normalizer_fn=slim.batch_norm,
                                           normalizer_params=batch_norm_params, scope='fc2')

            # Spatial attention: 1x1 reduce -> dilated 3x3 convs -> 1x1 to one map -> BN.
            # Intermediate convs get ReLU for the same reason as fc1.
            with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, padding='SAME'):
                spatial = slim.conv2d(inputs, num_squeeze, 1, scope='conv1')
                spatial = slim.repeat(spatial, num_dilated_convs, slim.conv2d, num_squeeze, 3,
                                      rate=dilation_value, scope='conv2')
            # Final projection stays linear; only BN is applied, as in the original.
            spatial = slim.conv2d(spatial, 1, 1, padding='SAME', activation_fn=None,
                                  normalizer_fn=slim.batch_norm,
                                  normalizer_params=batch_norm_params, scope='conv3')

        # Combine the two attention branches: channel and spatial maps broadcast
        # to the full input shape before the sigmoid.
        combined = tf.nn.sigmoid(channel + spatial)
        # Residual attention: equivalent to inputs * (1 + combined).
        output = inputs + inputs * combined
    return output
if __name__ == '__main__':
    # Smoke test: build BAM on the example placeholder and print the output's static shape.
    # Guarded so that importing this module does not create BAM/* variables in the
    # default graph; otherwise a later BAM(...) call with the default scope and
    # reuse=None would fail when trying to re-create those variables.
    print(BAM(X, batch_norm_params).shape)