Skip to content

Commit f4c11b4

Browse files
authored
Merge pull request #20 from facaiy/ENH/PoincareNormalize
Implement PoincareNormalize
2 parents d4b3114 + 40c722b commit f4c11b4

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

tensorflow_addons/layers/BUILD

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ py_library(
77
srcs = ([
88
"__init__.py",
99
"python/__init__.py",
10+
"python/poincare.py",
1011
"python/wrappers.py",
1112
]),
1213
srcs_version = "PY2AND3",
@@ -22,4 +23,17 @@ py_test(
2223
":layers_py",
2324
],
2425
srcs_version = "PY2AND3",
25-
)
26+
)
27+
28+
py_test(
29+
name = "poincare_py_test",
30+
size = "small",
31+
srcs = [
32+
"python/poincare_test.py",
33+
],
34+
main = "python/poincare_test.py",
35+
deps = [
36+
":layers_py",
37+
],
38+
srcs_version = "PY2AND3",
39+
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Implementing PoincareNormalize layer."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tensorflow.python.framework import ops
22+
from tensorflow.python.keras.utils import generic_utils
23+
from tensorflow.python.keras.engine.base_layer import Layer
24+
from tensorflow.python.ops import math_ops
25+
26+
27+
class PoincareNormalize(Layer):
28+
"""Project into the Poincare ball with norm <= 1.0 - epsilon.
29+
30+
https://en.wikipedia.org/wiki/Poincare_ball_model
31+
32+
Used in
33+
Poincare Embeddings for Learning Hierarchical Representations
34+
Maximilian Nickel, Douwe Kiela
35+
https://arxiv.org/pdf/1705.08039.pdf
36+
37+
For a 1-D tensor with `axis = 0`, computes
38+
39+
(x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon
40+
output =
41+
x otherwise
42+
43+
For `x` with more dimensions, independently normalizes each 1-D slice along
44+
dimension `axis`.
45+
46+
Arguments:
47+
axis: Axis along which to normalize. A scalar or a vector of
48+
integers.
49+
epsilon: A small deviation from the edge of the unit sphere for numerical
50+
stability.
51+
"""
52+
53+
def __init__(self, axis=1, epsilon=1e-5, **kwargs):
54+
super(PoincareNormalize, self).__init__(**kwargs)
55+
self.axis = axis
56+
self.epsilon = epsilon
57+
58+
def call(self, inputs):
59+
x = ops.convert_to_tensor(inputs)
60+
square_sum = math_ops.reduce_sum(
61+
math_ops.square(x), self.axis, keepdims=True)
62+
x_inv_norm = math_ops.rsqrt(square_sum)
63+
x_inv_norm = math_ops.minimum((1. - self.epsilon) * x_inv_norm, 1.)
64+
outputs = math_ops.multiply(x, x_inv_norm)
65+
return outputs
66+
67+
def compute_output_shape(self, input_shape):
68+
return input_shape
69+
70+
def get_config(self):
71+
config = {'axis': self.axis, 'epsilon': self.epsilon}
72+
base_config = super(PoincareNormalize, self).get_config()
73+
return dict(list(base_config.items()) + list(config.items()))
74+
75+
76+
generic_utils._GLOBAL_CUSTOM_OBJECTS['PoincareNormalize'] = PoincareNormalize
77+
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for PoincareNormalize layer."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
23+
from tensorflow.python.keras import testing_utils
24+
from tensorflow.python.platform import test
25+
from tensorflow_addons.layers.python.poincare import PoincareNormalize
26+
27+
28+
class PoincareNormalizeTest(test.TestCase):
29+
def _PoincareNormalize(self, x, dim, epsilon=1e-5):
30+
if isinstance(dim, list):
31+
norm = np.linalg.norm(x, axis=tuple(dim))
32+
for d in dim:
33+
norm = np.expand_dims(norm, d)
34+
norm_x = ((1. - epsilon) * x) / norm
35+
else:
36+
norm = np.expand_dims(
37+
np.apply_along_axis(np.linalg.norm, dim, x), dim)
38+
norm_x = ((1. - epsilon) * x) / norm
39+
return np.where(norm > 1.0 - epsilon, norm_x, x)
40+
41+
def testPoincareNormalize(self):
42+
x_shape = [20, 7, 3]
43+
epsilon = 1e-5
44+
tol = 1e-6
45+
np.random.seed(1)
46+
inputs = np.random.random_sample(x_shape).astype(np.float32)
47+
48+
for dim in range(len(x_shape)):
49+
outputs_expected = self._PoincareNormalize(inputs, dim, epsilon)
50+
51+
outputs = testing_utils.layer_test(
52+
PoincareNormalize,
53+
kwargs={
54+
'axis': dim,
55+
'epsilon': epsilon
56+
},
57+
input_data=inputs,
58+
expected_output=outputs_expected)
59+
for y in outputs_expected, outputs:
60+
norm = np.linalg.norm(y, axis=dim)
61+
self.assertLessEqual(norm.max(), 1. - epsilon + tol)
62+
63+
def testPoincareNormalizeDimArray(self):
64+
x_shape = [20, 7, 3]
65+
epsilon = 1e-5
66+
tol = 1e-6
67+
np.random.seed(1)
68+
inputs = np.random.random_sample(x_shape).astype(np.float32)
69+
dim = [1, 2]
70+
71+
outputs_expected = self._PoincareNormalize(inputs, dim, epsilon)
72+
73+
outputs = testing_utils.layer_test(
74+
PoincareNormalize,
75+
kwargs={
76+
'axis': dim,
77+
'epsilon': epsilon
78+
},
79+
input_data=inputs,
80+
expected_output=outputs_expected)
81+
for y in outputs_expected, outputs:
82+
norm = np.linalg.norm(y, axis=tuple(dim))
83+
self.assertLessEqual(norm.max(), 1. - epsilon + tol)
84+
85+
86+
if __name__ == '__main__':
87+
test.main()

0 commit comments

Comments
 (0)