Skip to content

Commit e65f79e

Browse files
authored
Add NetVLAD layer to TF add-ons. (#1237)
* Add NetVLAD layer to TF add-ons.
1 parent 8967c38 commit e65f79e

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
/tensorflow_addons/layers/gelu*.py @aakashkumarnain
5151
/tensorflow_addons/layers/maxout*.py @failure-to-thrive
52+
/tensorflow_addons/layers/netvlad*.py @joel-shor
5253
/tensorflow_addons/layers/normalizations*.py @smokrow
5354
/tensorflow_addons/layers/optical_flow*.py @failure-to-thrive
5455
/tensorflow_addons/layers/poincare*.py @rahulunair

tensorflow_addons/layers/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ py_library(
88
"__init__.py",
99
"gelu.py",
1010
"maxout.py",
11+
"netvlad.py",
1112
"normalizations.py",
1213
"optical_flow.py",
1314
"poincare.py",
@@ -134,3 +135,15 @@ py_test(
134135
":layers",
135136
],
136137
)
138+
139+
py_test(
140+
name = "netvlad_test",
141+
size = "small",
142+
srcs = [
143+
"netvlad_test.py",
144+
],
145+
main = "netvlad_test.py",
146+
deps = [
147+
":layers",
148+
],
149+
)

tensorflow_addons/layers/netvlad.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""NetVLAD keras layer."""
16+
17+
import math
18+
import tensorflow as tf
19+
from typeguard import typechecked
20+
21+
22+
@tf.keras.utils.register_keras_serializable(package="Addons")
23+
class NetVLAD(tf.keras.layers.Layer):
24+
"""Applies NetVLAD to the input.
25+
26+
This is a fully-differentiable version of "Vector of Locally Aggregated Descriptors" commonly used in image
27+
retrieval. It is also used in audio retrieval, and audio represenation learning (ex
28+
"Towards Learning a Universal Non-Semantic Representation of Speech", https://arxiv.org/abs/2002.12764).
29+
30+
"NetVLAD: CNN architecture for weakly supervised place recognition"
31+
Relja Arandjelovic, Petr Gronat, Akihiko Torii, Tomas Pajdla, Josef Sivic.
32+
https://arxiv.org/abs/1511.07247
33+
34+
Arguments:
35+
num_clusters: The number of clusters to use.
36+
Input shape:
37+
3D tensor with shape: `(batch_size, time, feature_dim)`.
38+
Output shape:
39+
2D tensor with shape: `(batch_size, feature_dim * num_clusters)`.
40+
"""
41+
42+
@typechecked
43+
def __init__(self, num_clusters: int, **kwargs):
44+
super().__init__(**kwargs)
45+
if num_clusters <= 0:
46+
raise ValueError("`num_clusters` must be greater than 1: %i" % num_clusters)
47+
self.num_clusters = num_clusters
48+
49+
def build(self, input_shape):
50+
"""Keras build method."""
51+
feature_dim = input_shape[-1]
52+
if not isinstance(feature_dim, int):
53+
feature_dim = feature_dim.value
54+
self.fc = tf.keras.layers.Dense(
55+
units=self.num_clusters,
56+
activation=tf.nn.softmax,
57+
kernel_regularizer=tf.keras.regularizers.l2(1e-5),
58+
)
59+
self.cluster_centers = self.add_weight(
60+
name="cluster_centers",
61+
shape=(1, feature_dim, self.num_clusters),
62+
initializer=tf.keras.initializers.TruncatedNormal(
63+
stddev=1.0 / math.sqrt(feature_dim)
64+
),
65+
trainable=True,
66+
)
67+
super(NetVLAD, self).build(input_shape)
68+
69+
def call(self, frames):
70+
"""Apply the NetVLAD module to the given frames.
71+
72+
Args:
73+
frames: A tensor with shape [batch_size, max_frames, feature_dim].
74+
75+
Returns:
76+
A tensor with shape [batch_size, feature_dim * num_clusters].
77+
78+
Raises:
79+
ValueError: If the `feature_dim` of input is not defined.
80+
"""
81+
frames.shape.assert_has_rank(3)
82+
feature_dim = frames.shape.as_list()[-1]
83+
if feature_dim is None:
84+
raise ValueError("Last dimension must be defined.")
85+
max_frames = tf.shape(frames)[-2]
86+
87+
# Compute soft-assignment from frames to clusters.
88+
# Essentially: softmax(w*x + b), although BN can be used instead of bias.
89+
frames = tf.reshape(frames, (-1, feature_dim))
90+
activation = self.fc(frames)
91+
activation = tf.reshape(activation, (-1, max_frames, self.num_clusters))
92+
93+
# Soft-count of number of frames assigned to each cluster.
94+
# Output shape: [batch_size, 1, num_clusters]
95+
a_sum = tf.math.reduce_sum(activation, axis=-2, keepdims=True)
96+
97+
# Compute sum_{i=1}^N softmax(w_k * x_i + b_k) * c_k(j),
98+
# for all clusters and dimensions.
99+
# Output shape: [batch_size, feature_dim, num_clusters]
100+
a = a_sum * self.cluster_centers
101+
102+
# Compute sum_{i=1}^N softmax(w_k * x_i + b_k) * x_i(j),
103+
# for all clusters and dimensions.
104+
# Output shape: (batch_size, feature_dim, num_clusters)
105+
frames = tf.reshape(frames, (-1, max_frames, feature_dim))
106+
b = tf.transpose(
107+
tf.matmul(tf.transpose(activation, perm=(0, 2, 1)), frames), perm=(0, 2, 1)
108+
)
109+
110+
# Output shape: (batch_size, feature_dim, num_clusters)
111+
vlad = b - a
112+
113+
# Normalize first across the feature dimensions.
114+
vlad = tf.nn.l2_normalize(vlad, 1)
115+
116+
# Output shape: [batch_size, feature_dim * num_clusters]
117+
vlad = tf.reshape(vlad, (-1, feature_dim * self.num_clusters))
118+
119+
# Renormalize across both the feature dimensions (already normalized) and
120+
# the cluster centers.
121+
vlad = tf.nn.l2_normalize(vlad, 1)
122+
123+
return vlad
124+
125+
def compute_output_shape(self, input_shape):
126+
input_shape = tf.TensorShape(input_shape).as_list()
127+
return tf.TensorShape([input_shape[0], input_shape[-1] * self.num_clusters])
128+
129+
def get_config(self):
130+
config = {"num_clusters": self.num_clusters}
131+
base_config = super().get_config()
132+
return dict(list(base_config.items()) + list(config.items()))
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for NetVLAD layer."""
16+
17+
import sys
18+
19+
import pytest
20+
from absl.testing import parameterized
21+
import numpy as np
22+
import tensorflow as tf
23+
from tensorflow_addons.layers.netvlad import NetVLAD
24+
from tensorflow_addons.utils import test_utils
25+
26+
27+
@test_utils.run_all_in_graph_and_eager_modes
28+
class NetVLADTest(tf.test.TestCase, parameterized.TestCase):
29+
"""Tests for NetVLAD."""
30+
31+
@parameterized.parameters(
32+
{"num_clusters": 1}, {"num_clusters": 4},
33+
)
34+
def test_simple(self, num_clusters):
35+
test_utils.layer_test(
36+
NetVLAD,
37+
kwargs={"num_clusters": num_clusters},
38+
input_shape=(5, 4, 100),
39+
expected_output_shape=(None, num_clusters * 100),
40+
)
41+
42+
def test_unknown(self):
43+
inputs = np.random.random((5, 4, 100)).astype("float32")
44+
test_utils.layer_test(
45+
NetVLAD,
46+
kwargs={"num_clusters": 3},
47+
input_shape=(None, None, 100),
48+
input_data=inputs,
49+
expected_output_shape=(None, 3 * 100),
50+
)
51+
52+
def test_invalid_shape(self):
53+
with self.assertRaisesRegexp(
54+
ValueError, r"`num_clusters` must be greater than 1"
55+
):
56+
test_utils.layer_test(
57+
NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20)
58+
)
59+
60+
with self.assertRaisesRegexp(ValueError, r"must have rank 3"):
61+
test_utils.layer_test(
62+
NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20)
63+
)
64+
65+
66+
if __name__ == "__main__":
67+
sys.exit(pytest.main([__file__]))

0 commit comments

Comments
 (0)