Commit 473b2b2
Remove as_numpy_iterator when creating the list of grouped datasets.
* Also move class_list filter to before the group_by function
* Apply the total_examples_per_class as a take() function on each
  grouped dataset
* Remove as much casting as possible from the dataset. Certain functions
  expect an int64 though and require casting.
owenvallis committed May 1, 2023
1 parent 4dbf73e commit 473b2b2
Showing 2 changed files with 24 additions and 21 deletions.
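Before the diffs, a quick illustration of the first commit bullet (a toy sketch, not library code): `as_numpy_iterator()` converts every element to a host-side NumPy value up front, while iterating the `tf.data.Dataset` directly yields `tf.Tensor`s, so follow-up ops such as `take()` stay inside the pipeline.

```python
# Toy contrast between the two iteration styles (not library code).
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

# Old style: eagerly materializes host-side NumPy values.
numpy_elems = list(ds.as_numpy_iterator())  # [1, 2, 3, 4] as np.int32 scalars

# New style: elements stay tf.Tensors, so grouped data can feed straight
# into further tf.data ops like from_tensor_slices() and take().
tensor_elems = [e for e in ds]  # [<tf.Tensor: 1>, <tf.Tensor: 2>, ...]
```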
20 changes: 11 additions & 9 deletions tensorflow_similarity/samplers/tfdata_sampler.py
@@ -27,20 +27,22 @@ def create_grouped_dataset(
         A List of `tf.data.Dataset` objects grouped by class id.
     """
     window_size = ds.cardinality().numpy()
+    if class_list is not None:
+        class_list = tf.constant(class_list)
+        ds = ds.filter(lambda x, y, *args: tf.reduce_any(tf.equal(y, class_list)))
+
+    # NOTE: We need to cast the key_func as the group_op expects an int64.
     grouped_by_cid = ds.group_by_window(
-        key_func=lambda x, y, *args: y,
+        key_func=lambda x, y, *args: tf.cast(y, dtype=tf.int64),
         reduce_func=lambda key, ds: ds.batch(window_size),
         window_size=window_size,
     )

     cid_datasets = []
-    for elem in grouped_by_cid.as_numpy_iterator():
-        # This assumes that elem has at least (x, y) and that y is a tensor or array_like.
-        if class_list is not None and elem[1][0] not in class_list:
-            continue
-        if total_examples is not None:
-            elem = elem[:total_examples]
+    for elem in grouped_by_cid:
         cid_ds = tf.data.Dataset.from_tensor_slices(elem)
+        if total_examples is not None:
+            cid_ds = cid_ds.take(total_examples)
         if buffer_size is not None:
             cid_ds = cid_ds.shuffle(buffer_size)
         cid_datasets.append(cid_ds.repeat())
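Putting the hunk together, the new flow is: filter to class_list first, group by the (int64-cast) label, then apply the per-class cap with take(). A minimal end-to-end sketch with toy data and a hand-picked window_size (the real function derives it from ds.cardinality()):

```python
# End-to-end sketch of the new flow (toy data; not the library function).
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    (tf.constant([1.0, 2.0, 3.0, 4.0]), tf.constant([1, 1, 2, 2]))
)

# 1) Filter to the requested classes BEFORE grouping.
class_list = tf.constant([1])
ds = ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, class_list)))

# 2) Group by label; group_by_window requires an int64 key.
window_size = 4
grouped = ds.group_by_window(
    key_func=lambda x, y: tf.cast(y, tf.int64),
    reduce_func=lambda key, g: g.batch(window_size),
    window_size=window_size,
)

# 3) Rebuild one dataset per class and cap it with take().
per_class = []
for elem in grouped:  # tf.Tensors throughout; no as_numpy_iterator()
    cid_ds = tf.data.Dataset.from_tensor_slices(elem)
    per_class.append(cid_ds.take(2))  # total_examples_per_class as take()

for cid_ds in per_class:
    print(list(cid_ds.as_numpy_iterator()))  # [(1.0, 1), (2.0, 1)]
```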
@@ -99,8 +101,8 @@ def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int | N
     count_ds = tf.data.experimental.Counter()

     ds = tf.data.Dataset.choose_from_datasets(
-        [ds, aug_ds],
-        count_ds.map(lambda x: tf.cast(0, dtype=tf.dtypes.int64) if x < warmup else tf.cast(1, dtype=tf.dtypes.int64)),
+        [ds.take(warmup), aug_ds],
+        count_ds.map(lambda x: tf.cast(0, dtype=tf.int64) if x < warmup else tf.cast(1, dtype=tf.int64)),
     )

     return ds
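The warmup logic in this hunk can be sketched in isolation: an index stream selects dataset 0 (raw) for the first warmup steps and dataset 1 (augmented) afterwards, which is why the raw stream only ever needs warmup elements, hence ds.take(warmup). A toy sketch, not the library function:

```python
# Isolated sketch of the warmup switch (toy data; not the library function).
import tensorflow as tf

warmup = 3
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
aug_ds = ds.map(lambda x: x * 10)  # stand-in for the augmented stream

# Infinite 0, 1, 2, ... counter; maps to index 0 during warmup, then 1.
count_ds = tf.data.experimental.Counter()
choice_ds = count_ds.map(
    lambda x: tf.cast(0, tf.int64) if x < warmup else tf.cast(1, tf.int64)
)

out = tf.data.Dataset.choose_from_datasets(
    [ds.take(warmup), aug_ds.repeat()], choice_ds
)
print(list(out.take(5).as_numpy_iterator()))  # [1, 2, 3, 10, 20]
```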
25 changes: 13 additions & 12 deletions tests/samplers/test_tfdata_sampler.py
@@ -9,8 +9,8 @@ class TestCreateGroupedDataset(tf.test.TestCase):
     def setUp(self):
         self.ds = tf.data.Dataset.from_tensor_slices(
             (
-                tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32),
-                tf.constant([1, 1, 2, 2, 3, 3, 4, 4], dtype=tf.int64),
+                tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
+                tf.constant([1, 1, 2, 2, 3, 3, 4, 4]),
             )
         )

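The explicit dtype arguments above could be dropped because tf.constant infers float32 for Python floats and int32 for Python ints; note this also changes the label dtype from int64 to int32. A quick check:

```python
# tf.constant dtype inference: Python floats -> float32, ints -> int32.
import tensorflow as tf

print(tf.constant([1.0, 2.0]).dtype)  # tf.float32
print(tf.constant([1, 2]).dtype)      # tf.int32
```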
@@ -38,6 +38,7 @@ def test_returns_correct_number_of_datasets_with_buffer_size(self):
         self.assertEqual(len(cid_datasets), 4)

     def test_datasets_repeat(self):
+        print(self.ds.element_spec)
         cid_datasets = tfds.create_grouped_dataset(self.ds)
         for cid_ds in cid_datasets:
             self.assertTrue(cid_ds.element_spec, self.ds.element_spec)
@@ -121,12 +122,12 @@ def setUp(self):
         self.ds = tf.data.Dataset.from_tensor_slices(
             (
                 tf.random.uniform((6, 2), dtype=tf.float32),
-                tf.constant([1, 1, 1, 2, 2, 2], dtype=tf.int64),
+                tf.constant([1, 1, 1, 2, 2, 2], dtype=tf.int32),
             )
         )
-        self.expected_elementspec = (
-            tf.TensorSpec(shape=(None, 2), dtype=tf.float32, name=None),
-            tf.TensorSpec(shape=(None,), dtype=tf.int64, name=None),
+        self.expected_element_spec = (
+            tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
+            tf.TensorSpec(shape=(None,), dtype=tf.int32),
         )

     def test_cardinality_is_finite(self):
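Aside from the attribute rename, dropping name=None from the specs is purely cosmetic, since None is already TensorSpec's default and the shortened specs compare equal. A quick check:

```python
# name=None is TensorSpec's default, so the shortened spec is equivalent.
import tensorflow as tf

a = tf.TensorSpec(shape=(None,), dtype=tf.int32)
b = tf.TensorSpec(shape=(None,), dtype=tf.int32, name=None)
print(a == b)  # True
```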
@@ -144,32 +145,32 @@ def test_cardinality_is_known(self):
     def test_output_batch_size(self):
         # Test that the output batch size is correct
         out_ds = tfds.TFDataSampler(self.ds)
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_classes_per_batch(self):
         # Test that the number of classes per batch is correct
         out_ds = tfds.TFDataSampler(self.ds, classes_per_batch=1)
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_examples_per_class_per_batch(self):
         # Test that the number of examples per class per batch is correct
         out_ds = tfds.TFDataSampler(self.ds, examples_per_class_per_batch=1)
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_class_list(self):
         # Test that the class list is correctly used
         out_ds = tfds.TFDataSampler(self.ds, class_list=[1])
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_total_examples_per_class(self):
         # Test that the total number of examples per class is correctly used
         out_ds = tfds.TFDataSampler(self.ds, total_examples_per_class=2)
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_augmenter(self):
         # Test that the augmenter is correctly applied
         out_ds = tfds.TFDataSampler(self.ds, augmenter=lambda x, y: (x * 2, y))
-        self.assertEqual(out_ds.element_spec, self.expected_elementspec)
+        self.assertEqual(out_ds.element_spec, self.expected_element_spec)

     def test_output_load_fn(self):
         # TODO(ovallis): Test that the load_fn is correctly applied
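For reference, a hypothetical usage sketch of TFDataSampler, inferred only from the parameters these tests exercise; the import path and any behavior beyond what the tests show are assumptions:

```python
# Hypothetical usage, inferred from the tests above (not documented API).
import tensorflow as tf
from tensorflow_similarity.samplers import tfdata_sampler as tfds  # assumed path

ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((6, 2), dtype=tf.float32),
     tf.constant([1, 1, 1, 2, 2, 2], dtype=tf.int32))
)

out_ds = tfds.TFDataSampler(
    ds,
    classes_per_batch=2,             # distinct classes per batch
    examples_per_class_per_batch=1,  # examples drawn per class per batch
    class_list=[1, 2],               # restrict sampling to these labels
    total_examples_per_class=2,      # per-class cap, applied via take()
    augmenter=lambda x, y: (x * 2, y),
)

for x, y in out_ds.take(1):
    print(x.shape, y)  # one batch of (features, labels)
```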
