forked from tensorflow/similarity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_memory_samplers.py
150 lines (111 loc) · 4.75 KB
/
test_memory_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import re
import pytest
import tensorflow as tf
from tensorflow_similarity.samplers import select_examples
from tensorflow_similarity.samplers import MultiShotMemorySampler
def test_valid_class_numbers():
    """Reject a classes_per_batch larger than the available classes.

    The labels below only contain three distinct classes, so building a
    sampler that asks for 42 classes per batch is expected to raise a
    ValueError.
    """
    labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
    examples = tf.constant([10, 20, 30, 10, 20, 30, 10])
    requested_classes = 42

    with pytest.raises(ValueError):
        MultiShotMemorySampler(
            x=examples,
            y=labels,
            classes_per_batch=requested_classes,
        )
@pytest.mark.parametrize("example_per_class", [2, 20])
def test_select_examples(example_per_class):
    """Exercise select_examples with a small and a large per-class count.

    Batches may be sampled with replacement, so the requested number of
    examples per class is tried both with a small value (2) and with a
    value (20) that exceeds the number of examples any class holds.
    """
    labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
    examples = tf.constant([10, 20, 30, 10, 20, 30, 10])
    cls_list = [1, 3]
    expected_size = len(cls_list) * example_per_class

    batch_x, batch_y = select_examples(
        examples, labels, cls_list, example_per_class
    )

    assert len(batch_y) == expected_size
    assert len(batch_x) == expected_size

    # In the fixture above each example value is ten times its label.
    value_for_class = {1: 10, 3: 30}
    for sample, label in zip(batch_x, batch_y):
        assert label in cls_list
        for cls_id, value in value_for_class.items():
            if label == cls_id:
                assert sample == value
@pytest.mark.parametrize("example_per_class", [2, 20])
def test_multi_shot_memory_sampler(example_per_class):
    """Exercise MultiShotMemorySampler.generate_batch with two batch sizes.

    Batches may be sampled with replacement, so the requested number of
    examples per class is tried both with a small value (2) and with a
    value (20) that exceeds the number of examples any class holds.
    """
    labels = tf.constant([1, 2, 3, 1, 2, 3, 1])
    examples = tf.constant([10, 20, 30, 10, 20, 30, 10])
    class_per_batch = 2
    expected_size = example_per_class * class_per_batch

    sampler = MultiShotMemorySampler(
        x=examples,
        y=labels,
        classes_per_batch=class_per_batch,
        examples_per_class_per_batch=example_per_class,
    )
    batch_x, batch_y = sampler.generate_batch(batch_id=606)

    assert len(batch_y) == expected_size
    assert len(batch_x) == expected_size

    # The batch must contain exactly the requested number of distinct labels.
    unique_labels, _ = tf.unique(batch_y)
    assert len(unique_labels) == class_per_batch

    # In the fixture above each example value is ten times its label.
    value_for_class = {1: 10, 2: 20, 3: 30}
    for sample, label in zip(batch_x, batch_y):
        for cls_id, value in value_for_class.items():
            if label == cls_id:
                assert sample == value
def test_msms_get_slice():
    """Check that get_slice(1, 2) returns rows 1 and 2 of the sampler data."""
    labels = tf.constant(range(4))
    # Row i is ten copies of i, so the first column identifies the row.
    examples = tf.constant([[row] * 10 for row in range(4)])
    sampler = MultiShotMemorySampler(x=examples, y=labels)

    # x and y are randomly shuffled by the sampler, so pin the stored values
    # back to the known ordering before slicing.
    sampler._x = examples
    sampler._y = labels

    slice_x, slice_y = sampler.get_slice(1, 2)

    assert slice_x.shape == (2, 10)
    assert slice_y.shape == (2,)
    for position, expected in enumerate((1, 2)):
        assert slice_x[position, 0] == expected
        assert slice_y[position] == expected
def test_msms_properties():
    """Check the sampler's num_examples and example_shape properties."""
    num_examples = 4
    example_shape = (10, 20, 3)

    examples = tf.ones([num_examples, *example_shape])
    labels = tf.constant(range(num_examples))
    sampler = MultiShotMemorySampler(x=examples, y=labels)

    assert sampler.num_examples == num_examples
    assert sampler.example_shape == example_shape
def test_small_class_size(capsys):
    """Test examples_per_class greater than the number of class examples.

    Class 2 has a single example while three examples per class are
    requested. The test asserts that each generated batch still holds three
    examples for both classes, that the warning message is printed to stdout
    while generating the first batch, and that it is not printed again for
    the following batch.

    Args:
        capsys: pytest fixture used to read what was written to stdout.
    """
    labels = tf.constant([1, 1, 1, 2])
    examples = tf.ones([4, 10, 10, 3])

    ms_sampler = MultiShotMemorySampler(
        x=examples,
        y=labels,
        classes_per_batch=2,
        examples_per_class_per_batch=3,
    )

    expected_msg = (
        "WARNING: Class 2 only has 1 unique examples, but "
        "examples_per_class is set to 3. The current batch will sample "
        "from class examples with replacement, but you may want to "
        "consider passing an Augmenter function or using the "
        "SingleShotMemorySampler().")
    # The message contains regex metacharacters ("()" and "."); escape it so
    # the search matches the literal text instead of a looser pattern.
    expected_pattern = re.escape(expected_msg)

    _, batch_y = ms_sampler.generate_batch(0)
    classes, _, class_counts = tf.unique_with_counts(batch_y)
    assert tf.math.reduce_all(tf.sort(classes) == tf.constant([1, 2]))
    assert tf.math.reduce_all(class_counts == tf.constant([3, 3]))

    # The first batch should produce the sampler warning.
    captured = capsys.readouterr()
    match = re.search(expected_pattern, captured.out)
    assert bool(match)

    _, batch_y = ms_sampler.generate_batch(0)
    classes, _, class_counts = tf.unique_with_counts(batch_y)
    assert tf.math.reduce_all(tf.sort(classes) == tf.constant([1, 2]))
    assert tf.math.reduce_all(class_counts == tf.constant([3, 3]))

    # Subsequent batch should NOT produce the sampler warning again.
    captured = capsys.readouterr()
    match = re.search(expected_pattern, captured.out)
    assert not bool(match)