Skip to content

Commit

Permalink
Removed run_in_graph_and_eager_mode from the pairwise_distance tests. (
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse authored Mar 27, 2020
1 parent 2ff35f1 commit a00b2de
Showing 1 changed file with 28 additions and 27 deletions.
55 changes: 28 additions & 27 deletions tensorflow_addons/losses/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,45 @@
import numpy as np
import tensorflow as tf
from tensorflow_addons.losses.metric_learning import pairwise_distance
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class PairWiseDistance(tf.test.TestCase):
def test_zero_distance(self):
"""Test that equal embeddings have a pairwise distance of 0."""
equal_embeddings = tf.constant([[1.0, 0.5], [1.0, 0.5]])
def test_zero_distance():
"""Test that equal embeddings have a pairwise distance of 0."""
equal_embeddings = tf.constant([[1.0, 0.5], [1.0, 0.5]])

distances = pairwise_distance(equal_embeddings, squared=False)
self.assertAllClose(tf.math.reduce_sum(distances), 0)
distances = pairwise_distance(equal_embeddings, squared=False)
np.testing.assert_allclose(tf.math.reduce_sum(distances), 0)

def test_positive_distances(self):
"""Test that the pairwise distances are always positive."""

# Create embeddings very close to each other in [1.0 - 2e-7, 1.0 + 2e-7]
# This will encourage errors in the computation
embeddings = 1.0 + 2e-7 * tf.random.uniform([64, 6], dtype=tf.float32)
distances = pairwise_distance(embeddings, squared=False)
self.assertAllGreaterEqual(distances, 0)
def test_positive_distances():
"""Test that the pairwise distances are always positive."""

def test_correct_distance(self):
"""Compare against numpy caluclation."""
tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])
# Create embeddings very close to each other in [1.0 - 2e-7, 1.0 + 2e-7]
# This will encourage errors in the computation
embeddings = 1.0 + 2e-7 * tf.random.uniform([64, 6], dtype=tf.float32)
distances = pairwise_distance(embeddings, squared=False)
assert np.all(distances >= 0)

expected_distance = np.array([[0, np.sqrt(2) / 2], [np.sqrt(2) / 2, 0]])

distances = pairwise_distance(tf_embeddings, squared=False)
self.assertAllClose(expected_distance, distances)
def test_correct_distance():
"""Compare against numpy caluclation."""
tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])

def test_correct_distance_squared(self):
"""Compare against numpy caluclation for squared distances."""
tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])
expected_distance = np.array([[0, np.sqrt(2) / 2], [np.sqrt(2) / 2, 0]])

expected_distance = np.array([[0, 0.5], [0.5, 0]])
distances = pairwise_distance(tf_embeddings, squared=False)
np.testing.assert_allclose(expected_distance, distances)

distances = pairwise_distance(tf_embeddings, squared=True)
self.assertAllClose(expected_distance, distances)

@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_correct_distance_squared():
"""Compare against numpy caluclation for squared distances."""
tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])

expected_distance = np.array([[0, 0.5], [0.5, 0]])

distances = pairwise_distance(tf_embeddings, squared=True)
np.testing.assert_allclose(expected_distance, distances)


if __name__ == "__main__":
Expand Down

0 comments on commit a00b2de

Please sign in to comment.