Skip to content

Commit

Permalink
Merge branch 'development' of https://github.com/tensorflow/similarity
Browse files Browse the repository at this point in the history
…into development
  • Loading branch information
abeltheo committed Mar 31, 2023
1 parent d49ccd2 commit 0bec1e0
Show file tree
Hide file tree
Showing 8 changed files with 904 additions and 913 deletions.
28 changes: 15 additions & 13 deletions tests/samplers/test_file_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@

from tensorflow_similarity.samplers import MultiShotFileSampler

class FileSamplersTest(tf.test.TestCase):

class FileSamplersTest(tf.test.TestCase):
def _create_random_image(self, filename, size=(32, 32)):
    """Save a randomly generated image in the system temp directory.

    Args:
        filename: Name of the file to create inside ``tempfile.gettempdir()``.
        size: Two-element tuple; a trailing channel dimension of 3 is
            appended to it to form the array shape.

    Returns:
        The full path of the image file that was written.
    """
    destination = os.path.join(tempfile.gettempdir(), filename)
    array_shape = size + (3,)
    pixels = np.random.random(array_shape).astype(np.float32)
    tf.keras.utils.save_img(destination, pixels)
    return destination


def test_multi_shot_file_sampler(self):
"""Test MultiShotFileSampler with various sizes.
Expand All @@ -29,7 +28,7 @@ def test_multi_shot_file_sampler(self):
filepaths = [self._create_random_image(filename) for filename in filenames]
examples_per_class = (2, 20)
images = [np.array(tf.keras.utils.load_img(path), dtype=np.float32) / 255 for path in filepaths]

for example_per_class in examples_per_class:
y = tf.constant([1, 2, 3, 1, 2, 3])
x = tf.constant(filepaths)
Expand All @@ -53,12 +52,20 @@ def test_multi_shot_file_sampler(self):

for x, y in zip(batch_x, batch_y):
if y == 1:
assert np.isclose(x.numpy(), images[0], atol=0.1).all() or np.isclose(x.numpy(), images[3], atol=0.1).all()
assert (
np.isclose(x.numpy(), images[0], atol=0.1).all()
or np.isclose(x.numpy(), images[3], atol=0.1).all()
)
elif y == 2:
assert np.isclose(x.numpy(), images[1], atol=0.1).all() or np.isclose(x.numpy(), images[4], atol=0.1).all()
assert (
np.isclose(x.numpy(), images[1], atol=0.1).all()
or np.isclose(x.numpy(), images[4], atol=0.1).all()
)
elif y == 3:
assert np.isclose(x.numpy(), images[2], atol=0.1).all() or np.isclose(x.numpy(), images[5], atol=0.1).all()

assert (
np.isclose(x.numpy(), images[2], atol=0.1).all()
or np.isclose(x.numpy(), images[5], atol=0.1).all()
)

def test_msfs_get_slice(self):
"""Test the multi shot file sampler get_slice method."""
Expand All @@ -84,7 +91,6 @@ def test_msfs_get_slice(self):
self.assertEqual(slice_y[0], 1)
self.assertEqual(slice_y[1], 2)


def test_msms_properties(self):
"""Test the multi shot file sampler num_examples and shape"""
filenames = ["1.jpg", "2.jpg", "3.jpg", "4.jpg"]
Expand All @@ -97,7 +103,6 @@ def test_msms_properties(self):
self.assertEqual(fs_sampler.num_examples, 4)
self.assertEqual(fs_sampler.example_shape, (128, 96, 3))


def test_small_class_size(self):
"""Test examples_per_class is > the number of class examples."""
filenames = ["1.jpg", "2.jpg", "3.jpg", "4.jpg"]
Expand All @@ -106,11 +111,8 @@ def test_small_class_size(self):
y = tf.constant([1, 1, 1, 2])
x = tf.constant(filepaths)


with self.captureWritesToStream(sys.stdout) as captured:
ms_sampler = MultiShotFileSampler(
x=x, y=y, classes_per_batch=2, examples_per_class_per_batch=3
)
ms_sampler = MultiShotFileSampler(x=x, y=y, classes_per_batch=2, examples_per_class_per_batch=3)
_, batch_y = ms_sampler.generate_batch(0)
y, _, class_counts = tf.unique_with_counts(batch_y)

Expand Down
281 changes: 139 additions & 142 deletions tests/samplers/test_memory_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,145 +7,142 @@


class MemorySamplersTest(tf.test.TestCase):
"""Tests for MultiShotMemorySampler and the select_examples helper."""

def test_valid_class_numbers(self):
"Check that the sampler raises when classes_per_batch exceeds the available classes."
y = tf.constant([1, 2, 3, 1, 2, 3, 1])
x = tf.constant([10, 20, 30, 10, 20, 30, 10])

# Only three distinct labels exist, so 42 classes per batch cannot be served.
class_per_batch = 42
with self.assertRaises(ValueError):
MultiShotMemorySampler(x=x, y=y, classes_per_batch=class_per_batch)

def test_select_examples(self):
"""Test select_examples with various sizes.
Users may sample with replacement when creating batches, so check that we
can handle when elements per class is either less than or greater than the
total count of elements in the class.
"""

examples_per_class = (2, 20)

for example_per_class in examples_per_class:
y = tf.constant([1, 2, 3, 1, 2, 3, 1])
x = tf.constant([10, 20, 30, 10, 20, 30, 10])
cls_list = [1, 3]
with self.subTest(example_per_class=example_per_class):
batch_x, batch_y = select_examples(x, y, cls_list, example_per_class)

self.assertLen(batch_y, len(cls_list) * example_per_class)
self.assertLen(batch_x, len(cls_list) * example_per_class)

for x, y in zip(batch_x, batch_y):
self.assertIn(y, cls_list)

# Every feature value was built as ten times its label.
if y == 1:
self.assertEqual(x, 10)
elif y == 3:
self.assertEqual(x, 30)

def test_multi_shot_memory_sampler(self):
"""Test MultiShotMemorySampler with various sizes.
Users may sample with replacement when creating batches, so check that we
can handle when elements per class is either less than or greater than the
total count of elements in the class.
"""

examples_per_class = (2, 20)

for example_per_class in examples_per_class:
y = tf.constant([1, 2, 3, 1, 2, 3, 1])
x = tf.constant([10, 20, 30, 10, 20, 30, 10])
class_per_batch = 2
batch_size = example_per_class * class_per_batch
with self.subTest(example_per_class=example_per_class):
ms_sampler = MultiShotMemorySampler(
x=x,
y=y,
classes_per_batch=class_per_batch,
examples_per_class_per_batch=example_per_class,
) # noqa

batch_x, batch_y = ms_sampler.generate_batch(batch_id=606)

self.assertLen(batch_y, batch_size)
self.assertLen(batch_x, batch_size)
# The batch must contain exactly class_per_batch distinct labels.
num_classes, _ = tf.unique(batch_y)
self.assertLen(num_classes, class_per_batch)

for x, y in zip(batch_x, batch_y):
if y == 1:
self.assertEqual(x, 10)
elif y == 2:
self.assertEqual(x, 20)
elif y == 3:
self.assertEqual(x, 30)

def test_msms_get_slice(self):
"""Test the multi shot memory sampler get_slice method."""
y = tf.constant(range(4))
x = tf.constant([[0] * 10, [1] * 10, [2] * 10, [3] * 10])

ms_sampler = MultiShotMemorySampler(x=x, y=y)
# x and y are randomly shuffled so we fix the values here.
ms_sampler._x = x
ms_sampler._y = y
slice_x, slice_y = ms_sampler.get_slice(1, 2)

self.assertEqual(slice_x.shape, (2, 10))
self.assertEqual(slice_y.shape, (2,))

self.assertEqual(slice_x[0, 0], 1)
self.assertEqual(slice_x[1, 0], 2)

self.assertEqual(slice_y[0], 1)
self.assertEqual(slice_y[1], 2)

def test_msms_properties(self):
"""Test the multi shot memory sampler num_examples and shape"""
y = tf.constant(range(4))
x = tf.ones([4, 10, 20, 3])

ms_sampler = MultiShotMemorySampler(x=x, y=y)

self.assertEqual(ms_sampler.num_examples, 4)
self.assertEqual(ms_sampler.example_shape, (10, 20, 3))

def test_small_class_size(self):
"""Test examples_per_class is > the number of class examples."""
y = tf.constant([1, 1, 1, 2])
x = tf.ones([4, 10, 10, 3])

with self.captureWritesToStream(sys.stdout) as captured:
ms_sampler = MultiShotMemorySampler(
x=x, y=y, classes_per_batch=2, examples_per_class_per_batch=3
)
_, batch_y = ms_sampler.generate_batch(0)
y, _, class_counts = tf.unique_with_counts(batch_y)

self.assertAllEqual(tf.sort(y), tf.constant([1, 2]))
self.assertAllEqual(class_counts, tf.constant([3, 3]))

# NOTE(review): this text is used as a regex pattern below, so "(", ")" and
# "." are interpreted as metacharacters rather than matched literally.
expected_msg = (
"WARNING: Class 2 only has 1 unique examples, but "
"examples_per_class is set to 3. The current batch will sample "
"from class examples with replacement, but you may want to "
"consider passing an Augmenter function or using the "
"SingleShotMemorySampler()."
)

match = re.search(expected_msg, captured.contents())
self.assertIsNotNone(match)

with self.captureWritesToStream(sys.stdout) as captured:
_, batch_y = ms_sampler.generate_batch(0)
y, _, class_counts = tf.unique_with_counts(batch_y)

self.assertAllEqual(tf.sort(y), tf.constant([1, 2]))
self.assertAllEqual(class_counts, tf.constant([3, 3]))

# A subsequent batch should not produce the sampler warning again.
match = re.search(expected_msg, captured.contents())
self.assertIsNone(match)
def test_valid_class_numbers(self):
    """Check that asking for more classes per batch than exist raises ValueError."""
    labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
    features = tf.constant([10, 20, 30, 10, 20, 30, 10])

    # Only three distinct labels exist, so 42 classes per batch cannot be served.
    with self.assertRaises(ValueError):
        MultiShotMemorySampler(x=features, y=labels, classes_per_batch=42)

def test_select_examples(self):
    """Exercise select_examples with small and large per-class counts.

    Batches may be drawn with replacement, so the helper is checked both
    when the requested examples per class (2) fit within the data and when
    they exceed the number of examples any class has (20).
    """
    for per_class in (2, 20):
        labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
        features = tf.constant([10, 20, 30, 10, 20, 30, 10])
        wanted_classes = [1, 3]

        with self.subTest(example_per_class=per_class):
            batch_x, batch_y = select_examples(features, labels, wanted_classes, per_class)

            expected_len = len(wanted_classes) * per_class
            self.assertLen(batch_y, expected_len)
            self.assertLen(batch_x, expected_len)

            for sample, label in zip(batch_x, batch_y):
                # Only the requested classes may appear in the batch.
                self.assertIn(label, wanted_classes)

                # Every feature value was built as ten times its label.
                if label == 1:
                    self.assertEqual(sample, 10)
                elif label == 3:
                    self.assertEqual(sample, 30)

def test_multi_shot_memory_sampler(self):
    """Exercise MultiShotMemorySampler with small and large per-class counts.

    Batches may be drawn with replacement, so the sampler is checked both
    when the requested examples per class (2) fit within the data and when
    they exceed the number of examples any class has (20).
    """
    classes_in_batch = 2

    for per_class in (2, 20):
        labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
        features = tf.constant([10, 20, 30, 10, 20, 30, 10])
        expected_size = per_class * classes_in_batch

        with self.subTest(example_per_class=per_class):
            sampler = MultiShotMemorySampler(
                x=features,
                y=labels,
                classes_per_batch=classes_in_batch,
                examples_per_class_per_batch=per_class,
            )

            batch_x, batch_y = sampler.generate_batch(batch_id=606)

            self.assertLen(batch_y, expected_size)
            self.assertLen(batch_x, expected_size)

            # The batch must contain exactly the requested number of distinct labels.
            distinct_labels, _ = tf.unique(batch_y)
            self.assertLen(distinct_labels, classes_in_batch)

            # Every feature value was built as ten times its label.
            for sample, label in zip(batch_x, batch_y):
                if label == 1:
                    self.assertEqual(sample, 10)
                elif label == 2:
                    self.assertEqual(sample, 20)
                elif label == 3:
                    self.assertEqual(sample, 30)

def test_msms_get_slice(self):
    """Check that get_slice(1, 2) returns the rows labelled 1 and 2."""
    labels = tf.constant(range(4))
    # Row i is filled with the value i, so a row identifies itself.
    rows = tf.constant([[value] * 10 for value in range(4)])

    sampler = MultiShotMemorySampler(x=rows, y=labels)
    # The sampler shuffles x and y, so pin them back to a known order.
    sampler._x, sampler._y = rows, labels
    slice_x, slice_y = sampler.get_slice(1, 2)

    self.assertEqual(slice_x.shape, (2, 10))
    self.assertEqual(slice_y.shape, (2,))

    expected_values = (1, 2)
    for position, expected in enumerate(expected_values):
        self.assertEqual(slice_x[position, 0], expected)
    for position, expected in enumerate(expected_values):
        self.assertEqual(slice_y[position], expected)

def test_msms_properties(self):
    """Verify the num_examples and example_shape properties of the memory sampler."""
    example_count = 4
    single_example_shape = (10, 20, 3)

    labels = tf.constant(range(example_count))
    features = tf.ones([example_count, *single_example_shape])

    sampler = MultiShotMemorySampler(x=features, y=labels)

    self.assertEqual(sampler.num_examples, example_count)
    self.assertEqual(sampler.example_shape, single_example_shape)

def test_small_class_size(self):
    """Test sampling when examples_per_class exceeds a class's example count.

    Class 2 has a single example while three examples per class are
    requested. The first batch must still hold three examples of each class
    and print the replacement-sampling warning to stdout; a second batch
    must have the same composition without printing the warning again.
    """
    y = tf.constant([1, 1, 1, 2])
    x = tf.ones([4, 10, 10, 3])

    # Matched as a literal substring: the text contains "(", ")" and ".",
    # which would be interpreted as metacharacters if used as a regex.
    expected_msg = (
        "WARNING: Class 2 only has 1 unique examples, but "
        "examples_per_class is set to 3. The current batch will sample "
        "from class examples with replacement, but you may want to "
        "consider passing an Augmenter function or using the "
        "SingleShotMemorySampler()."
    )

    with self.captureWritesToStream(sys.stdout) as captured:
        ms_sampler = MultiShotMemorySampler(x=x, y=y, classes_per_batch=2, examples_per_class_per_batch=3)
        _, batch_y = ms_sampler.generate_batch(0)
        unique_labels, _, class_counts = tf.unique_with_counts(batch_y)

    self.assertAllEqual(tf.sort(unique_labels), tf.constant([1, 2]))
    self.assertAllEqual(class_counts, tf.constant([3, 3]))

    # The first batch must emit the warning.
    self.assertIn(expected_msg, captured.contents())

    with self.captureWritesToStream(sys.stdout) as captured:
        _, batch_y = ms_sampler.generate_batch(0)
        unique_labels, _, class_counts = tf.unique_with_counts(batch_y)

    self.assertAllEqual(tf.sort(unique_labels), tf.constant([1, 2]))
    self.assertAllEqual(class_counts, tf.constant([3, 3]))

    # A subsequent batch should not produce the sampler warning again.
    self.assertNotIn(expected_msg, captured.contents())
5 changes: 2 additions & 3 deletions tests/samplers/test_tfdataset_samplers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow as tf

from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler

Expand All @@ -8,8 +8,7 @@ def test_basic(self):
dataset_name = "mnist"
sampler = TFDatasetMultiShotMemorySampler(dataset_name=dataset_name, classes_per_batch=10)
batch = sampler.generate_batch(42)
self.assertEqual(batch[0].shape,(20, 28, 28, 1))

self.assertEqual(batch[0].shape, (20, 28, 28, 1))

def test_wrong_key(self):
dataset_name = "mnist"
Expand Down
Loading

0 comments on commit 0bec1e0

Please sign in to comment.